ndtypes 0.2.0dev5 → 0.2.0dev6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +12 -0
  3. data/Rakefile +8 -0
  4. data/ext/ruby_ndtypes/GPATH +0 -0
  5. data/ext/ruby_ndtypes/GRTAGS +0 -0
  6. data/ext/ruby_ndtypes/GTAGS +0 -0
  7. data/ext/ruby_ndtypes/extconf.rb +1 -1
  8. data/ext/ruby_ndtypes/include/ndtypes.h +231 -122
  9. data/ext/ruby_ndtypes/include/ruby_ndtypes.h +1 -1
  10. data/ext/ruby_ndtypes/lib/libndtypes.a +0 -0
  11. data/ext/ruby_ndtypes/lib/libndtypes.so.0.2.0dev3 +0 -0
  12. data/ext/ruby_ndtypes/ndtypes/Makefile +87 -0
  13. data/ext/ruby_ndtypes/ndtypes/config.h +68 -0
  14. data/ext/ruby_ndtypes/ndtypes/config.log +477 -0
  15. data/ext/ruby_ndtypes/ndtypes/config.status +1027 -0
  16. data/ext/ruby_ndtypes/ndtypes/doc/_static/style.css +7 -0
  17. data/ext/ruby_ndtypes/ndtypes/doc/_templates/layout.html +2 -0
  18. data/ext/ruby_ndtypes/ndtypes/doc/conf.py +40 -4
  19. data/ext/ruby_ndtypes/ndtypes/doc/images/xndlogo.png +0 -0
  20. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +1 -1
  21. data/ext/ruby_ndtypes/ndtypes/doc/requirements.txt +2 -0
  22. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile +287 -0
  23. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +20 -4
  24. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +22 -3
  25. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +1 -1
  26. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.o +0 -0
  27. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.o +0 -0
  28. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile +73 -0
  29. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +246 -229
  30. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +15 -11
  31. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.o +0 -0
  32. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +38 -28
  33. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +91 -91
  34. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +1 -1
  35. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +4 -3
  36. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.o +0 -0
  37. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +8 -7
  38. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.o +0 -0
  39. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +2 -2
  40. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.o +0 -0
  41. data/ext/ruby_ndtypes/ndtypes/libndtypes/context.o +0 -0
  42. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +263 -182
  43. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.o +0 -0
  44. data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.o +0 -0
  45. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +67 -7
  46. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.o +0 -0
  47. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +1112 -1000
  48. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +69 -58
  49. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.o +0 -0
  50. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +150 -99
  51. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +185 -15
  52. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.o +0 -0
  53. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +301 -276
  54. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +1 -1
  55. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +9 -4
  56. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.o +0 -0
  57. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.a +0 -0
  58. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so +1 -0
  59. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0 +1 -0
  60. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3 +0 -0
  61. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +729 -228
  62. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.o +0 -0
  63. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +768 -403
  64. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h +1002 -0
  65. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +231 -122
  66. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.o +0 -0
  67. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +176 -84
  68. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +26 -14
  69. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.o +0 -0
  70. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +57 -35
  71. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.o +0 -0
  72. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.c +420 -0
  73. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.o +0 -0
  74. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +8 -8
  75. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +1 -1
  76. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.o +0 -0
  77. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile +48 -0
  78. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +200 -116
  79. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.o +0 -0
  80. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +46 -4
  81. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.o +0 -0
  82. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +58 -27
  83. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +1 -1
  84. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.o +0 -0
  85. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +3 -5
  86. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +12 -4
  87. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.o +0 -0
  88. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile +55 -0
  89. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +8 -8
  90. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +5 -5
  91. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +274 -172
  92. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +24 -4
  93. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +2 -2
  94. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +14 -14
  95. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +32 -30
  96. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +37 -0
  97. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +36 -0
  98. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +16 -0
  99. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +5 -5
  100. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +706 -253
  101. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_unify.c +132 -0
  102. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.c +703 -0
  103. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.o +0 -0
  104. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +335 -127
  105. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.o +0 -0
  106. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +2 -2
  107. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.o +0 -0
  108. data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +88 -71
  109. data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +0 -1
  110. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +10 -13
  111. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +395 -314
  112. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.a +0 -0
  113. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so +1 -0
  114. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0 +1 -0
  115. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0.2.0dev3 +0 -0
  116. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/ndtypes.h +1002 -0
  117. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +15 -33
  118. data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +340 -132
  119. data/ext/ruby_ndtypes/ndtypes/setup.py +11 -2
  120. data/ext/ruby_ndtypes/ruby_ndtypes.c +364 -241
  121. data/ext/ruby_ndtypes/ruby_ndtypes.h +1 -1
  122. data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +0 -1
  123. data/lib/ndtypes.rb +11 -0
  124. data/lib/ndtypes/version.rb +2 -2
  125. data/lib/ruby_ndtypes.so +0 -0
  126. data/ndtypes.gemspec +3 -0
  127. data/spec/ndtypes_spec.rb +6 -0
  128. metadata +98 -4
  129. data/ext/ruby_ndtypes/gc_guard.c +0 -36
  130. data/ext/ruby_ndtypes/gc_guard.h +0 -12
@@ -726,7 +726,7 @@ extern int yylex \
726
726
  #undef yyTABLES_NAME
727
727
  #endif
728
728
 
729
- #line 222 "lexer.l"
729
+ #line 227 "lexer.l"
730
730
 
731
731
 
732
732
  #line 732 "lexer.h"
@@ -144,24 +144,26 @@ yycolumn = 1;
144
144
  "void" { return VOID; }
145
145
  "bool" { return BOOL; }
146
146
 
147
- "Signed" { return SIGNED_KIND; }
147
+ "signed" { return SIGNED_KIND; }
148
148
  "int8" { return INT8; }
149
149
  "int16" { return INT16; }
150
150
  "int32" { return INT32; }
151
151
  "int64" { return INT64; }
152
152
 
153
- "Unsigned" { return UNSIGNED_KIND; }
153
+ "unsigned" { return UNSIGNED_KIND; }
154
154
  "uint8" { return UINT8; }
155
155
  "uint16" { return UINT16; }
156
156
  "uint32" { return UINT32; }
157
157
  "uint64" { return UINT64; }
158
158
 
159
- "Float" { return FLOAT_KIND; }
159
+ "float" { return FLOAT_KIND; }
160
+ "bfloat16" { return BFLOAT16; }
160
161
  "float16" { return FLOAT16; }
161
162
  "float32" { return FLOAT32; }
162
163
  "float64" { return FLOAT64; }
163
164
 
164
- "Complex" { return COMPLEX_KIND; }
165
+ "complex" { return COMPLEX_KIND; }
166
+ "bcomplex32" { return BCOMPLEX32; }
165
167
  "complex32" { return COMPLEX32; }
166
168
  "complex64" { return COMPLEX64; }
167
169
  "complex128" { return COMPLEX128; }
@@ -186,6 +188,9 @@ yycolumn = 1;
186
188
 
187
189
  "fixed" { return FIXED; }
188
190
  "var" { return VAR; }
191
+ "array" { return ARRAY; }
192
+
193
+ "of" { return OF; }
189
194
 
190
195
  "..." { return ELLIPSIS; }
191
196
  "->" { return RARROW; }
@@ -0,0 +1 @@
1
+ ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
@@ -0,0 +1 @@
1
+ ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
@@ -103,7 +103,7 @@ resolve_broadcast(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
103
103
  }
104
104
 
105
105
  static int
106
- check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
106
+ check_contig(const ndt_t *ptypes[], const ndt_t *ctypes[], int64_t nargs)
107
107
  {
108
108
  for (int i = 0; i < nargs; i++) {
109
109
  const ndt_t *p = ptypes[i];
@@ -130,14 +130,14 @@ check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
130
130
  return 1;
131
131
  }
132
132
 
133
- static ndt_t *
134
- to_fortran(const ndt_t *p, ndt_t *c, ndt_context_t *ctx)
133
+ static const ndt_t *
134
+ to_fortran(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
135
135
  {
136
136
  if (p->tag == EllipsisDim && p->EllipsisDim.tag == RequireF) {
137
- ndt_t *t = ndt_to_fortran(c, ctx);
138
- return t;
137
+ return ndt_to_fortran(c, ctx);
139
138
  }
140
139
  else {
140
+ ndt_incref(c);
141
141
  return c;
142
142
  }
143
143
  }
@@ -217,41 +217,124 @@ resolve_typevar(const char *key, symtable_entry_t w, symtable_t *tbl, ndt_contex
217
217
  }
218
218
  }
219
219
 
220
+ typedef struct {
221
+ int64_t index;
222
+ const ndt_t *type;
223
+ } indexed_type_t;
224
+
225
+ static indexed_type_t indexed_type_error = { 0, NULL };
226
+
227
+ static inline int64_t
228
+ adjust_index(const int64_t i, const int64_t shape, ndt_context_t *ctx)
229
+ {
230
+ const int64_t k = i < 0 ? i + shape : i;
231
+
232
+ if (k < 0 || k >= shape) {
233
+ ndt_err_format(ctx, NDT_IndexError,
234
+ "index with value %" PRIi64 " out of bounds", i);
235
+ return -1;
236
+ }
237
+
238
+ return k;
239
+ }
240
+
241
+ static inline indexed_type_t
242
+ var_dim_next(const indexed_type_t *x, const int64_t start, const int64_t step,
243
+ const int64_t i)
244
+ {
245
+ indexed_type_t next;
246
+ const ndt_t *t = x->type;
247
+
248
+ next.index = start + i * step;
249
+ next.type = t->VarDim.type;
250
+
251
+ return next;
252
+ }
253
+
254
+ /* skip stored indices */
255
+ static indexed_type_t
256
+ apply_stored_index(const indexed_type_t *x, ndt_context_t *ctx)
257
+ {
258
+ const ndt_t * const t = x->type;
259
+ int64_t start, step, shape;
260
+
261
+ if (t->tag != VarDimElem) {
262
+ ndt_err_format(ctx, NDT_RuntimeError,
263
+ "apply_stored_index: need VarDimElem");
264
+ return indexed_type_error;
265
+ }
266
+
267
+ shape = ndt_var_indices(&start, &step, t, x->index, ctx);
268
+ if (shape < 0) {
269
+ return indexed_type_error;
270
+ }
271
+
272
+ const int64_t i = adjust_index(t->VarDimElem.index, shape, ctx);
273
+ if (i < 0) {
274
+ return indexed_type_error;
275
+ }
276
+
277
+ return var_dim_next(x, start, step, i);
278
+ }
279
+
280
+ static indexed_type_t
281
+ apply_stored_indices(const indexed_type_t *x, ndt_context_t *ctx)
282
+ {
283
+ indexed_type_t tl = *x;
284
+
285
+ while (tl.type->tag == VarDimElem) {
286
+ tl = apply_stored_index(&tl, ctx);
287
+ }
288
+
289
+ return tl;
290
+ }
291
+
220
292
  static int
221
- match_concrete_var_dim(const ndt_t *t, int64_t tindex,
222
- const ndt_t *u, int64_t uindex,
293
+ match_concrete_var_dim(const indexed_type_t *a, const indexed_type_t *b,
223
294
  const int outer_dims, ndt_context_t *ctx)
224
295
  {
225
- int64_t tshape, tstart, tstep;
226
- int64_t ushape, ustart, ustep;
296
+ int64_t xshape, xstart, xstep;
297
+ int64_t yshape, ystart, ystep;
227
298
 
228
299
  if (outer_dims == 0) {
229
300
  return 1;
230
301
  }
302
+
303
+ const indexed_type_t x = apply_stored_indices(a, ctx);
304
+ if (x.type == NULL) {
305
+ return -1;
306
+ }
307
+
308
+ const indexed_type_t y = apply_stored_indices(b, ctx);
309
+ if (x.type == NULL) {
310
+ return -1;
311
+ }
312
+
313
+ const ndt_t * const t = x.type;
314
+ const ndt_t * const u = y.type;
315
+
231
316
  if (t->Concrete.VarDim.itemsize != u->Concrete.VarDim.itemsize) {
232
317
  return 0;
233
318
  }
234
319
 
235
- tshape = ndt_var_indices(&tstart, &tstep, t, tindex, ctx);
236
- if (tshape < 0) {
320
+ xshape = ndt_var_indices(&xstart, &xstep, t, x.index, ctx);
321
+ if (xshape < 0) {
237
322
  return -1;
238
323
  }
239
324
 
240
- ushape = ndt_var_indices(&ustart, &ustep, u, uindex, ctx);
241
- if (ushape < 0) {
325
+ yshape = ndt_var_indices(&ystart, &ystep, u, y.index, ctx);
326
+ if (yshape < 0) {
242
327
  return -1;
243
328
  }
244
329
 
245
- if (ushape != tshape) {
330
+ if (xshape != yshape) {
246
331
  return 0;
247
332
  }
248
333
 
249
- for (int64_t i = 0; i < tshape; i++) {
250
- int64_t tnext = tstart + i * tstep;
251
- int64_t unext = ustart + i * ustep;
252
- int ret = match_concrete_var_dim(t->VarDim.type, tnext,
253
- u->VarDim.type, unext,
254
- outer_dims-1, ctx);
334
+ for (int64_t i = 0; i < xshape; i++) {
335
+ const indexed_type_t xnext = var_dim_next(&x, xstart, xstep, i);
336
+ const indexed_type_t ynext = var_dim_next(&y, ystart, ystep, i);
337
+ int ret = match_concrete_var_dim(&xnext, &ynext, outer_dims-1, ctx);
255
338
  if (ret <= 0) {
256
339
  return ret;
257
340
  }
@@ -265,6 +348,7 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
265
348
  {
266
349
  const char *key = "var";
267
350
  symtable_entry_t v;
351
+ int vdims, wdims;
268
352
 
269
353
  v = symtable_find(tbl, key);
270
354
  if (v.tag == Unbound) {
@@ -274,16 +358,36 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
274
358
  return 1;
275
359
  }
276
360
 
277
- if (w.VarSeq.size != v.VarSeq.size) {
361
+ vdims = ndt_logical_ndim(v.VarSeq.dims[0]);
362
+ wdims = ndt_logical_ndim(w.VarSeq.dims[0]);
363
+
364
+ if (wdims != vdims) {
278
365
  return 0;
279
366
  }
280
- if (v.VarSeq.size == 0) {
367
+ if (vdims == 0) {
281
368
  return 1;
282
369
  }
283
370
 
284
- return match_concrete_var_dim(w.VarSeq.dims[0], 0,
285
- v.VarSeq.dims[0], 0,
286
- v.VarSeq.size, ctx);
371
+ const indexed_type_t x = { w.VarSeq.linear_index, w.VarSeq.dims[0] };
372
+ const indexed_type_t y = { v.VarSeq.linear_index, v.VarSeq.dims[0] };
373
+ return match_concrete_var_dim(&x, &y, vdims, ctx);
374
+ }
375
+
376
+ static int
377
+ resolve_array(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
378
+ {
379
+ const char *key = "array";
380
+ symtable_entry_t v;
381
+
382
+ v = symtable_find(tbl, key);
383
+ if (v.tag == Unbound) {
384
+ if (symtable_add(tbl, key, w, ctx) < 0) {
385
+ return -1;
386
+ }
387
+ return 1;
388
+ }
389
+
390
+ return v.ArraySeq.size == w.ArraySeq.size;
287
391
  }
288
392
 
289
393
  static int
@@ -332,8 +436,32 @@ match_record_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
332
436
  }
333
437
 
334
438
  static int
335
- match_categorical(ndt_value_t *p, int64_t plen,
336
- ndt_value_t *c, int64_t clen)
439
+ match_union_tags(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
440
+ ndt_context_t *ctx)
441
+ {
442
+ int64_t i;
443
+ int n;
444
+
445
+ assert(p->tag == Union && c->tag == Union);
446
+
447
+ if (p->Union.ntags != c->Union.ntags) {
448
+ return 0;
449
+ }
450
+
451
+ for (i = 0; i < p->Union.ntags; i++) {
452
+ n = strcmp(p->Union.tags[i], c->Union.tags[i]);
453
+ if (n != 0) return 0;
454
+
455
+ n = match_datashape(p->Union.types[i], c->Union.types[i], tbl, ctx);
456
+ if (n <= 0) return n;
457
+ }
458
+
459
+ return 1;
460
+ }
461
+
462
+ static int
463
+ match_categorical(const ndt_value_t *p, int64_t plen,
464
+ const ndt_value_t *c, int64_t clen)
337
465
  {
338
466
  int64_t i;
339
467
 
@@ -378,7 +506,7 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
378
506
  }
379
507
  return outer_inner(v, i+1, t->FixedDim.type, ndim);
380
508
  }
381
- case VarDim: {
509
+ case VarDim: case VarDimElem: {
382
510
  switch (v->tag) {
383
511
  case VarSeq:
384
512
  v->VarSeq.size = i+1;
@@ -389,6 +517,18 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
389
517
  }
390
518
  return outer_inner(v, i+1, t->VarDim.type, ndim);
391
519
  }
520
+ case Array: {
521
+ switch (v->tag) {
522
+ case ArraySeq:
523
+ v->ArraySeq.size = i+1;
524
+ v->ArraySeq.dims[i] = t;
525
+ break;
526
+ default:
527
+ return NULL;
528
+ }
529
+ return outer_inner(v, i+1, t->Array.type, ndim);
530
+ }
531
+
392
532
  default:
393
533
  return NULL;
394
534
  }
@@ -400,6 +540,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
400
540
  {
401
541
  int n;
402
542
 
543
+ if (c->tag == VarDimElem) {
544
+ return match_datashape(p, c->VarDimElem.type, tbl, ctx);
545
+ }
546
+
403
547
  if (ndt_is_optional(c) != ndt_is_optional(p)) return 0;
404
548
 
405
549
  switch (p->tag) {
@@ -445,57 +589,11 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
445
589
  return match_datashape(p->SymbolicDim.type, c->FixedDim.type, tbl, ctx);
446
590
  }
447
591
 
448
- case EllipsisDim: {
449
- symtable_entry_t outer;
450
- const ndt_t *inner;
451
-
452
- if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
453
- return 0;
454
- }
455
- if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
456
- return 0;
457
- }
458
-
459
- if (p->EllipsisDim.name == NULL) {
460
- outer.tag = BroadcastSeq;
461
- outer.BroadcastSeq.size = 0;
462
- }
463
- else if (strcmp(p->EllipsisDim.name, "var") == 0) {
464
- outer.tag = VarSeq;
465
- outer.VarSeq.size = 0;
466
- }
467
- else {
468
- outer.tag = FixedSeq;
469
- outer.FixedSeq.size = 0;
470
- }
471
-
472
- inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
473
- if (inner == NULL) {
474
- return 0;
475
- }
476
-
477
- n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
478
- if (n <= 0) {
479
- return n;
480
- }
481
-
482
- switch (outer.tag) {
483
- case BroadcastSeq:
484
- return resolve_broadcast(outer, tbl, ctx);
485
- case FixedSeq:
486
- return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
487
- case VarSeq:
488
- return resolve_var(outer, tbl, ctx);
489
- default: /* NOT REACHED */
490
- ndt_internal_error("invalid tag");
491
- }
492
- }
493
-
494
592
  case Bool:
495
593
  case Int8: case Int16: case Int32: case Int64:
496
594
  case Uint8: case Uint16: case Uint32: case Uint64:
497
- case Float16: case Float32: case Float64:
498
- case Complex32: case Complex64: case Complex128:
595
+ case BFloat16: case Float16: case Float32: case Float64:
596
+ case BComplex32: case Complex32: case Complex64: case Complex128:
499
597
  case String:
500
598
  return p->tag == c->tag;
501
599
  case FixedString:
@@ -528,6 +626,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
528
626
  return c->tag == Categorical &&
529
627
  match_categorical(p->Categorical.types, p->Categorical.ntypes,
530
628
  c->Categorical.types, c->Categorical.ntypes);
629
+ case Array:
630
+ if (c->tag != Array) return 0;
631
+ if (c->Array.itemsize != p->Array.itemsize) return 0;
632
+ return match_datashape(p->Array.type, c->Array.type, tbl, ctx);
531
633
  case Ref:
532
634
  if (c->tag != Ref) return 0;
533
635
  return match_datashape(p->Ref.type, c->Ref.type, tbl, ctx);
@@ -536,9 +638,12 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
536
638
  if (c->tag != Tuple) return 0;
537
639
  return match_tuple_fields(p, c, tbl, ctx);
538
640
  case Record:
539
- if (p->Tuple.flag == Variadic) return 0;
641
+ if (p->Record.flag == Variadic) return 0;
540
642
  if (c->tag != Record) return 0;
541
643
  return match_record_fields(p, c, tbl, ctx);
644
+ case Union:
645
+ if (c->tag != Union) return 0;
646
+ return match_union_tags(p, c, tbl, ctx);
542
647
  case Function: {
543
648
  int64_t i;
544
649
  if (c->tag != Function ||
@@ -576,12 +681,83 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
576
681
  case Constr:
577
682
  return c->tag == Constr && strcmp(p->Constr.name, c->Constr.name) == 0 &&
578
683
  ndt_equal(p->Constr.type, c->Constr.type);
684
+ case VarDimElem:
685
+ ndt_err_format(ctx, NDT_ValueError,
686
+ "VarDimElem cannot occur in pattern");
687
+ return -1;
688
+ case EllipsisDim:
689
+ ndt_err_format(ctx, NDT_RuntimeError,
690
+ "match_datashape: internal_error: unexpected EllipsisDim");
691
+ return -1;
579
692
  }
580
693
 
581
694
  /* NOT REACHED: tags should be exhaustive. */
582
695
  ndt_internal_error("invalid type");
583
696
  }
584
697
 
698
+ static int
699
+ match_datashape_top(const ndt_t *p, const ndt_t *c, int64_t linear_index,
700
+ symtable_t *tbl, ndt_context_t *ctx)
701
+ {
702
+ switch (p->tag) {
703
+ case EllipsisDim: {
704
+ symtable_entry_t outer;
705
+ const ndt_t *inner;
706
+ int n;
707
+
708
+ if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
709
+ return 0;
710
+ }
711
+ if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
712
+ return 0;
713
+ }
714
+
715
+ if (p->EllipsisDim.name == NULL) {
716
+ outer.tag = BroadcastSeq;
717
+ outer.BroadcastSeq.size = 0;
718
+ }
719
+ else if (strcmp(p->EllipsisDim.name, "var") == 0) {
720
+ outer.tag = VarSeq;
721
+ outer.VarSeq.size = 0;
722
+ outer.VarSeq.linear_index = linear_index;
723
+ }
724
+ else if (strcmp(p->EllipsisDim.name, "array") == 0) {
725
+ outer.tag = ArraySeq;
726
+ outer.ArraySeq.size = 0;
727
+ }
728
+ else {
729
+ outer.tag = FixedSeq;
730
+ outer.FixedSeq.size = 0;
731
+ }
732
+
733
+ inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
734
+ if (inner == NULL) {
735
+ return 0;
736
+ }
737
+
738
+ n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
739
+ if (n <= 0) {
740
+ return n;
741
+ }
742
+
743
+ switch (outer.tag) {
744
+ case BroadcastSeq:
745
+ return resolve_broadcast(outer, tbl, ctx);
746
+ case FixedSeq:
747
+ return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
748
+ case VarSeq:
749
+ return resolve_var(outer, tbl, ctx);
750
+ case ArraySeq:
751
+ return resolve_array(outer, tbl, ctx);
752
+ default: /* NOT REACHED */
753
+ ndt_internal_error("invalid tag");
754
+ }
755
+ }
756
+ default:
757
+ return match_datashape(p, c, tbl, ctx);
758
+ }
759
+ }
760
+
585
761
  int
586
762
  ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
587
763
  {
@@ -597,19 +773,18 @@ ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
597
773
  return -1;
598
774
  }
599
775
 
600
- ret = match_datashape(p, c, tbl, ctx);
776
+ ret = match_datashape_top(p, c, 0, tbl, ctx);
601
777
  symtable_del(tbl);
602
778
  return ret;
603
779
  }
604
780
 
605
- static ndt_t *
781
+ static const ndt_t *
606
782
  broadcast(const ndt_t *t, const int64_t *shape,
607
783
  int outer_dims, int inner_dims,
608
784
  bool use_max, ndt_context_t *ctx)
609
785
  {
610
786
  ndt_ndarray_t u;
611
- const ndt_t *dtype;
612
- ndt_t *v;
787
+ const ndt_t *v, *w;
613
788
  int64_t step;
614
789
  int ndim;
615
790
  int i, k;
@@ -619,14 +794,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
619
794
  return NULL;
620
795
  }
621
796
 
622
- dtype = ndt_dtype(t);
623
- v = ndt_copy(dtype, ctx);
624
- if (v == NULL) {
625
- return NULL;
626
- }
797
+ v = ndt_dtype(t);
798
+ ndt_incref(v);
627
799
 
628
800
  for (i=ndim-1; i>=ndim-inner_dims; i--) {
629
- v = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
801
+ w = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
802
+ ndt_move(&v, w);
630
803
  if (v == NULL) {
631
804
  return NULL;
632
805
  }
@@ -634,7 +807,8 @@ broadcast(const ndt_t *t, const int64_t *shape,
634
807
 
635
808
  for (k=outer_dims-1; i>=0 && k>=0; i--, k--) {
636
809
  step = u.shape[i]<=1 ? 0 : u.steps[i];
637
- v = ndt_fixed_dim(v, shape[k], step, ctx);
810
+ w = ndt_fixed_dim(v, shape[k], step, ctx);
811
+ ndt_move(&v, w);
638
812
  if (v == NULL) {
639
813
  return NULL;
640
814
  }
@@ -642,11 +816,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
642
816
 
643
817
  for (; k>=0; k--) {
644
818
  if (use_max) {
645
- v = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
819
+ w = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
646
820
  }
647
821
  else {
648
- v = ndt_fixed_dim(v, shape[k], 0, ctx);
822
+ w = ndt_fixed_dim(v, shape[k], 0, ctx);
649
823
  }
824
+ ndt_move(&v, w);
650
825
  if (v == NULL) {
651
826
  return NULL;
652
827
  }
@@ -656,34 +831,44 @@ broadcast(const ndt_t *t, const int64_t *shape,
656
831
  }
657
832
 
658
833
  int
659
- ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
660
- const ndt_t *in[], const int nin,
661
- const int64_t *shape, int outer_dims,
662
- ndt_context_t *ctx)
834
+ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
835
+ const int64_t *shape, const int outer_dims, ndt_context_t *ctx)
663
836
  {
664
- ndt_t *u;
837
+ const int nin = spec->nin;
838
+ const int nout = spec->nout;
665
839
  int inner_dims;
666
840
  int i;
667
841
 
668
842
  for (i = 0; i < nin; i++) {
669
843
  inner_dims = sig->Function.types[i]->ndim-1;
670
- spec->broadcast[i] = broadcast(in[i], shape,
671
- outer_dims, inner_dims, false, ctx);
672
- if (spec->broadcast[i] == NULL) {
844
+ const ndt_t *t = spec->types[i];
845
+ const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, false, ctx);
846
+ if (u == NULL) {
673
847
  return -1;
674
848
  }
675
- spec->nbroadcast++;
849
+ ndt_decref(t);
850
+ spec->types[i] = u;
676
851
  }
677
852
 
678
- for (i = 0; i < spec->nout; i++) {
853
+ for (i = 0; i < nout; i++) {
679
854
  inner_dims = sig->Function.types[nin+i]->ndim-1;
680
- u = broadcast(spec->out[i], shape,
681
- outer_dims, inner_dims, true, ctx);
855
+ const ndt_t *t = spec->types[nin+i];
856
+ const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, true, ctx);
682
857
  if (u == NULL) {
683
858
  return -1;
684
859
  }
685
- ndt_del(spec->out[i]);
686
- spec->out[i] = u;
860
+
861
+ if (check_broadcast) {
862
+ if (!ndt_equal(t, u)) {
863
+ ndt_err_format(ctx, NDT_ValueError,
864
+ "explicit 'out' type not compatible with input types");
865
+ ndt_decref(u);
866
+ return -1;
867
+ }
868
+ }
869
+
870
+ ndt_decref(t);
871
+ spec->types[nin+i] = u;
687
872
  }
688
873
 
689
874
  spec->outer_dims = outer_dims;
@@ -692,8 +877,7 @@ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
692
877
  }
693
878
 
694
879
  static int
695
- broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
696
- const ndt_t *in[], const int nin,
880
+ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
697
881
  const symtable_t *tbl, ndt_context_t *ctx)
698
882
  {
699
883
  symtable_entry_t v;
@@ -705,7 +889,7 @@ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
705
889
  return -1;
706
890
  }
707
891
 
708
- return ndt_broadcast_all(spec, sig, in, nin,
892
+ return ndt_broadcast_all(spec, sig, check_broadcast,
709
893
  v.BroadcastSeq.dims, v.BroadcastSeq.size,
710
894
  ctx);
711
895
  }
@@ -746,20 +930,23 @@ resolve_constraint(const ndt_constraint_t *c, const void *args, symtable_t *tbl,
746
930
  */
747
931
  int
748
932
  ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
749
- const ndt_t *in[], const int nin,
933
+ const ndt_t *types[], const int64_t li[],
934
+ const int nin, const int nout, bool check_broadcast,
750
935
  const ndt_constraint_t *c, const void *args,
751
936
  ndt_context_t *ctx)
752
937
  {
753
938
  symtable_t *tbl;
754
- ndt_t *t;
939
+ const ndt_t *t;
755
940
  const char *name;
756
- int ret;
941
+ const int nargs = nin + nout;
757
942
  int64_t i;
943
+ int ret;
758
944
 
759
945
  assert(spec->flags == 0);
760
- assert(spec->nout == 0);
761
- assert(spec->nbroadcast == 0);
762
946
  assert(spec->outer_dims == 0);
947
+ assert(spec->nin == 0);
948
+ assert(spec->nout == 0);
949
+ assert(spec->nargs == 0);
763
950
 
764
951
  if (sig->tag != Function) {
765
952
  ndt_err_format(ctx, NDT_ValueError,
@@ -773,8 +960,30 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
773
960
  return -1;
774
961
  }
775
962
 
776
- for (i = 0; i < nin; i++) {
777
- if (ndt_is_abstract(in[i])) {
963
+ /*
964
+ * Two configurations are allowed:
965
+ * 1) nout==0 and all 'out' types are inferred.
966
+ * 2) nout==sig->Function.nout and all 'out' types are given.
967
+ */
968
+ if (nout && nout != sig->Function.nout) {
969
+ ndt_err_format(ctx, NDT_ValueError,
970
+ "expected %" PRIi64 " 'out' arguments, got %d", sig->Function.nout, nout);
971
+ return -1;
972
+ }
973
+
974
+ /*
975
+ * Broadcast configurations:
976
+ * nout==0 -> check_broadcast==false.
977
+ */
978
+ if (!nout && check_broadcast == true) {
979
+ ndt_err_format(ctx, NDT_RuntimeError,
980
+ "internal error: without explicit 'out' arguments "
981
+ "'check_broadcast' must be false");
982
+ return -1;
983
+ }
984
+
985
+ for (i = 0; i < nargs; i++) {
986
+ if (ndt_is_abstract(types[i])) {
778
987
  ndt_err_format(ctx, NDT_ValueError,
779
988
  "type checking requires concrete argument types");
780
989
  return -1;
@@ -786,8 +995,8 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
786
995
  return -1;
787
996
  }
788
997
 
789
- for (i = 0; i < nin; i++) {
790
- ret = match_datashape(sig->Function.types[i], in[i], tbl, ctx);
998
+ for (i = 0; i < nargs; i++) {
999
+ ret = match_datashape_top(sig->Function.types[i], types[i], li[i], tbl, ctx);
791
1000
  if (ret <= 0) {
792
1001
  symtable_del(tbl);
793
1002
 
@@ -805,15 +1014,34 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
805
1014
  return -1;
806
1015
  }
807
1016
 
808
- for (i = 0; i < sig->Function.nout; i++) {
809
- spec->out[i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
810
- if (spec->out[i] == NULL) {
811
- ndt_apply_spec_clear(spec);
812
- symtable_del(tbl);
813
- return -1;
1017
+ spec->nin = 0;
1018
+ for (i = 0; i < nin; i++) {
1019
+ ndt_incref(types[i]);
1020
+ spec->types[i] = types[i];
1021
+ spec->nin++;
1022
+ }
1023
+
1024
+ spec->nout = 0;
1025
+ if (nout == 0) {
1026
+ /* Infer the return types. */
1027
+ for (i = 0; i < sig->Function.nout; i++) {
1028
+ spec->types[nin+i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
1029
+ if (spec->types[nin+i] == NULL) {
1030
+ ndt_apply_spec_clear(spec);
1031
+ symtable_del(tbl);
1032
+ return -1;
1033
+ }
1034
+ spec->nout++;
1035
+ }
1036
+ }
1037
+ else {
1038
+ for (i = 0; i < nout; i++) {
1039
+ ndt_incref(types[nin+i]);
1040
+ spec->types[nin+i] = types[nin+i];
1041
+ spec->nout++;
814
1042
  }
815
- spec->nout++;
816
1043
  }
1044
+ spec->nargs = spec->nin + spec->nout;
817
1045
 
818
1046
  if (sig->flags & NDT_ELLIPSIS) {
819
1047
  if (sig->Function.nargs == 0 || sig->Function.types[0]->tag != EllipsisDim) {
@@ -833,8 +1061,17 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
833
1061
  case FixedSeq:
834
1062
  spec->outer_dims = v.FixedSeq.size;
835
1063
  break;
836
- case VarSeq:
837
- spec->outer_dims = v.VarSeq.size;
1064
+ case VarSeq: {
1065
+ if (v.VarSeq.size > 0) {
1066
+ spec->outer_dims = ndt_logical_ndim(v.VarSeq.dims[0]);
1067
+ }
1068
+ else {
1069
+ spec->outer_dims = 0;
1070
+ }
1071
+ break;
1072
+ }
1073
+ case ArraySeq:
1074
+ spec->outer_dims = v.ArraySeq.size;
838
1075
  break;
839
1076
  default:
840
1077
  ndt_err_format(ctx, NDT_RuntimeError,
@@ -845,7 +1082,7 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
845
1082
  }
846
1083
  }
847
1084
  else {
848
- if (broadcast_all(spec, sig, in, nin, tbl, ctx) < 0) {
1085
+ if (broadcast_all(spec, sig, check_broadcast, tbl, ctx) < 0) {
849
1086
  ndt_apply_spec_clear(spec);
850
1087
  symtable_del(tbl);
851
1088
  return -1;
@@ -855,31 +1092,33 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
855
1092
 
856
1093
  symtable_del(tbl);
857
1094
 
858
- for (i = 0; i < sig->Function.nout; i++) {
859
- ndt_t *_p = sig->Function.types[nin+i];
860
- ndt_t *_c = spec->out[i];
861
- ndt_t *_t = to_fortran(_p, _c, ctx);
862
- if (_t == NULL) {
863
- ndt_apply_spec_clear(spec);
864
- return -1;
865
- }
866
- if (_t != _c) {
867
- ndt_del(_c);
1095
+ if (nout == 0) {
1096
+ for (i = 0; i < sig->Function.nout; i++) {
1097
+ const ndt_t *_p = sig->Function.types[nin+i];
1098
+ const ndt_t *_c = spec->types[nin+i];
1099
+ const ndt_t *_t = to_fortran(_p, _c, ctx);
1100
+
1101
+ if (_t == NULL) {
1102
+ ndt_apply_spec_clear(spec);
1103
+ return -1;
1104
+ }
1105
+
1106
+ ndt_decref(_c);
1107
+ spec->types[nin+i] = _t;
868
1108
  }
869
- spec->out[i] = _t;
870
1109
  }
871
1110
 
872
- if (!check_contig(sig->Function.types, (ndt_t **)in, nin)) {
1111
+ if (!check_contig(sig->Function.types, types, nargs)) {
873
1112
  ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
1113
+ ndt_apply_spec_clear(spec);
874
1114
  return -1;
875
1115
  }
876
- if (!check_contig(sig->Function.types+nin, spec->out, spec->nout)) {
877
- ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
1116
+
1117
+ if (ndt_select_kernel_strategy(spec, ctx) < 0) {
1118
+ ndt_apply_spec_clear(spec);
878
1119
  return -1;
879
1120
  }
880
1121
 
881
- ndt_select_kernel_strategy(spec, sig, in, nin);
882
-
883
1122
  return 0;
884
1123
  }
885
1124
 
@@ -888,11 +1127,11 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
888
1127
  /* Optimized binary typecheck for fixed input */
889
1128
  /*****************************************************************************/
890
1129
 
891
- static ndt_t *
892
- binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
893
- const int64_t *shape, int size, ndt_context_t *ctx)
1130
+ static const ndt_t *
1131
+ fast_broadcast(const ndt_ndarray_t *t, const ndt_t *dtype,
1132
+ const int64_t *shape, int size, ndt_context_t *ctx)
894
1133
  {
895
- ndt_t *v;
1134
+ const ndt_t *v;
896
1135
  int64_t step;
897
1136
  int i, k;
898
1137
 
@@ -903,14 +1142,18 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
903
1142
 
904
1143
  for (i=t->ndim-1, k=size-1; i>=0 && k>=0; i--, k--) {
905
1144
  step = t->shape[i]<=1 ? 0 : t->steps[i];
906
- v = ndt_fixed_dim(v, shape[k], step, ctx);
1145
+ const ndt_t *w = ndt_fixed_dim(v, shape[k], step, ctx);
1146
+ ndt_decref(v);
1147
+ v = w;
907
1148
  if (v == NULL) {
908
1149
  return NULL;
909
1150
  }
910
1151
  }
911
1152
 
912
1153
  for (; k>=0; k--) {
913
- v = ndt_fixed_dim(v, shape[k], 0, ctx);
1154
+ const ndt_t *w = ndt_fixed_dim(v, shape[k], 0, ctx);
1155
+ ndt_decref(v);
1156
+ v = w;
914
1157
  if (v == NULL) {
915
1158
  return NULL;
916
1159
  }
@@ -919,15 +1162,19 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
919
1162
  return v;
920
1163
  }
921
1164
 
922
- static ndt_t *
923
- fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
1165
+ static const ndt_t *
1166
+ fixed_dim_from_shape(const int64_t shape[], int len, const ndt_t *dtype,
924
1167
  ndt_context_t *ctx)
925
1168
  {
926
- ndt_t *t;
1169
+ const ndt_t *t;
927
1170
  int i;
928
1171
 
1172
+ ndt_incref(dtype);
1173
+
929
1174
  for (i=len-1, t=dtype; i >= 0; i--) {
930
- t = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
1175
+ const ndt_t *u = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
1176
+ ndt_decref(t);
1177
+ t = u;
931
1178
  if (t == NULL) {
932
1179
  return NULL;
933
1180
  }
@@ -936,83 +1183,313 @@ fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
936
1183
  return t;
937
1184
  }
938
1185
 
939
- static bool
940
- shape_equal(const ndt_ndarray_t *a, const ndt_ndarray_t *b)
1186
+ static int
1187
+ broadcast_error(const char *msg, ndt_context_t *ctx)
941
1188
  {
942
- if (b->ndim != a->ndim) {
943
- return false;
1189
+ ndt_err_format(ctx, NDT_TypeError, "%s", msg);
1190
+ return -1;
1191
+ }
1192
+
1193
+ static int
1194
+ _ndt_unary_broadcast(ndt_apply_spec_t *spec,
1195
+ const ndt_ndarray_t *x, const ndt_t *dtype_x,
1196
+ const ndt_t *out, const ndt_t *dtype_out,
1197
+ const bool check_broadcast, const int inner,
1198
+ ndt_context_t *ctx)
1199
+ {
1200
+ int64_t shape[NDT_MAX_DIM];
1201
+ int size = x->ndim;
1202
+ ndt_ndarray_t y;
1203
+ const ndt_t *t;
1204
+
1205
+ for (int i = 0; i < size; i++) {
1206
+ shape[i] = x->shape[i];
944
1207
  }
945
1208
 
946
- for (int i = 0; i < a->ndim; i++) {
947
- if (b->shape[i] != a->shape[i]) {
948
- return false;
1209
+ if (out != NULL) {
1210
+ if (ndt_as_ndarray(&y, out, ctx) < 0) {
1211
+ return -1;
949
1212
  }
1213
+
1214
+ size = _resolve_broadcast(shape, size, y.shape, y.ndim);
1215
+ if (size < 0) {
1216
+ return broadcast_error("could not broadcast output argument", ctx);
1217
+ }
1218
+ }
1219
+
1220
+ spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
1221
+ if (spec->types[0] == NULL) {
1222
+ return -1;
1223
+ }
1224
+
1225
+ if (out != NULL) {
1226
+ if (check_broadcast) {
1227
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1228
+ if (t == NULL) {
1229
+ ndt_decref(spec->types[0]);
1230
+ return -1;
1231
+ }
1232
+
1233
+ if (!ndt_equal(t, out)) {
1234
+ ndt_err_format(ctx, NDT_ValueError,
1235
+ "explicit 'out' type not compatible with input types");
1236
+ ndt_decref(spec->types[0]);
1237
+ ndt_decref(t);
1238
+ return -1;
1239
+ }
1240
+ }
1241
+ else {
1242
+ t = fast_broadcast(&y, dtype_out, shape, size, ctx);
1243
+ if (t == NULL) {
1244
+ ndt_decref(spec->types[0]);
1245
+ return -1;
1246
+ }
1247
+ }
1248
+ }
1249
+ else {
1250
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1251
+ if (t == NULL) {
1252
+ ndt_decref(spec->types[0]);
1253
+ return -1;
1254
+ }
1255
+ }
1256
+ spec->types[1] = t;
1257
+
1258
+ spec->outer_dims = size-inner;
1259
+ spec->nin = 1;
1260
+ spec->nout = 1;
1261
+ spec->nargs = 2;
1262
+
1263
+ if (ndt_select_kernel_strategy(spec, ctx) < 0) {
1264
+ ndt_apply_spec_clear(spec);
1265
+ return -1;
1266
+ }
1267
+
1268
+ return 0;
1269
+ }
1270
+
1271
+ static bool
1272
+ unary_all_ellipses(const ndt_t *t0, const ndt_t *t1, ndt_context_t *ctx)
1273
+ {
1274
+ if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
1275
+ (t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL)) {
1276
+ ndt_err_format(ctx, NDT_RuntimeError,
1277
+ "fast unary typecheck expects leading ellipsis dimensions");
1278
+ return false;
950
1279
  }
951
1280
 
952
1281
  return true;
953
1282
  }
954
1283
 
955
- static int
956
- _ndt_binary_broadcast(ndt_apply_spec_t *spec, const ndt_t *sig,
957
- const ndt_ndarray_t *x, const ndt_ndarray_t *y,
958
- const ndt_t *in[], const int nin, ndt_t *dtype,
959
- int inner, ndt_context_t *ctx)
1284
+ static bool
1285
+ unary_all_same_symbol(const ndt_t *t0, const ndt_t *t1)
960
1286
  {
961
- int64_t shape[NDT_MAX_DIM];
962
- int size;
963
-
964
- if (shape_equal(x, y)) {
965
- spec->nout = 1;
966
- spec->nbroadcast = 0;
967
- spec->outer_dims = x->ndim-inner;
968
- spec->out[0] = fixed_dim_from_shape(x->shape, x->ndim, dtype, ctx);
969
- if (spec->out[0] == NULL) {
970
- return -1;
971
- }
1287
+ if (t0->tag != SymbolicDim || t1->tag != SymbolicDim) {
1288
+ return false;
972
1289
  }
973
- else {
974
- for (int i = 0; i < x->ndim; i++) {
975
- shape[i] = x->shape[i];
1290
+
1291
+ return strcmp(t0->SymbolicDim.name, t1->SymbolicDim.name) == 0;
1292
+ }
1293
+
1294
+ static bool
1295
+ unary_all_ndim0(const ndt_t *t0, const ndt_t *t1)
1296
+ {
1297
+ return t0->ndim == 0 && t1->ndim == 0;
1298
+ }
1299
+
1300
+ /*
1301
+ * Optimized type checking for very specific signatures. The caller must
1302
+ * have identified the kernel location and the signature. For performance
1303
+ * reasons, no substitution is performed on the dtype, so the dtype must be
1304
+ * concrete.
1305
+ *
1306
+ * Supported signature: 1) ... * T0 -> ... * T1
1307
+ */
1308
+ int
1309
+ ndt_fast_unary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1310
+ const ndt_t *types[], const int nin, const int nout,
1311
+ const bool check_broadcast, ndt_context_t *ctx)
1312
+ {
1313
+ const ndt_t *p0, *p1;
1314
+ ndt_ndarray_t x;
1315
+ const ndt_t *out = NULL;
1316
+ const ndt_t *dtype = NULL;
1317
+
1318
+ assert(spec->flags == 0);
1319
+ assert(spec->outer_dims == 0);
1320
+ assert(spec->nin == 0);
1321
+ assert(spec->nout == 0);
1322
+ assert(spec->nargs == 0);
1323
+
1324
+ if (sig->tag != Function ||
1325
+ sig->Function.nin != 1) {
1326
+ ndt_err_format(ctx, NDT_RuntimeError,
1327
+ "fast unary typecheck expects a signature with one input and "
1328
+ "one output");
1329
+ return -1;
1330
+ }
1331
+
1332
+ if (nin != 1) {
1333
+ ndt_err_format(ctx, NDT_RuntimeError,
1334
+ "fast unary typecheck expects one input argument");
1335
+ return -1;
1336
+ }
1337
+
1338
+ p0 = sig->Function.types[0];
1339
+ p1 = sig->Function.types[1];
1340
+
1341
+ if (nout) {
1342
+ if (nout != 1) {
1343
+ ndt_err_format(ctx, NDT_RuntimeError,
1344
+ "fast unary typecheck expects at most one explicit out argument");
1345
+ return -1;
976
1346
  }
1347
+ out = types[1];
1348
+ dtype = ndt_dtype(types[1]);
977
1349
 
978
- size = _resolve_broadcast(shape, x->ndim, y->shape, y->ndim);
979
- if (size < 0) {
980
- ndt_err_format(ctx, NDT_TypeError, "broadcast error");
981
- ndt_del(dtype);
1350
+ if (!ndt_equal(dtype, ndt_dtype(p1))) {
1351
+ ndt_err_format(ctx, NDT_ValueError,
1352
+ "dtype of the out argument does not match");
982
1353
  return -1;
983
1354
  }
1355
+ }
1356
+ else {
1357
+ dtype = ndt_dtype(sig->Function.types[1]);
1358
+ }
1359
+
1360
+ if (ndt_is_abstract(dtype)) {
1361
+ ndt_err_format(ctx, NDT_RuntimeError,
1362
+ "fast unary typecheck expects a concrete dtype");
1363
+ return -1;
1364
+ }
1365
+
1366
+ if (!unary_all_ellipses(p0, p1, ctx)) {
1367
+ return -1;
1368
+ }
1369
+
1370
+ if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
1371
+ return -1;
1372
+ }
1373
+
1374
+ p0 = p0->EllipsisDim.type;
1375
+ p1 = p1->EllipsisDim.type;
1376
+
1377
+ if (unary_all_same_symbol(p0, p1)) {
1378
+ return _ndt_unary_broadcast(spec,
1379
+ &x, ndt_dtype(types[0]),
1380
+ out, dtype, check_broadcast,
1381
+ 1, ctx);
1382
+ }
1383
+ else if (unary_all_ndim0(p0, p1)) {
1384
+ return _ndt_unary_broadcast(spec,
1385
+ &x, ndt_dtype(types[0]),
1386
+ out, dtype, check_broadcast,
1387
+ 0, ctx);
1388
+ }
1389
+ else {
1390
+ ndt_err_format(ctx, NDT_RuntimeError,
1391
+ "unsupported signature in fast unary typecheck");
1392
+ return -1;
1393
+ }
1394
+ }
1395
+
1396
+ static int
1397
+ _ndt_binary_broadcast(ndt_apply_spec_t *spec,
1398
+ const ndt_ndarray_t *x, const ndt_t *dtype_x,
1399
+ const ndt_ndarray_t *y, const ndt_t *dtype_y,
1400
+ const ndt_t *out, const ndt_t *dtype_out,
1401
+ const bool check_broadcast, const int inner,
1402
+ ndt_context_t *ctx)
1403
+ {
1404
+ int64_t shape[NDT_MAX_DIM];
1405
+ int size = x->ndim;
1406
+ ndt_ndarray_t z;
1407
+ const ndt_t *t;
984
1408
 
985
- spec->nout = 1;
986
- spec->nbroadcast = 2;
987
- spec->outer_dims = size-inner;
1409
+ for (int i = 0; i < size; i++) {
1410
+ shape[i] = x->shape[i];
1411
+ }
1412
+
1413
+ size = _resolve_broadcast(shape, size, y->shape, y->ndim);
1414
+ if (size < 0) {
1415
+ return broadcast_error("could not broadcast input arguments", ctx);
1416
+ }
988
1417
 
989
- spec->out[0] = fixed_dim_from_shape(shape, size, dtype, ctx);
990
- if (spec->out[0] == NULL) {
1418
+ if (out != NULL) {
1419
+ if (ndt_as_ndarray(&z, out, ctx) < 0) {
991
1420
  return -1;
992
1421
  }
993
1422
 
994
- spec->broadcast[0] = binary_broadcast_1D(x, ndt_dtype(in[0]), shape, size, ctx);
995
- if (spec->broadcast[0] == NULL) {
996
- ndt_del(spec->out[0]);
997
- return -1;
1423
+ size = _resolve_broadcast(shape, size, z.shape, z.ndim);
1424
+ if (size < 0) {
1425
+ return broadcast_error("could not broadcast output argument", ctx);
998
1426
  }
1427
+ }
1428
+
1429
+ spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
1430
+ if (spec->types[0] == NULL) {
1431
+ return -1;
1432
+ }
1433
+
1434
+ spec->types[1] = fast_broadcast(y, dtype_y, shape, size, ctx);
1435
+ if (spec->types[1] == NULL) {
1436
+ ndt_decref(spec->types[0]);
1437
+ return -1;
1438
+ }
1439
+
1440
+ if (out != NULL) {
1441
+ if (check_broadcast) {
1442
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1443
+ if (t == NULL) {
1444
+ ndt_decref(spec->types[0]);
1445
+ ndt_decref(spec->types[1]);
1446
+ return -1;
1447
+ }
999
1448
 
1000
- spec->broadcast[1] = binary_broadcast_1D(y, ndt_dtype(in[1]), shape, size, ctx);
1001
- if (spec->broadcast[1] == NULL) {
1002
- ndt_del(spec->out[0]);
1003
- ndt_del(spec->broadcast[0]);
1449
+ if (!ndt_equal(t, out)) {
1450
+ ndt_err_format(ctx, NDT_ValueError,
1451
+ "explicit 'out' type not compatible with input types");
1452
+ ndt_decref(spec->types[0]);
1453
+ ndt_decref(spec->types[1]);
1454
+ ndt_decref(t);
1455
+ return -1;
1456
+ }
1457
+ }
1458
+ else {
1459
+ t = fast_broadcast(&z, dtype_out, shape, size, ctx);
1460
+ if (t == NULL) {
1461
+ ndt_decref(spec->types[0]);
1462
+ ndt_decref(spec->types[1]);
1463
+ return -1;
1464
+ }
1465
+ }
1466
+ }
1467
+ else {
1468
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1469
+ if (t == NULL) {
1470
+ ndt_decref(spec->types[0]);
1471
+ ndt_decref(spec->types[1]);
1004
1472
  return -1;
1005
1473
  }
1006
1474
  }
1475
+ spec->types[2] = t;
1007
1476
 
1008
- ndt_select_kernel_strategy(spec, sig, in, nin);
1477
+ spec->outer_dims = size-inner;
1478
+ spec->nin = 2;
1479
+ spec->nout = 1;
1480
+ spec->nargs = 3;
1481
+
1482
+ if (ndt_select_kernel_strategy(spec, ctx) < 0) {
1483
+ ndt_apply_spec_clear(spec);
1484
+ return -1;
1485
+ }
1009
1486
 
1010
1487
  return 0;
1011
1488
  }
1012
1489
 
1013
1490
  static bool
1014
- all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1015
- ndt_context_t *ctx)
1491
+ binary_all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1492
+ ndt_context_t *ctx)
1016
1493
  {
1017
1494
  if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
1018
1495
  (t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL) ||
@@ -1026,7 +1503,7 @@ all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1026
1503
  }
1027
1504
 
1028
1505
  static bool
1029
- all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1506
+ binary_all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1030
1507
  {
1031
1508
  if (t0->tag != SymbolicDim || t1->tag != SymbolicDim ||
1032
1509
  t2->tag != SymbolicDim) {
@@ -1038,37 +1515,37 @@ all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1038
1515
  }
1039
1516
 
1040
1517
  static bool
1041
- all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1518
+ binary_all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1042
1519
  {
1043
1520
  return t0->ndim == 0 && t1->ndim == 0 && t2->ndim == 0;
1044
1521
  }
1045
1522
 
1046
1523
  /*
1047
1524
  * Optimized type checking for very specific signatures. The caller must
1048
- * have identified the kernel location, signature and the dtype. For
1049
- * performance reasons, no substitution is performed on the dtype, so
1050
- * the dtype must be concrete.
1525
+ * have identified the kernel location and the signature. For performance
1526
+ * reasons, no substitution is performed on the dtype, so the dtype must be
1527
+ * concrete.
1051
1528
  *
1052
- * Supported signatures:
1053
- * 1) ... * N * T0, ... * N * T1 -> N * T2
1054
- * 2) ... * T0, ... * T1 -> ... * T2
1529
+ * Supported signature: 1) ... * T0, ... * T1 -> ... * T2
1055
1530
  */
1056
1531
  int
1057
1532
  ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1058
- const ndt_t *in[], const int nin, ndt_t *dtype,
1059
- ndt_context_t *ctx)
1533
+ const ndt_t *types[], const int nin, const int nout,
1534
+ const bool check_broadcast, ndt_context_t *ctx)
1060
1535
  {
1061
- ndt_t *p0, *p1, *p2;
1536
+ const ndt_t *p0, *p1, *p2;
1062
1537
  ndt_ndarray_t x, y;
1538
+ const ndt_t *out = NULL;
1539
+ const ndt_t *dtype = NULL;
1063
1540
 
1064
1541
  assert(spec->flags == 0);
1065
- assert(spec->nout == 0);
1066
- assert(spec->nbroadcast == 0);
1067
1542
  assert(spec->outer_dims == 0);
1543
+ assert(spec->nin == 0);
1544
+ assert(spec->nout == 0);
1545
+ assert(spec->nargs == 0);
1068
1546
 
1069
1547
  if (sig->tag != Function ||
1070
- sig->Function.nin != 2 ||
1071
- sig->Function.nout != 1) {
1548
+ sig->Function.nin != 2) {
1072
1549
  ndt_err_format(ctx, NDT_RuntimeError,
1073
1550
  "fast binary typecheck expects a signature with two inputs and "
1074
1551
  "one output");
@@ -1081,27 +1558,44 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1081
1558
  return -1;
1082
1559
  }
1083
1560
 
1561
+ p0 = sig->Function.types[0];
1562
+ p1 = sig->Function.types[1];
1563
+ p2 = sig->Function.types[2];
1564
+
1565
+ if (nout) {
1566
+ if (nout != 1) {
1567
+ ndt_err_format(ctx, NDT_RuntimeError,
1568
+ "fast binary typecheck expects at most one explicit out argument");
1569
+ return -1;
1570
+ }
1571
+ out = types[2];
1572
+ dtype = ndt_dtype(types[2]);
1573
+
1574
+ if (!ndt_equal(dtype, ndt_dtype(p2))) {
1575
+ ndt_err_format(ctx, NDT_ValueError,
1576
+ "dtype of the out argument does not match");
1577
+ return -1;
1578
+ }
1579
+ }
1580
+ else {
1581
+ dtype = ndt_dtype(sig->Function.types[2]);
1582
+ }
1583
+
1084
1584
  if (ndt_is_abstract(dtype)) {
1085
1585
  ndt_err_format(ctx, NDT_RuntimeError,
1086
1586
  "fast binary typecheck expects a concrete dtype");
1087
1587
  return -1;
1088
1588
  }
1089
1589
 
1090
- p0 = sig->Function.types[0];
1091
- p1 = sig->Function.types[1];
1092
- p2 = sig->Function.types[2];
1093
-
1094
- if (!all_ellipses(p0, p1, p2, ctx)) {
1590
+ if (!binary_all_ellipses(p0, p1, p2, ctx)) {
1095
1591
  return -1;
1096
1592
  }
1097
1593
 
1098
- if (ndt_as_ndarray(&x, in[0], ctx) < 0) {
1099
- ndt_del(dtype);
1594
+ if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
1100
1595
  return -1;
1101
1596
  }
1102
1597
 
1103
- if (ndt_as_ndarray(&y, in[1], ctx) < 0) {
1104
- ndt_del(dtype);
1598
+ if (ndt_as_ndarray(&y, types[1], ctx) < 0) {
1105
1599
  return -1;
1106
1600
  }
1107
1601
 
@@ -1109,20 +1603,27 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1109
1603
  p1 = p1->EllipsisDim.type;
1110
1604
  p2 = p2->EllipsisDim.type;
1111
1605
 
1112
- if (all_same_symbol(p0, p1, p2)) {
1606
+ if (binary_all_same_symbol(p0, p1, p2)) {
1113
1607
  if (x.ndim > 0 && y.ndim > 0) {
1114
1608
  const int64_t xshape = x.shape[x.ndim-1];
1115
1609
  const int64_t yshape = y.shape[y.ndim-1];
1116
1610
  if (xshape != 1 && yshape != 1 && xshape != yshape) {
1117
1611
  ndt_err_format(ctx, NDT_TypeError, "mismatch in inner dimensions");
1118
- ndt_del(dtype);
1119
1612
  return -1;
1120
1613
  }
1121
1614
  }
1122
- return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 1, ctx);
1123
- }
1124
- else if (all_ndim0(p0, p1, p2)) {
1125
- return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 0, ctx);
1615
+ return _ndt_binary_broadcast(spec,
1616
+ &x, ndt_dtype(types[0]),
1617
+ &y, ndt_dtype(types[1]),
1618
+ out, dtype, check_broadcast,
1619
+ 1, ctx);
1620
+ }
1621
+ else if (binary_all_ndim0(p0, p1, p2)) {
1622
+ return _ndt_binary_broadcast(spec,
1623
+ &x, ndt_dtype(types[0]),
1624
+ &y, ndt_dtype(types[1]),
1625
+ out, dtype, check_broadcast,
1626
+ 0, ctx);
1126
1627
  }
1127
1628
  else {
1128
1629
  ndt_err_format(ctx, NDT_RuntimeError,