ndtypes 0.2.0dev5 → 0.2.0dev6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +12 -0
  3. data/Rakefile +8 -0
  4. data/ext/ruby_ndtypes/GPATH +0 -0
  5. data/ext/ruby_ndtypes/GRTAGS +0 -0
  6. data/ext/ruby_ndtypes/GTAGS +0 -0
  7. data/ext/ruby_ndtypes/extconf.rb +1 -1
  8. data/ext/ruby_ndtypes/include/ndtypes.h +231 -122
  9. data/ext/ruby_ndtypes/include/ruby_ndtypes.h +1 -1
  10. data/ext/ruby_ndtypes/lib/libndtypes.a +0 -0
  11. data/ext/ruby_ndtypes/lib/libndtypes.so.0.2.0dev3 +0 -0
  12. data/ext/ruby_ndtypes/ndtypes/Makefile +87 -0
  13. data/ext/ruby_ndtypes/ndtypes/config.h +68 -0
  14. data/ext/ruby_ndtypes/ndtypes/config.log +477 -0
  15. data/ext/ruby_ndtypes/ndtypes/config.status +1027 -0
  16. data/ext/ruby_ndtypes/ndtypes/doc/_static/style.css +7 -0
  17. data/ext/ruby_ndtypes/ndtypes/doc/_templates/layout.html +2 -0
  18. data/ext/ruby_ndtypes/ndtypes/doc/conf.py +40 -4
  19. data/ext/ruby_ndtypes/ndtypes/doc/images/xndlogo.png +0 -0
  20. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +1 -1
  21. data/ext/ruby_ndtypes/ndtypes/doc/requirements.txt +2 -0
  22. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile +287 -0
  23. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +20 -4
  24. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +22 -3
  25. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +1 -1
  26. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.o +0 -0
  27. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.o +0 -0
  28. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile +73 -0
  29. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +246 -229
  30. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +15 -11
  31. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.o +0 -0
  32. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +38 -28
  33. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +91 -91
  34. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +1 -1
  35. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +4 -3
  36. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.o +0 -0
  37. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +8 -7
  38. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.o +0 -0
  39. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +2 -2
  40. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.o +0 -0
  41. data/ext/ruby_ndtypes/ndtypes/libndtypes/context.o +0 -0
  42. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +263 -182
  43. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.o +0 -0
  44. data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.o +0 -0
  45. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +67 -7
  46. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.o +0 -0
  47. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +1112 -1000
  48. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +69 -58
  49. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.o +0 -0
  50. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +150 -99
  51. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +185 -15
  52. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.o +0 -0
  53. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +301 -276
  54. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +1 -1
  55. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +9 -4
  56. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.o +0 -0
  57. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.a +0 -0
  58. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so +1 -0
  59. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0 +1 -0
  60. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3 +0 -0
  61. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +729 -228
  62. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.o +0 -0
  63. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +768 -403
  64. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h +1002 -0
  65. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +231 -122
  66. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.o +0 -0
  67. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +176 -84
  68. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +26 -14
  69. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.o +0 -0
  70. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +57 -35
  71. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.o +0 -0
  72. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.c +420 -0
  73. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.o +0 -0
  74. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +8 -8
  75. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +1 -1
  76. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.o +0 -0
  77. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile +48 -0
  78. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +200 -116
  79. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.o +0 -0
  80. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +46 -4
  81. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.o +0 -0
  82. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +58 -27
  83. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +1 -1
  84. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.o +0 -0
  85. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +3 -5
  86. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +12 -4
  87. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.o +0 -0
  88. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile +55 -0
  89. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +8 -8
  90. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +5 -5
  91. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +274 -172
  92. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +24 -4
  93. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +2 -2
  94. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +14 -14
  95. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +32 -30
  96. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +37 -0
  97. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +36 -0
  98. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +16 -0
  99. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +5 -5
  100. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +706 -253
  101. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_unify.c +132 -0
  102. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.c +703 -0
  103. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.o +0 -0
  104. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +335 -127
  105. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.o +0 -0
  106. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +2 -2
  107. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.o +0 -0
  108. data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +88 -71
  109. data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +0 -1
  110. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +10 -13
  111. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +395 -314
  112. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.a +0 -0
  113. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so +1 -0
  114. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0 +1 -0
  115. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0.2.0dev3 +0 -0
  116. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/ndtypes.h +1002 -0
  117. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +15 -33
  118. data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +340 -132
  119. data/ext/ruby_ndtypes/ndtypes/setup.py +11 -2
  120. data/ext/ruby_ndtypes/ruby_ndtypes.c +364 -241
  121. data/ext/ruby_ndtypes/ruby_ndtypes.h +1 -1
  122. data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +0 -1
  123. data/lib/ndtypes.rb +11 -0
  124. data/lib/ndtypes/version.rb +2 -2
  125. data/lib/ruby_ndtypes.so +0 -0
  126. data/ndtypes.gemspec +3 -0
  127. data/spec/ndtypes_spec.rb +6 -0
  128. metadata +98 -4
  129. data/ext/ruby_ndtypes/gc_guard.c +0 -36
  130. data/ext/ruby_ndtypes/gc_guard.h +0 -12
@@ -95,21 +95,46 @@ ndt_asprintf(ndt_context_t *ctx, const char *fmt, ...)
95
95
  /* Type functions (unstable API) */
96
96
  /*****************************************************************************/
97
97
 
98
+ /* Number of elements, currently only for ndarrays. Return -1 on error. */
99
+ int64_t
100
+ ndt_nelem(const ndt_t *t)
101
+ {
102
+ NDT_STATIC_CONTEXT(ctx);
103
+ ndt_ndarray_t x;
104
+ int64_t n;
105
+
106
+ if (ndt_as_ndarray(&x, t, &ctx) < 0) {
107
+ ndt_err_clear(&ctx);
108
+ return -1;
109
+ }
110
+
111
+ n = 1;
112
+ for (int i = 0; i < x.ndim; i++) {
113
+ n *= x.shape[i];
114
+ }
115
+
116
+ return n;
117
+ }
118
+
98
119
  /* Return the next type in a dimension chain. Undefined for non-dimensions. */
99
120
  static const ndt_t *
100
- next_dim(const ndt_t *a)
121
+ next_dim(const ndt_t *t)
101
122
  {
102
- assert(a->ndim > 0);
123
+ assert(t->ndim > 0);
103
124
 
104
- switch (a->tag) {
125
+ switch (t->tag) {
105
126
  case FixedDim:
106
- return a->FixedDim.type;
127
+ return t->FixedDim.type;
107
128
  case VarDim:
108
- return a->VarDim.type;
129
+ return t->VarDim.type;
130
+ case VarDimElem:
131
+ return t->VarDimElem.type;
109
132
  case SymbolicDim:
110
- return a->SymbolicDim.type;
133
+ return t->SymbolicDim.type;
111
134
  case EllipsisDim:
112
- return a->EllipsisDim.type;
135
+ return t->EllipsisDim.type;
136
+ case Array:
137
+ return t->Array.type;
113
138
  default:
114
139
  /* NOT REACHED: tags should be exhaustive. */
115
140
  ndt_internal_error("invalid value");
@@ -126,16 +151,19 @@ ndt_dtype(const ndt_t *t)
126
151
  return t;
127
152
  }
128
153
 
129
- const ndt_t *
130
- ndt_dim_at(const ndt_t *t, int n)
154
+ int
155
+ ndt_logical_ndim(const ndt_t *t)
131
156
  {
132
- assert(n <= t->ndim);
157
+ int ndim = t->ndim;
133
158
 
134
- for (int i = 0; i < n; i++) {
135
- t = next_dim(t);
159
+ while (t->tag == VarDim || t->tag == VarDimElem) {
160
+ if (t->tag == VarDimElem) {
161
+ --ndim;
162
+ }
163
+ t = t->VarDim.type;
136
164
  }
137
165
 
138
- return t;
166
+ return ndim;
139
167
  }
140
168
 
141
169
  const ndt_t *
@@ -170,6 +198,75 @@ ndt_dims_dtype(const ndt_t *dims[NDT_MAX_DIM], const ndt_t **dtype, const ndt_t
170
198
  return n;
171
199
  }
172
200
 
201
+ static const ndt_t *
202
+ compress(const ndt_t *t)
203
+ {
204
+ while (t->tag == VarDimElem) {
205
+ t = t->VarDimElem.type;
206
+ }
207
+
208
+ return t;
209
+ }
210
+
211
+ /*
212
+ * Return the next logical type in a dimension chain. Undefined for
213
+ * non-dimensions.
214
+ */
215
+ static const ndt_t *
216
+ next_logical_dim(const ndt_t *t)
217
+ {
218
+ switch (t->tag) {
219
+ case FixedDim:
220
+ return t->FixedDim.type;
221
+ case VarDim:
222
+ return compress(t->VarDim.type);
223
+ case VarDimElem:
224
+ return next_logical_dim(compress(t));
225
+ case SymbolicDim:
226
+ return t->SymbolicDim.type;
227
+ case EllipsisDim:
228
+ return t->EllipsisDim.type;
229
+ default:
230
+ return t;
231
+ }
232
+ }
233
+
234
+ const ndt_t *
235
+ ndt_logical_dim_at(const ndt_t *t, int n)
236
+ {
237
+ for (int i = 0; i < n; i++) {
238
+ t = next_logical_dim(t);
239
+ }
240
+
241
+ return t;
242
+ }
243
+
244
+ const ndt_t *
245
+ ndt_copy_contiguous_at(const ndt_t *t, int n, const ndt_t *dtype, ndt_context_t *ctx)
246
+ {
247
+ const ndt_t *u;
248
+
249
+ if (!ndt_is_ndarray(t)) {
250
+ ndt_err_format(ctx, NDT_NotImplementedError,
251
+ "partial copies are currently restricted to fixed dimensions");
252
+ return NULL;
253
+ }
254
+
255
+ if (n < 0 || n > t->ndim) {
256
+ ndt_err_format(ctx, NDT_ValueError, "n out of bounds");
257
+ return NULL;
258
+ }
259
+
260
+ u = ndt_logical_dim_at(t, n);
261
+
262
+ if (dtype != NULL) {
263
+ return ndt_copy_contiguous_dtype(u, dtype, 0, ctx);
264
+ }
265
+ else {
266
+ return ndt_copy_contiguous(u, 0, ctx);
267
+ }
268
+ }
269
+
173
270
  int
174
271
  ndt_as_ndarray(ndt_ndarray_t *a, const ndt_t *t, ndt_context_t *ctx)
175
272
  {
@@ -205,9 +302,89 @@ ndt_as_ndarray(ndt_ndarray_t *a, const ndt_t *t, ndt_context_t *ctx)
205
302
  return n;
206
303
  }
207
304
 
305
+ static const ndt_t *
306
+ _ndt_transpose(const ndt_ndarray_t *a, const int p[], const ndt_t *dtype,
307
+ ndt_context_t *ctx)
308
+ {
309
+ const ndt_t *t;
310
+
311
+ t = dtype;
312
+ ndt_incref(t);
313
+
314
+ for (int i = a->ndim-1; i >= 0; i--) {
315
+ const ndt_t *u = ndt_fixed_dim(t, a->shape[p[i]], a->steps[p[i]], ctx);
316
+ ndt_decref(t);
317
+ if (u == NULL) {
318
+ return NULL;
319
+ }
320
+ t = u;
321
+ }
322
+
323
+ return t;
324
+ }
325
+
326
+ const ndt_t *
327
+ ndt_transpose(const ndt_t *t, const int *p, int ndim, ndt_context_t *ctx)
328
+ {
329
+ ndt_ndarray_t a;
330
+ int permute[NDT_MAX_DIM];
331
+
332
+ if (ndt_as_ndarray(&a, t, ctx) < 0) {
333
+ return NULL;
334
+ }
335
+
336
+ if (p == NULL) {
337
+ if (ndim != 0) {
338
+ ndt_err_format(ctx, NDT_RuntimeError,
339
+ "internal error: ndim != 0 but no permutations given");
340
+ return NULL;
341
+ }
342
+
343
+ ndim = a.ndim;
344
+ if (ndim <= 1) {
345
+ ndt_incref(t);
346
+ return t;
347
+ }
348
+
349
+ for (int i = 0; i < ndim; i++) {
350
+ permute[i] = ndim-i-1;
351
+ }
352
+ p = permute;
353
+ }
354
+ else {
355
+ if (ndim != a.ndim) {
356
+ ndt_err_format(ctx, NDT_ValueError,
357
+ "number of permutations != ndim");
358
+ return NULL;
359
+ }
360
+
361
+ for (int i = 0; i < ndim; i++) {
362
+ permute[i] = 0;
363
+ }
364
+
365
+ for (int i = 0; i < ndim; i++) {
366
+ if (p[i] < 0 || p[i] >= ndim) {
367
+ ndt_err_format(ctx, NDT_ValueError,
368
+ "permutation with value %d out of bounds", p[i]);
369
+ return NULL;
370
+ }
371
+
372
+ if (permute[p[i]] != 0) {
373
+ ndt_err_format(ctx, NDT_ValueError,
374
+ "duplicate permutation index=%d", p[i]);
375
+ return NULL;
376
+ }
377
+
378
+ permute[p[i]] = 1;
379
+ }
380
+ }
381
+
382
+ return _ndt_transpose(&a, p, ndt_dtype(t), ctx);
383
+ }
384
+
208
385
  /* Unoptimized hash function for experimenting. */
209
386
  ndt_ssize_t
210
- ndt_hash(ndt_t *t, ndt_context_t *ctx)
387
+ ndt_hash(const ndt_t *t, ndt_context_t *ctx)
211
388
  {
212
389
  unsigned char *s, *cp;
213
390
  size_t len;
@@ -242,11 +419,11 @@ ndt_hash(ndt_t *t, ndt_context_t *ctx)
242
419
 
243
420
  const ndt_apply_spec_t ndt_apply_spec_empty = {
244
421
  .flags = 0U,
245
- .nout = 0,
246
- .nbroadcast = 0,
247
422
  .outer_dims = 0,
248
- .out = {NULL},
249
- .broadcast = {NULL}
423
+ .nin = 0,
424
+ .nout = 0,
425
+ .nargs = 0,
426
+ .types = {NULL}
250
427
  };
251
428
 
252
429
  ndt_apply_spec_t *
@@ -259,9 +436,10 @@ ndt_apply_spec_new(ndt_context_t *ctx)
259
436
  return ndt_memory_error(ctx);
260
437
  }
261
438
  spec->flags = 0U;
262
- spec->nout = 0;
263
- spec->nbroadcast = 0;
264
439
  spec->outer_dims = 0;
440
+ spec->nin = 0;
441
+ spec->nout = 0;
442
+ spec->nargs = 0;
265
443
 
266
444
  return spec;
267
445
  }
@@ -269,24 +447,20 @@ ndt_apply_spec_new(ndt_context_t *ctx)
269
447
  void
270
448
  ndt_apply_spec_clear(ndt_apply_spec_t *spec)
271
449
  {
272
- int i;
273
-
274
450
  if (spec == NULL) {
275
451
  return;
276
452
  }
277
453
 
278
- for (i = 0; i < spec->nbroadcast; i++) {
279
- ndt_del(spec->broadcast[i]);
280
- }
281
-
282
- for (i = 0; i < spec->nout; i++) {
283
- ndt_del(spec->out[i]);
454
+ for (int i = 0; i < spec->nin+spec->nout; i++) {
455
+ ndt_decref(spec->types[i]);
456
+ spec->types[i] = NULL;
284
457
  }
285
458
 
286
459
  spec->flags = 0U;
287
- spec->nout = 0;
288
- spec->nbroadcast = 0;
289
460
  spec->outer_dims = 0;
461
+ spec->nin = 0;
462
+ spec->nout = 0;
463
+ spec->nargs = 0;
290
464
  }
291
465
 
292
466
  void
@@ -300,6 +474,15 @@ ndt_apply_spec_del(ndt_apply_spec_t *spec)
300
474
  ndt_free(spec);
301
475
  }
302
476
 
477
+ #define X (NDT_INNER_XND)
478
+ #define S (NDT_INNER_STRIDED)
479
+ #define C (NDT_INNER_C)
480
+ #define F (NDT_INNER_F)
481
+
482
+ #define ES (NDT_EXT_STRIDED)
483
+ #define EC (NDT_EXT_C)
484
+ #define EZ (NDT_EXT_ZERO)
485
+
303
486
  /* This function is used in places where it is _really_ convenient not to
304
487
  have to deal with deallocating an error message. */
305
488
  const char *
@@ -307,38 +490,35 @@ ndt_apply_flags_as_string(const ndt_apply_spec_t *spec)
307
490
  {
308
491
  switch (spec->flags) {
309
492
  case 0: return "None";
310
- case NDT_C: return "C";
311
- case NDT_FORTRAN: return "Fortran";
312
- case NDT_FORTRAN|NDT_C: return "C|Fortran";
313
- case NDT_STRIDED: return "Strided";
314
- case NDT_STRIDED|NDT_C: return "C|Strided";
315
- case NDT_STRIDED|NDT_FORTRAN: return "Fortran|Strided";
316
- case NDT_STRIDED|NDT_FORTRAN|NDT_C: return "C|Fortran|Strided";
317
- case NDT_XND: return "Xnd";
318
- case NDT_XND|NDT_C: return "C|Xnd";
319
- case NDT_XND|NDT_FORTRAN: return "Fortran|Xnd";
320
- case NDT_XND|NDT_FORTRAN|NDT_C: return "C|Fortran|Xnd";
321
- case NDT_XND|NDT_STRIDED: return "Strided|Xnd";
322
- case NDT_XND|NDT_STRIDED|NDT_C: return "C|Strided|Xnd";
323
- case NDT_XND|NDT_STRIDED|NDT_FORTRAN: return "Fortran|Strided|Xnd";
324
- case NDT_XND|NDT_STRIDED|NDT_FORTRAN|NDT_C: return "C|Fortran|Strided|Xnd";
325
- case NDT_ELEMWISE_1D: return "Elemwise1D";
326
- case NDT_ELEMWISE_1D|NDT_C: return "Elemwise1D|C";
327
- case NDT_ELEMWISE_1D|NDT_FORTRAN: return "Elemwise1D|Fortran";
328
- case NDT_ELEMWISE_1D|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran";
329
- case NDT_ELEMWISE_1D|NDT_STRIDED: return "Elemwise1D|Strided";
330
- case NDT_ELEMWISE_1D|NDT_STRIDED|NDT_C: return "Elemwise1D|C|Strided";
331
- case NDT_ELEMWISE_1D|NDT_STRIDED|NDT_FORTRAN: return "Elemwise1D|Fortran|Strided";
332
- case NDT_ELEMWISE_1D|NDT_STRIDED|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran|Strided";
333
- case NDT_ELEMWISE_1D|NDT_XND: return "Elemwise1D|Xnd";
334
- case NDT_ELEMWISE_1D|NDT_XND|NDT_C: return "Elemwise1D|C|Xnd";
335
- case NDT_ELEMWISE_1D|NDT_XND|NDT_FORTRAN: return "Elemwise1D|Fortran|Xnd";
336
- case NDT_ELEMWISE_1D|NDT_XND|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran|Xnd";
337
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED: return "Elemwise1D|Strided|Xnd";
338
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_C: return "Elemwise1D|C|Strided|Xnd";
339
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_FORTRAN: return "Elemwise1D|Fortran|Strided|Xnd";
340
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran|Strided|Xnd";
341
- default: return "error: invalid combination of spec->flags";
493
+ case X: return "Xnd";
494
+
495
+ case S|X: return "Strided|Xnd";
496
+ case C|S|X: return "C|Strided|Xnd";
497
+ case F|S|X: return "Fortran|Strided|Xnd";
498
+ case C|F|S|X: return "C|Fortran|Strided|Xnd";
499
+
500
+ case ES|S|X: return "OptS|Strided|Xnd";
501
+ case ES|C|S|X: return "OptS|C|Strided|Xnd";
502
+ case ES|F|S|X: return "OptS|Fortran|Strided|Xnd";
503
+ case ES|C|F|S|X: return "OptS|C|Fortran|Strided|Xnd";
504
+ case EC|ES|C|S|X: return "OptC|OptS|C|Strided|Xnd";
505
+ case EC|ES|C|F|S|X: return "OptC|OptS|C|Fortran|Strided|Xnd";
506
+ case EZ|EC|ES|C|F|S|X: return "OptZ|OptC|OptS|C|Fortran|Strided|Xnd";
507
+ case EZ|EC|ES|C|S|X: return "OptZ|OptC|OptS|C|Strided|Xnd";
508
+ case EZ|ES|C|S|X: return "OptZ|OptS|C|Strided|Xnd";
509
+ case EZ|ES|C|F|S|X: return "OptZ|OptS|C|Fortran|Strided|Xnd";
510
+
511
+ default:
512
+ if (spec->flags & NDT_EXT_ZERO) fprintf(stderr, "EZ ");
513
+ if (spec->flags & NDT_EXT_C) fprintf(stderr, "EC ");
514
+ if (spec->flags & NDT_EXT_STRIDED) fprintf(stderr, "ES ");
515
+ if (spec->flags & NDT_INNER_C) fprintf(stderr, "C ");
516
+ if (spec->flags & NDT_INNER_F) fprintf(stderr, "F ");
517
+ if (spec->flags & NDT_INNER_STRIDED) fprintf(stderr, "S ");
518
+ if (spec->flags & NDT_INNER_XND) fprintf(stderr, "X ");
519
+ fprintf(stderr, "\n");
520
+
521
+ return "unknown flags";
342
522
  }
343
523
  }
344
524
 
@@ -362,16 +542,21 @@ ndt_meta_new(ndt_context_t *ctx)
362
542
  }
363
543
 
364
544
  void
365
- ndt_meta_del(ndt_meta_t *m)
545
+ ndt_meta_clear(ndt_meta_t *m)
366
546
  {
367
547
  if (m == NULL) {
368
548
  return;
369
549
  }
370
550
 
371
551
  for (int i = 0; i < m->ndims; i++) {
372
- ndt_free(m->offsets[i]);
552
+ ndt_decref_offsets(m->offsets[i]);
373
553
  }
554
+ }
374
555
 
556
+ void
557
+ ndt_meta_del(ndt_meta_t *m)
558
+ {
559
+ ndt_meta_clear(m);
375
560
  ndt_free(m);
376
561
  }
377
562
 
@@ -380,95 +565,118 @@ ndt_meta_del(ndt_meta_t *m)
380
565
  /* Optimized kernel strategy (unstable API) */
381
566
  /*****************************************************************************/
382
567
 
383
- static bool
384
- all_inner_1D_contiguous(const ndt_t *types[], int n)
568
+ static uint32_t
569
+ check_c(uint32_t flags, ndt_ndarray_t *x, int outer)
385
570
  {
386
- for (int i = 0; i < n; i++) {
387
- if (types[i]->ndim == 0) {
388
- return false;
571
+ int inner = x->ndim-outer;
572
+ int64_t *shape = x->shape+outer;
573
+ int64_t *steps = x->steps+outer;
574
+ int64_t step = 1;
575
+
576
+ for (int i = inner-1; i >= 0; i--) {
577
+ if (shape[i] > 1 && steps[i] != step) {
578
+ return flags & ~(NDT_INNER_C|NDT_EXT_C|NDT_EXT_ZERO);
389
579
  }
390
- if (!ndt_is_c_contiguous(ndt_dim_at(types[i], types[i]->ndim-1))) {
391
- return false;
580
+ step *= shape[i];
581
+ }
582
+
583
+ if (x->ndim == inner) {
584
+ return flags & ~(NDT_EXT_C|NDT_EXT_ZERO);
585
+ }
586
+ else if (x->ndim > inner) {
587
+ if (shape[-1] > 1 && steps[-1] != step) {
588
+ flags &= ~NDT_EXT_C;
589
+ if (steps[-1] != 0) {
590
+ flags &= ~NDT_EXT_ZERO;
591
+ }
392
592
  }
393
593
  }
394
594
 
395
- return true;
595
+ return flags;
396
596
  }
397
597
 
398
- static bool
399
- all_inner_c_contiguous(const ndt_t *types[], int n, int outer)
598
+ static uint32_t
599
+ check_f(uint32_t flags, ndt_ndarray_t *x, int outer)
400
600
  {
401
- for (int i = 0; i < n; i++) {
402
- if (!ndt_is_c_contiguous(ndt_dim_at(types[i], outer))) {
403
- return false;
601
+ int inner = x->ndim-outer;
602
+ int64_t *shape = x->shape+outer;
603
+ int64_t *steps = x->steps+outer;
604
+ int64_t step = 1;
605
+
606
+ for (int i = 0; i < inner; i++) {
607
+ if (shape[i] > 1 && steps[i] != step) {
608
+ return flags & ~NDT_INNER_F;
404
609
  }
610
+ step *= shape[i];
405
611
  }
406
612
 
407
- return true;
613
+ return flags;
408
614
  }
409
615
 
410
- static bool
411
- all_inner_f_contiguous(const ndt_t *types[], int n, int outer)
616
+ static uint32_t
617
+ check_strided(uint32_t flags, int outer)
412
618
  {
413
- for (int i = 0; i < n; i++) {
414
- if (!ndt_is_f_contiguous(ndt_dim_at(types[i], outer))) {
415
- return false;
416
- }
619
+ if (outer == 0) {
620
+ return flags & ~NDT_EXT_STRIDED;
417
621
  }
418
622
 
419
- return true;
623
+ return flags;
420
624
  }
421
625
 
422
- static bool
423
- all_ndarray(const ndt_t *types[], int n)
626
+ static uint32_t
627
+ check_var(uint32_t flags, const ndt_t *t, int outer)
424
628
  {
425
- for (int i = 0; i < n; i++) {
426
- if (!ndt_is_ndarray(types[i])) {
427
- return false;
428
- }
429
- if (ndt_subtree_is_optional(types[i])) {
430
- return false;
431
- }
432
- }
433
-
434
- return true;
629
+ #if 0
630
+ /*
631
+ * This currently does not handle a corner case with zeros in a shape.
632
+ * The loop in gumath is safe, however, since it stops at shape==0.
633
+ */
634
+ if (ndt_logical_ndim(t) == outer) {
635
+ return flags & NDT_INNER_XND;
636
+ }
637
+
638
+ return 0U;
639
+ #endif
640
+ (void)t;
641
+ (void)outer;
642
+ return flags & NDT_INNER_XND;
435
643
  }
436
644
 
437
- void
438
- ndt_select_kernel_strategy(ndt_apply_spec_t *spec, const ndt_t *sig,
439
- const ndt_t *in[], int nin)
645
+ static uint32_t
646
+ select_flags(const ndt_t *types[], int n, int outer, ndt_context_t *ctx)
440
647
  {
441
- const ndt_t **out = (const ndt_t **)spec->out;
442
- bool in_inner_c, in_inner_f, out_inner_c;
648
+ uint32_t flags = NDT_SPEC_FLAGS_ALL;
649
+ ndt_ndarray_t x;
443
650
 
444
- assert(sig->tag == Function);
445
- assert(spec->flags == 0);
651
+ for (int i = 0; i < n; i++) {
652
+ const ndt_t *t = types[i];
446
653
 
447
- if (spec->nbroadcast > 0) {
448
- in = (const ndt_t **)spec->broadcast;
449
- nin = spec->nbroadcast;
654
+ if (ndt_as_ndarray(&x, t, ctx) < 0) { /* var dimension */
655
+ ndt_err_clear(ctx);
656
+ if (t->tag == VarDim || t->tag == VarDimElem) {
657
+ flags = check_var(flags, t, outer);
658
+ }
659
+ }
660
+ else {
661
+ if (outer > t->ndim) {
662
+ ndt_err_format(ctx, NDT_RuntimeError,
663
+ "number of outer dimensions greater than ndim");
664
+ return UINT32_MAX;
665
+ }
666
+
667
+ flags = check_strided(flags, outer);
668
+ flags = check_c(flags, &x, outer);
669
+ flags = check_f(flags, &x, outer);
670
+ }
450
671
  }
451
672
 
452
- in_inner_c = all_inner_c_contiguous(in, nin, spec->outer_dims);
453
- in_inner_f = all_inner_f_contiguous(in, nin, spec->outer_dims);
454
- out_inner_c = all_inner_c_contiguous(out, spec->nout, spec->outer_dims);
455
-
456
- spec->flags = NDT_XND;
673
+ return flags;
674
+ }
457
675
 
458
- if (sig->Function.elemwise) {
459
- if (all_inner_1D_contiguous(in, nin) &&
460
- all_inner_1D_contiguous(out, spec->nout)) {
461
- spec->flags |= NDT_ELEMWISE_1D;
462
- }
463
- }
464
- if (in_inner_c && out_inner_c) {
465
- spec->flags |= NDT_C;
466
- }
467
- if (in_inner_f && out_inner_c) {
468
- spec->flags |= NDT_FORTRAN;
469
- }
676
+ int
677
+ ndt_select_kernel_strategy(ndt_apply_spec_t *spec, ndt_context_t *ctx)
678
+ {
679
+ spec->flags = select_flags(spec->types, spec->nargs, spec->outer_dims, ctx);
470
680
 
471
- if (all_ndarray(in, nin) && all_ndarray(out, spec->nout)) {
472
- spec->flags |= NDT_STRIDED;
473
- }
681
+ return spec->flags == UINT32_MAX ? -1 : 0;
474
682
  }