ndtypes 0.2.0dev5 → 0.2.0dev6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +12 -0
- data/Rakefile +8 -0
- data/ext/ruby_ndtypes/GPATH +0 -0
- data/ext/ruby_ndtypes/GRTAGS +0 -0
- data/ext/ruby_ndtypes/GTAGS +0 -0
- data/ext/ruby_ndtypes/extconf.rb +1 -1
- data/ext/ruby_ndtypes/include/ndtypes.h +231 -122
- data/ext/ruby_ndtypes/include/ruby_ndtypes.h +1 -1
- data/ext/ruby_ndtypes/lib/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/lib/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/Makefile +87 -0
- data/ext/ruby_ndtypes/ndtypes/config.h +68 -0
- data/ext/ruby_ndtypes/ndtypes/config.log +477 -0
- data/ext/ruby_ndtypes/ndtypes/config.status +1027 -0
- data/ext/ruby_ndtypes/ndtypes/doc/_static/style.css +7 -0
- data/ext/ruby_ndtypes/ndtypes/doc/_templates/layout.html +2 -0
- data/ext/ruby_ndtypes/ndtypes/doc/conf.py +40 -4
- data/ext/ruby_ndtypes/ndtypes/doc/images/xndlogo.png +0 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +1 -1
- data/ext/ruby_ndtypes/ndtypes/doc/requirements.txt +2 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile +287 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +20 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +22 -3
- data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile +73 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +246 -229
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +15 -11
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +38 -28
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +91 -91
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +4 -3
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +8 -7
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/context.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +263 -182
- data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +67 -7
- data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +1112 -1000
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +69 -58
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +150 -99
- data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +185 -15
- data/ext/ruby_ndtypes/ndtypes/libndtypes/io.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +301 -276
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +9 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so +1 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0 +1 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +729 -228
- data/ext/ruby_ndtypes/ndtypes/libndtypes/match.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +768 -403
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h +1002 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +231 -122
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +176 -84
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +26 -14
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +57 -35
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.c +420 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +8 -8
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile +48 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +200 -116
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +46 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +58 -27
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +3 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +12 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile +55 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +8 -8
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +5 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +274 -172
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +24 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +14 -14
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +32 -30
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +37 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +36 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +16 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +5 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +706 -253
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_unify.c +132 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.c +703 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +335 -127
- data/ext/ruby_ndtypes/ndtypes/libndtypes/util.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/values.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +88 -71
- data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +0 -1
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +10 -13
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +395 -314
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so +1 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0 +1 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/ndtypes.h +1002 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +15 -33
- data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +340 -132
- data/ext/ruby_ndtypes/ndtypes/setup.py +11 -2
- data/ext/ruby_ndtypes/ruby_ndtypes.c +364 -241
- data/ext/ruby_ndtypes/ruby_ndtypes.h +1 -1
- data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +0 -1
- data/lib/ndtypes.rb +11 -0
- data/lib/ndtypes/version.rb +2 -2
- data/lib/ruby_ndtypes.so +0 -0
- data/ndtypes.gemspec +3 -0
- data/spec/ndtypes_spec.rb +6 -0
- metadata +98 -4
- data/ext/ruby_ndtypes/gc_guard.c +0 -36
- data/ext/ruby_ndtypes/gc_guard.h +0 -12
@@ -144,24 +144,26 @@ yycolumn = 1;
|
|
144
144
|
"void" { return VOID; }
|
145
145
|
"bool" { return BOOL; }
|
146
146
|
|
147
|
-
"
|
147
|
+
"signed" { return SIGNED_KIND; }
|
148
148
|
"int8" { return INT8; }
|
149
149
|
"int16" { return INT16; }
|
150
150
|
"int32" { return INT32; }
|
151
151
|
"int64" { return INT64; }
|
152
152
|
|
153
|
-
"
|
153
|
+
"unsigned" { return UNSIGNED_KIND; }
|
154
154
|
"uint8" { return UINT8; }
|
155
155
|
"uint16" { return UINT16; }
|
156
156
|
"uint32" { return UINT32; }
|
157
157
|
"uint64" { return UINT64; }
|
158
158
|
|
159
|
-
"
|
159
|
+
"float" { return FLOAT_KIND; }
|
160
|
+
"bfloat16" { return BFLOAT16; }
|
160
161
|
"float16" { return FLOAT16; }
|
161
162
|
"float32" { return FLOAT32; }
|
162
163
|
"float64" { return FLOAT64; }
|
163
164
|
|
164
|
-
"
|
165
|
+
"complex" { return COMPLEX_KIND; }
|
166
|
+
"bcomplex32" { return BCOMPLEX32; }
|
165
167
|
"complex32" { return COMPLEX32; }
|
166
168
|
"complex64" { return COMPLEX64; }
|
167
169
|
"complex128" { return COMPLEX128; }
|
@@ -186,6 +188,9 @@ yycolumn = 1;
|
|
186
188
|
|
187
189
|
"fixed" { return FIXED; }
|
188
190
|
"var" { return VAR; }
|
191
|
+
"array" { return ARRAY; }
|
192
|
+
|
193
|
+
"of" { return OF; }
|
189
194
|
|
190
195
|
"..." { return ELLIPSIS; }
|
191
196
|
"->" { return RARROW; }
|
Binary file
|
Binary file
|
@@ -0,0 +1 @@
|
|
1
|
+
ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
|
@@ -0,0 +1 @@
|
|
1
|
+
ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
|
Binary file
|
@@ -103,7 +103,7 @@ resolve_broadcast(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
|
103
103
|
}
|
104
104
|
|
105
105
|
static int
|
106
|
-
check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
|
106
|
+
check_contig(const ndt_t *ptypes[], const ndt_t *ctypes[], int64_t nargs)
|
107
107
|
{
|
108
108
|
for (int i = 0; i < nargs; i++) {
|
109
109
|
const ndt_t *p = ptypes[i];
|
@@ -130,14 +130,14 @@ check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
|
|
130
130
|
return 1;
|
131
131
|
}
|
132
132
|
|
133
|
-
static ndt_t *
|
134
|
-
to_fortran(const ndt_t *p, ndt_t *c, ndt_context_t *ctx)
|
133
|
+
static const ndt_t *
|
134
|
+
to_fortran(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
|
135
135
|
{
|
136
136
|
if (p->tag == EllipsisDim && p->EllipsisDim.tag == RequireF) {
|
137
|
-
|
138
|
-
return t;
|
137
|
+
return ndt_to_fortran(c, ctx);
|
139
138
|
}
|
140
139
|
else {
|
140
|
+
ndt_incref(c);
|
141
141
|
return c;
|
142
142
|
}
|
143
143
|
}
|
@@ -217,41 +217,124 @@ resolve_typevar(const char *key, symtable_entry_t w, symtable_t *tbl, ndt_contex
|
|
217
217
|
}
|
218
218
|
}
|
219
219
|
|
220
|
+
typedef struct {
|
221
|
+
int64_t index;
|
222
|
+
const ndt_t *type;
|
223
|
+
} indexed_type_t;
|
224
|
+
|
225
|
+
static indexed_type_t indexed_type_error = { 0, NULL };
|
226
|
+
|
227
|
+
static inline int64_t
|
228
|
+
adjust_index(const int64_t i, const int64_t shape, ndt_context_t *ctx)
|
229
|
+
{
|
230
|
+
const int64_t k = i < 0 ? i + shape : i;
|
231
|
+
|
232
|
+
if (k < 0 || k >= shape) {
|
233
|
+
ndt_err_format(ctx, NDT_IndexError,
|
234
|
+
"index with value %" PRIi64 " out of bounds", i);
|
235
|
+
return -1;
|
236
|
+
}
|
237
|
+
|
238
|
+
return k;
|
239
|
+
}
|
240
|
+
|
241
|
+
static inline indexed_type_t
|
242
|
+
var_dim_next(const indexed_type_t *x, const int64_t start, const int64_t step,
|
243
|
+
const int64_t i)
|
244
|
+
{
|
245
|
+
indexed_type_t next;
|
246
|
+
const ndt_t *t = x->type;
|
247
|
+
|
248
|
+
next.index = start + i * step;
|
249
|
+
next.type = t->VarDim.type;
|
250
|
+
|
251
|
+
return next;
|
252
|
+
}
|
253
|
+
|
254
|
+
/* skip stored indices */
|
255
|
+
static indexed_type_t
|
256
|
+
apply_stored_index(const indexed_type_t *x, ndt_context_t *ctx)
|
257
|
+
{
|
258
|
+
const ndt_t * const t = x->type;
|
259
|
+
int64_t start, step, shape;
|
260
|
+
|
261
|
+
if (t->tag != VarDimElem) {
|
262
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
263
|
+
"apply_stored_index: need VarDimElem");
|
264
|
+
return indexed_type_error;
|
265
|
+
}
|
266
|
+
|
267
|
+
shape = ndt_var_indices(&start, &step, t, x->index, ctx);
|
268
|
+
if (shape < 0) {
|
269
|
+
return indexed_type_error;
|
270
|
+
}
|
271
|
+
|
272
|
+
const int64_t i = adjust_index(t->VarDimElem.index, shape, ctx);
|
273
|
+
if (i < 0) {
|
274
|
+
return indexed_type_error;
|
275
|
+
}
|
276
|
+
|
277
|
+
return var_dim_next(x, start, step, i);
|
278
|
+
}
|
279
|
+
|
280
|
+
static indexed_type_t
|
281
|
+
apply_stored_indices(const indexed_type_t *x, ndt_context_t *ctx)
|
282
|
+
{
|
283
|
+
indexed_type_t tl = *x;
|
284
|
+
|
285
|
+
while (tl.type->tag == VarDimElem) {
|
286
|
+
tl = apply_stored_index(&tl, ctx);
|
287
|
+
}
|
288
|
+
|
289
|
+
return tl;
|
290
|
+
}
|
291
|
+
|
220
292
|
static int
|
221
|
-
match_concrete_var_dim(const
|
222
|
-
const ndt_t *u, int64_t uindex,
|
293
|
+
match_concrete_var_dim(const indexed_type_t *a, const indexed_type_t *b,
|
223
294
|
const int outer_dims, ndt_context_t *ctx)
|
224
295
|
{
|
225
|
-
int64_t
|
226
|
-
int64_t
|
296
|
+
int64_t xshape, xstart, xstep;
|
297
|
+
int64_t yshape, ystart, ystep;
|
227
298
|
|
228
299
|
if (outer_dims == 0) {
|
229
300
|
return 1;
|
230
301
|
}
|
302
|
+
|
303
|
+
const indexed_type_t x = apply_stored_indices(a, ctx);
|
304
|
+
if (x.type == NULL) {
|
305
|
+
return -1;
|
306
|
+
}
|
307
|
+
|
308
|
+
const indexed_type_t y = apply_stored_indices(b, ctx);
|
309
|
+
if (x.type == NULL) {
|
310
|
+
return -1;
|
311
|
+
}
|
312
|
+
|
313
|
+
const ndt_t * const t = x.type;
|
314
|
+
const ndt_t * const u = y.type;
|
315
|
+
|
231
316
|
if (t->Concrete.VarDim.itemsize != u->Concrete.VarDim.itemsize) {
|
232
317
|
return 0;
|
233
318
|
}
|
234
319
|
|
235
|
-
|
236
|
-
if (
|
320
|
+
xshape = ndt_var_indices(&xstart, &xstep, t, x.index, ctx);
|
321
|
+
if (xshape < 0) {
|
237
322
|
return -1;
|
238
323
|
}
|
239
324
|
|
240
|
-
|
241
|
-
if (
|
325
|
+
yshape = ndt_var_indices(&ystart, &ystep, u, y.index, ctx);
|
326
|
+
if (yshape < 0) {
|
242
327
|
return -1;
|
243
328
|
}
|
244
329
|
|
245
|
-
if (
|
330
|
+
if (xshape != yshape) {
|
246
331
|
return 0;
|
247
332
|
}
|
248
333
|
|
249
|
-
for (int64_t i = 0; i <
|
250
|
-
|
251
|
-
|
252
|
-
int ret = match_concrete_var_dim(
|
253
|
-
u->VarDim.type, unext,
|
254
|
-
outer_dims-1, ctx);
|
334
|
+
for (int64_t i = 0; i < xshape; i++) {
|
335
|
+
const indexed_type_t xnext = var_dim_next(&x, xstart, xstep, i);
|
336
|
+
const indexed_type_t ynext = var_dim_next(&y, ystart, ystep, i);
|
337
|
+
int ret = match_concrete_var_dim(&xnext, &ynext, outer_dims-1, ctx);
|
255
338
|
if (ret <= 0) {
|
256
339
|
return ret;
|
257
340
|
}
|
@@ -265,6 +348,7 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
|
265
348
|
{
|
266
349
|
const char *key = "var";
|
267
350
|
symtable_entry_t v;
|
351
|
+
int vdims, wdims;
|
268
352
|
|
269
353
|
v = symtable_find(tbl, key);
|
270
354
|
if (v.tag == Unbound) {
|
@@ -274,16 +358,36 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
|
274
358
|
return 1;
|
275
359
|
}
|
276
360
|
|
277
|
-
|
361
|
+
vdims = ndt_logical_ndim(v.VarSeq.dims[0]);
|
362
|
+
wdims = ndt_logical_ndim(w.VarSeq.dims[0]);
|
363
|
+
|
364
|
+
if (wdims != vdims) {
|
278
365
|
return 0;
|
279
366
|
}
|
280
|
-
if (
|
367
|
+
if (vdims == 0) {
|
281
368
|
return 1;
|
282
369
|
}
|
283
370
|
|
284
|
-
|
285
|
-
|
286
|
-
|
371
|
+
const indexed_type_t x = { w.VarSeq.linear_index, w.VarSeq.dims[0] };
|
372
|
+
const indexed_type_t y = { v.VarSeq.linear_index, v.VarSeq.dims[0] };
|
373
|
+
return match_concrete_var_dim(&x, &y, vdims, ctx);
|
374
|
+
}
|
375
|
+
|
376
|
+
static int
|
377
|
+
resolve_array(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
378
|
+
{
|
379
|
+
const char *key = "array";
|
380
|
+
symtable_entry_t v;
|
381
|
+
|
382
|
+
v = symtable_find(tbl, key);
|
383
|
+
if (v.tag == Unbound) {
|
384
|
+
if (symtable_add(tbl, key, w, ctx) < 0) {
|
385
|
+
return -1;
|
386
|
+
}
|
387
|
+
return 1;
|
388
|
+
}
|
389
|
+
|
390
|
+
return v.ArraySeq.size == w.ArraySeq.size;
|
287
391
|
}
|
288
392
|
|
289
393
|
static int
|
@@ -332,8 +436,32 @@ match_record_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
332
436
|
}
|
333
437
|
|
334
438
|
static int
|
335
|
-
|
336
|
-
|
439
|
+
match_union_tags(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
440
|
+
ndt_context_t *ctx)
|
441
|
+
{
|
442
|
+
int64_t i;
|
443
|
+
int n;
|
444
|
+
|
445
|
+
assert(p->tag == Union && c->tag == Union);
|
446
|
+
|
447
|
+
if (p->Union.ntags != c->Union.ntags) {
|
448
|
+
return 0;
|
449
|
+
}
|
450
|
+
|
451
|
+
for (i = 0; i < p->Union.ntags; i++) {
|
452
|
+
n = strcmp(p->Union.tags[i], c->Union.tags[i]);
|
453
|
+
if (n != 0) return 0;
|
454
|
+
|
455
|
+
n = match_datashape(p->Union.types[i], c->Union.types[i], tbl, ctx);
|
456
|
+
if (n <= 0) return n;
|
457
|
+
}
|
458
|
+
|
459
|
+
return 1;
|
460
|
+
}
|
461
|
+
|
462
|
+
static int
|
463
|
+
match_categorical(const ndt_value_t *p, int64_t plen,
|
464
|
+
const ndt_value_t *c, int64_t clen)
|
337
465
|
{
|
338
466
|
int64_t i;
|
339
467
|
|
@@ -378,7 +506,7 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
|
|
378
506
|
}
|
379
507
|
return outer_inner(v, i+1, t->FixedDim.type, ndim);
|
380
508
|
}
|
381
|
-
case VarDim: {
|
509
|
+
case VarDim: case VarDimElem: {
|
382
510
|
switch (v->tag) {
|
383
511
|
case VarSeq:
|
384
512
|
v->VarSeq.size = i+1;
|
@@ -389,6 +517,18 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
|
|
389
517
|
}
|
390
518
|
return outer_inner(v, i+1, t->VarDim.type, ndim);
|
391
519
|
}
|
520
|
+
case Array: {
|
521
|
+
switch (v->tag) {
|
522
|
+
case ArraySeq:
|
523
|
+
v->ArraySeq.size = i+1;
|
524
|
+
v->ArraySeq.dims[i] = t;
|
525
|
+
break;
|
526
|
+
default:
|
527
|
+
return NULL;
|
528
|
+
}
|
529
|
+
return outer_inner(v, i+1, t->Array.type, ndim);
|
530
|
+
}
|
531
|
+
|
392
532
|
default:
|
393
533
|
return NULL;
|
394
534
|
}
|
@@ -400,6 +540,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
400
540
|
{
|
401
541
|
int n;
|
402
542
|
|
543
|
+
if (c->tag == VarDimElem) {
|
544
|
+
return match_datashape(p, c->VarDimElem.type, tbl, ctx);
|
545
|
+
}
|
546
|
+
|
403
547
|
if (ndt_is_optional(c) != ndt_is_optional(p)) return 0;
|
404
548
|
|
405
549
|
switch (p->tag) {
|
@@ -445,57 +589,11 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
445
589
|
return match_datashape(p->SymbolicDim.type, c->FixedDim.type, tbl, ctx);
|
446
590
|
}
|
447
591
|
|
448
|
-
case EllipsisDim: {
|
449
|
-
symtable_entry_t outer;
|
450
|
-
const ndt_t *inner;
|
451
|
-
|
452
|
-
if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
|
453
|
-
return 0;
|
454
|
-
}
|
455
|
-
if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
|
456
|
-
return 0;
|
457
|
-
}
|
458
|
-
|
459
|
-
if (p->EllipsisDim.name == NULL) {
|
460
|
-
outer.tag = BroadcastSeq;
|
461
|
-
outer.BroadcastSeq.size = 0;
|
462
|
-
}
|
463
|
-
else if (strcmp(p->EllipsisDim.name, "var") == 0) {
|
464
|
-
outer.tag = VarSeq;
|
465
|
-
outer.VarSeq.size = 0;
|
466
|
-
}
|
467
|
-
else {
|
468
|
-
outer.tag = FixedSeq;
|
469
|
-
outer.FixedSeq.size = 0;
|
470
|
-
}
|
471
|
-
|
472
|
-
inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
|
473
|
-
if (inner == NULL) {
|
474
|
-
return 0;
|
475
|
-
}
|
476
|
-
|
477
|
-
n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
|
478
|
-
if (n <= 0) {
|
479
|
-
return n;
|
480
|
-
}
|
481
|
-
|
482
|
-
switch (outer.tag) {
|
483
|
-
case BroadcastSeq:
|
484
|
-
return resolve_broadcast(outer, tbl, ctx);
|
485
|
-
case FixedSeq:
|
486
|
-
return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
|
487
|
-
case VarSeq:
|
488
|
-
return resolve_var(outer, tbl, ctx);
|
489
|
-
default: /* NOT REACHED */
|
490
|
-
ndt_internal_error("invalid tag");
|
491
|
-
}
|
492
|
-
}
|
493
|
-
|
494
592
|
case Bool:
|
495
593
|
case Int8: case Int16: case Int32: case Int64:
|
496
594
|
case Uint8: case Uint16: case Uint32: case Uint64:
|
497
|
-
case Float16: case Float32: case Float64:
|
498
|
-
case Complex32: case Complex64: case Complex128:
|
595
|
+
case BFloat16: case Float16: case Float32: case Float64:
|
596
|
+
case BComplex32: case Complex32: case Complex64: case Complex128:
|
499
597
|
case String:
|
500
598
|
return p->tag == c->tag;
|
501
599
|
case FixedString:
|
@@ -528,6 +626,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
528
626
|
return c->tag == Categorical &&
|
529
627
|
match_categorical(p->Categorical.types, p->Categorical.ntypes,
|
530
628
|
c->Categorical.types, c->Categorical.ntypes);
|
629
|
+
case Array:
|
630
|
+
if (c->tag != Array) return 0;
|
631
|
+
if (c->Array.itemsize != p->Array.itemsize) return 0;
|
632
|
+
return match_datashape(p->Array.type, c->Array.type, tbl, ctx);
|
531
633
|
case Ref:
|
532
634
|
if (c->tag != Ref) return 0;
|
533
635
|
return match_datashape(p->Ref.type, c->Ref.type, tbl, ctx);
|
@@ -536,9 +638,12 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
536
638
|
if (c->tag != Tuple) return 0;
|
537
639
|
return match_tuple_fields(p, c, tbl, ctx);
|
538
640
|
case Record:
|
539
|
-
if (p->
|
641
|
+
if (p->Record.flag == Variadic) return 0;
|
540
642
|
if (c->tag != Record) return 0;
|
541
643
|
return match_record_fields(p, c, tbl, ctx);
|
644
|
+
case Union:
|
645
|
+
if (c->tag != Union) return 0;
|
646
|
+
return match_union_tags(p, c, tbl, ctx);
|
542
647
|
case Function: {
|
543
648
|
int64_t i;
|
544
649
|
if (c->tag != Function ||
|
@@ -576,12 +681,83 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
576
681
|
case Constr:
|
577
682
|
return c->tag == Constr && strcmp(p->Constr.name, c->Constr.name) == 0 &&
|
578
683
|
ndt_equal(p->Constr.type, c->Constr.type);
|
684
|
+
case VarDimElem:
|
685
|
+
ndt_err_format(ctx, NDT_ValueError,
|
686
|
+
"VarDimElem cannot occur in pattern");
|
687
|
+
return -1;
|
688
|
+
case EllipsisDim:
|
689
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
690
|
+
"match_datashape: internal_error: unexpected EllipsisDim");
|
691
|
+
return -1;
|
579
692
|
}
|
580
693
|
|
581
694
|
/* NOT REACHED: tags should be exhaustive. */
|
582
695
|
ndt_internal_error("invalid type");
|
583
696
|
}
|
584
697
|
|
698
|
+
static int
|
699
|
+
match_datashape_top(const ndt_t *p, const ndt_t *c, int64_t linear_index,
|
700
|
+
symtable_t *tbl, ndt_context_t *ctx)
|
701
|
+
{
|
702
|
+
switch (p->tag) {
|
703
|
+
case EllipsisDim: {
|
704
|
+
symtable_entry_t outer;
|
705
|
+
const ndt_t *inner;
|
706
|
+
int n;
|
707
|
+
|
708
|
+
if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
|
709
|
+
return 0;
|
710
|
+
}
|
711
|
+
if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
|
712
|
+
return 0;
|
713
|
+
}
|
714
|
+
|
715
|
+
if (p->EllipsisDim.name == NULL) {
|
716
|
+
outer.tag = BroadcastSeq;
|
717
|
+
outer.BroadcastSeq.size = 0;
|
718
|
+
}
|
719
|
+
else if (strcmp(p->EllipsisDim.name, "var") == 0) {
|
720
|
+
outer.tag = VarSeq;
|
721
|
+
outer.VarSeq.size = 0;
|
722
|
+
outer.VarSeq.linear_index = linear_index;
|
723
|
+
}
|
724
|
+
else if (strcmp(p->EllipsisDim.name, "array") == 0) {
|
725
|
+
outer.tag = ArraySeq;
|
726
|
+
outer.ArraySeq.size = 0;
|
727
|
+
}
|
728
|
+
else {
|
729
|
+
outer.tag = FixedSeq;
|
730
|
+
outer.FixedSeq.size = 0;
|
731
|
+
}
|
732
|
+
|
733
|
+
inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
|
734
|
+
if (inner == NULL) {
|
735
|
+
return 0;
|
736
|
+
}
|
737
|
+
|
738
|
+
n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
|
739
|
+
if (n <= 0) {
|
740
|
+
return n;
|
741
|
+
}
|
742
|
+
|
743
|
+
switch (outer.tag) {
|
744
|
+
case BroadcastSeq:
|
745
|
+
return resolve_broadcast(outer, tbl, ctx);
|
746
|
+
case FixedSeq:
|
747
|
+
return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
|
748
|
+
case VarSeq:
|
749
|
+
return resolve_var(outer, tbl, ctx);
|
750
|
+
case ArraySeq:
|
751
|
+
return resolve_array(outer, tbl, ctx);
|
752
|
+
default: /* NOT REACHED */
|
753
|
+
ndt_internal_error("invalid tag");
|
754
|
+
}
|
755
|
+
}
|
756
|
+
default:
|
757
|
+
return match_datashape(p, c, tbl, ctx);
|
758
|
+
}
|
759
|
+
}
|
760
|
+
|
585
761
|
int
|
586
762
|
ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
|
587
763
|
{
|
@@ -597,19 +773,18 @@ ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
|
|
597
773
|
return -1;
|
598
774
|
}
|
599
775
|
|
600
|
-
ret =
|
776
|
+
ret = match_datashape_top(p, c, 0, tbl, ctx);
|
601
777
|
symtable_del(tbl);
|
602
778
|
return ret;
|
603
779
|
}
|
604
780
|
|
605
|
-
static ndt_t *
|
781
|
+
static const ndt_t *
|
606
782
|
broadcast(const ndt_t *t, const int64_t *shape,
|
607
783
|
int outer_dims, int inner_dims,
|
608
784
|
bool use_max, ndt_context_t *ctx)
|
609
785
|
{
|
610
786
|
ndt_ndarray_t u;
|
611
|
-
const ndt_t *
|
612
|
-
ndt_t *v;
|
787
|
+
const ndt_t *v, *w;
|
613
788
|
int64_t step;
|
614
789
|
int ndim;
|
615
790
|
int i, k;
|
@@ -619,14 +794,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
619
794
|
return NULL;
|
620
795
|
}
|
621
796
|
|
622
|
-
|
623
|
-
v
|
624
|
-
if (v == NULL) {
|
625
|
-
return NULL;
|
626
|
-
}
|
797
|
+
v = ndt_dtype(t);
|
798
|
+
ndt_incref(v);
|
627
799
|
|
628
800
|
for (i=ndim-1; i>=ndim-inner_dims; i--) {
|
629
|
-
|
801
|
+
w = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
|
802
|
+
ndt_move(&v, w);
|
630
803
|
if (v == NULL) {
|
631
804
|
return NULL;
|
632
805
|
}
|
@@ -634,7 +807,8 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
634
807
|
|
635
808
|
for (k=outer_dims-1; i>=0 && k>=0; i--, k--) {
|
636
809
|
step = u.shape[i]<=1 ? 0 : u.steps[i];
|
637
|
-
|
810
|
+
w = ndt_fixed_dim(v, shape[k], step, ctx);
|
811
|
+
ndt_move(&v, w);
|
638
812
|
if (v == NULL) {
|
639
813
|
return NULL;
|
640
814
|
}
|
@@ -642,11 +816,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
642
816
|
|
643
817
|
for (; k>=0; k--) {
|
644
818
|
if (use_max) {
|
645
|
-
|
819
|
+
w = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
|
646
820
|
}
|
647
821
|
else {
|
648
|
-
|
822
|
+
w = ndt_fixed_dim(v, shape[k], 0, ctx);
|
649
823
|
}
|
824
|
+
ndt_move(&v, w);
|
650
825
|
if (v == NULL) {
|
651
826
|
return NULL;
|
652
827
|
}
|
@@ -656,34 +831,44 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
656
831
|
}
|
657
832
|
|
658
833
|
int
|
659
|
-
ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
660
|
-
const
|
661
|
-
const int64_t *shape, int outer_dims,
|
662
|
-
ndt_context_t *ctx)
|
834
|
+
ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
|
835
|
+
const int64_t *shape, const int outer_dims, ndt_context_t *ctx)
|
663
836
|
{
|
664
|
-
|
837
|
+
const int nin = spec->nin;
|
838
|
+
const int nout = spec->nout;
|
665
839
|
int inner_dims;
|
666
840
|
int i;
|
667
841
|
|
668
842
|
for (i = 0; i < nin; i++) {
|
669
843
|
inner_dims = sig->Function.types[i]->ndim-1;
|
670
|
-
|
671
|
-
|
672
|
-
if (
|
844
|
+
const ndt_t *t = spec->types[i];
|
845
|
+
const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, false, ctx);
|
846
|
+
if (u == NULL) {
|
673
847
|
return -1;
|
674
848
|
}
|
675
|
-
|
849
|
+
ndt_decref(t);
|
850
|
+
spec->types[i] = u;
|
676
851
|
}
|
677
852
|
|
678
|
-
for (i = 0; i <
|
853
|
+
for (i = 0; i < nout; i++) {
|
679
854
|
inner_dims = sig->Function.types[nin+i]->ndim-1;
|
680
|
-
|
681
|
-
|
855
|
+
const ndt_t *t = spec->types[nin+i];
|
856
|
+
const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, true, ctx);
|
682
857
|
if (u == NULL) {
|
683
858
|
return -1;
|
684
859
|
}
|
685
|
-
|
686
|
-
|
860
|
+
|
861
|
+
if (check_broadcast) {
|
862
|
+
if (!ndt_equal(t, u)) {
|
863
|
+
ndt_err_format(ctx, NDT_ValueError,
|
864
|
+
"explicit 'out' type not compatible with input types");
|
865
|
+
ndt_decref(u);
|
866
|
+
return -1;
|
867
|
+
}
|
868
|
+
}
|
869
|
+
|
870
|
+
ndt_decref(t);
|
871
|
+
spec->types[nin+i] = u;
|
687
872
|
}
|
688
873
|
|
689
874
|
spec->outer_dims = outer_dims;
|
@@ -692,8 +877,7 @@ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
692
877
|
}
|
693
878
|
|
694
879
|
static int
|
695
|
-
broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
696
|
-
const ndt_t *in[], const int nin,
|
880
|
+
broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
|
697
881
|
const symtable_t *tbl, ndt_context_t *ctx)
|
698
882
|
{
|
699
883
|
symtable_entry_t v;
|
@@ -705,7 +889,7 @@ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
705
889
|
return -1;
|
706
890
|
}
|
707
891
|
|
708
|
-
return ndt_broadcast_all(spec, sig,
|
892
|
+
return ndt_broadcast_all(spec, sig, check_broadcast,
|
709
893
|
v.BroadcastSeq.dims, v.BroadcastSeq.size,
|
710
894
|
ctx);
|
711
895
|
}
|
@@ -746,20 +930,23 @@ resolve_constraint(const ndt_constraint_t *c, const void *args, symtable_t *tbl,
|
|
746
930
|
*/
|
747
931
|
int
|
748
932
|
ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
749
|
-
const ndt_t *
|
933
|
+
const ndt_t *types[], const int64_t li[],
|
934
|
+
const int nin, const int nout, bool check_broadcast,
|
750
935
|
const ndt_constraint_t *c, const void *args,
|
751
936
|
ndt_context_t *ctx)
|
752
937
|
{
|
753
938
|
symtable_t *tbl;
|
754
|
-
ndt_t *t;
|
939
|
+
const ndt_t *t;
|
755
940
|
const char *name;
|
756
|
-
int
|
941
|
+
const int nargs = nin + nout;
|
757
942
|
int64_t i;
|
943
|
+
int ret;
|
758
944
|
|
759
945
|
assert(spec->flags == 0);
|
760
|
-
assert(spec->nout == 0);
|
761
|
-
assert(spec->nbroadcast == 0);
|
762
946
|
assert(spec->outer_dims == 0);
|
947
|
+
assert(spec->nin == 0);
|
948
|
+
assert(spec->nout == 0);
|
949
|
+
assert(spec->nargs == 0);
|
763
950
|
|
764
951
|
if (sig->tag != Function) {
|
765
952
|
ndt_err_format(ctx, NDT_ValueError,
|
@@ -773,8 +960,30 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
773
960
|
return -1;
|
774
961
|
}
|
775
962
|
|
776
|
-
|
777
|
-
|
963
|
+
/*
|
964
|
+
* Two configurations are allowed:
|
965
|
+
* 1) nout==0 and all 'out' types are inferred.
|
966
|
+
* 2) nout==sig->Function.nout and all 'out' types are given.
|
967
|
+
*/
|
968
|
+
if (nout && nout != sig->Function.nout) {
|
969
|
+
ndt_err_format(ctx, NDT_ValueError,
|
970
|
+
"expected %" PRIi64 " 'out' arguments, got %d", sig->Function.nout, nout);
|
971
|
+
return -1;
|
972
|
+
}
|
973
|
+
|
974
|
+
/*
|
975
|
+
* Broadcast configurations:
|
976
|
+
* nout==0 -> check_broadcast==false.
|
977
|
+
*/
|
978
|
+
if (!nout && check_broadcast == true) {
|
979
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
980
|
+
"internal error: without explicit 'out' arguments "
|
981
|
+
"'check_broadcast' must be false");
|
982
|
+
return -1;
|
983
|
+
}
|
984
|
+
|
985
|
+
for (i = 0; i < nargs; i++) {
|
986
|
+
if (ndt_is_abstract(types[i])) {
|
778
987
|
ndt_err_format(ctx, NDT_ValueError,
|
779
988
|
"type checking requires concrete argument types");
|
780
989
|
return -1;
|
@@ -786,8 +995,8 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
786
995
|
return -1;
|
787
996
|
}
|
788
997
|
|
789
|
-
for (i = 0; i <
|
790
|
-
ret =
|
998
|
+
for (i = 0; i < nargs; i++) {
|
999
|
+
ret = match_datashape_top(sig->Function.types[i], types[i], li[i], tbl, ctx);
|
791
1000
|
if (ret <= 0) {
|
792
1001
|
symtable_del(tbl);
|
793
1002
|
|
@@ -805,15 +1014,34 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
805
1014
|
return -1;
|
806
1015
|
}
|
807
1016
|
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
1017
|
+
spec->nin = 0;
|
1018
|
+
for (i = 0; i < nin; i++) {
|
1019
|
+
ndt_incref(types[i]);
|
1020
|
+
spec->types[i] = types[i];
|
1021
|
+
spec->nin++;
|
1022
|
+
}
|
1023
|
+
|
1024
|
+
spec->nout = 0;
|
1025
|
+
if (nout == 0) {
|
1026
|
+
/* Infer the return types. */
|
1027
|
+
for (i = 0; i < sig->Function.nout; i++) {
|
1028
|
+
spec->types[nin+i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
|
1029
|
+
if (spec->types[nin+i] == NULL) {
|
1030
|
+
ndt_apply_spec_clear(spec);
|
1031
|
+
symtable_del(tbl);
|
1032
|
+
return -1;
|
1033
|
+
}
|
1034
|
+
spec->nout++;
|
1035
|
+
}
|
1036
|
+
}
|
1037
|
+
else {
|
1038
|
+
for (i = 0; i < nout; i++) {
|
1039
|
+
ndt_incref(types[nin+i]);
|
1040
|
+
spec->types[nin+i] = types[nin+i];
|
1041
|
+
spec->nout++;
|
814
1042
|
}
|
815
|
-
spec->nout++;
|
816
1043
|
}
|
1044
|
+
spec->nargs = spec->nin + spec->nout;
|
817
1045
|
|
818
1046
|
if (sig->flags & NDT_ELLIPSIS) {
|
819
1047
|
if (sig->Function.nargs == 0 || sig->Function.types[0]->tag != EllipsisDim) {
|
@@ -833,8 +1061,17 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
833
1061
|
case FixedSeq:
|
834
1062
|
spec->outer_dims = v.FixedSeq.size;
|
835
1063
|
break;
|
836
|
-
case VarSeq:
|
837
|
-
|
1064
|
+
case VarSeq: {
|
1065
|
+
if (v.VarSeq.size > 0) {
|
1066
|
+
spec->outer_dims = ndt_logical_ndim(v.VarSeq.dims[0]);
|
1067
|
+
}
|
1068
|
+
else {
|
1069
|
+
spec->outer_dims = 0;
|
1070
|
+
}
|
1071
|
+
break;
|
1072
|
+
}
|
1073
|
+
case ArraySeq:
|
1074
|
+
spec->outer_dims = v.ArraySeq.size;
|
838
1075
|
break;
|
839
1076
|
default:
|
840
1077
|
ndt_err_format(ctx, NDT_RuntimeError,
|
@@ -845,7 +1082,7 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
845
1082
|
}
|
846
1083
|
}
|
847
1084
|
else {
|
848
|
-
if (broadcast_all(spec, sig,
|
1085
|
+
if (broadcast_all(spec, sig, check_broadcast, tbl, ctx) < 0) {
|
849
1086
|
ndt_apply_spec_clear(spec);
|
850
1087
|
symtable_del(tbl);
|
851
1088
|
return -1;
|
@@ -855,31 +1092,33 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
855
1092
|
|
856
1093
|
symtable_del(tbl);
|
857
1094
|
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
1095
|
+
if (nout == 0) {
|
1096
|
+
for (i = 0; i < sig->Function.nout; i++) {
|
1097
|
+
const ndt_t *_p = sig->Function.types[nin+i];
|
1098
|
+
const ndt_t *_c = spec->types[nin+i];
|
1099
|
+
const ndt_t *_t = to_fortran(_p, _c, ctx);
|
1100
|
+
|
1101
|
+
if (_t == NULL) {
|
1102
|
+
ndt_apply_spec_clear(spec);
|
1103
|
+
return -1;
|
1104
|
+
}
|
1105
|
+
|
1106
|
+
ndt_decref(_c);
|
1107
|
+
spec->types[nin+i] = _t;
|
868
1108
|
}
|
869
|
-
spec->out[i] = _t;
|
870
1109
|
}
|
871
1110
|
|
872
|
-
if (!check_contig(sig->Function.types,
|
1111
|
+
if (!check_contig(sig->Function.types, types, nargs)) {
|
873
1112
|
ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
|
1113
|
+
ndt_apply_spec_clear(spec);
|
874
1114
|
return -1;
|
875
1115
|
}
|
876
|
-
|
877
|
-
|
1116
|
+
|
1117
|
+
if (ndt_select_kernel_strategy(spec, ctx) < 0) {
|
1118
|
+
ndt_apply_spec_clear(spec);
|
878
1119
|
return -1;
|
879
1120
|
}
|
880
1121
|
|
881
|
-
ndt_select_kernel_strategy(spec, sig, in, nin);
|
882
|
-
|
883
1122
|
return 0;
|
884
1123
|
}
|
885
1124
|
|
@@ -888,11 +1127,11 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
888
1127
|
/* Optimized binary typecheck for fixed input */
|
889
1128
|
/*****************************************************************************/
|
890
1129
|
|
891
|
-
static ndt_t *
|
892
|
-
|
893
|
-
|
1130
|
+
static const ndt_t *
|
1131
|
+
fast_broadcast(const ndt_ndarray_t *t, const ndt_t *dtype,
|
1132
|
+
const int64_t *shape, int size, ndt_context_t *ctx)
|
894
1133
|
{
|
895
|
-
ndt_t *v;
|
1134
|
+
const ndt_t *v;
|
896
1135
|
int64_t step;
|
897
1136
|
int i, k;
|
898
1137
|
|
@@ -903,14 +1142,18 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
|
|
903
1142
|
|
904
1143
|
for (i=t->ndim-1, k=size-1; i>=0 && k>=0; i--, k--) {
|
905
1144
|
step = t->shape[i]<=1 ? 0 : t->steps[i];
|
906
|
-
|
1145
|
+
const ndt_t *w = ndt_fixed_dim(v, shape[k], step, ctx);
|
1146
|
+
ndt_decref(v);
|
1147
|
+
v = w;
|
907
1148
|
if (v == NULL) {
|
908
1149
|
return NULL;
|
909
1150
|
}
|
910
1151
|
}
|
911
1152
|
|
912
1153
|
for (; k>=0; k--) {
|
913
|
-
|
1154
|
+
const ndt_t *w = ndt_fixed_dim(v, shape[k], 0, ctx);
|
1155
|
+
ndt_decref(v);
|
1156
|
+
v = w;
|
914
1157
|
if (v == NULL) {
|
915
1158
|
return NULL;
|
916
1159
|
}
|
@@ -919,15 +1162,19 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
|
|
919
1162
|
return v;
|
920
1163
|
}
|
921
1164
|
|
922
|
-
static ndt_t *
|
923
|
-
fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
|
1165
|
+
static const ndt_t *
|
1166
|
+
fixed_dim_from_shape(const int64_t shape[], int len, const ndt_t *dtype,
|
924
1167
|
ndt_context_t *ctx)
|
925
1168
|
{
|
926
|
-
ndt_t *t;
|
1169
|
+
const ndt_t *t;
|
927
1170
|
int i;
|
928
1171
|
|
1172
|
+
ndt_incref(dtype);
|
1173
|
+
|
929
1174
|
for (i=len-1, t=dtype; i >= 0; i--) {
|
930
|
-
|
1175
|
+
const ndt_t *u = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
|
1176
|
+
ndt_decref(t);
|
1177
|
+
t = u;
|
931
1178
|
if (t == NULL) {
|
932
1179
|
return NULL;
|
933
1180
|
}
|
@@ -936,83 +1183,313 @@ fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
|
|
936
1183
|
return t;
|
937
1184
|
}
|
938
1185
|
|
939
|
-
static
|
940
|
-
|
1186
|
+
static int
|
1187
|
+
broadcast_error(const char *msg, ndt_context_t *ctx)
|
941
1188
|
{
|
942
|
-
|
943
|
-
|
1189
|
+
ndt_err_format(ctx, NDT_TypeError, "%s", msg);
|
1190
|
+
return -1;
|
1191
|
+
}
|
1192
|
+
|
1193
|
+
static int
|
1194
|
+
_ndt_unary_broadcast(ndt_apply_spec_t *spec,
|
1195
|
+
const ndt_ndarray_t *x, const ndt_t *dtype_x,
|
1196
|
+
const ndt_t *out, const ndt_t *dtype_out,
|
1197
|
+
const bool check_broadcast, const int inner,
|
1198
|
+
ndt_context_t *ctx)
|
1199
|
+
{
|
1200
|
+
int64_t shape[NDT_MAX_DIM];
|
1201
|
+
int size = x->ndim;
|
1202
|
+
ndt_ndarray_t y;
|
1203
|
+
const ndt_t *t;
|
1204
|
+
|
1205
|
+
for (int i = 0; i < size; i++) {
|
1206
|
+
shape[i] = x->shape[i];
|
944
1207
|
}
|
945
1208
|
|
946
|
-
|
947
|
-
if (
|
948
|
-
return
|
1209
|
+
if (out != NULL) {
|
1210
|
+
if (ndt_as_ndarray(&y, out, ctx) < 0) {
|
1211
|
+
return -1;
|
949
1212
|
}
|
1213
|
+
|
1214
|
+
size = _resolve_broadcast(shape, size, y.shape, y.ndim);
|
1215
|
+
if (size < 0) {
|
1216
|
+
return broadcast_error("could not broadcast output argument", ctx);
|
1217
|
+
}
|
1218
|
+
}
|
1219
|
+
|
1220
|
+
spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
|
1221
|
+
if (spec->types[0] == NULL) {
|
1222
|
+
return -1;
|
1223
|
+
}
|
1224
|
+
|
1225
|
+
if (out != NULL) {
|
1226
|
+
if (check_broadcast) {
|
1227
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
1228
|
+
if (t == NULL) {
|
1229
|
+
ndt_decref(spec->types[0]);
|
1230
|
+
return -1;
|
1231
|
+
}
|
1232
|
+
|
1233
|
+
if (!ndt_equal(t, out)) {
|
1234
|
+
ndt_err_format(ctx, NDT_ValueError,
|
1235
|
+
"explicit 'out' type not compatible with input types");
|
1236
|
+
ndt_decref(spec->types[0]);
|
1237
|
+
ndt_decref(t);
|
1238
|
+
return -1;
|
1239
|
+
}
|
1240
|
+
}
|
1241
|
+
else {
|
1242
|
+
t = fast_broadcast(&y, dtype_out, shape, size, ctx);
|
1243
|
+
if (t == NULL) {
|
1244
|
+
ndt_decref(spec->types[0]);
|
1245
|
+
return -1;
|
1246
|
+
}
|
1247
|
+
}
|
1248
|
+
}
|
1249
|
+
else {
|
1250
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
1251
|
+
if (t == NULL) {
|
1252
|
+
ndt_decref(spec->types[0]);
|
1253
|
+
return -1;
|
1254
|
+
}
|
1255
|
+
}
|
1256
|
+
spec->types[1] = t;
|
1257
|
+
|
1258
|
+
spec->outer_dims = size-inner;
|
1259
|
+
spec->nin = 1;
|
1260
|
+
spec->nout = 1;
|
1261
|
+
spec->nargs = 2;
|
1262
|
+
|
1263
|
+
if (ndt_select_kernel_strategy(spec, ctx) < 0) {
|
1264
|
+
ndt_apply_spec_clear(spec);
|
1265
|
+
return -1;
|
1266
|
+
}
|
1267
|
+
|
1268
|
+
return 0;
|
1269
|
+
}
|
1270
|
+
|
1271
|
+
static bool
|
1272
|
+
unary_all_ellipses(const ndt_t *t0, const ndt_t *t1, ndt_context_t *ctx)
|
1273
|
+
{
|
1274
|
+
if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
|
1275
|
+
(t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL)) {
|
1276
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1277
|
+
"fast unary typecheck expects leading ellipsis dimensions");
|
1278
|
+
return false;
|
950
1279
|
}
|
951
1280
|
|
952
1281
|
return true;
|
953
1282
|
}
|
954
1283
|
|
955
|
-
static
|
956
|
-
|
957
|
-
const ndt_ndarray_t *x, const ndt_ndarray_t *y,
|
958
|
-
const ndt_t *in[], const int nin, ndt_t *dtype,
|
959
|
-
int inner, ndt_context_t *ctx)
|
1284
|
+
static bool
|
1285
|
+
unary_all_same_symbol(const ndt_t *t0, const ndt_t *t1)
|
960
1286
|
{
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
if (shape_equal(x, y)) {
|
965
|
-
spec->nout = 1;
|
966
|
-
spec->nbroadcast = 0;
|
967
|
-
spec->outer_dims = x->ndim-inner;
|
968
|
-
spec->out[0] = fixed_dim_from_shape(x->shape, x->ndim, dtype, ctx);
|
969
|
-
if (spec->out[0] == NULL) {
|
970
|
-
return -1;
|
971
|
-
}
|
1287
|
+
if (t0->tag != SymbolicDim || t1->tag != SymbolicDim) {
|
1288
|
+
return false;
|
972
1289
|
}
|
973
|
-
|
974
|
-
|
975
|
-
|
1290
|
+
|
1291
|
+
return strcmp(t0->SymbolicDim.name, t1->SymbolicDim.name) == 0;
|
1292
|
+
}
|
1293
|
+
|
1294
|
+
static bool
|
1295
|
+
unary_all_ndim0(const ndt_t *t0, const ndt_t *t1)
|
1296
|
+
{
|
1297
|
+
return t0->ndim == 0 && t1->ndim == 0;
|
1298
|
+
}
|
1299
|
+
|
1300
|
+
/*
|
1301
|
+
* Optimized type checking for very specific signatures. The caller must
|
1302
|
+
* have identified the kernel location and the signature. For performance
|
1303
|
+
* reasons, no substitution is performed on the dtype, so the dtype must be
|
1304
|
+
* concrete.
|
1305
|
+
*
|
1306
|
+
* Supported signature: 1) ... * T0 -> ... * T1
|
1307
|
+
*/
|
1308
|
+
int
|
1309
|
+
ndt_fast_unary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
1310
|
+
const ndt_t *types[], const int nin, const int nout,
|
1311
|
+
const bool check_broadcast, ndt_context_t *ctx)
|
1312
|
+
{
|
1313
|
+
const ndt_t *p0, *p1;
|
1314
|
+
ndt_ndarray_t x;
|
1315
|
+
const ndt_t *out = NULL;
|
1316
|
+
const ndt_t *dtype = NULL;
|
1317
|
+
|
1318
|
+
assert(spec->flags == 0);
|
1319
|
+
assert(spec->outer_dims == 0);
|
1320
|
+
assert(spec->nin == 0);
|
1321
|
+
assert(spec->nout == 0);
|
1322
|
+
assert(spec->nargs == 0);
|
1323
|
+
|
1324
|
+
if (sig->tag != Function ||
|
1325
|
+
sig->Function.nin != 1) {
|
1326
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1327
|
+
"fast unary typecheck expects a signature with one input and "
|
1328
|
+
"one output");
|
1329
|
+
return -1;
|
1330
|
+
}
|
1331
|
+
|
1332
|
+
if (nin != 1) {
|
1333
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1334
|
+
"fast unary typecheck expects one input argument");
|
1335
|
+
return -1;
|
1336
|
+
}
|
1337
|
+
|
1338
|
+
p0 = sig->Function.types[0];
|
1339
|
+
p1 = sig->Function.types[1];
|
1340
|
+
|
1341
|
+
if (nout) {
|
1342
|
+
if (nout != 1) {
|
1343
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1344
|
+
"fast unary typecheck expects at most one explicit out argument");
|
1345
|
+
return -1;
|
976
1346
|
}
|
1347
|
+
out = types[1];
|
1348
|
+
dtype = ndt_dtype(types[1]);
|
977
1349
|
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
ndt_del(dtype);
|
1350
|
+
if (!ndt_equal(dtype, ndt_dtype(p1))) {
|
1351
|
+
ndt_err_format(ctx, NDT_ValueError,
|
1352
|
+
"dtype of the out argument does not match");
|
982
1353
|
return -1;
|
983
1354
|
}
|
1355
|
+
}
|
1356
|
+
else {
|
1357
|
+
dtype = ndt_dtype(sig->Function.types[1]);
|
1358
|
+
}
|
1359
|
+
|
1360
|
+
if (ndt_is_abstract(dtype)) {
|
1361
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1362
|
+
"fast unary typecheck expects a concrete dtype");
|
1363
|
+
return -1;
|
1364
|
+
}
|
1365
|
+
|
1366
|
+
if (!unary_all_ellipses(p0, p1, ctx)) {
|
1367
|
+
return -1;
|
1368
|
+
}
|
1369
|
+
|
1370
|
+
if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
|
1371
|
+
return -1;
|
1372
|
+
}
|
1373
|
+
|
1374
|
+
p0 = p0->EllipsisDim.type;
|
1375
|
+
p1 = p1->EllipsisDim.type;
|
1376
|
+
|
1377
|
+
if (unary_all_same_symbol(p0, p1)) {
|
1378
|
+
return _ndt_unary_broadcast(spec,
|
1379
|
+
&x, ndt_dtype(types[0]),
|
1380
|
+
out, dtype, check_broadcast,
|
1381
|
+
1, ctx);
|
1382
|
+
}
|
1383
|
+
else if (unary_all_ndim0(p0, p1)) {
|
1384
|
+
return _ndt_unary_broadcast(spec,
|
1385
|
+
&x, ndt_dtype(types[0]),
|
1386
|
+
out, dtype, check_broadcast,
|
1387
|
+
0, ctx);
|
1388
|
+
}
|
1389
|
+
else {
|
1390
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1391
|
+
"unsupported signature in fast unary typecheck");
|
1392
|
+
return -1;
|
1393
|
+
}
|
1394
|
+
}
|
1395
|
+
|
1396
|
+
static int
|
1397
|
+
_ndt_binary_broadcast(ndt_apply_spec_t *spec,
|
1398
|
+
const ndt_ndarray_t *x, const ndt_t *dtype_x,
|
1399
|
+
const ndt_ndarray_t *y, const ndt_t *dtype_y,
|
1400
|
+
const ndt_t *out, const ndt_t *dtype_out,
|
1401
|
+
const bool check_broadcast, const int inner,
|
1402
|
+
ndt_context_t *ctx)
|
1403
|
+
{
|
1404
|
+
int64_t shape[NDT_MAX_DIM];
|
1405
|
+
int size = x->ndim;
|
1406
|
+
ndt_ndarray_t z;
|
1407
|
+
const ndt_t *t;
|
984
1408
|
|
985
|
-
|
986
|
-
|
987
|
-
|
1409
|
+
for (int i = 0; i < size; i++) {
|
1410
|
+
shape[i] = x->shape[i];
|
1411
|
+
}
|
1412
|
+
|
1413
|
+
size = _resolve_broadcast(shape, size, y->shape, y->ndim);
|
1414
|
+
if (size < 0) {
|
1415
|
+
return broadcast_error("could not broadcast input arguments", ctx);
|
1416
|
+
}
|
988
1417
|
|
989
|
-
|
990
|
-
if (
|
1418
|
+
if (out != NULL) {
|
1419
|
+
if (ndt_as_ndarray(&z, out, ctx) < 0) {
|
991
1420
|
return -1;
|
992
1421
|
}
|
993
1422
|
|
994
|
-
|
995
|
-
if (
|
996
|
-
|
997
|
-
return -1;
|
1423
|
+
size = _resolve_broadcast(shape, size, z.shape, z.ndim);
|
1424
|
+
if (size < 0) {
|
1425
|
+
return broadcast_error("could not broadcast output argument", ctx);
|
998
1426
|
}
|
1427
|
+
}
|
1428
|
+
|
1429
|
+
spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
|
1430
|
+
if (spec->types[0] == NULL) {
|
1431
|
+
return -1;
|
1432
|
+
}
|
1433
|
+
|
1434
|
+
spec->types[1] = fast_broadcast(y, dtype_y, shape, size, ctx);
|
1435
|
+
if (spec->types[1] == NULL) {
|
1436
|
+
ndt_decref(spec->types[0]);
|
1437
|
+
return -1;
|
1438
|
+
}
|
1439
|
+
|
1440
|
+
if (out != NULL) {
|
1441
|
+
if (check_broadcast) {
|
1442
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
1443
|
+
if (t == NULL) {
|
1444
|
+
ndt_decref(spec->types[0]);
|
1445
|
+
ndt_decref(spec->types[1]);
|
1446
|
+
return -1;
|
1447
|
+
}
|
999
1448
|
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1449
|
+
if (!ndt_equal(t, out)) {
|
1450
|
+
ndt_err_format(ctx, NDT_ValueError,
|
1451
|
+
"explicit 'out' type not compatible with input types");
|
1452
|
+
ndt_decref(spec->types[0]);
|
1453
|
+
ndt_decref(spec->types[1]);
|
1454
|
+
ndt_decref(t);
|
1455
|
+
return -1;
|
1456
|
+
}
|
1457
|
+
}
|
1458
|
+
else {
|
1459
|
+
t = fast_broadcast(&z, dtype_out, shape, size, ctx);
|
1460
|
+
if (t == NULL) {
|
1461
|
+
ndt_decref(spec->types[0]);
|
1462
|
+
ndt_decref(spec->types[1]);
|
1463
|
+
return -1;
|
1464
|
+
}
|
1465
|
+
}
|
1466
|
+
}
|
1467
|
+
else {
|
1468
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
1469
|
+
if (t == NULL) {
|
1470
|
+
ndt_decref(spec->types[0]);
|
1471
|
+
ndt_decref(spec->types[1]);
|
1004
1472
|
return -1;
|
1005
1473
|
}
|
1006
1474
|
}
|
1475
|
+
spec->types[2] = t;
|
1007
1476
|
|
1008
|
-
|
1477
|
+
spec->outer_dims = size-inner;
|
1478
|
+
spec->nin = 2;
|
1479
|
+
spec->nout = 1;
|
1480
|
+
spec->nargs = 3;
|
1481
|
+
|
1482
|
+
if (ndt_select_kernel_strategy(spec, ctx) < 0) {
|
1483
|
+
ndt_apply_spec_clear(spec);
|
1484
|
+
return -1;
|
1485
|
+
}
|
1009
1486
|
|
1010
1487
|
return 0;
|
1011
1488
|
}
|
1012
1489
|
|
1013
1490
|
static bool
|
1014
|
-
|
1015
|
-
|
1491
|
+
binary_all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
|
1492
|
+
ndt_context_t *ctx)
|
1016
1493
|
{
|
1017
1494
|
if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
|
1018
1495
|
(t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL) ||
|
@@ -1026,7 +1503,7 @@ all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
|
|
1026
1503
|
}
|
1027
1504
|
|
1028
1505
|
static bool
|
1029
|
-
|
1506
|
+
binary_all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
1030
1507
|
{
|
1031
1508
|
if (t0->tag != SymbolicDim || t1->tag != SymbolicDim ||
|
1032
1509
|
t2->tag != SymbolicDim) {
|
@@ -1038,37 +1515,37 @@ all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
|
1038
1515
|
}
|
1039
1516
|
|
1040
1517
|
static bool
|
1041
|
-
|
1518
|
+
binary_all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
1042
1519
|
{
|
1043
1520
|
return t0->ndim == 0 && t1->ndim == 0 && t2->ndim == 0;
|
1044
1521
|
}
|
1045
1522
|
|
1046
1523
|
/*
|
1047
1524
|
* Optimized type checking for very specific signatures. The caller must
|
1048
|
-
* have identified the kernel location
|
1049
|
-
*
|
1050
|
-
*
|
1525
|
+
* have identified the kernel location and the signature. For performance
|
1526
|
+
* reasons, no substitution is performed on the dtype, so the dtype must be
|
1527
|
+
* concrete.
|
1051
1528
|
*
|
1052
|
-
* Supported
|
1053
|
-
* 1) ... * N * T0, ... * N * T1 -> N * T2
|
1054
|
-
* 2) ... * T0, ... * T1 -> ... * T2
|
1529
|
+
* Supported signature: 1) ... * T0, ... * T1 -> ... * T2
|
1055
1530
|
*/
|
1056
1531
|
int
|
1057
1532
|
ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
1058
|
-
const ndt_t *
|
1059
|
-
ndt_context_t *ctx)
|
1533
|
+
const ndt_t *types[], const int nin, const int nout,
|
1534
|
+
const bool check_broadcast, ndt_context_t *ctx)
|
1060
1535
|
{
|
1061
|
-
ndt_t *p0, *p1, *p2;
|
1536
|
+
const ndt_t *p0, *p1, *p2;
|
1062
1537
|
ndt_ndarray_t x, y;
|
1538
|
+
const ndt_t *out = NULL;
|
1539
|
+
const ndt_t *dtype = NULL;
|
1063
1540
|
|
1064
1541
|
assert(spec->flags == 0);
|
1065
|
-
assert(spec->nout == 0);
|
1066
|
-
assert(spec->nbroadcast == 0);
|
1067
1542
|
assert(spec->outer_dims == 0);
|
1543
|
+
assert(spec->nin == 0);
|
1544
|
+
assert(spec->nout == 0);
|
1545
|
+
assert(spec->nargs == 0);
|
1068
1546
|
|
1069
1547
|
if (sig->tag != Function ||
|
1070
|
-
sig->Function.nin != 2
|
1071
|
-
sig->Function.nout != 1) {
|
1548
|
+
sig->Function.nin != 2) {
|
1072
1549
|
ndt_err_format(ctx, NDT_RuntimeError,
|
1073
1550
|
"fast binary typecheck expects a signature with two inputs and "
|
1074
1551
|
"one output");
|
@@ -1081,27 +1558,44 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
1081
1558
|
return -1;
|
1082
1559
|
}
|
1083
1560
|
|
1561
|
+
p0 = sig->Function.types[0];
|
1562
|
+
p1 = sig->Function.types[1];
|
1563
|
+
p2 = sig->Function.types[2];
|
1564
|
+
|
1565
|
+
if (nout) {
|
1566
|
+
if (nout != 1) {
|
1567
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1568
|
+
"fast binary typecheck expects at most one explicit out argument");
|
1569
|
+
return -1;
|
1570
|
+
}
|
1571
|
+
out = types[2];
|
1572
|
+
dtype = ndt_dtype(types[2]);
|
1573
|
+
|
1574
|
+
if (!ndt_equal(dtype, ndt_dtype(p2))) {
|
1575
|
+
ndt_err_format(ctx, NDT_ValueError,
|
1576
|
+
"dtype of the out argument does not match");
|
1577
|
+
return -1;
|
1578
|
+
}
|
1579
|
+
}
|
1580
|
+
else {
|
1581
|
+
dtype = ndt_dtype(sig->Function.types[2]);
|
1582
|
+
}
|
1583
|
+
|
1084
1584
|
if (ndt_is_abstract(dtype)) {
|
1085
1585
|
ndt_err_format(ctx, NDT_RuntimeError,
|
1086
1586
|
"fast binary typecheck expects a concrete dtype");
|
1087
1587
|
return -1;
|
1088
1588
|
}
|
1089
1589
|
|
1090
|
-
p0
|
1091
|
-
p1 = sig->Function.types[1];
|
1092
|
-
p2 = sig->Function.types[2];
|
1093
|
-
|
1094
|
-
if (!all_ellipses(p0, p1, p2, ctx)) {
|
1590
|
+
if (!binary_all_ellipses(p0, p1, p2, ctx)) {
|
1095
1591
|
return -1;
|
1096
1592
|
}
|
1097
1593
|
|
1098
|
-
if (ndt_as_ndarray(&x,
|
1099
|
-
ndt_del(dtype);
|
1594
|
+
if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
|
1100
1595
|
return -1;
|
1101
1596
|
}
|
1102
1597
|
|
1103
|
-
if (ndt_as_ndarray(&y,
|
1104
|
-
ndt_del(dtype);
|
1598
|
+
if (ndt_as_ndarray(&y, types[1], ctx) < 0) {
|
1105
1599
|
return -1;
|
1106
1600
|
}
|
1107
1601
|
|
@@ -1109,20 +1603,27 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
1109
1603
|
p1 = p1->EllipsisDim.type;
|
1110
1604
|
p2 = p2->EllipsisDim.type;
|
1111
1605
|
|
1112
|
-
if (
|
1606
|
+
if (binary_all_same_symbol(p0, p1, p2)) {
|
1113
1607
|
if (x.ndim > 0 && y.ndim > 0) {
|
1114
1608
|
const int64_t xshape = x.shape[x.ndim-1];
|
1115
1609
|
const int64_t yshape = y.shape[y.ndim-1];
|
1116
1610
|
if (xshape != 1 && yshape != 1 && xshape != yshape) {
|
1117
1611
|
ndt_err_format(ctx, NDT_TypeError, "mismatch in inner dimensions");
|
1118
|
-
ndt_del(dtype);
|
1119
1612
|
return -1;
|
1120
1613
|
}
|
1121
1614
|
}
|
1122
|
-
return _ndt_binary_broadcast(spec,
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1615
|
+
return _ndt_binary_broadcast(spec,
|
1616
|
+
&x, ndt_dtype(types[0]),
|
1617
|
+
&y, ndt_dtype(types[1]),
|
1618
|
+
out, dtype, check_broadcast,
|
1619
|
+
1, ctx);
|
1620
|
+
}
|
1621
|
+
else if (binary_all_ndim0(p0, p1, p2)) {
|
1622
|
+
return _ndt_binary_broadcast(spec,
|
1623
|
+
&x, ndt_dtype(types[0]),
|
1624
|
+
&y, ndt_dtype(types[1]),
|
1625
|
+
out, dtype, check_broadcast,
|
1626
|
+
0, ctx);
|
1126
1627
|
}
|
1127
1628
|
else {
|
1128
1629
|
ndt_err_format(ctx, NDT_RuntimeError,
|