ndtypes 0.2.0dev5 → 0.2.0dev6

Sign up to get free protection for your applications and to get access to all the features.
Files changed (130) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +12 -0
  3. data/Rakefile +8 -0
  4. data/ext/ruby_ndtypes/GPATH +0 -0
  5. data/ext/ruby_ndtypes/GRTAGS +0 -0
  6. data/ext/ruby_ndtypes/GTAGS +0 -0
  7. data/ext/ruby_ndtypes/extconf.rb +1 -1
  8. data/ext/ruby_ndtypes/include/ndtypes.h +231 -122
  9. data/ext/ruby_ndtypes/include/ruby_ndtypes.h +1 -1
  10. data/ext/ruby_ndtypes/lib/libndtypes.a +0 -0
  11. data/ext/ruby_ndtypes/lib/libndtypes.so.0.2.0dev3 +0 -0
  12. data/ext/ruby_ndtypes/ndtypes/Makefile +87 -0
  13. data/ext/ruby_ndtypes/ndtypes/config.h +68 -0
  14. data/ext/ruby_ndtypes/ndtypes/config.log +477 -0
  15. data/ext/ruby_ndtypes/ndtypes/config.status +1027 -0
  16. data/ext/ruby_ndtypes/ndtypes/doc/_static/style.css +7 -0
  17. data/ext/ruby_ndtypes/ndtypes/doc/_templates/layout.html +2 -0
  18. data/ext/ruby_ndtypes/ndtypes/doc/conf.py +40 -4
  19. data/ext/ruby_ndtypes/ndtypes/doc/images/xndlogo.png +0 -0
  20. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +1 -1
  21. data/ext/ruby_ndtypes/ndtypes/doc/requirements.txt +2 -0
  22. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile +287 -0
  23. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +20 -4
  24. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +22 -3
  25. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +1 -1
  26. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.o +0 -0
  27. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.o +0 -0
  28. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile +73 -0
  29. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +246 -229
  30. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +15 -11
  31. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.o +0 -0
  32. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +38 -28
  33. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +91 -91
  34. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +1 -1
  35. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +4 -3
  36. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.o +0 -0
  37. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +8 -7
  38. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.o +0 -0
  39. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +2 -2
  40. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.o +0 -0
  41. data/ext/ruby_ndtypes/ndtypes/libndtypes/context.o +0 -0
  42. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +263 -182
  43. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.o +0 -0
  44. data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.o +0 -0
  45. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +67 -7
  46. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.o +0 -0
  47. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +1112 -1000
  48. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +69 -58
  49. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.o +0 -0
  50. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +150 -99
  51. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +185 -15
  52. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.o +0 -0
  53. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +301 -276
  54. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +1 -1
  55. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +9 -4
  56. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.o +0 -0
  57. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.a +0 -0
  58. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so +1 -0
  59. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0 +1 -0
  60. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3 +0 -0
  61. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +729 -228
  62. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.o +0 -0
  63. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +768 -403
  64. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h +1002 -0
  65. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +231 -122
  66. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.o +0 -0
  67. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +176 -84
  68. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +26 -14
  69. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.o +0 -0
  70. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +57 -35
  71. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.o +0 -0
  72. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.c +420 -0
  73. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.o +0 -0
  74. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +8 -8
  75. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +1 -1
  76. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.o +0 -0
  77. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile +48 -0
  78. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +200 -116
  79. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.o +0 -0
  80. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +46 -4
  81. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.o +0 -0
  82. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +58 -27
  83. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +1 -1
  84. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.o +0 -0
  85. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +3 -5
  86. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +12 -4
  87. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.o +0 -0
  88. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile +55 -0
  89. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +8 -8
  90. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +5 -5
  91. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +274 -172
  92. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +24 -4
  93. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +2 -2
  94. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +14 -14
  95. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +32 -30
  96. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +37 -0
  97. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +36 -0
  98. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +16 -0
  99. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +5 -5
  100. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +706 -253
  101. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_unify.c +132 -0
  102. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.c +703 -0
  103. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.o +0 -0
  104. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +335 -127
  105. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.o +0 -0
  106. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +2 -2
  107. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.o +0 -0
  108. data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +88 -71
  109. data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +0 -1
  110. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +10 -13
  111. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +395 -314
  112. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.a +0 -0
  113. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so +1 -0
  114. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0 +1 -0
  115. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0.2.0dev3 +0 -0
  116. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/ndtypes.h +1002 -0
  117. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +15 -33
  118. data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +340 -132
  119. data/ext/ruby_ndtypes/ndtypes/setup.py +11 -2
  120. data/ext/ruby_ndtypes/ruby_ndtypes.c +364 -241
  121. data/ext/ruby_ndtypes/ruby_ndtypes.h +1 -1
  122. data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +0 -1
  123. data/lib/ndtypes.rb +11 -0
  124. data/lib/ndtypes/version.rb +2 -2
  125. data/lib/ruby_ndtypes.so +0 -0
  126. data/ndtypes.gemspec +3 -0
  127. data/spec/ndtypes_spec.rb +6 -0
  128. metadata +98 -4
  129. data/ext/ruby_ndtypes/gc_guard.c +0 -36
  130. data/ext/ruby_ndtypes/gc_guard.h +0 -12
@@ -95,21 +95,46 @@ ndt_asprintf(ndt_context_t *ctx, const char *fmt, ...)
95
95
  /* Type functions (unstable API) */
96
96
  /*****************************************************************************/
97
97
 
98
+ /* Number of elements, currently only for ndarrays. Return -1 on error. */
99
+ int64_t
100
+ ndt_nelem(const ndt_t *t)
101
+ {
102
+ NDT_STATIC_CONTEXT(ctx);
103
+ ndt_ndarray_t x;
104
+ int64_t n;
105
+
106
+ if (ndt_as_ndarray(&x, t, &ctx) < 0) {
107
+ ndt_err_clear(&ctx);
108
+ return -1;
109
+ }
110
+
111
+ n = 1;
112
+ for (int i = 0; i < x.ndim; i++) {
113
+ n *= x.shape[i];
114
+ }
115
+
116
+ return n;
117
+ }
118
+
98
119
  /* Return the next type in a dimension chain. Undefined for non-dimensions. */
99
120
  static const ndt_t *
100
- next_dim(const ndt_t *a)
121
+ next_dim(const ndt_t *t)
101
122
  {
102
- assert(a->ndim > 0);
123
+ assert(t->ndim > 0);
103
124
 
104
- switch (a->tag) {
125
+ switch (t->tag) {
105
126
  case FixedDim:
106
- return a->FixedDim.type;
127
+ return t->FixedDim.type;
107
128
  case VarDim:
108
- return a->VarDim.type;
129
+ return t->VarDim.type;
130
+ case VarDimElem:
131
+ return t->VarDimElem.type;
109
132
  case SymbolicDim:
110
- return a->SymbolicDim.type;
133
+ return t->SymbolicDim.type;
111
134
  case EllipsisDim:
112
- return a->EllipsisDim.type;
135
+ return t->EllipsisDim.type;
136
+ case Array:
137
+ return t->Array.type;
113
138
  default:
114
139
  /* NOT REACHED: tags should be exhaustive. */
115
140
  ndt_internal_error("invalid value");
@@ -126,16 +151,19 @@ ndt_dtype(const ndt_t *t)
126
151
  return t;
127
152
  }
128
153
 
129
- const ndt_t *
130
- ndt_dim_at(const ndt_t *t, int n)
154
+ int
155
+ ndt_logical_ndim(const ndt_t *t)
131
156
  {
132
- assert(n <= t->ndim);
157
+ int ndim = t->ndim;
133
158
 
134
- for (int i = 0; i < n; i++) {
135
- t = next_dim(t);
159
+ while (t->tag == VarDim || t->tag == VarDimElem) {
160
+ if (t->tag == VarDimElem) {
161
+ --ndim;
162
+ }
163
+ t = t->VarDim.type;
136
164
  }
137
165
 
138
- return t;
166
+ return ndim;
139
167
  }
140
168
 
141
169
  const ndt_t *
@@ -170,6 +198,75 @@ ndt_dims_dtype(const ndt_t *dims[NDT_MAX_DIM], const ndt_t **dtype, const ndt_t
170
198
  return n;
171
199
  }
172
200
 
201
+ static const ndt_t *
202
+ compress(const ndt_t *t)
203
+ {
204
+ while (t->tag == VarDimElem) {
205
+ t = t->VarDimElem.type;
206
+ }
207
+
208
+ return t;
209
+ }
210
+
211
+ /*
212
+ * Return the next logical type in a dimension chain. Undefined for
213
+ * non-dimensions.
214
+ */
215
+ static const ndt_t *
216
+ next_logical_dim(const ndt_t *t)
217
+ {
218
+ switch (t->tag) {
219
+ case FixedDim:
220
+ return t->FixedDim.type;
221
+ case VarDim:
222
+ return compress(t->VarDim.type);
223
+ case VarDimElem:
224
+ return next_logical_dim(compress(t));
225
+ case SymbolicDim:
226
+ return t->SymbolicDim.type;
227
+ case EllipsisDim:
228
+ return t->EllipsisDim.type;
229
+ default:
230
+ return t;
231
+ }
232
+ }
233
+
234
+ const ndt_t *
235
+ ndt_logical_dim_at(const ndt_t *t, int n)
236
+ {
237
+ for (int i = 0; i < n; i++) {
238
+ t = next_logical_dim(t);
239
+ }
240
+
241
+ return t;
242
+ }
243
+
244
+ const ndt_t *
245
+ ndt_copy_contiguous_at(const ndt_t *t, int n, const ndt_t *dtype, ndt_context_t *ctx)
246
+ {
247
+ const ndt_t *u;
248
+
249
+ if (!ndt_is_ndarray(t)) {
250
+ ndt_err_format(ctx, NDT_NotImplementedError,
251
+ "partial copies are currently restricted to fixed dimensions");
252
+ return NULL;
253
+ }
254
+
255
+ if (n < 0 || n > t->ndim) {
256
+ ndt_err_format(ctx, NDT_ValueError, "n out of bounds");
257
+ return NULL;
258
+ }
259
+
260
+ u = ndt_logical_dim_at(t, n);
261
+
262
+ if (dtype != NULL) {
263
+ return ndt_copy_contiguous_dtype(u, dtype, 0, ctx);
264
+ }
265
+ else {
266
+ return ndt_copy_contiguous(u, 0, ctx);
267
+ }
268
+ }
269
+
173
270
  int
174
271
  ndt_as_ndarray(ndt_ndarray_t *a, const ndt_t *t, ndt_context_t *ctx)
175
272
  {
@@ -205,9 +302,89 @@ ndt_as_ndarray(ndt_ndarray_t *a, const ndt_t *t, ndt_context_t *ctx)
205
302
  return n;
206
303
  }
207
304
 
305
+ static const ndt_t *
306
+ _ndt_transpose(const ndt_ndarray_t *a, const int p[], const ndt_t *dtype,
307
+ ndt_context_t *ctx)
308
+ {
309
+ const ndt_t *t;
310
+
311
+ t = dtype;
312
+ ndt_incref(t);
313
+
314
+ for (int i = a->ndim-1; i >= 0; i--) {
315
+ const ndt_t *u = ndt_fixed_dim(t, a->shape[p[i]], a->steps[p[i]], ctx);
316
+ ndt_decref(t);
317
+ if (u == NULL) {
318
+ return NULL;
319
+ }
320
+ t = u;
321
+ }
322
+
323
+ return t;
324
+ }
325
+
326
+ const ndt_t *
327
+ ndt_transpose(const ndt_t *t, const int *p, int ndim, ndt_context_t *ctx)
328
+ {
329
+ ndt_ndarray_t a;
330
+ int permute[NDT_MAX_DIM];
331
+
332
+ if (ndt_as_ndarray(&a, t, ctx) < 0) {
333
+ return NULL;
334
+ }
335
+
336
+ if (p == NULL) {
337
+ if (ndim != 0) {
338
+ ndt_err_format(ctx, NDT_RuntimeError,
339
+ "internal error: ndim != 0 but no permutations given");
340
+ return NULL;
341
+ }
342
+
343
+ ndim = a.ndim;
344
+ if (ndim <= 1) {
345
+ ndt_incref(t);
346
+ return t;
347
+ }
348
+
349
+ for (int i = 0; i < ndim; i++) {
350
+ permute[i] = ndim-i-1;
351
+ }
352
+ p = permute;
353
+ }
354
+ else {
355
+ if (ndim != a.ndim) {
356
+ ndt_err_format(ctx, NDT_ValueError,
357
+ "number of permutations != ndim");
358
+ return NULL;
359
+ }
360
+
361
+ for (int i = 0; i < ndim; i++) {
362
+ permute[i] = 0;
363
+ }
364
+
365
+ for (int i = 0; i < ndim; i++) {
366
+ if (p[i] < 0 || p[i] >= ndim) {
367
+ ndt_err_format(ctx, NDT_ValueError,
368
+ "permutation with value %d out of bounds", p[i]);
369
+ return NULL;
370
+ }
371
+
372
+ if (permute[p[i]] != 0) {
373
+ ndt_err_format(ctx, NDT_ValueError,
374
+ "duplicate permutation index=%d", p[i]);
375
+ return NULL;
376
+ }
377
+
378
+ permute[p[i]] = 1;
379
+ }
380
+ }
381
+
382
+ return _ndt_transpose(&a, p, ndt_dtype(t), ctx);
383
+ }
384
+
208
385
  /* Unoptimized hash function for experimenting. */
209
386
  ndt_ssize_t
210
- ndt_hash(ndt_t *t, ndt_context_t *ctx)
387
+ ndt_hash(const ndt_t *t, ndt_context_t *ctx)
211
388
  {
212
389
  unsigned char *s, *cp;
213
390
  size_t len;
@@ -242,11 +419,11 @@ ndt_hash(ndt_t *t, ndt_context_t *ctx)
242
419
 
243
420
  const ndt_apply_spec_t ndt_apply_spec_empty = {
244
421
  .flags = 0U,
245
- .nout = 0,
246
- .nbroadcast = 0,
247
422
  .outer_dims = 0,
248
- .out = {NULL},
249
- .broadcast = {NULL}
423
+ .nin = 0,
424
+ .nout = 0,
425
+ .nargs = 0,
426
+ .types = {NULL}
250
427
  };
251
428
 
252
429
  ndt_apply_spec_t *
@@ -259,9 +436,10 @@ ndt_apply_spec_new(ndt_context_t *ctx)
259
436
  return ndt_memory_error(ctx);
260
437
  }
261
438
  spec->flags = 0U;
262
- spec->nout = 0;
263
- spec->nbroadcast = 0;
264
439
  spec->outer_dims = 0;
440
+ spec->nin = 0;
441
+ spec->nout = 0;
442
+ spec->nargs = 0;
265
443
 
266
444
  return spec;
267
445
  }
@@ -269,24 +447,20 @@ ndt_apply_spec_new(ndt_context_t *ctx)
269
447
  void
270
448
  ndt_apply_spec_clear(ndt_apply_spec_t *spec)
271
449
  {
272
- int i;
273
-
274
450
  if (spec == NULL) {
275
451
  return;
276
452
  }
277
453
 
278
- for (i = 0; i < spec->nbroadcast; i++) {
279
- ndt_del(spec->broadcast[i]);
280
- }
281
-
282
- for (i = 0; i < spec->nout; i++) {
283
- ndt_del(spec->out[i]);
454
+ for (int i = 0; i < spec->nin+spec->nout; i++) {
455
+ ndt_decref(spec->types[i]);
456
+ spec->types[i] = NULL;
284
457
  }
285
458
 
286
459
  spec->flags = 0U;
287
- spec->nout = 0;
288
- spec->nbroadcast = 0;
289
460
  spec->outer_dims = 0;
461
+ spec->nin = 0;
462
+ spec->nout = 0;
463
+ spec->nargs = 0;
290
464
  }
291
465
 
292
466
  void
@@ -300,6 +474,15 @@ ndt_apply_spec_del(ndt_apply_spec_t *spec)
300
474
  ndt_free(spec);
301
475
  }
302
476
 
477
+ #define X (NDT_INNER_XND)
478
+ #define S (NDT_INNER_STRIDED)
479
+ #define C (NDT_INNER_C)
480
+ #define F (NDT_INNER_F)
481
+
482
+ #define ES (NDT_EXT_STRIDED)
483
+ #define EC (NDT_EXT_C)
484
+ #define EZ (NDT_EXT_ZERO)
485
+
303
486
  /* This function is used in places where it is _really_ convenient not to
304
487
  have to deal with deallocating an error message. */
305
488
  const char *
@@ -307,38 +490,35 @@ ndt_apply_flags_as_string(const ndt_apply_spec_t *spec)
307
490
  {
308
491
  switch (spec->flags) {
309
492
  case 0: return "None";
310
- case NDT_C: return "C";
311
- case NDT_FORTRAN: return "Fortran";
312
- case NDT_FORTRAN|NDT_C: return "C|Fortran";
313
- case NDT_STRIDED: return "Strided";
314
- case NDT_STRIDED|NDT_C: return "C|Strided";
315
- case NDT_STRIDED|NDT_FORTRAN: return "Fortran|Strided";
316
- case NDT_STRIDED|NDT_FORTRAN|NDT_C: return "C|Fortran|Strided";
317
- case NDT_XND: return "Xnd";
318
- case NDT_XND|NDT_C: return "C|Xnd";
319
- case NDT_XND|NDT_FORTRAN: return "Fortran|Xnd";
320
- case NDT_XND|NDT_FORTRAN|NDT_C: return "C|Fortran|Xnd";
321
- case NDT_XND|NDT_STRIDED: return "Strided|Xnd";
322
- case NDT_XND|NDT_STRIDED|NDT_C: return "C|Strided|Xnd";
323
- case NDT_XND|NDT_STRIDED|NDT_FORTRAN: return "Fortran|Strided|Xnd";
324
- case NDT_XND|NDT_STRIDED|NDT_FORTRAN|NDT_C: return "C|Fortran|Strided|Xnd";
325
- case NDT_ELEMWISE_1D: return "Elemwise1D";
326
- case NDT_ELEMWISE_1D|NDT_C: return "Elemwise1D|C";
327
- case NDT_ELEMWISE_1D|NDT_FORTRAN: return "Elemwise1D|Fortran";
328
- case NDT_ELEMWISE_1D|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran";
329
- case NDT_ELEMWISE_1D|NDT_STRIDED: return "Elemwise1D|Strided";
330
- case NDT_ELEMWISE_1D|NDT_STRIDED|NDT_C: return "Elemwise1D|C|Strided";
331
- case NDT_ELEMWISE_1D|NDT_STRIDED|NDT_FORTRAN: return "Elemwise1D|Fortran|Strided";
332
- case NDT_ELEMWISE_1D|NDT_STRIDED|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran|Strided";
333
- case NDT_ELEMWISE_1D|NDT_XND: return "Elemwise1D|Xnd";
334
- case NDT_ELEMWISE_1D|NDT_XND|NDT_C: return "Elemwise1D|C|Xnd";
335
- case NDT_ELEMWISE_1D|NDT_XND|NDT_FORTRAN: return "Elemwise1D|Fortran|Xnd";
336
- case NDT_ELEMWISE_1D|NDT_XND|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran|Xnd";
337
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED: return "Elemwise1D|Strided|Xnd";
338
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_C: return "Elemwise1D|C|Strided|Xnd";
339
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_FORTRAN: return "Elemwise1D|Fortran|Strided|Xnd";
340
- case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran|Strided|Xnd";
341
- default: return "error: invalid combination of spec->flags";
493
+ case X: return "Xnd";
494
+
495
+ case S|X: return "Strided|Xnd";
496
+ case C|S|X: return "C|Strided|Xnd";
497
+ case F|S|X: return "Fortran|Strided|Xnd";
498
+ case C|F|S|X: return "C|Fortran|Strided|Xnd";
499
+
500
+ case ES|S|X: return "OptS|Strided|Xnd";
501
+ case ES|C|S|X: return "OptS|C|Strided|Xnd";
502
+ case ES|F|S|X: return "OptS|Fortran|Strided|Xnd";
503
+ case ES|C|F|S|X: return "OptS|C|Fortran|Strided|Xnd";
504
+ case EC|ES|C|S|X: return "OptC|OptS|C|Strided|Xnd";
505
+ case EC|ES|C|F|S|X: return "OptC|OptS|C|Fortran|Strided|Xnd";
506
+ case EZ|EC|ES|C|F|S|X: return "OptZ|OptC|OptS|C|Fortran|Strided|Xnd";
507
+ case EZ|EC|ES|C|S|X: return "OptZ|OptC|OptS|C|Strided|Xnd";
508
+ case EZ|ES|C|S|X: return "OptZ|OptS|C|Strided|Xnd";
509
+ case EZ|ES|C|F|S|X: return "OptZ|OptS|C|Fortran|Strided|Xnd";
510
+
511
+ default:
512
+ if (spec->flags & NDT_EXT_ZERO) fprintf(stderr, "EZ ");
513
+ if (spec->flags & NDT_EXT_C) fprintf(stderr, "EC ");
514
+ if (spec->flags & NDT_EXT_STRIDED) fprintf(stderr, "ES ");
515
+ if (spec->flags & NDT_INNER_C) fprintf(stderr, "C ");
516
+ if (spec->flags & NDT_INNER_F) fprintf(stderr, "F ");
517
+ if (spec->flags & NDT_INNER_STRIDED) fprintf(stderr, "S ");
518
+ if (spec->flags & NDT_INNER_XND) fprintf(stderr, "X ");
519
+ fprintf(stderr, "\n");
520
+
521
+ return "unknown flags";
342
522
  }
343
523
  }
344
524
 
@@ -362,16 +542,21 @@ ndt_meta_new(ndt_context_t *ctx)
362
542
  }
363
543
 
364
544
  void
365
- ndt_meta_del(ndt_meta_t *m)
545
+ ndt_meta_clear(ndt_meta_t *m)
366
546
  {
367
547
  if (m == NULL) {
368
548
  return;
369
549
  }
370
550
 
371
551
  for (int i = 0; i < m->ndims; i++) {
372
- ndt_free(m->offsets[i]);
552
+ ndt_decref_offsets(m->offsets[i]);
373
553
  }
554
+ }
374
555
 
556
+ void
557
+ ndt_meta_del(ndt_meta_t *m)
558
+ {
559
+ ndt_meta_clear(m);
375
560
  ndt_free(m);
376
561
  }
377
562
 
@@ -380,95 +565,118 @@ ndt_meta_del(ndt_meta_t *m)
380
565
  /* Optimized kernel strategy (unstable API) */
381
566
  /*****************************************************************************/
382
567
 
383
- static bool
384
- all_inner_1D_contiguous(const ndt_t *types[], int n)
568
+ static uint32_t
569
+ check_c(uint32_t flags, ndt_ndarray_t *x, int outer)
385
570
  {
386
- for (int i = 0; i < n; i++) {
387
- if (types[i]->ndim == 0) {
388
- return false;
571
+ int inner = x->ndim-outer;
572
+ int64_t *shape = x->shape+outer;
573
+ int64_t *steps = x->steps+outer;
574
+ int64_t step = 1;
575
+
576
+ for (int i = inner-1; i >= 0; i--) {
577
+ if (shape[i] > 1 && steps[i] != step) {
578
+ return flags & ~(NDT_INNER_C|NDT_EXT_C|NDT_EXT_ZERO);
389
579
  }
390
- if (!ndt_is_c_contiguous(ndt_dim_at(types[i], types[i]->ndim-1))) {
391
- return false;
580
+ step *= shape[i];
581
+ }
582
+
583
+ if (x->ndim == inner) {
584
+ return flags & ~(NDT_EXT_C|NDT_EXT_ZERO);
585
+ }
586
+ else if (x->ndim > inner) {
587
+ if (shape[-1] > 1 && steps[-1] != step) {
588
+ flags &= ~NDT_EXT_C;
589
+ if (steps[-1] != 0) {
590
+ flags &= ~NDT_EXT_ZERO;
591
+ }
392
592
  }
393
593
  }
394
594
 
395
- return true;
595
+ return flags;
396
596
  }
397
597
 
398
- static bool
399
- all_inner_c_contiguous(const ndt_t *types[], int n, int outer)
598
+ static uint32_t
599
+ check_f(uint32_t flags, ndt_ndarray_t *x, int outer)
400
600
  {
401
- for (int i = 0; i < n; i++) {
402
- if (!ndt_is_c_contiguous(ndt_dim_at(types[i], outer))) {
403
- return false;
601
+ int inner = x->ndim-outer;
602
+ int64_t *shape = x->shape+outer;
603
+ int64_t *steps = x->steps+outer;
604
+ int64_t step = 1;
605
+
606
+ for (int i = 0; i < inner; i++) {
607
+ if (shape[i] > 1 && steps[i] != step) {
608
+ return flags & ~NDT_INNER_F;
404
609
  }
610
+ step *= shape[i];
405
611
  }
406
612
 
407
- return true;
613
+ return flags;
408
614
  }
409
615
 
410
- static bool
411
- all_inner_f_contiguous(const ndt_t *types[], int n, int outer)
616
+ static uint32_t
617
+ check_strided(uint32_t flags, int outer)
412
618
  {
413
- for (int i = 0; i < n; i++) {
414
- if (!ndt_is_f_contiguous(ndt_dim_at(types[i], outer))) {
415
- return false;
416
- }
619
+ if (outer == 0) {
620
+ return flags & ~NDT_EXT_STRIDED;
417
621
  }
418
622
 
419
- return true;
623
+ return flags;
420
624
  }
421
625
 
422
- static bool
423
- all_ndarray(const ndt_t *types[], int n)
626
+ static uint32_t
627
+ check_var(uint32_t flags, const ndt_t *t, int outer)
424
628
  {
425
- for (int i = 0; i < n; i++) {
426
- if (!ndt_is_ndarray(types[i])) {
427
- return false;
428
- }
429
- if (ndt_subtree_is_optional(types[i])) {
430
- return false;
431
- }
432
- }
433
-
434
- return true;
629
+ #if 0
630
+ /*
631
+ * This currently does not handle a corner case with zeros in a shape.
632
+ * The loop in gumath is safe, however, since it stops at shape==0.
633
+ */
634
+ if (ndt_logical_ndim(t) == outer) {
635
+ return flags & NDT_INNER_XND;
636
+ }
637
+
638
+ return 0U;
639
+ #endif
640
+ (void)t;
641
+ (void)outer;
642
+ return flags & NDT_INNER_XND;
435
643
  }
436
644
 
437
- void
438
- ndt_select_kernel_strategy(ndt_apply_spec_t *spec, const ndt_t *sig,
439
- const ndt_t *in[], int nin)
645
+ static uint32_t
646
+ select_flags(const ndt_t *types[], int n, int outer, ndt_context_t *ctx)
440
647
  {
441
- const ndt_t **out = (const ndt_t **)spec->out;
442
- bool in_inner_c, in_inner_f, out_inner_c;
648
+ uint32_t flags = NDT_SPEC_FLAGS_ALL;
649
+ ndt_ndarray_t x;
443
650
 
444
- assert(sig->tag == Function);
445
- assert(spec->flags == 0);
651
+ for (int i = 0; i < n; i++) {
652
+ const ndt_t *t = types[i];
446
653
 
447
- if (spec->nbroadcast > 0) {
448
- in = (const ndt_t **)spec->broadcast;
449
- nin = spec->nbroadcast;
654
+ if (ndt_as_ndarray(&x, t, ctx) < 0) { /* var dimension */
655
+ ndt_err_clear(ctx);
656
+ if (t->tag == VarDim || t->tag == VarDimElem) {
657
+ flags = check_var(flags, t, outer);
658
+ }
659
+ }
660
+ else {
661
+ if (outer > t->ndim) {
662
+ ndt_err_format(ctx, NDT_RuntimeError,
663
+ "number of outer dimensions greater than ndim");
664
+ return UINT32_MAX;
665
+ }
666
+
667
+ flags = check_strided(flags, outer);
668
+ flags = check_c(flags, &x, outer);
669
+ flags = check_f(flags, &x, outer);
670
+ }
450
671
  }
451
672
 
452
- in_inner_c = all_inner_c_contiguous(in, nin, spec->outer_dims);
453
- in_inner_f = all_inner_f_contiguous(in, nin, spec->outer_dims);
454
- out_inner_c = all_inner_c_contiguous(out, spec->nout, spec->outer_dims);
455
-
456
- spec->flags = NDT_XND;
673
+ return flags;
674
+ }
457
675
 
458
- if (sig->Function.elemwise) {
459
- if (all_inner_1D_contiguous(in, nin) &&
460
- all_inner_1D_contiguous(out, spec->nout)) {
461
- spec->flags |= NDT_ELEMWISE_1D;
462
- }
463
- }
464
- if (in_inner_c && out_inner_c) {
465
- spec->flags |= NDT_C;
466
- }
467
- if (in_inner_f && out_inner_c) {
468
- spec->flags |= NDT_FORTRAN;
469
- }
676
+ int
677
+ ndt_select_kernel_strategy(ndt_apply_spec_t *spec, ndt_context_t *ctx)
678
+ {
679
+ spec->flags = select_flags(spec->types, spec->nargs, spec->outer_dims, ctx);
470
680
 
471
- if (all_ndarray(in, nin) && all_ndarray(out, spec->nout)) {
472
- spec->flags |= NDT_STRIDED;
473
- }
681
+ return spec->flags == UINT32_MAX ? -1 : 0;
474
682
  }