ndtypes 0.2.0dev5 → 0.2.0dev6

Sign up to get free protection for your applications and to get access to all the features.
Files changed (130) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +12 -0
  3. data/Rakefile +8 -0
  4. data/ext/ruby_ndtypes/GPATH +0 -0
  5. data/ext/ruby_ndtypes/GRTAGS +0 -0
  6. data/ext/ruby_ndtypes/GTAGS +0 -0
  7. data/ext/ruby_ndtypes/extconf.rb +1 -1
  8. data/ext/ruby_ndtypes/include/ndtypes.h +231 -122
  9. data/ext/ruby_ndtypes/include/ruby_ndtypes.h +1 -1
  10. data/ext/ruby_ndtypes/lib/libndtypes.a +0 -0
  11. data/ext/ruby_ndtypes/lib/libndtypes.so.0.2.0dev3 +0 -0
  12. data/ext/ruby_ndtypes/ndtypes/Makefile +87 -0
  13. data/ext/ruby_ndtypes/ndtypes/config.h +68 -0
  14. data/ext/ruby_ndtypes/ndtypes/config.log +477 -0
  15. data/ext/ruby_ndtypes/ndtypes/config.status +1027 -0
  16. data/ext/ruby_ndtypes/ndtypes/doc/_static/style.css +7 -0
  17. data/ext/ruby_ndtypes/ndtypes/doc/_templates/layout.html +2 -0
  18. data/ext/ruby_ndtypes/ndtypes/doc/conf.py +40 -4
  19. data/ext/ruby_ndtypes/ndtypes/doc/images/xndlogo.png +0 -0
  20. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +1 -1
  21. data/ext/ruby_ndtypes/ndtypes/doc/requirements.txt +2 -0
  22. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile +287 -0
  23. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +20 -4
  24. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +22 -3
  25. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +1 -1
  26. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.o +0 -0
  27. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.o +0 -0
  28. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile +73 -0
  29. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +246 -229
  30. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +15 -11
  31. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.o +0 -0
  32. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +38 -28
  33. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +91 -91
  34. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +1 -1
  35. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +4 -3
  36. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.o +0 -0
  37. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +8 -7
  38. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.o +0 -0
  39. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +2 -2
  40. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.o +0 -0
  41. data/ext/ruby_ndtypes/ndtypes/libndtypes/context.o +0 -0
  42. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +263 -182
  43. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.o +0 -0
  44. data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.o +0 -0
  45. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +67 -7
  46. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.o +0 -0
  47. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +1112 -1000
  48. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +69 -58
  49. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.o +0 -0
  50. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +150 -99
  51. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +185 -15
  52. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.o +0 -0
  53. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +301 -276
  54. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +1 -1
  55. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +9 -4
  56. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.o +0 -0
  57. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.a +0 -0
  58. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so +1 -0
  59. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0 +1 -0
  60. data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3 +0 -0
  61. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +729 -228
  62. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.o +0 -0
  63. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +768 -403
  64. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h +1002 -0
  65. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +231 -122
  66. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.o +0 -0
  67. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +176 -84
  68. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +26 -14
  69. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.o +0 -0
  70. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +57 -35
  71. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.o +0 -0
  72. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.c +420 -0
  73. data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.o +0 -0
  74. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +8 -8
  75. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +1 -1
  76. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.o +0 -0
  77. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile +48 -0
  78. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +200 -116
  79. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.o +0 -0
  80. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +46 -4
  81. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.o +0 -0
  82. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +58 -27
  83. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +1 -1
  84. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.o +0 -0
  85. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +3 -5
  86. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +12 -4
  87. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.o +0 -0
  88. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile +55 -0
  89. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +8 -8
  90. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +5 -5
  91. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +274 -172
  92. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +24 -4
  93. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +2 -2
  94. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +14 -14
  95. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +32 -30
  96. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +37 -0
  97. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +36 -0
  98. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +16 -0
  99. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +5 -5
  100. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +706 -253
  101. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_unify.c +132 -0
  102. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.c +703 -0
  103. data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.o +0 -0
  104. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +335 -127
  105. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.o +0 -0
  106. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +2 -2
  107. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.o +0 -0
  108. data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +88 -71
  109. data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +0 -1
  110. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +10 -13
  111. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +395 -314
  112. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.a +0 -0
  113. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so +1 -0
  114. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0 +1 -0
  115. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0.2.0dev3 +0 -0
  116. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/ndtypes.h +1002 -0
  117. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +15 -33
  118. data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +340 -132
  119. data/ext/ruby_ndtypes/ndtypes/setup.py +11 -2
  120. data/ext/ruby_ndtypes/ruby_ndtypes.c +364 -241
  121. data/ext/ruby_ndtypes/ruby_ndtypes.h +1 -1
  122. data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +0 -1
  123. data/lib/ndtypes.rb +11 -0
  124. data/lib/ndtypes/version.rb +2 -2
  125. data/lib/ruby_ndtypes.so +0 -0
  126. data/ndtypes.gemspec +3 -0
  127. data/spec/ndtypes_spec.rb +6 -0
  128. metadata +98 -4
  129. data/ext/ruby_ndtypes/gc_guard.c +0 -36
  130. data/ext/ruby_ndtypes/gc_guard.h +0 -12
@@ -726,7 +726,7 @@ extern int yylex \
726
726
  #undef yyTABLES_NAME
727
727
  #endif
728
728
 
729
- #line 222 "lexer.l"
729
+ #line 227 "lexer.l"
730
730
 
731
731
 
732
732
  #line 732 "lexer.h"
@@ -144,24 +144,26 @@ yycolumn = 1;
144
144
  "void" { return VOID; }
145
145
  "bool" { return BOOL; }
146
146
 
147
- "Signed" { return SIGNED_KIND; }
147
+ "signed" { return SIGNED_KIND; }
148
148
  "int8" { return INT8; }
149
149
  "int16" { return INT16; }
150
150
  "int32" { return INT32; }
151
151
  "int64" { return INT64; }
152
152
 
153
- "Unsigned" { return UNSIGNED_KIND; }
153
+ "unsigned" { return UNSIGNED_KIND; }
154
154
  "uint8" { return UINT8; }
155
155
  "uint16" { return UINT16; }
156
156
  "uint32" { return UINT32; }
157
157
  "uint64" { return UINT64; }
158
158
 
159
- "Float" { return FLOAT_KIND; }
159
+ "float" { return FLOAT_KIND; }
160
+ "bfloat16" { return BFLOAT16; }
160
161
  "float16" { return FLOAT16; }
161
162
  "float32" { return FLOAT32; }
162
163
  "float64" { return FLOAT64; }
163
164
 
164
- "Complex" { return COMPLEX_KIND; }
165
+ "complex" { return COMPLEX_KIND; }
166
+ "bcomplex32" { return BCOMPLEX32; }
165
167
  "complex32" { return COMPLEX32; }
166
168
  "complex64" { return COMPLEX64; }
167
169
  "complex128" { return COMPLEX128; }
@@ -186,6 +188,9 @@ yycolumn = 1;
186
188
 
187
189
  "fixed" { return FIXED; }
188
190
  "var" { return VAR; }
191
+ "array" { return ARRAY; }
192
+
193
+ "of" { return OF; }
189
194
 
190
195
  "..." { return ELLIPSIS; }
191
196
  "->" { return RARROW; }
@@ -0,0 +1 @@
1
+ ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
@@ -0,0 +1 @@
1
+ ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
@@ -103,7 +103,7 @@ resolve_broadcast(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
103
103
  }
104
104
 
105
105
  static int
106
- check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
106
+ check_contig(const ndt_t *ptypes[], const ndt_t *ctypes[], int64_t nargs)
107
107
  {
108
108
  for (int i = 0; i < nargs; i++) {
109
109
  const ndt_t *p = ptypes[i];
@@ -130,14 +130,14 @@ check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
130
130
  return 1;
131
131
  }
132
132
 
133
- static ndt_t *
134
- to_fortran(const ndt_t *p, ndt_t *c, ndt_context_t *ctx)
133
+ static const ndt_t *
134
+ to_fortran(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
135
135
  {
136
136
  if (p->tag == EllipsisDim && p->EllipsisDim.tag == RequireF) {
137
- ndt_t *t = ndt_to_fortran(c, ctx);
138
- return t;
137
+ return ndt_to_fortran(c, ctx);
139
138
  }
140
139
  else {
140
+ ndt_incref(c);
141
141
  return c;
142
142
  }
143
143
  }
@@ -217,41 +217,124 @@ resolve_typevar(const char *key, symtable_entry_t w, symtable_t *tbl, ndt_contex
217
217
  }
218
218
  }
219
219
 
220
+ typedef struct {
221
+ int64_t index;
222
+ const ndt_t *type;
223
+ } indexed_type_t;
224
+
225
+ static indexed_type_t indexed_type_error = { 0, NULL };
226
+
227
+ static inline int64_t
228
+ adjust_index(const int64_t i, const int64_t shape, ndt_context_t *ctx)
229
+ {
230
+ const int64_t k = i < 0 ? i + shape : i;
231
+
232
+ if (k < 0 || k >= shape) {
233
+ ndt_err_format(ctx, NDT_IndexError,
234
+ "index with value %" PRIi64 " out of bounds", i);
235
+ return -1;
236
+ }
237
+
238
+ return k;
239
+ }
240
+
241
+ static inline indexed_type_t
242
+ var_dim_next(const indexed_type_t *x, const int64_t start, const int64_t step,
243
+ const int64_t i)
244
+ {
245
+ indexed_type_t next;
246
+ const ndt_t *t = x->type;
247
+
248
+ next.index = start + i * step;
249
+ next.type = t->VarDim.type;
250
+
251
+ return next;
252
+ }
253
+
254
+ /* skip stored indices */
255
+ static indexed_type_t
256
+ apply_stored_index(const indexed_type_t *x, ndt_context_t *ctx)
257
+ {
258
+ const ndt_t * const t = x->type;
259
+ int64_t start, step, shape;
260
+
261
+ if (t->tag != VarDimElem) {
262
+ ndt_err_format(ctx, NDT_RuntimeError,
263
+ "apply_stored_index: need VarDimElem");
264
+ return indexed_type_error;
265
+ }
266
+
267
+ shape = ndt_var_indices(&start, &step, t, x->index, ctx);
268
+ if (shape < 0) {
269
+ return indexed_type_error;
270
+ }
271
+
272
+ const int64_t i = adjust_index(t->VarDimElem.index, shape, ctx);
273
+ if (i < 0) {
274
+ return indexed_type_error;
275
+ }
276
+
277
+ return var_dim_next(x, start, step, i);
278
+ }
279
+
280
+ static indexed_type_t
281
+ apply_stored_indices(const indexed_type_t *x, ndt_context_t *ctx)
282
+ {
283
+ indexed_type_t tl = *x;
284
+
285
+ while (tl.type->tag == VarDimElem) {
286
+ tl = apply_stored_index(&tl, ctx);
287
+ }
288
+
289
+ return tl;
290
+ }
291
+
220
292
  static int
221
- match_concrete_var_dim(const ndt_t *t, int64_t tindex,
222
- const ndt_t *u, int64_t uindex,
293
+ match_concrete_var_dim(const indexed_type_t *a, const indexed_type_t *b,
223
294
  const int outer_dims, ndt_context_t *ctx)
224
295
  {
225
- int64_t tshape, tstart, tstep;
226
- int64_t ushape, ustart, ustep;
296
+ int64_t xshape, xstart, xstep;
297
+ int64_t yshape, ystart, ystep;
227
298
 
228
299
  if (outer_dims == 0) {
229
300
  return 1;
230
301
  }
302
+
303
+ const indexed_type_t x = apply_stored_indices(a, ctx);
304
+ if (x.type == NULL) {
305
+ return -1;
306
+ }
307
+
308
+ const indexed_type_t y = apply_stored_indices(b, ctx);
309
+ if (x.type == NULL) {
310
+ return -1;
311
+ }
312
+
313
+ const ndt_t * const t = x.type;
314
+ const ndt_t * const u = y.type;
315
+
231
316
  if (t->Concrete.VarDim.itemsize != u->Concrete.VarDim.itemsize) {
232
317
  return 0;
233
318
  }
234
319
 
235
- tshape = ndt_var_indices(&tstart, &tstep, t, tindex, ctx);
236
- if (tshape < 0) {
320
+ xshape = ndt_var_indices(&xstart, &xstep, t, x.index, ctx);
321
+ if (xshape < 0) {
237
322
  return -1;
238
323
  }
239
324
 
240
- ushape = ndt_var_indices(&ustart, &ustep, u, uindex, ctx);
241
- if (ushape < 0) {
325
+ yshape = ndt_var_indices(&ystart, &ystep, u, y.index, ctx);
326
+ if (yshape < 0) {
242
327
  return -1;
243
328
  }
244
329
 
245
- if (ushape != tshape) {
330
+ if (xshape != yshape) {
246
331
  return 0;
247
332
  }
248
333
 
249
- for (int64_t i = 0; i < tshape; i++) {
250
- int64_t tnext = tstart + i * tstep;
251
- int64_t unext = ustart + i * ustep;
252
- int ret = match_concrete_var_dim(t->VarDim.type, tnext,
253
- u->VarDim.type, unext,
254
- outer_dims-1, ctx);
334
+ for (int64_t i = 0; i < xshape; i++) {
335
+ const indexed_type_t xnext = var_dim_next(&x, xstart, xstep, i);
336
+ const indexed_type_t ynext = var_dim_next(&y, ystart, ystep, i);
337
+ int ret = match_concrete_var_dim(&xnext, &ynext, outer_dims-1, ctx);
255
338
  if (ret <= 0) {
256
339
  return ret;
257
340
  }
@@ -265,6 +348,7 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
265
348
  {
266
349
  const char *key = "var";
267
350
  symtable_entry_t v;
351
+ int vdims, wdims;
268
352
 
269
353
  v = symtable_find(tbl, key);
270
354
  if (v.tag == Unbound) {
@@ -274,16 +358,36 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
274
358
  return 1;
275
359
  }
276
360
 
277
- if (w.VarSeq.size != v.VarSeq.size) {
361
+ vdims = ndt_logical_ndim(v.VarSeq.dims[0]);
362
+ wdims = ndt_logical_ndim(w.VarSeq.dims[0]);
363
+
364
+ if (wdims != vdims) {
278
365
  return 0;
279
366
  }
280
- if (v.VarSeq.size == 0) {
367
+ if (vdims == 0) {
281
368
  return 1;
282
369
  }
283
370
 
284
- return match_concrete_var_dim(w.VarSeq.dims[0], 0,
285
- v.VarSeq.dims[0], 0,
286
- v.VarSeq.size, ctx);
371
+ const indexed_type_t x = { w.VarSeq.linear_index, w.VarSeq.dims[0] };
372
+ const indexed_type_t y = { v.VarSeq.linear_index, v.VarSeq.dims[0] };
373
+ return match_concrete_var_dim(&x, &y, vdims, ctx);
374
+ }
375
+
376
+ static int
377
+ resolve_array(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
378
+ {
379
+ const char *key = "array";
380
+ symtable_entry_t v;
381
+
382
+ v = symtable_find(tbl, key);
383
+ if (v.tag == Unbound) {
384
+ if (symtable_add(tbl, key, w, ctx) < 0) {
385
+ return -1;
386
+ }
387
+ return 1;
388
+ }
389
+
390
+ return v.ArraySeq.size == w.ArraySeq.size;
287
391
  }
288
392
 
289
393
  static int
@@ -332,8 +436,32 @@ match_record_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
332
436
  }
333
437
 
334
438
  static int
335
- match_categorical(ndt_value_t *p, int64_t plen,
336
- ndt_value_t *c, int64_t clen)
439
+ match_union_tags(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
440
+ ndt_context_t *ctx)
441
+ {
442
+ int64_t i;
443
+ int n;
444
+
445
+ assert(p->tag == Union && c->tag == Union);
446
+
447
+ if (p->Union.ntags != c->Union.ntags) {
448
+ return 0;
449
+ }
450
+
451
+ for (i = 0; i < p->Union.ntags; i++) {
452
+ n = strcmp(p->Union.tags[i], c->Union.tags[i]);
453
+ if (n != 0) return 0;
454
+
455
+ n = match_datashape(p->Union.types[i], c->Union.types[i], tbl, ctx);
456
+ if (n <= 0) return n;
457
+ }
458
+
459
+ return 1;
460
+ }
461
+
462
+ static int
463
+ match_categorical(const ndt_value_t *p, int64_t plen,
464
+ const ndt_value_t *c, int64_t clen)
337
465
  {
338
466
  int64_t i;
339
467
 
@@ -378,7 +506,7 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
378
506
  }
379
507
  return outer_inner(v, i+1, t->FixedDim.type, ndim);
380
508
  }
381
- case VarDim: {
509
+ case VarDim: case VarDimElem: {
382
510
  switch (v->tag) {
383
511
  case VarSeq:
384
512
  v->VarSeq.size = i+1;
@@ -389,6 +517,18 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
389
517
  }
390
518
  return outer_inner(v, i+1, t->VarDim.type, ndim);
391
519
  }
520
+ case Array: {
521
+ switch (v->tag) {
522
+ case ArraySeq:
523
+ v->ArraySeq.size = i+1;
524
+ v->ArraySeq.dims[i] = t;
525
+ break;
526
+ default:
527
+ return NULL;
528
+ }
529
+ return outer_inner(v, i+1, t->Array.type, ndim);
530
+ }
531
+
392
532
  default:
393
533
  return NULL;
394
534
  }
@@ -400,6 +540,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
400
540
  {
401
541
  int n;
402
542
 
543
+ if (c->tag == VarDimElem) {
544
+ return match_datashape(p, c->VarDimElem.type, tbl, ctx);
545
+ }
546
+
403
547
  if (ndt_is_optional(c) != ndt_is_optional(p)) return 0;
404
548
 
405
549
  switch (p->tag) {
@@ -445,57 +589,11 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
445
589
  return match_datashape(p->SymbolicDim.type, c->FixedDim.type, tbl, ctx);
446
590
  }
447
591
 
448
- case EllipsisDim: {
449
- symtable_entry_t outer;
450
- const ndt_t *inner;
451
-
452
- if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
453
- return 0;
454
- }
455
- if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
456
- return 0;
457
- }
458
-
459
- if (p->EllipsisDim.name == NULL) {
460
- outer.tag = BroadcastSeq;
461
- outer.BroadcastSeq.size = 0;
462
- }
463
- else if (strcmp(p->EllipsisDim.name, "var") == 0) {
464
- outer.tag = VarSeq;
465
- outer.VarSeq.size = 0;
466
- }
467
- else {
468
- outer.tag = FixedSeq;
469
- outer.FixedSeq.size = 0;
470
- }
471
-
472
- inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
473
- if (inner == NULL) {
474
- return 0;
475
- }
476
-
477
- n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
478
- if (n <= 0) {
479
- return n;
480
- }
481
-
482
- switch (outer.tag) {
483
- case BroadcastSeq:
484
- return resolve_broadcast(outer, tbl, ctx);
485
- case FixedSeq:
486
- return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
487
- case VarSeq:
488
- return resolve_var(outer, tbl, ctx);
489
- default: /* NOT REACHED */
490
- ndt_internal_error("invalid tag");
491
- }
492
- }
493
-
494
592
  case Bool:
495
593
  case Int8: case Int16: case Int32: case Int64:
496
594
  case Uint8: case Uint16: case Uint32: case Uint64:
497
- case Float16: case Float32: case Float64:
498
- case Complex32: case Complex64: case Complex128:
595
+ case BFloat16: case Float16: case Float32: case Float64:
596
+ case BComplex32: case Complex32: case Complex64: case Complex128:
499
597
  case String:
500
598
  return p->tag == c->tag;
501
599
  case FixedString:
@@ -528,6 +626,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
528
626
  return c->tag == Categorical &&
529
627
  match_categorical(p->Categorical.types, p->Categorical.ntypes,
530
628
  c->Categorical.types, c->Categorical.ntypes);
629
+ case Array:
630
+ if (c->tag != Array) return 0;
631
+ if (c->Array.itemsize != p->Array.itemsize) return 0;
632
+ return match_datashape(p->Array.type, c->Array.type, tbl, ctx);
531
633
  case Ref:
532
634
  if (c->tag != Ref) return 0;
533
635
  return match_datashape(p->Ref.type, c->Ref.type, tbl, ctx);
@@ -536,9 +638,12 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
536
638
  if (c->tag != Tuple) return 0;
537
639
  return match_tuple_fields(p, c, tbl, ctx);
538
640
  case Record:
539
- if (p->Tuple.flag == Variadic) return 0;
641
+ if (p->Record.flag == Variadic) return 0;
540
642
  if (c->tag != Record) return 0;
541
643
  return match_record_fields(p, c, tbl, ctx);
644
+ case Union:
645
+ if (c->tag != Union) return 0;
646
+ return match_union_tags(p, c, tbl, ctx);
542
647
  case Function: {
543
648
  int64_t i;
544
649
  if (c->tag != Function ||
@@ -576,12 +681,83 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
576
681
  case Constr:
577
682
  return c->tag == Constr && strcmp(p->Constr.name, c->Constr.name) == 0 &&
578
683
  ndt_equal(p->Constr.type, c->Constr.type);
684
+ case VarDimElem:
685
+ ndt_err_format(ctx, NDT_ValueError,
686
+ "VarDimElem cannot occur in pattern");
687
+ return -1;
688
+ case EllipsisDim:
689
+ ndt_err_format(ctx, NDT_RuntimeError,
690
+ "match_datashape: internal_error: unexpected EllipsisDim");
691
+ return -1;
579
692
  }
580
693
 
581
694
  /* NOT REACHED: tags should be exhaustive. */
582
695
  ndt_internal_error("invalid type");
583
696
  }
584
697
 
698
+ static int
699
+ match_datashape_top(const ndt_t *p, const ndt_t *c, int64_t linear_index,
700
+ symtable_t *tbl, ndt_context_t *ctx)
701
+ {
702
+ switch (p->tag) {
703
+ case EllipsisDim: {
704
+ symtable_entry_t outer;
705
+ const ndt_t *inner;
706
+ int n;
707
+
708
+ if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
709
+ return 0;
710
+ }
711
+ if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
712
+ return 0;
713
+ }
714
+
715
+ if (p->EllipsisDim.name == NULL) {
716
+ outer.tag = BroadcastSeq;
717
+ outer.BroadcastSeq.size = 0;
718
+ }
719
+ else if (strcmp(p->EllipsisDim.name, "var") == 0) {
720
+ outer.tag = VarSeq;
721
+ outer.VarSeq.size = 0;
722
+ outer.VarSeq.linear_index = linear_index;
723
+ }
724
+ else if (strcmp(p->EllipsisDim.name, "array") == 0) {
725
+ outer.tag = ArraySeq;
726
+ outer.ArraySeq.size = 0;
727
+ }
728
+ else {
729
+ outer.tag = FixedSeq;
730
+ outer.FixedSeq.size = 0;
731
+ }
732
+
733
+ inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
734
+ if (inner == NULL) {
735
+ return 0;
736
+ }
737
+
738
+ n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
739
+ if (n <= 0) {
740
+ return n;
741
+ }
742
+
743
+ switch (outer.tag) {
744
+ case BroadcastSeq:
745
+ return resolve_broadcast(outer, tbl, ctx);
746
+ case FixedSeq:
747
+ return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
748
+ case VarSeq:
749
+ return resolve_var(outer, tbl, ctx);
750
+ case ArraySeq:
751
+ return resolve_array(outer, tbl, ctx);
752
+ default: /* NOT REACHED */
753
+ ndt_internal_error("invalid tag");
754
+ }
755
+ }
756
+ default:
757
+ return match_datashape(p, c, tbl, ctx);
758
+ }
759
+ }
760
+
585
761
  int
586
762
  ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
587
763
  {
@@ -597,19 +773,18 @@ ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
597
773
  return -1;
598
774
  }
599
775
 
600
- ret = match_datashape(p, c, tbl, ctx);
776
+ ret = match_datashape_top(p, c, 0, tbl, ctx);
601
777
  symtable_del(tbl);
602
778
  return ret;
603
779
  }
604
780
 
605
- static ndt_t *
781
+ static const ndt_t *
606
782
  broadcast(const ndt_t *t, const int64_t *shape,
607
783
  int outer_dims, int inner_dims,
608
784
  bool use_max, ndt_context_t *ctx)
609
785
  {
610
786
  ndt_ndarray_t u;
611
- const ndt_t *dtype;
612
- ndt_t *v;
787
+ const ndt_t *v, *w;
613
788
  int64_t step;
614
789
  int ndim;
615
790
  int i, k;
@@ -619,14 +794,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
619
794
  return NULL;
620
795
  }
621
796
 
622
- dtype = ndt_dtype(t);
623
- v = ndt_copy(dtype, ctx);
624
- if (v == NULL) {
625
- return NULL;
626
- }
797
+ v = ndt_dtype(t);
798
+ ndt_incref(v);
627
799
 
628
800
  for (i=ndim-1; i>=ndim-inner_dims; i--) {
629
- v = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
801
+ w = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
802
+ ndt_move(&v, w);
630
803
  if (v == NULL) {
631
804
  return NULL;
632
805
  }
@@ -634,7 +807,8 @@ broadcast(const ndt_t *t, const int64_t *shape,
634
807
 
635
808
  for (k=outer_dims-1; i>=0 && k>=0; i--, k--) {
636
809
  step = u.shape[i]<=1 ? 0 : u.steps[i];
637
- v = ndt_fixed_dim(v, shape[k], step, ctx);
810
+ w = ndt_fixed_dim(v, shape[k], step, ctx);
811
+ ndt_move(&v, w);
638
812
  if (v == NULL) {
639
813
  return NULL;
640
814
  }
@@ -642,11 +816,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
642
816
 
643
817
  for (; k>=0; k--) {
644
818
  if (use_max) {
645
- v = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
819
+ w = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
646
820
  }
647
821
  else {
648
- v = ndt_fixed_dim(v, shape[k], 0, ctx);
822
+ w = ndt_fixed_dim(v, shape[k], 0, ctx);
649
823
  }
824
+ ndt_move(&v, w);
650
825
  if (v == NULL) {
651
826
  return NULL;
652
827
  }
@@ -656,34 +831,44 @@ broadcast(const ndt_t *t, const int64_t *shape,
656
831
  }
657
832
 
658
833
  int
659
- ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
660
- const ndt_t *in[], const int nin,
661
- const int64_t *shape, int outer_dims,
662
- ndt_context_t *ctx)
834
+ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
835
+ const int64_t *shape, const int outer_dims, ndt_context_t *ctx)
663
836
  {
664
- ndt_t *u;
837
+ const int nin = spec->nin;
838
+ const int nout = spec->nout;
665
839
  int inner_dims;
666
840
  int i;
667
841
 
668
842
  for (i = 0; i < nin; i++) {
669
843
  inner_dims = sig->Function.types[i]->ndim-1;
670
- spec->broadcast[i] = broadcast(in[i], shape,
671
- outer_dims, inner_dims, false, ctx);
672
- if (spec->broadcast[i] == NULL) {
844
+ const ndt_t *t = spec->types[i];
845
+ const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, false, ctx);
846
+ if (u == NULL) {
673
847
  return -1;
674
848
  }
675
- spec->nbroadcast++;
849
+ ndt_decref(t);
850
+ spec->types[i] = u;
676
851
  }
677
852
 
678
- for (i = 0; i < spec->nout; i++) {
853
+ for (i = 0; i < nout; i++) {
679
854
  inner_dims = sig->Function.types[nin+i]->ndim-1;
680
- u = broadcast(spec->out[i], shape,
681
- outer_dims, inner_dims, true, ctx);
855
+ const ndt_t *t = spec->types[nin+i];
856
+ const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, true, ctx);
682
857
  if (u == NULL) {
683
858
  return -1;
684
859
  }
685
- ndt_del(spec->out[i]);
686
- spec->out[i] = u;
860
+
861
+ if (check_broadcast) {
862
+ if (!ndt_equal(t, u)) {
863
+ ndt_err_format(ctx, NDT_ValueError,
864
+ "explicit 'out' type not compatible with input types");
865
+ ndt_decref(u);
866
+ return -1;
867
+ }
868
+ }
869
+
870
+ ndt_decref(t);
871
+ spec->types[nin+i] = u;
687
872
  }
688
873
 
689
874
  spec->outer_dims = outer_dims;
@@ -692,8 +877,7 @@ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
692
877
  }
693
878
 
694
879
  static int
695
- broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
696
- const ndt_t *in[], const int nin,
880
+ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
697
881
  const symtable_t *tbl, ndt_context_t *ctx)
698
882
  {
699
883
  symtable_entry_t v;
@@ -705,7 +889,7 @@ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
705
889
  return -1;
706
890
  }
707
891
 
708
- return ndt_broadcast_all(spec, sig, in, nin,
892
+ return ndt_broadcast_all(spec, sig, check_broadcast,
709
893
  v.BroadcastSeq.dims, v.BroadcastSeq.size,
710
894
  ctx);
711
895
  }
@@ -746,20 +930,23 @@ resolve_constraint(const ndt_constraint_t *c, const void *args, symtable_t *tbl,
746
930
  */
747
931
  int
748
932
  ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
749
- const ndt_t *in[], const int nin,
933
+ const ndt_t *types[], const int64_t li[],
934
+ const int nin, const int nout, bool check_broadcast,
750
935
  const ndt_constraint_t *c, const void *args,
751
936
  ndt_context_t *ctx)
752
937
  {
753
938
  symtable_t *tbl;
754
- ndt_t *t;
939
+ const ndt_t *t;
755
940
  const char *name;
756
- int ret;
941
+ const int nargs = nin + nout;
757
942
  int64_t i;
943
+ int ret;
758
944
 
759
945
  assert(spec->flags == 0);
760
- assert(spec->nout == 0);
761
- assert(spec->nbroadcast == 0);
762
946
  assert(spec->outer_dims == 0);
947
+ assert(spec->nin == 0);
948
+ assert(spec->nout == 0);
949
+ assert(spec->nargs == 0);
763
950
 
764
951
  if (sig->tag != Function) {
765
952
  ndt_err_format(ctx, NDT_ValueError,
@@ -773,8 +960,30 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
773
960
  return -1;
774
961
  }
775
962
 
776
- for (i = 0; i < nin; i++) {
777
- if (ndt_is_abstract(in[i])) {
963
+ /*
964
+ * Two configurations are allowed:
965
+ * 1) nout==0 and all 'out' types are inferred.
966
+ * 2) nout==sig->Function.nout and all 'out' types are given.
967
+ */
968
+ if (nout && nout != sig->Function.nout) {
969
+ ndt_err_format(ctx, NDT_ValueError,
970
+ "expected %" PRIi64 " 'out' arguments, got %d", sig->Function.nout, nout);
971
+ return -1;
972
+ }
973
+
974
+ /*
975
+ * Broadcast configurations:
976
+ * nout==0 -> check_broadcast==false.
977
+ */
978
+ if (!nout && check_broadcast == true) {
979
+ ndt_err_format(ctx, NDT_RuntimeError,
980
+ "internal error: without explicit 'out' arguments "
981
+ "'check_broadcast' must be false");
982
+ return -1;
983
+ }
984
+
985
+ for (i = 0; i < nargs; i++) {
986
+ if (ndt_is_abstract(types[i])) {
778
987
  ndt_err_format(ctx, NDT_ValueError,
779
988
  "type checking requires concrete argument types");
780
989
  return -1;
@@ -786,8 +995,8 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
786
995
  return -1;
787
996
  }
788
997
 
789
- for (i = 0; i < nin; i++) {
790
- ret = match_datashape(sig->Function.types[i], in[i], tbl, ctx);
998
+ for (i = 0; i < nargs; i++) {
999
+ ret = match_datashape_top(sig->Function.types[i], types[i], li[i], tbl, ctx);
791
1000
  if (ret <= 0) {
792
1001
  symtable_del(tbl);
793
1002
 
@@ -805,15 +1014,34 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
805
1014
  return -1;
806
1015
  }
807
1016
 
808
- for (i = 0; i < sig->Function.nout; i++) {
809
- spec->out[i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
810
- if (spec->out[i] == NULL) {
811
- ndt_apply_spec_clear(spec);
812
- symtable_del(tbl);
813
- return -1;
1017
+ spec->nin = 0;
1018
+ for (i = 0; i < nin; i++) {
1019
+ ndt_incref(types[i]);
1020
+ spec->types[i] = types[i];
1021
+ spec->nin++;
1022
+ }
1023
+
1024
+ spec->nout = 0;
1025
+ if (nout == 0) {
1026
+ /* Infer the return types. */
1027
+ for (i = 0; i < sig->Function.nout; i++) {
1028
+ spec->types[nin+i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
1029
+ if (spec->types[nin+i] == NULL) {
1030
+ ndt_apply_spec_clear(spec);
1031
+ symtable_del(tbl);
1032
+ return -1;
1033
+ }
1034
+ spec->nout++;
1035
+ }
1036
+ }
1037
+ else {
1038
+ for (i = 0; i < nout; i++) {
1039
+ ndt_incref(types[nin+i]);
1040
+ spec->types[nin+i] = types[nin+i];
1041
+ spec->nout++;
814
1042
  }
815
- spec->nout++;
816
1043
  }
1044
+ spec->nargs = spec->nin + spec->nout;
817
1045
 
818
1046
  if (sig->flags & NDT_ELLIPSIS) {
819
1047
  if (sig->Function.nargs == 0 || sig->Function.types[0]->tag != EllipsisDim) {
@@ -833,8 +1061,17 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
833
1061
  case FixedSeq:
834
1062
  spec->outer_dims = v.FixedSeq.size;
835
1063
  break;
836
- case VarSeq:
837
- spec->outer_dims = v.VarSeq.size;
1064
+ case VarSeq: {
1065
+ if (v.VarSeq.size > 0) {
1066
+ spec->outer_dims = ndt_logical_ndim(v.VarSeq.dims[0]);
1067
+ }
1068
+ else {
1069
+ spec->outer_dims = 0;
1070
+ }
1071
+ break;
1072
+ }
1073
+ case ArraySeq:
1074
+ spec->outer_dims = v.ArraySeq.size;
838
1075
  break;
839
1076
  default:
840
1077
  ndt_err_format(ctx, NDT_RuntimeError,
@@ -845,7 +1082,7 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
845
1082
  }
846
1083
  }
847
1084
  else {
848
- if (broadcast_all(spec, sig, in, nin, tbl, ctx) < 0) {
1085
+ if (broadcast_all(spec, sig, check_broadcast, tbl, ctx) < 0) {
849
1086
  ndt_apply_spec_clear(spec);
850
1087
  symtable_del(tbl);
851
1088
  return -1;
@@ -855,31 +1092,33 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
855
1092
 
856
1093
  symtable_del(tbl);
857
1094
 
858
- for (i = 0; i < sig->Function.nout; i++) {
859
- ndt_t *_p = sig->Function.types[nin+i];
860
- ndt_t *_c = spec->out[i];
861
- ndt_t *_t = to_fortran(_p, _c, ctx);
862
- if (_t == NULL) {
863
- ndt_apply_spec_clear(spec);
864
- return -1;
865
- }
866
- if (_t != _c) {
867
- ndt_del(_c);
1095
+ if (nout == 0) {
1096
+ for (i = 0; i < sig->Function.nout; i++) {
1097
+ const ndt_t *_p = sig->Function.types[nin+i];
1098
+ const ndt_t *_c = spec->types[nin+i];
1099
+ const ndt_t *_t = to_fortran(_p, _c, ctx);
1100
+
1101
+ if (_t == NULL) {
1102
+ ndt_apply_spec_clear(spec);
1103
+ return -1;
1104
+ }
1105
+
1106
+ ndt_decref(_c);
1107
+ spec->types[nin+i] = _t;
868
1108
  }
869
- spec->out[i] = _t;
870
1109
  }
871
1110
 
872
- if (!check_contig(sig->Function.types, (ndt_t **)in, nin)) {
1111
+ if (!check_contig(sig->Function.types, types, nargs)) {
873
1112
  ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
1113
+ ndt_apply_spec_clear(spec);
874
1114
  return -1;
875
1115
  }
876
- if (!check_contig(sig->Function.types+nin, spec->out, spec->nout)) {
877
- ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
1116
+
1117
+ if (ndt_select_kernel_strategy(spec, ctx) < 0) {
1118
+ ndt_apply_spec_clear(spec);
878
1119
  return -1;
879
1120
  }
880
1121
 
881
- ndt_select_kernel_strategy(spec, sig, in, nin);
882
-
883
1122
  return 0;
884
1123
  }
885
1124
 
@@ -888,11 +1127,11 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
888
1127
  /* Optimized binary typecheck for fixed input */
889
1128
  /*****************************************************************************/
890
1129
 
891
- static ndt_t *
892
- binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
893
- const int64_t *shape, int size, ndt_context_t *ctx)
1130
+ static const ndt_t *
1131
+ fast_broadcast(const ndt_ndarray_t *t, const ndt_t *dtype,
1132
+ const int64_t *shape, int size, ndt_context_t *ctx)
894
1133
  {
895
- ndt_t *v;
1134
+ const ndt_t *v;
896
1135
  int64_t step;
897
1136
  int i, k;
898
1137
 
@@ -903,14 +1142,18 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
903
1142
 
904
1143
  for (i=t->ndim-1, k=size-1; i>=0 && k>=0; i--, k--) {
905
1144
  step = t->shape[i]<=1 ? 0 : t->steps[i];
906
- v = ndt_fixed_dim(v, shape[k], step, ctx);
1145
+ const ndt_t *w = ndt_fixed_dim(v, shape[k], step, ctx);
1146
+ ndt_decref(v);
1147
+ v = w;
907
1148
  if (v == NULL) {
908
1149
  return NULL;
909
1150
  }
910
1151
  }
911
1152
 
912
1153
  for (; k>=0; k--) {
913
- v = ndt_fixed_dim(v, shape[k], 0, ctx);
1154
+ const ndt_t *w = ndt_fixed_dim(v, shape[k], 0, ctx);
1155
+ ndt_decref(v);
1156
+ v = w;
914
1157
  if (v == NULL) {
915
1158
  return NULL;
916
1159
  }
@@ -919,15 +1162,19 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
919
1162
  return v;
920
1163
  }
921
1164
 
922
- static ndt_t *
923
- fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
1165
+ static const ndt_t *
1166
+ fixed_dim_from_shape(const int64_t shape[], int len, const ndt_t *dtype,
924
1167
  ndt_context_t *ctx)
925
1168
  {
926
- ndt_t *t;
1169
+ const ndt_t *t;
927
1170
  int i;
928
1171
 
1172
+ ndt_incref(dtype);
1173
+
929
1174
  for (i=len-1, t=dtype; i >= 0; i--) {
930
- t = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
1175
+ const ndt_t *u = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
1176
+ ndt_decref(t);
1177
+ t = u;
931
1178
  if (t == NULL) {
932
1179
  return NULL;
933
1180
  }
@@ -936,83 +1183,313 @@ fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
936
1183
  return t;
937
1184
  }
938
1185
 
939
- static bool
940
- shape_equal(const ndt_ndarray_t *a, const ndt_ndarray_t *b)
1186
+ static int
1187
+ broadcast_error(const char *msg, ndt_context_t *ctx)
941
1188
  {
942
- if (b->ndim != a->ndim) {
943
- return false;
1189
+ ndt_err_format(ctx, NDT_TypeError, "%s", msg);
1190
+ return -1;
1191
+ }
1192
+
1193
+ static int
1194
+ _ndt_unary_broadcast(ndt_apply_spec_t *spec,
1195
+ const ndt_ndarray_t *x, const ndt_t *dtype_x,
1196
+ const ndt_t *out, const ndt_t *dtype_out,
1197
+ const bool check_broadcast, const int inner,
1198
+ ndt_context_t *ctx)
1199
+ {
1200
+ int64_t shape[NDT_MAX_DIM];
1201
+ int size = x->ndim;
1202
+ ndt_ndarray_t y;
1203
+ const ndt_t *t;
1204
+
1205
+ for (int i = 0; i < size; i++) {
1206
+ shape[i] = x->shape[i];
944
1207
  }
945
1208
 
946
- for (int i = 0; i < a->ndim; i++) {
947
- if (b->shape[i] != a->shape[i]) {
948
- return false;
1209
+ if (out != NULL) {
1210
+ if (ndt_as_ndarray(&y, out, ctx) < 0) {
1211
+ return -1;
949
1212
  }
1213
+
1214
+ size = _resolve_broadcast(shape, size, y.shape, y.ndim);
1215
+ if (size < 0) {
1216
+ return broadcast_error("could not broadcast output argument", ctx);
1217
+ }
1218
+ }
1219
+
1220
+ spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
1221
+ if (spec->types[0] == NULL) {
1222
+ return -1;
1223
+ }
1224
+
1225
+ if (out != NULL) {
1226
+ if (check_broadcast) {
1227
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1228
+ if (t == NULL) {
1229
+ ndt_decref(spec->types[0]);
1230
+ return -1;
1231
+ }
1232
+
1233
+ if (!ndt_equal(t, out)) {
1234
+ ndt_err_format(ctx, NDT_ValueError,
1235
+ "explicit 'out' type not compatible with input types");
1236
+ ndt_decref(spec->types[0]);
1237
+ ndt_decref(t);
1238
+ return -1;
1239
+ }
1240
+ }
1241
+ else {
1242
+ t = fast_broadcast(&y, dtype_out, shape, size, ctx);
1243
+ if (t == NULL) {
1244
+ ndt_decref(spec->types[0]);
1245
+ return -1;
1246
+ }
1247
+ }
1248
+ }
1249
+ else {
1250
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1251
+ if (t == NULL) {
1252
+ ndt_decref(spec->types[0]);
1253
+ return -1;
1254
+ }
1255
+ }
1256
+ spec->types[1] = t;
1257
+
1258
+ spec->outer_dims = size-inner;
1259
+ spec->nin = 1;
1260
+ spec->nout = 1;
1261
+ spec->nargs = 2;
1262
+
1263
+ if (ndt_select_kernel_strategy(spec, ctx) < 0) {
1264
+ ndt_apply_spec_clear(spec);
1265
+ return -1;
1266
+ }
1267
+
1268
+ return 0;
1269
+ }
1270
+
1271
+ static bool
1272
+ unary_all_ellipses(const ndt_t *t0, const ndt_t *t1, ndt_context_t *ctx)
1273
+ {
1274
+ if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
1275
+ (t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL)) {
1276
+ ndt_err_format(ctx, NDT_RuntimeError,
1277
+ "fast unary typecheck expects leading ellipsis dimensions");
1278
+ return false;
950
1279
  }
951
1280
 
952
1281
  return true;
953
1282
  }
954
1283
 
955
- static int
956
- _ndt_binary_broadcast(ndt_apply_spec_t *spec, const ndt_t *sig,
957
- const ndt_ndarray_t *x, const ndt_ndarray_t *y,
958
- const ndt_t *in[], const int nin, ndt_t *dtype,
959
- int inner, ndt_context_t *ctx)
1284
+ static bool
1285
+ unary_all_same_symbol(const ndt_t *t0, const ndt_t *t1)
960
1286
  {
961
- int64_t shape[NDT_MAX_DIM];
962
- int size;
963
-
964
- if (shape_equal(x, y)) {
965
- spec->nout = 1;
966
- spec->nbroadcast = 0;
967
- spec->outer_dims = x->ndim-inner;
968
- spec->out[0] = fixed_dim_from_shape(x->shape, x->ndim, dtype, ctx);
969
- if (spec->out[0] == NULL) {
970
- return -1;
971
- }
1287
+ if (t0->tag != SymbolicDim || t1->tag != SymbolicDim) {
1288
+ return false;
972
1289
  }
973
- else {
974
- for (int i = 0; i < x->ndim; i++) {
975
- shape[i] = x->shape[i];
1290
+
1291
+ return strcmp(t0->SymbolicDim.name, t1->SymbolicDim.name) == 0;
1292
+ }
1293
+
1294
+ static bool
1295
+ unary_all_ndim0(const ndt_t *t0, const ndt_t *t1)
1296
+ {
1297
+ return t0->ndim == 0 && t1->ndim == 0;
1298
+ }
1299
+
1300
+ /*
1301
+ * Optimized type checking for very specific signatures. The caller must
1302
+ * have identified the kernel location and the signature. For performance
1303
+ * reasons, no substitution is performed on the dtype, so the dtype must be
1304
+ * concrete.
1305
+ *
1306
+ * Supported signature: 1) ... * T0 -> ... * T1
1307
+ */
1308
+ int
1309
+ ndt_fast_unary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1310
+ const ndt_t *types[], const int nin, const int nout,
1311
+ const bool check_broadcast, ndt_context_t *ctx)
1312
+ {
1313
+ const ndt_t *p0, *p1;
1314
+ ndt_ndarray_t x;
1315
+ const ndt_t *out = NULL;
1316
+ const ndt_t *dtype = NULL;
1317
+
1318
+ assert(spec->flags == 0);
1319
+ assert(spec->outer_dims == 0);
1320
+ assert(spec->nin == 0);
1321
+ assert(spec->nout == 0);
1322
+ assert(spec->nargs == 0);
1323
+
1324
+ if (sig->tag != Function ||
1325
+ sig->Function.nin != 1) {
1326
+ ndt_err_format(ctx, NDT_RuntimeError,
1327
+ "fast unary typecheck expects a signature with one input and "
1328
+ "one output");
1329
+ return -1;
1330
+ }
1331
+
1332
+ if (nin != 1) {
1333
+ ndt_err_format(ctx, NDT_RuntimeError,
1334
+ "fast unary typecheck expects one input argument");
1335
+ return -1;
1336
+ }
1337
+
1338
+ p0 = sig->Function.types[0];
1339
+ p1 = sig->Function.types[1];
1340
+
1341
+ if (nout) {
1342
+ if (nout != 1) {
1343
+ ndt_err_format(ctx, NDT_RuntimeError,
1344
+ "fast unary typecheck expects at most one explicit out argument");
1345
+ return -1;
976
1346
  }
1347
+ out = types[1];
1348
+ dtype = ndt_dtype(types[1]);
977
1349
 
978
- size = _resolve_broadcast(shape, x->ndim, y->shape, y->ndim);
979
- if (size < 0) {
980
- ndt_err_format(ctx, NDT_TypeError, "broadcast error");
981
- ndt_del(dtype);
1350
+ if (!ndt_equal(dtype, ndt_dtype(p1))) {
1351
+ ndt_err_format(ctx, NDT_ValueError,
1352
+ "dtype of the out argument does not match");
982
1353
  return -1;
983
1354
  }
1355
+ }
1356
+ else {
1357
+ dtype = ndt_dtype(sig->Function.types[1]);
1358
+ }
1359
+
1360
+ if (ndt_is_abstract(dtype)) {
1361
+ ndt_err_format(ctx, NDT_RuntimeError,
1362
+ "fast unary typecheck expects a concrete dtype");
1363
+ return -1;
1364
+ }
1365
+
1366
+ if (!unary_all_ellipses(p0, p1, ctx)) {
1367
+ return -1;
1368
+ }
1369
+
1370
+ if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
1371
+ return -1;
1372
+ }
1373
+
1374
+ p0 = p0->EllipsisDim.type;
1375
+ p1 = p1->EllipsisDim.type;
1376
+
1377
+ if (unary_all_same_symbol(p0, p1)) {
1378
+ return _ndt_unary_broadcast(spec,
1379
+ &x, ndt_dtype(types[0]),
1380
+ out, dtype, check_broadcast,
1381
+ 1, ctx);
1382
+ }
1383
+ else if (unary_all_ndim0(p0, p1)) {
1384
+ return _ndt_unary_broadcast(spec,
1385
+ &x, ndt_dtype(types[0]),
1386
+ out, dtype, check_broadcast,
1387
+ 0, ctx);
1388
+ }
1389
+ else {
1390
+ ndt_err_format(ctx, NDT_RuntimeError,
1391
+ "unsupported signature in fast unary typecheck");
1392
+ return -1;
1393
+ }
1394
+ }
1395
+
1396
+ static int
1397
+ _ndt_binary_broadcast(ndt_apply_spec_t *spec,
1398
+ const ndt_ndarray_t *x, const ndt_t *dtype_x,
1399
+ const ndt_ndarray_t *y, const ndt_t *dtype_y,
1400
+ const ndt_t *out, const ndt_t *dtype_out,
1401
+ const bool check_broadcast, const int inner,
1402
+ ndt_context_t *ctx)
1403
+ {
1404
+ int64_t shape[NDT_MAX_DIM];
1405
+ int size = x->ndim;
1406
+ ndt_ndarray_t z;
1407
+ const ndt_t *t;
984
1408
 
985
- spec->nout = 1;
986
- spec->nbroadcast = 2;
987
- spec->outer_dims = size-inner;
1409
+ for (int i = 0; i < size; i++) {
1410
+ shape[i] = x->shape[i];
1411
+ }
1412
+
1413
+ size = _resolve_broadcast(shape, size, y->shape, y->ndim);
1414
+ if (size < 0) {
1415
+ return broadcast_error("could not broadcast input arguments", ctx);
1416
+ }
988
1417
 
989
- spec->out[0] = fixed_dim_from_shape(shape, size, dtype, ctx);
990
- if (spec->out[0] == NULL) {
1418
+ if (out != NULL) {
1419
+ if (ndt_as_ndarray(&z, out, ctx) < 0) {
991
1420
  return -1;
992
1421
  }
993
1422
 
994
- spec->broadcast[0] = binary_broadcast_1D(x, ndt_dtype(in[0]), shape, size, ctx);
995
- if (spec->broadcast[0] == NULL) {
996
- ndt_del(spec->out[0]);
997
- return -1;
1423
+ size = _resolve_broadcast(shape, size, z.shape, z.ndim);
1424
+ if (size < 0) {
1425
+ return broadcast_error("could not broadcast output argument", ctx);
998
1426
  }
1427
+ }
1428
+
1429
+ spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
1430
+ if (spec->types[0] == NULL) {
1431
+ return -1;
1432
+ }
1433
+
1434
+ spec->types[1] = fast_broadcast(y, dtype_y, shape, size, ctx);
1435
+ if (spec->types[1] == NULL) {
1436
+ ndt_decref(spec->types[0]);
1437
+ return -1;
1438
+ }
1439
+
1440
+ if (out != NULL) {
1441
+ if (check_broadcast) {
1442
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1443
+ if (t == NULL) {
1444
+ ndt_decref(spec->types[0]);
1445
+ ndt_decref(spec->types[1]);
1446
+ return -1;
1447
+ }
999
1448
 
1000
- spec->broadcast[1] = binary_broadcast_1D(y, ndt_dtype(in[1]), shape, size, ctx);
1001
- if (spec->broadcast[1] == NULL) {
1002
- ndt_del(spec->out[0]);
1003
- ndt_del(spec->broadcast[0]);
1449
+ if (!ndt_equal(t, out)) {
1450
+ ndt_err_format(ctx, NDT_ValueError,
1451
+ "explicit 'out' type not compatible with input types");
1452
+ ndt_decref(spec->types[0]);
1453
+ ndt_decref(spec->types[1]);
1454
+ ndt_decref(t);
1455
+ return -1;
1456
+ }
1457
+ }
1458
+ else {
1459
+ t = fast_broadcast(&z, dtype_out, shape, size, ctx);
1460
+ if (t == NULL) {
1461
+ ndt_decref(spec->types[0]);
1462
+ ndt_decref(spec->types[1]);
1463
+ return -1;
1464
+ }
1465
+ }
1466
+ }
1467
+ else {
1468
+ t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
1469
+ if (t == NULL) {
1470
+ ndt_decref(spec->types[0]);
1471
+ ndt_decref(spec->types[1]);
1004
1472
  return -1;
1005
1473
  }
1006
1474
  }
1475
+ spec->types[2] = t;
1007
1476
 
1008
- ndt_select_kernel_strategy(spec, sig, in, nin);
1477
+ spec->outer_dims = size-inner;
1478
+ spec->nin = 2;
1479
+ spec->nout = 1;
1480
+ spec->nargs = 3;
1481
+
1482
+ if (ndt_select_kernel_strategy(spec, ctx) < 0) {
1483
+ ndt_apply_spec_clear(spec);
1484
+ return -1;
1485
+ }
1009
1486
 
1010
1487
  return 0;
1011
1488
  }
1012
1489
 
1013
1490
  static bool
1014
- all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1015
- ndt_context_t *ctx)
1491
+ binary_all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1492
+ ndt_context_t *ctx)
1016
1493
  {
1017
1494
  if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
1018
1495
  (t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL) ||
@@ -1026,7 +1503,7 @@ all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1026
1503
  }
1027
1504
 
1028
1505
  static bool
1029
- all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1506
+ binary_all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1030
1507
  {
1031
1508
  if (t0->tag != SymbolicDim || t1->tag != SymbolicDim ||
1032
1509
  t2->tag != SymbolicDim) {
@@ -1038,37 +1515,37 @@ all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1038
1515
  }
1039
1516
 
1040
1517
  static bool
1041
- all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1518
+ binary_all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1042
1519
  {
1043
1520
  return t0->ndim == 0 && t1->ndim == 0 && t2->ndim == 0;
1044
1521
  }
1045
1522
 
1046
1523
  /*
1047
1524
  * Optimized type checking for very specific signatures. The caller must
1048
- * have identified the kernel location, signature and the dtype. For
1049
- * performance reasons, no substitution is performed on the dtype, so
1050
- * the dtype must be concrete.
1525
+ * have identified the kernel location and the signature. For performance
1526
+ * reasons, no substitution is performed on the dtype, so the dtype must be
1527
+ * concrete.
1051
1528
  *
1052
- * Supported signatures:
1053
- * 1) ... * N * T0, ... * N * T1 -> N * T2
1054
- * 2) ... * T0, ... * T1 -> ... * T2
1529
+ * Supported signature: 1) ... * T0, ... * T1 -> ... * T2
1055
1530
  */
1056
1531
  int
1057
1532
  ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1058
- const ndt_t *in[], const int nin, ndt_t *dtype,
1059
- ndt_context_t *ctx)
1533
+ const ndt_t *types[], const int nin, const int nout,
1534
+ const bool check_broadcast, ndt_context_t *ctx)
1060
1535
  {
1061
- ndt_t *p0, *p1, *p2;
1536
+ const ndt_t *p0, *p1, *p2;
1062
1537
  ndt_ndarray_t x, y;
1538
+ const ndt_t *out = NULL;
1539
+ const ndt_t *dtype = NULL;
1063
1540
 
1064
1541
  assert(spec->flags == 0);
1065
- assert(spec->nout == 0);
1066
- assert(spec->nbroadcast == 0);
1067
1542
  assert(spec->outer_dims == 0);
1543
+ assert(spec->nin == 0);
1544
+ assert(spec->nout == 0);
1545
+ assert(spec->nargs == 0);
1068
1546
 
1069
1547
  if (sig->tag != Function ||
1070
- sig->Function.nin != 2 ||
1071
- sig->Function.nout != 1) {
1548
+ sig->Function.nin != 2) {
1072
1549
  ndt_err_format(ctx, NDT_RuntimeError,
1073
1550
  "fast binary typecheck expects a signature with two inputs and "
1074
1551
  "one output");
@@ -1081,27 +1558,44 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1081
1558
  return -1;
1082
1559
  }
1083
1560
 
1561
+ p0 = sig->Function.types[0];
1562
+ p1 = sig->Function.types[1];
1563
+ p2 = sig->Function.types[2];
1564
+
1565
+ if (nout) {
1566
+ if (nout != 1) {
1567
+ ndt_err_format(ctx, NDT_RuntimeError,
1568
+ "fast binary typecheck expects at most one explicit out argument");
1569
+ return -1;
1570
+ }
1571
+ out = types[2];
1572
+ dtype = ndt_dtype(types[2]);
1573
+
1574
+ if (!ndt_equal(dtype, ndt_dtype(p2))) {
1575
+ ndt_err_format(ctx, NDT_ValueError,
1576
+ "dtype of the out argument does not match");
1577
+ return -1;
1578
+ }
1579
+ }
1580
+ else {
1581
+ dtype = ndt_dtype(sig->Function.types[2]);
1582
+ }
1583
+
1084
1584
  if (ndt_is_abstract(dtype)) {
1085
1585
  ndt_err_format(ctx, NDT_RuntimeError,
1086
1586
  "fast binary typecheck expects a concrete dtype");
1087
1587
  return -1;
1088
1588
  }
1089
1589
 
1090
- p0 = sig->Function.types[0];
1091
- p1 = sig->Function.types[1];
1092
- p2 = sig->Function.types[2];
1093
-
1094
- if (!all_ellipses(p0, p1, p2, ctx)) {
1590
+ if (!binary_all_ellipses(p0, p1, p2, ctx)) {
1095
1591
  return -1;
1096
1592
  }
1097
1593
 
1098
- if (ndt_as_ndarray(&x, in[0], ctx) < 0) {
1099
- ndt_del(dtype);
1594
+ if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
1100
1595
  return -1;
1101
1596
  }
1102
1597
 
1103
- if (ndt_as_ndarray(&y, in[1], ctx) < 0) {
1104
- ndt_del(dtype);
1598
+ if (ndt_as_ndarray(&y, types[1], ctx) < 0) {
1105
1599
  return -1;
1106
1600
  }
1107
1601
 
@@ -1109,20 +1603,27 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1109
1603
  p1 = p1->EllipsisDim.type;
1110
1604
  p2 = p2->EllipsisDim.type;
1111
1605
 
1112
- if (all_same_symbol(p0, p1, p2)) {
1606
+ if (binary_all_same_symbol(p0, p1, p2)) {
1113
1607
  if (x.ndim > 0 && y.ndim > 0) {
1114
1608
  const int64_t xshape = x.shape[x.ndim-1];
1115
1609
  const int64_t yshape = y.shape[y.ndim-1];
1116
1610
  if (xshape != 1 && yshape != 1 && xshape != yshape) {
1117
1611
  ndt_err_format(ctx, NDT_TypeError, "mismatch in inner dimensions");
1118
- ndt_del(dtype);
1119
1612
  return -1;
1120
1613
  }
1121
1614
  }
1122
- return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 1, ctx);
1123
- }
1124
- else if (all_ndim0(p0, p1, p2)) {
1125
- return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 0, ctx);
1615
+ return _ndt_binary_broadcast(spec,
1616
+ &x, ndt_dtype(types[0]),
1617
+ &y, ndt_dtype(types[1]),
1618
+ out, dtype, check_broadcast,
1619
+ 1, ctx);
1620
+ }
1621
+ else if (binary_all_ndim0(p0, p1, p2)) {
1622
+ return _ndt_binary_broadcast(spec,
1623
+ &x, ndt_dtype(types[0]),
1624
+ &y, ndt_dtype(types[1]),
1625
+ out, dtype, check_broadcast,
1626
+ 0, ctx);
1126
1627
  }
1127
1628
  else {
1128
1629
  ndt_err_format(ctx, NDT_RuntimeError,