ndtypes 0.2.0dev5 → 0.2.0dev6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +12 -0
- data/Rakefile +8 -0
- data/ext/ruby_ndtypes/GPATH +0 -0
- data/ext/ruby_ndtypes/GRTAGS +0 -0
- data/ext/ruby_ndtypes/GTAGS +0 -0
- data/ext/ruby_ndtypes/extconf.rb +1 -1
- data/ext/ruby_ndtypes/include/ndtypes.h +231 -122
- data/ext/ruby_ndtypes/include/ruby_ndtypes.h +1 -1
- data/ext/ruby_ndtypes/lib/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/lib/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/Makefile +87 -0
- data/ext/ruby_ndtypes/ndtypes/config.h +68 -0
- data/ext/ruby_ndtypes/ndtypes/config.log +477 -0
- data/ext/ruby_ndtypes/ndtypes/config.status +1027 -0
- data/ext/ruby_ndtypes/ndtypes/doc/_static/style.css +7 -0
- data/ext/ruby_ndtypes/ndtypes/doc/_templates/layout.html +2 -0
- data/ext/ruby_ndtypes/ndtypes/doc/conf.py +40 -4
- data/ext/ruby_ndtypes/ndtypes/doc/images/xndlogo.png +0 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +1 -1
- data/ext/ruby_ndtypes/ndtypes/doc/requirements.txt +2 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile +287 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +20 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +22 -3
- data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile +73 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +246 -229
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +15 -11
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +38 -28
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +91 -91
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +4 -3
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +8 -7
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/context.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +263 -182
- data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +67 -7
- data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +1112 -1000
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +69 -58
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +150 -99
- data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +185 -15
- data/ext/ruby_ndtypes/ndtypes/libndtypes/io.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +301 -276
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +9 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so +1 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0 +1 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +729 -228
- data/ext/ruby_ndtypes/ndtypes/libndtypes/match.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +768 -403
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h +1002 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +231 -122
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +176 -84
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +26 -14
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +57 -35
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.c +420 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +8 -8
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile +48 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +200 -116
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +46 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +58 -27
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +3 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +12 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile +55 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +8 -8
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +5 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +274 -172
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +24 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +14 -14
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +32 -30
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +37 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +36 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +16 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +5 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +706 -253
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_unify.c +132 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.c +703 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +335 -127
- data/ext/ruby_ndtypes/ndtypes/libndtypes/util.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/values.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +88 -71
- data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +0 -1
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +10 -13
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +395 -314
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so +1 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0 +1 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/ndtypes.h +1002 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +15 -33
- data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +340 -132
- data/ext/ruby_ndtypes/ndtypes/setup.py +11 -2
- data/ext/ruby_ndtypes/ruby_ndtypes.c +364 -241
- data/ext/ruby_ndtypes/ruby_ndtypes.h +1 -1
- data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +0 -1
- data/lib/ndtypes.rb +11 -0
- data/lib/ndtypes/version.rb +2 -2
- data/lib/ruby_ndtypes.so +0 -0
- data/ndtypes.gemspec +3 -0
- data/spec/ndtypes_spec.rb +6 -0
- metadata +98 -4
- data/ext/ruby_ndtypes/gc_guard.c +0 -36
- data/ext/ruby_ndtypes/gc_guard.h +0 -12
|
@@ -144,24 +144,26 @@ yycolumn = 1;
|
|
|
144
144
|
"void" { return VOID; }
|
|
145
145
|
"bool" { return BOOL; }
|
|
146
146
|
|
|
147
|
-
"
|
|
147
|
+
"signed" { return SIGNED_KIND; }
|
|
148
148
|
"int8" { return INT8; }
|
|
149
149
|
"int16" { return INT16; }
|
|
150
150
|
"int32" { return INT32; }
|
|
151
151
|
"int64" { return INT64; }
|
|
152
152
|
|
|
153
|
-
"
|
|
153
|
+
"unsigned" { return UNSIGNED_KIND; }
|
|
154
154
|
"uint8" { return UINT8; }
|
|
155
155
|
"uint16" { return UINT16; }
|
|
156
156
|
"uint32" { return UINT32; }
|
|
157
157
|
"uint64" { return UINT64; }
|
|
158
158
|
|
|
159
|
-
"
|
|
159
|
+
"float" { return FLOAT_KIND; }
|
|
160
|
+
"bfloat16" { return BFLOAT16; }
|
|
160
161
|
"float16" { return FLOAT16; }
|
|
161
162
|
"float32" { return FLOAT32; }
|
|
162
163
|
"float64" { return FLOAT64; }
|
|
163
164
|
|
|
164
|
-
"
|
|
165
|
+
"complex" { return COMPLEX_KIND; }
|
|
166
|
+
"bcomplex32" { return BCOMPLEX32; }
|
|
165
167
|
"complex32" { return COMPLEX32; }
|
|
166
168
|
"complex64" { return COMPLEX64; }
|
|
167
169
|
"complex128" { return COMPLEX128; }
|
|
@@ -186,6 +188,9 @@ yycolumn = 1;
|
|
|
186
188
|
|
|
187
189
|
"fixed" { return FIXED; }
|
|
188
190
|
"var" { return VAR; }
|
|
191
|
+
"array" { return ARRAY; }
|
|
192
|
+
|
|
193
|
+
"of" { return OF; }
|
|
189
194
|
|
|
190
195
|
"..." { return ELLIPSIS; }
|
|
191
196
|
"->" { return RARROW; }
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3
|
|
Binary file
|
|
@@ -103,7 +103,7 @@ resolve_broadcast(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
|
|
103
103
|
}
|
|
104
104
|
|
|
105
105
|
static int
|
|
106
|
-
check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
|
|
106
|
+
check_contig(const ndt_t *ptypes[], const ndt_t *ctypes[], int64_t nargs)
|
|
107
107
|
{
|
|
108
108
|
for (int i = 0; i < nargs; i++) {
|
|
109
109
|
const ndt_t *p = ptypes[i];
|
|
@@ -130,14 +130,14 @@ check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
|
|
|
130
130
|
return 1;
|
|
131
131
|
}
|
|
132
132
|
|
|
133
|
-
static ndt_t *
|
|
134
|
-
to_fortran(const ndt_t *p, ndt_t *c, ndt_context_t *ctx)
|
|
133
|
+
static const ndt_t *
|
|
134
|
+
to_fortran(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
|
|
135
135
|
{
|
|
136
136
|
if (p->tag == EllipsisDim && p->EllipsisDim.tag == RequireF) {
|
|
137
|
-
|
|
138
|
-
return t;
|
|
137
|
+
return ndt_to_fortran(c, ctx);
|
|
139
138
|
}
|
|
140
139
|
else {
|
|
140
|
+
ndt_incref(c);
|
|
141
141
|
return c;
|
|
142
142
|
}
|
|
143
143
|
}
|
|
@@ -217,41 +217,124 @@ resolve_typevar(const char *key, symtable_entry_t w, symtable_t *tbl, ndt_contex
|
|
|
217
217
|
}
|
|
218
218
|
}
|
|
219
219
|
|
|
220
|
+
typedef struct {
|
|
221
|
+
int64_t index;
|
|
222
|
+
const ndt_t *type;
|
|
223
|
+
} indexed_type_t;
|
|
224
|
+
|
|
225
|
+
static indexed_type_t indexed_type_error = { 0, NULL };
|
|
226
|
+
|
|
227
|
+
static inline int64_t
|
|
228
|
+
adjust_index(const int64_t i, const int64_t shape, ndt_context_t *ctx)
|
|
229
|
+
{
|
|
230
|
+
const int64_t k = i < 0 ? i + shape : i;
|
|
231
|
+
|
|
232
|
+
if (k < 0 || k >= shape) {
|
|
233
|
+
ndt_err_format(ctx, NDT_IndexError,
|
|
234
|
+
"index with value %" PRIi64 " out of bounds", i);
|
|
235
|
+
return -1;
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
return k;
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
static inline indexed_type_t
|
|
242
|
+
var_dim_next(const indexed_type_t *x, const int64_t start, const int64_t step,
|
|
243
|
+
const int64_t i)
|
|
244
|
+
{
|
|
245
|
+
indexed_type_t next;
|
|
246
|
+
const ndt_t *t = x->type;
|
|
247
|
+
|
|
248
|
+
next.index = start + i * step;
|
|
249
|
+
next.type = t->VarDim.type;
|
|
250
|
+
|
|
251
|
+
return next;
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
/* skip stored indices */
|
|
255
|
+
static indexed_type_t
|
|
256
|
+
apply_stored_index(const indexed_type_t *x, ndt_context_t *ctx)
|
|
257
|
+
{
|
|
258
|
+
const ndt_t * const t = x->type;
|
|
259
|
+
int64_t start, step, shape;
|
|
260
|
+
|
|
261
|
+
if (t->tag != VarDimElem) {
|
|
262
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
263
|
+
"apply_stored_index: need VarDimElem");
|
|
264
|
+
return indexed_type_error;
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
shape = ndt_var_indices(&start, &step, t, x->index, ctx);
|
|
268
|
+
if (shape < 0) {
|
|
269
|
+
return indexed_type_error;
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
const int64_t i = adjust_index(t->VarDimElem.index, shape, ctx);
|
|
273
|
+
if (i < 0) {
|
|
274
|
+
return indexed_type_error;
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
return var_dim_next(x, start, step, i);
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
static indexed_type_t
|
|
281
|
+
apply_stored_indices(const indexed_type_t *x, ndt_context_t *ctx)
|
|
282
|
+
{
|
|
283
|
+
indexed_type_t tl = *x;
|
|
284
|
+
|
|
285
|
+
while (tl.type->tag == VarDimElem) {
|
|
286
|
+
tl = apply_stored_index(&tl, ctx);
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
return tl;
|
|
290
|
+
}
|
|
291
|
+
|
|
220
292
|
static int
|
|
221
|
-
match_concrete_var_dim(const
|
|
222
|
-
const ndt_t *u, int64_t uindex,
|
|
293
|
+
match_concrete_var_dim(const indexed_type_t *a, const indexed_type_t *b,
|
|
223
294
|
const int outer_dims, ndt_context_t *ctx)
|
|
224
295
|
{
|
|
225
|
-
int64_t
|
|
226
|
-
int64_t
|
|
296
|
+
int64_t xshape, xstart, xstep;
|
|
297
|
+
int64_t yshape, ystart, ystep;
|
|
227
298
|
|
|
228
299
|
if (outer_dims == 0) {
|
|
229
300
|
return 1;
|
|
230
301
|
}
|
|
302
|
+
|
|
303
|
+
const indexed_type_t x = apply_stored_indices(a, ctx);
|
|
304
|
+
if (x.type == NULL) {
|
|
305
|
+
return -1;
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
const indexed_type_t y = apply_stored_indices(b, ctx);
|
|
309
|
+
if (x.type == NULL) {
|
|
310
|
+
return -1;
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
const ndt_t * const t = x.type;
|
|
314
|
+
const ndt_t * const u = y.type;
|
|
315
|
+
|
|
231
316
|
if (t->Concrete.VarDim.itemsize != u->Concrete.VarDim.itemsize) {
|
|
232
317
|
return 0;
|
|
233
318
|
}
|
|
234
319
|
|
|
235
|
-
|
|
236
|
-
if (
|
|
320
|
+
xshape = ndt_var_indices(&xstart, &xstep, t, x.index, ctx);
|
|
321
|
+
if (xshape < 0) {
|
|
237
322
|
return -1;
|
|
238
323
|
}
|
|
239
324
|
|
|
240
|
-
|
|
241
|
-
if (
|
|
325
|
+
yshape = ndt_var_indices(&ystart, &ystep, u, y.index, ctx);
|
|
326
|
+
if (yshape < 0) {
|
|
242
327
|
return -1;
|
|
243
328
|
}
|
|
244
329
|
|
|
245
|
-
if (
|
|
330
|
+
if (xshape != yshape) {
|
|
246
331
|
return 0;
|
|
247
332
|
}
|
|
248
333
|
|
|
249
|
-
for (int64_t i = 0; i <
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
int ret = match_concrete_var_dim(
|
|
253
|
-
u->VarDim.type, unext,
|
|
254
|
-
outer_dims-1, ctx);
|
|
334
|
+
for (int64_t i = 0; i < xshape; i++) {
|
|
335
|
+
const indexed_type_t xnext = var_dim_next(&x, xstart, xstep, i);
|
|
336
|
+
const indexed_type_t ynext = var_dim_next(&y, ystart, ystep, i);
|
|
337
|
+
int ret = match_concrete_var_dim(&xnext, &ynext, outer_dims-1, ctx);
|
|
255
338
|
if (ret <= 0) {
|
|
256
339
|
return ret;
|
|
257
340
|
}
|
|
@@ -265,6 +348,7 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
|
|
265
348
|
{
|
|
266
349
|
const char *key = "var";
|
|
267
350
|
symtable_entry_t v;
|
|
351
|
+
int vdims, wdims;
|
|
268
352
|
|
|
269
353
|
v = symtable_find(tbl, key);
|
|
270
354
|
if (v.tag == Unbound) {
|
|
@@ -274,16 +358,36 @@ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
|
|
274
358
|
return 1;
|
|
275
359
|
}
|
|
276
360
|
|
|
277
|
-
|
|
361
|
+
vdims = ndt_logical_ndim(v.VarSeq.dims[0]);
|
|
362
|
+
wdims = ndt_logical_ndim(w.VarSeq.dims[0]);
|
|
363
|
+
|
|
364
|
+
if (wdims != vdims) {
|
|
278
365
|
return 0;
|
|
279
366
|
}
|
|
280
|
-
if (
|
|
367
|
+
if (vdims == 0) {
|
|
281
368
|
return 1;
|
|
282
369
|
}
|
|
283
370
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
371
|
+
const indexed_type_t x = { w.VarSeq.linear_index, w.VarSeq.dims[0] };
|
|
372
|
+
const indexed_type_t y = { v.VarSeq.linear_index, v.VarSeq.dims[0] };
|
|
373
|
+
return match_concrete_var_dim(&x, &y, vdims, ctx);
|
|
374
|
+
}
|
|
375
|
+
|
|
376
|
+
static int
|
|
377
|
+
resolve_array(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
|
378
|
+
{
|
|
379
|
+
const char *key = "array";
|
|
380
|
+
symtable_entry_t v;
|
|
381
|
+
|
|
382
|
+
v = symtable_find(tbl, key);
|
|
383
|
+
if (v.tag == Unbound) {
|
|
384
|
+
if (symtable_add(tbl, key, w, ctx) < 0) {
|
|
385
|
+
return -1;
|
|
386
|
+
}
|
|
387
|
+
return 1;
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
return v.ArraySeq.size == w.ArraySeq.size;
|
|
287
391
|
}
|
|
288
392
|
|
|
289
393
|
static int
|
|
@@ -332,8 +436,32 @@ match_record_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
|
332
436
|
}
|
|
333
437
|
|
|
334
438
|
static int
|
|
335
|
-
|
|
336
|
-
|
|
439
|
+
match_union_tags(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
440
|
+
ndt_context_t *ctx)
|
|
441
|
+
{
|
|
442
|
+
int64_t i;
|
|
443
|
+
int n;
|
|
444
|
+
|
|
445
|
+
assert(p->tag == Union && c->tag == Union);
|
|
446
|
+
|
|
447
|
+
if (p->Union.ntags != c->Union.ntags) {
|
|
448
|
+
return 0;
|
|
449
|
+
}
|
|
450
|
+
|
|
451
|
+
for (i = 0; i < p->Union.ntags; i++) {
|
|
452
|
+
n = strcmp(p->Union.tags[i], c->Union.tags[i]);
|
|
453
|
+
if (n != 0) return 0;
|
|
454
|
+
|
|
455
|
+
n = match_datashape(p->Union.types[i], c->Union.types[i], tbl, ctx);
|
|
456
|
+
if (n <= 0) return n;
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
return 1;
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
static int
|
|
463
|
+
match_categorical(const ndt_value_t *p, int64_t plen,
|
|
464
|
+
const ndt_value_t *c, int64_t clen)
|
|
337
465
|
{
|
|
338
466
|
int64_t i;
|
|
339
467
|
|
|
@@ -378,7 +506,7 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
|
|
|
378
506
|
}
|
|
379
507
|
return outer_inner(v, i+1, t->FixedDim.type, ndim);
|
|
380
508
|
}
|
|
381
|
-
case VarDim: {
|
|
509
|
+
case VarDim: case VarDimElem: {
|
|
382
510
|
switch (v->tag) {
|
|
383
511
|
case VarSeq:
|
|
384
512
|
v->VarSeq.size = i+1;
|
|
@@ -389,6 +517,18 @@ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
|
|
|
389
517
|
}
|
|
390
518
|
return outer_inner(v, i+1, t->VarDim.type, ndim);
|
|
391
519
|
}
|
|
520
|
+
case Array: {
|
|
521
|
+
switch (v->tag) {
|
|
522
|
+
case ArraySeq:
|
|
523
|
+
v->ArraySeq.size = i+1;
|
|
524
|
+
v->ArraySeq.dims[i] = t;
|
|
525
|
+
break;
|
|
526
|
+
default:
|
|
527
|
+
return NULL;
|
|
528
|
+
}
|
|
529
|
+
return outer_inner(v, i+1, t->Array.type, ndim);
|
|
530
|
+
}
|
|
531
|
+
|
|
392
532
|
default:
|
|
393
533
|
return NULL;
|
|
394
534
|
}
|
|
@@ -400,6 +540,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
|
400
540
|
{
|
|
401
541
|
int n;
|
|
402
542
|
|
|
543
|
+
if (c->tag == VarDimElem) {
|
|
544
|
+
return match_datashape(p, c->VarDimElem.type, tbl, ctx);
|
|
545
|
+
}
|
|
546
|
+
|
|
403
547
|
if (ndt_is_optional(c) != ndt_is_optional(p)) return 0;
|
|
404
548
|
|
|
405
549
|
switch (p->tag) {
|
|
@@ -445,57 +589,11 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
|
445
589
|
return match_datashape(p->SymbolicDim.type, c->FixedDim.type, tbl, ctx);
|
|
446
590
|
}
|
|
447
591
|
|
|
448
|
-
case EllipsisDim: {
|
|
449
|
-
symtable_entry_t outer;
|
|
450
|
-
const ndt_t *inner;
|
|
451
|
-
|
|
452
|
-
if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
|
|
453
|
-
return 0;
|
|
454
|
-
}
|
|
455
|
-
if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
|
|
456
|
-
return 0;
|
|
457
|
-
}
|
|
458
|
-
|
|
459
|
-
if (p->EllipsisDim.name == NULL) {
|
|
460
|
-
outer.tag = BroadcastSeq;
|
|
461
|
-
outer.BroadcastSeq.size = 0;
|
|
462
|
-
}
|
|
463
|
-
else if (strcmp(p->EllipsisDim.name, "var") == 0) {
|
|
464
|
-
outer.tag = VarSeq;
|
|
465
|
-
outer.VarSeq.size = 0;
|
|
466
|
-
}
|
|
467
|
-
else {
|
|
468
|
-
outer.tag = FixedSeq;
|
|
469
|
-
outer.FixedSeq.size = 0;
|
|
470
|
-
}
|
|
471
|
-
|
|
472
|
-
inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
|
|
473
|
-
if (inner == NULL) {
|
|
474
|
-
return 0;
|
|
475
|
-
}
|
|
476
|
-
|
|
477
|
-
n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
|
|
478
|
-
if (n <= 0) {
|
|
479
|
-
return n;
|
|
480
|
-
}
|
|
481
|
-
|
|
482
|
-
switch (outer.tag) {
|
|
483
|
-
case BroadcastSeq:
|
|
484
|
-
return resolve_broadcast(outer, tbl, ctx);
|
|
485
|
-
case FixedSeq:
|
|
486
|
-
return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
|
|
487
|
-
case VarSeq:
|
|
488
|
-
return resolve_var(outer, tbl, ctx);
|
|
489
|
-
default: /* NOT REACHED */
|
|
490
|
-
ndt_internal_error("invalid tag");
|
|
491
|
-
}
|
|
492
|
-
}
|
|
493
|
-
|
|
494
592
|
case Bool:
|
|
495
593
|
case Int8: case Int16: case Int32: case Int64:
|
|
496
594
|
case Uint8: case Uint16: case Uint32: case Uint64:
|
|
497
|
-
case Float16: case Float32: case Float64:
|
|
498
|
-
case Complex32: case Complex64: case Complex128:
|
|
595
|
+
case BFloat16: case Float16: case Float32: case Float64:
|
|
596
|
+
case BComplex32: case Complex32: case Complex64: case Complex128:
|
|
499
597
|
case String:
|
|
500
598
|
return p->tag == c->tag;
|
|
501
599
|
case FixedString:
|
|
@@ -528,6 +626,10 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
|
528
626
|
return c->tag == Categorical &&
|
|
529
627
|
match_categorical(p->Categorical.types, p->Categorical.ntypes,
|
|
530
628
|
c->Categorical.types, c->Categorical.ntypes);
|
|
629
|
+
case Array:
|
|
630
|
+
if (c->tag != Array) return 0;
|
|
631
|
+
if (c->Array.itemsize != p->Array.itemsize) return 0;
|
|
632
|
+
return match_datashape(p->Array.type, c->Array.type, tbl, ctx);
|
|
531
633
|
case Ref:
|
|
532
634
|
if (c->tag != Ref) return 0;
|
|
533
635
|
return match_datashape(p->Ref.type, c->Ref.type, tbl, ctx);
|
|
@@ -536,9 +638,12 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
|
536
638
|
if (c->tag != Tuple) return 0;
|
|
537
639
|
return match_tuple_fields(p, c, tbl, ctx);
|
|
538
640
|
case Record:
|
|
539
|
-
if (p->
|
|
641
|
+
if (p->Record.flag == Variadic) return 0;
|
|
540
642
|
if (c->tag != Record) return 0;
|
|
541
643
|
return match_record_fields(p, c, tbl, ctx);
|
|
644
|
+
case Union:
|
|
645
|
+
if (c->tag != Union) return 0;
|
|
646
|
+
return match_union_tags(p, c, tbl, ctx);
|
|
542
647
|
case Function: {
|
|
543
648
|
int64_t i;
|
|
544
649
|
if (c->tag != Function ||
|
|
@@ -576,12 +681,83 @@ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
|
|
576
681
|
case Constr:
|
|
577
682
|
return c->tag == Constr && strcmp(p->Constr.name, c->Constr.name) == 0 &&
|
|
578
683
|
ndt_equal(p->Constr.type, c->Constr.type);
|
|
684
|
+
case VarDimElem:
|
|
685
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
686
|
+
"VarDimElem cannot occur in pattern");
|
|
687
|
+
return -1;
|
|
688
|
+
case EllipsisDim:
|
|
689
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
690
|
+
"match_datashape: internal_error: unexpected EllipsisDim");
|
|
691
|
+
return -1;
|
|
579
692
|
}
|
|
580
693
|
|
|
581
694
|
/* NOT REACHED: tags should be exhaustive. */
|
|
582
695
|
ndt_internal_error("invalid type");
|
|
583
696
|
}
|
|
584
697
|
|
|
698
|
+
static int
|
|
699
|
+
match_datashape_top(const ndt_t *p, const ndt_t *c, int64_t linear_index,
|
|
700
|
+
symtable_t *tbl, ndt_context_t *ctx)
|
|
701
|
+
{
|
|
702
|
+
switch (p->tag) {
|
|
703
|
+
case EllipsisDim: {
|
|
704
|
+
symtable_entry_t outer;
|
|
705
|
+
const ndt_t *inner;
|
|
706
|
+
int n;
|
|
707
|
+
|
|
708
|
+
if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
|
|
709
|
+
return 0;
|
|
710
|
+
}
|
|
711
|
+
if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
|
|
712
|
+
return 0;
|
|
713
|
+
}
|
|
714
|
+
|
|
715
|
+
if (p->EllipsisDim.name == NULL) {
|
|
716
|
+
outer.tag = BroadcastSeq;
|
|
717
|
+
outer.BroadcastSeq.size = 0;
|
|
718
|
+
}
|
|
719
|
+
else if (strcmp(p->EllipsisDim.name, "var") == 0) {
|
|
720
|
+
outer.tag = VarSeq;
|
|
721
|
+
outer.VarSeq.size = 0;
|
|
722
|
+
outer.VarSeq.linear_index = linear_index;
|
|
723
|
+
}
|
|
724
|
+
else if (strcmp(p->EllipsisDim.name, "array") == 0) {
|
|
725
|
+
outer.tag = ArraySeq;
|
|
726
|
+
outer.ArraySeq.size = 0;
|
|
727
|
+
}
|
|
728
|
+
else {
|
|
729
|
+
outer.tag = FixedSeq;
|
|
730
|
+
outer.FixedSeq.size = 0;
|
|
731
|
+
}
|
|
732
|
+
|
|
733
|
+
inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
|
|
734
|
+
if (inner == NULL) {
|
|
735
|
+
return 0;
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
|
|
739
|
+
if (n <= 0) {
|
|
740
|
+
return n;
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
switch (outer.tag) {
|
|
744
|
+
case BroadcastSeq:
|
|
745
|
+
return resolve_broadcast(outer, tbl, ctx);
|
|
746
|
+
case FixedSeq:
|
|
747
|
+
return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
|
|
748
|
+
case VarSeq:
|
|
749
|
+
return resolve_var(outer, tbl, ctx);
|
|
750
|
+
case ArraySeq:
|
|
751
|
+
return resolve_array(outer, tbl, ctx);
|
|
752
|
+
default: /* NOT REACHED */
|
|
753
|
+
ndt_internal_error("invalid tag");
|
|
754
|
+
}
|
|
755
|
+
}
|
|
756
|
+
default:
|
|
757
|
+
return match_datashape(p, c, tbl, ctx);
|
|
758
|
+
}
|
|
759
|
+
}
|
|
760
|
+
|
|
585
761
|
int
|
|
586
762
|
ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
|
|
587
763
|
{
|
|
@@ -597,19 +773,18 @@ ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
|
|
|
597
773
|
return -1;
|
|
598
774
|
}
|
|
599
775
|
|
|
600
|
-
ret =
|
|
776
|
+
ret = match_datashape_top(p, c, 0, tbl, ctx);
|
|
601
777
|
symtable_del(tbl);
|
|
602
778
|
return ret;
|
|
603
779
|
}
|
|
604
780
|
|
|
605
|
-
static ndt_t *
|
|
781
|
+
static const ndt_t *
|
|
606
782
|
broadcast(const ndt_t *t, const int64_t *shape,
|
|
607
783
|
int outer_dims, int inner_dims,
|
|
608
784
|
bool use_max, ndt_context_t *ctx)
|
|
609
785
|
{
|
|
610
786
|
ndt_ndarray_t u;
|
|
611
|
-
const ndt_t *
|
|
612
|
-
ndt_t *v;
|
|
787
|
+
const ndt_t *v, *w;
|
|
613
788
|
int64_t step;
|
|
614
789
|
int ndim;
|
|
615
790
|
int i, k;
|
|
@@ -619,14 +794,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
|
619
794
|
return NULL;
|
|
620
795
|
}
|
|
621
796
|
|
|
622
|
-
|
|
623
|
-
v
|
|
624
|
-
if (v == NULL) {
|
|
625
|
-
return NULL;
|
|
626
|
-
}
|
|
797
|
+
v = ndt_dtype(t);
|
|
798
|
+
ndt_incref(v);
|
|
627
799
|
|
|
628
800
|
for (i=ndim-1; i>=ndim-inner_dims; i--) {
|
|
629
|
-
|
|
801
|
+
w = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
|
|
802
|
+
ndt_move(&v, w);
|
|
630
803
|
if (v == NULL) {
|
|
631
804
|
return NULL;
|
|
632
805
|
}
|
|
@@ -634,7 +807,8 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
|
634
807
|
|
|
635
808
|
for (k=outer_dims-1; i>=0 && k>=0; i--, k--) {
|
|
636
809
|
step = u.shape[i]<=1 ? 0 : u.steps[i];
|
|
637
|
-
|
|
810
|
+
w = ndt_fixed_dim(v, shape[k], step, ctx);
|
|
811
|
+
ndt_move(&v, w);
|
|
638
812
|
if (v == NULL) {
|
|
639
813
|
return NULL;
|
|
640
814
|
}
|
|
@@ -642,11 +816,12 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
|
642
816
|
|
|
643
817
|
for (; k>=0; k--) {
|
|
644
818
|
if (use_max) {
|
|
645
|
-
|
|
819
|
+
w = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
|
|
646
820
|
}
|
|
647
821
|
else {
|
|
648
|
-
|
|
822
|
+
w = ndt_fixed_dim(v, shape[k], 0, ctx);
|
|
649
823
|
}
|
|
824
|
+
ndt_move(&v, w);
|
|
650
825
|
if (v == NULL) {
|
|
651
826
|
return NULL;
|
|
652
827
|
}
|
|
@@ -656,34 +831,44 @@ broadcast(const ndt_t *t, const int64_t *shape,
|
|
|
656
831
|
}
|
|
657
832
|
|
|
658
833
|
int
|
|
659
|
-
ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
660
|
-
const
|
|
661
|
-
const int64_t *shape, int outer_dims,
|
|
662
|
-
ndt_context_t *ctx)
|
|
834
|
+
ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
|
|
835
|
+
const int64_t *shape, const int outer_dims, ndt_context_t *ctx)
|
|
663
836
|
{
|
|
664
|
-
|
|
837
|
+
const int nin = spec->nin;
|
|
838
|
+
const int nout = spec->nout;
|
|
665
839
|
int inner_dims;
|
|
666
840
|
int i;
|
|
667
841
|
|
|
668
842
|
for (i = 0; i < nin; i++) {
|
|
669
843
|
inner_dims = sig->Function.types[i]->ndim-1;
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
if (
|
|
844
|
+
const ndt_t *t = spec->types[i];
|
|
845
|
+
const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, false, ctx);
|
|
846
|
+
if (u == NULL) {
|
|
673
847
|
return -1;
|
|
674
848
|
}
|
|
675
|
-
|
|
849
|
+
ndt_decref(t);
|
|
850
|
+
spec->types[i] = u;
|
|
676
851
|
}
|
|
677
852
|
|
|
678
|
-
for (i = 0; i <
|
|
853
|
+
for (i = 0; i < nout; i++) {
|
|
679
854
|
inner_dims = sig->Function.types[nin+i]->ndim-1;
|
|
680
|
-
|
|
681
|
-
|
|
855
|
+
const ndt_t *t = spec->types[nin+i];
|
|
856
|
+
const ndt_t *u = broadcast(t, shape, outer_dims, inner_dims, true, ctx);
|
|
682
857
|
if (u == NULL) {
|
|
683
858
|
return -1;
|
|
684
859
|
}
|
|
685
|
-
|
|
686
|
-
|
|
860
|
+
|
|
861
|
+
if (check_broadcast) {
|
|
862
|
+
if (!ndt_equal(t, u)) {
|
|
863
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
864
|
+
"explicit 'out' type not compatible with input types");
|
|
865
|
+
ndt_decref(u);
|
|
866
|
+
return -1;
|
|
867
|
+
}
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
ndt_decref(t);
|
|
871
|
+
spec->types[nin+i] = u;
|
|
687
872
|
}
|
|
688
873
|
|
|
689
874
|
spec->outer_dims = outer_dims;
|
|
@@ -692,8 +877,7 @@ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
692
877
|
}
|
|
693
878
|
|
|
694
879
|
static int
|
|
695
|
-
broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
696
|
-
const ndt_t *in[], const int nin,
|
|
880
|
+
broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig, bool check_broadcast,
|
|
697
881
|
const symtable_t *tbl, ndt_context_t *ctx)
|
|
698
882
|
{
|
|
699
883
|
symtable_entry_t v;
|
|
@@ -705,7 +889,7 @@ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
705
889
|
return -1;
|
|
706
890
|
}
|
|
707
891
|
|
|
708
|
-
return ndt_broadcast_all(spec, sig,
|
|
892
|
+
return ndt_broadcast_all(spec, sig, check_broadcast,
|
|
709
893
|
v.BroadcastSeq.dims, v.BroadcastSeq.size,
|
|
710
894
|
ctx);
|
|
711
895
|
}
|
|
@@ -746,20 +930,23 @@ resolve_constraint(const ndt_constraint_t *c, const void *args, symtable_t *tbl,
|
|
|
746
930
|
*/
|
|
747
931
|
int
|
|
748
932
|
ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
749
|
-
const ndt_t *
|
|
933
|
+
const ndt_t *types[], const int64_t li[],
|
|
934
|
+
const int nin, const int nout, bool check_broadcast,
|
|
750
935
|
const ndt_constraint_t *c, const void *args,
|
|
751
936
|
ndt_context_t *ctx)
|
|
752
937
|
{
|
|
753
938
|
symtable_t *tbl;
|
|
754
|
-
ndt_t *t;
|
|
939
|
+
const ndt_t *t;
|
|
755
940
|
const char *name;
|
|
756
|
-
int
|
|
941
|
+
const int nargs = nin + nout;
|
|
757
942
|
int64_t i;
|
|
943
|
+
int ret;
|
|
758
944
|
|
|
759
945
|
assert(spec->flags == 0);
|
|
760
|
-
assert(spec->nout == 0);
|
|
761
|
-
assert(spec->nbroadcast == 0);
|
|
762
946
|
assert(spec->outer_dims == 0);
|
|
947
|
+
assert(spec->nin == 0);
|
|
948
|
+
assert(spec->nout == 0);
|
|
949
|
+
assert(spec->nargs == 0);
|
|
763
950
|
|
|
764
951
|
if (sig->tag != Function) {
|
|
765
952
|
ndt_err_format(ctx, NDT_ValueError,
|
|
@@ -773,8 +960,30 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
773
960
|
return -1;
|
|
774
961
|
}
|
|
775
962
|
|
|
776
|
-
|
|
777
|
-
|
|
963
|
+
/*
|
|
964
|
+
* Two configurations are allowed:
|
|
965
|
+
* 1) nout==0 and all 'out' types are inferred.
|
|
966
|
+
* 2) nout==sig->Function.nout and all 'out' types are given.
|
|
967
|
+
*/
|
|
968
|
+
if (nout && nout != sig->Function.nout) {
|
|
969
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
970
|
+
"expected %" PRIi64 " 'out' arguments, got %d", sig->Function.nout, nout);
|
|
971
|
+
return -1;
|
|
972
|
+
}
|
|
973
|
+
|
|
974
|
+
/*
|
|
975
|
+
* Broadcast configurations:
|
|
976
|
+
* nout==0 -> check_broadcast==false.
|
|
977
|
+
*/
|
|
978
|
+
if (!nout && check_broadcast == true) {
|
|
979
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
980
|
+
"internal error: without explicit 'out' arguments "
|
|
981
|
+
"'check_broadcast' must be false");
|
|
982
|
+
return -1;
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
for (i = 0; i < nargs; i++) {
|
|
986
|
+
if (ndt_is_abstract(types[i])) {
|
|
778
987
|
ndt_err_format(ctx, NDT_ValueError,
|
|
779
988
|
"type checking requires concrete argument types");
|
|
780
989
|
return -1;
|
|
@@ -786,8 +995,8 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
786
995
|
return -1;
|
|
787
996
|
}
|
|
788
997
|
|
|
789
|
-
for (i = 0; i <
|
|
790
|
-
ret =
|
|
998
|
+
for (i = 0; i < nargs; i++) {
|
|
999
|
+
ret = match_datashape_top(sig->Function.types[i], types[i], li[i], tbl, ctx);
|
|
791
1000
|
if (ret <= 0) {
|
|
792
1001
|
symtable_del(tbl);
|
|
793
1002
|
|
|
@@ -805,15 +1014,34 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
805
1014
|
return -1;
|
|
806
1015
|
}
|
|
807
1016
|
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
1017
|
+
spec->nin = 0;
|
|
1018
|
+
for (i = 0; i < nin; i++) {
|
|
1019
|
+
ndt_incref(types[i]);
|
|
1020
|
+
spec->types[i] = types[i];
|
|
1021
|
+
spec->nin++;
|
|
1022
|
+
}
|
|
1023
|
+
|
|
1024
|
+
spec->nout = 0;
|
|
1025
|
+
if (nout == 0) {
|
|
1026
|
+
/* Infer the return types. */
|
|
1027
|
+
for (i = 0; i < sig->Function.nout; i++) {
|
|
1028
|
+
spec->types[nin+i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
|
|
1029
|
+
if (spec->types[nin+i] == NULL) {
|
|
1030
|
+
ndt_apply_spec_clear(spec);
|
|
1031
|
+
symtable_del(tbl);
|
|
1032
|
+
return -1;
|
|
1033
|
+
}
|
|
1034
|
+
spec->nout++;
|
|
1035
|
+
}
|
|
1036
|
+
}
|
|
1037
|
+
else {
|
|
1038
|
+
for (i = 0; i < nout; i++) {
|
|
1039
|
+
ndt_incref(types[nin+i]);
|
|
1040
|
+
spec->types[nin+i] = types[nin+i];
|
|
1041
|
+
spec->nout++;
|
|
814
1042
|
}
|
|
815
|
-
spec->nout++;
|
|
816
1043
|
}
|
|
1044
|
+
spec->nargs = spec->nin + spec->nout;
|
|
817
1045
|
|
|
818
1046
|
if (sig->flags & NDT_ELLIPSIS) {
|
|
819
1047
|
if (sig->Function.nargs == 0 || sig->Function.types[0]->tag != EllipsisDim) {
|
|
@@ -833,8 +1061,17 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
833
1061
|
case FixedSeq:
|
|
834
1062
|
spec->outer_dims = v.FixedSeq.size;
|
|
835
1063
|
break;
|
|
836
|
-
case VarSeq:
|
|
837
|
-
|
|
1064
|
+
case VarSeq: {
|
|
1065
|
+
if (v.VarSeq.size > 0) {
|
|
1066
|
+
spec->outer_dims = ndt_logical_ndim(v.VarSeq.dims[0]);
|
|
1067
|
+
}
|
|
1068
|
+
else {
|
|
1069
|
+
spec->outer_dims = 0;
|
|
1070
|
+
}
|
|
1071
|
+
break;
|
|
1072
|
+
}
|
|
1073
|
+
case ArraySeq:
|
|
1074
|
+
spec->outer_dims = v.ArraySeq.size;
|
|
838
1075
|
break;
|
|
839
1076
|
default:
|
|
840
1077
|
ndt_err_format(ctx, NDT_RuntimeError,
|
|
@@ -845,7 +1082,7 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
845
1082
|
}
|
|
846
1083
|
}
|
|
847
1084
|
else {
|
|
848
|
-
if (broadcast_all(spec, sig,
|
|
1085
|
+
if (broadcast_all(spec, sig, check_broadcast, tbl, ctx) < 0) {
|
|
849
1086
|
ndt_apply_spec_clear(spec);
|
|
850
1087
|
symtable_del(tbl);
|
|
851
1088
|
return -1;
|
|
@@ -855,31 +1092,33 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
855
1092
|
|
|
856
1093
|
symtable_del(tbl);
|
|
857
1094
|
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
1095
|
+
if (nout == 0) {
|
|
1096
|
+
for (i = 0; i < sig->Function.nout; i++) {
|
|
1097
|
+
const ndt_t *_p = sig->Function.types[nin+i];
|
|
1098
|
+
const ndt_t *_c = spec->types[nin+i];
|
|
1099
|
+
const ndt_t *_t = to_fortran(_p, _c, ctx);
|
|
1100
|
+
|
|
1101
|
+
if (_t == NULL) {
|
|
1102
|
+
ndt_apply_spec_clear(spec);
|
|
1103
|
+
return -1;
|
|
1104
|
+
}
|
|
1105
|
+
|
|
1106
|
+
ndt_decref(_c);
|
|
1107
|
+
spec->types[nin+i] = _t;
|
|
868
1108
|
}
|
|
869
|
-
spec->out[i] = _t;
|
|
870
1109
|
}
|
|
871
1110
|
|
|
872
|
-
if (!check_contig(sig->Function.types,
|
|
1111
|
+
if (!check_contig(sig->Function.types, types, nargs)) {
|
|
873
1112
|
ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
|
|
1113
|
+
ndt_apply_spec_clear(spec);
|
|
874
1114
|
return -1;
|
|
875
1115
|
}
|
|
876
|
-
|
|
877
|
-
|
|
1116
|
+
|
|
1117
|
+
if (ndt_select_kernel_strategy(spec, ctx) < 0) {
|
|
1118
|
+
ndt_apply_spec_clear(spec);
|
|
878
1119
|
return -1;
|
|
879
1120
|
}
|
|
880
1121
|
|
|
881
|
-
ndt_select_kernel_strategy(spec, sig, in, nin);
|
|
882
|
-
|
|
883
1122
|
return 0;
|
|
884
1123
|
}
|
|
885
1124
|
|
|
@@ -888,11 +1127,11 @@ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
888
1127
|
/* Optimized binary typecheck for fixed input */
|
|
889
1128
|
/*****************************************************************************/
|
|
890
1129
|
|
|
891
|
-
static ndt_t *
|
|
892
|
-
|
|
893
|
-
|
|
1130
|
+
static const ndt_t *
|
|
1131
|
+
fast_broadcast(const ndt_ndarray_t *t, const ndt_t *dtype,
|
|
1132
|
+
const int64_t *shape, int size, ndt_context_t *ctx)
|
|
894
1133
|
{
|
|
895
|
-
ndt_t *v;
|
|
1134
|
+
const ndt_t *v;
|
|
896
1135
|
int64_t step;
|
|
897
1136
|
int i, k;
|
|
898
1137
|
|
|
@@ -903,14 +1142,18 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
|
|
|
903
1142
|
|
|
904
1143
|
for (i=t->ndim-1, k=size-1; i>=0 && k>=0; i--, k--) {
|
|
905
1144
|
step = t->shape[i]<=1 ? 0 : t->steps[i];
|
|
906
|
-
|
|
1145
|
+
const ndt_t *w = ndt_fixed_dim(v, shape[k], step, ctx);
|
|
1146
|
+
ndt_decref(v);
|
|
1147
|
+
v = w;
|
|
907
1148
|
if (v == NULL) {
|
|
908
1149
|
return NULL;
|
|
909
1150
|
}
|
|
910
1151
|
}
|
|
911
1152
|
|
|
912
1153
|
for (; k>=0; k--) {
|
|
913
|
-
|
|
1154
|
+
const ndt_t *w = ndt_fixed_dim(v, shape[k], 0, ctx);
|
|
1155
|
+
ndt_decref(v);
|
|
1156
|
+
v = w;
|
|
914
1157
|
if (v == NULL) {
|
|
915
1158
|
return NULL;
|
|
916
1159
|
}
|
|
@@ -919,15 +1162,19 @@ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
|
|
|
919
1162
|
return v;
|
|
920
1163
|
}
|
|
921
1164
|
|
|
922
|
-
static ndt_t *
|
|
923
|
-
fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
|
|
1165
|
+
static const ndt_t *
|
|
1166
|
+
fixed_dim_from_shape(const int64_t shape[], int len, const ndt_t *dtype,
|
|
924
1167
|
ndt_context_t *ctx)
|
|
925
1168
|
{
|
|
926
|
-
ndt_t *t;
|
|
1169
|
+
const ndt_t *t;
|
|
927
1170
|
int i;
|
|
928
1171
|
|
|
1172
|
+
ndt_incref(dtype);
|
|
1173
|
+
|
|
929
1174
|
for (i=len-1, t=dtype; i >= 0; i--) {
|
|
930
|
-
|
|
1175
|
+
const ndt_t *u = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
|
|
1176
|
+
ndt_decref(t);
|
|
1177
|
+
t = u;
|
|
931
1178
|
if (t == NULL) {
|
|
932
1179
|
return NULL;
|
|
933
1180
|
}
|
|
@@ -936,83 +1183,313 @@ fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
|
|
|
936
1183
|
return t;
|
|
937
1184
|
}
|
|
938
1185
|
|
|
939
|
-
static
|
|
940
|
-
|
|
1186
|
+
static int
|
|
1187
|
+
broadcast_error(const char *msg, ndt_context_t *ctx)
|
|
941
1188
|
{
|
|
942
|
-
|
|
943
|
-
|
|
1189
|
+
ndt_err_format(ctx, NDT_TypeError, "%s", msg);
|
|
1190
|
+
return -1;
|
|
1191
|
+
}
|
|
1192
|
+
|
|
1193
|
+
static int
|
|
1194
|
+
_ndt_unary_broadcast(ndt_apply_spec_t *spec,
|
|
1195
|
+
const ndt_ndarray_t *x, const ndt_t *dtype_x,
|
|
1196
|
+
const ndt_t *out, const ndt_t *dtype_out,
|
|
1197
|
+
const bool check_broadcast, const int inner,
|
|
1198
|
+
ndt_context_t *ctx)
|
|
1199
|
+
{
|
|
1200
|
+
int64_t shape[NDT_MAX_DIM];
|
|
1201
|
+
int size = x->ndim;
|
|
1202
|
+
ndt_ndarray_t y;
|
|
1203
|
+
const ndt_t *t;
|
|
1204
|
+
|
|
1205
|
+
for (int i = 0; i < size; i++) {
|
|
1206
|
+
shape[i] = x->shape[i];
|
|
944
1207
|
}
|
|
945
1208
|
|
|
946
|
-
|
|
947
|
-
if (
|
|
948
|
-
return
|
|
1209
|
+
if (out != NULL) {
|
|
1210
|
+
if (ndt_as_ndarray(&y, out, ctx) < 0) {
|
|
1211
|
+
return -1;
|
|
949
1212
|
}
|
|
1213
|
+
|
|
1214
|
+
size = _resolve_broadcast(shape, size, y.shape, y.ndim);
|
|
1215
|
+
if (size < 0) {
|
|
1216
|
+
return broadcast_error("could not broadcast output argument", ctx);
|
|
1217
|
+
}
|
|
1218
|
+
}
|
|
1219
|
+
|
|
1220
|
+
spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
|
|
1221
|
+
if (spec->types[0] == NULL) {
|
|
1222
|
+
return -1;
|
|
1223
|
+
}
|
|
1224
|
+
|
|
1225
|
+
if (out != NULL) {
|
|
1226
|
+
if (check_broadcast) {
|
|
1227
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
|
1228
|
+
if (t == NULL) {
|
|
1229
|
+
ndt_decref(spec->types[0]);
|
|
1230
|
+
return -1;
|
|
1231
|
+
}
|
|
1232
|
+
|
|
1233
|
+
if (!ndt_equal(t, out)) {
|
|
1234
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
1235
|
+
"explicit 'out' type not compatible with input types");
|
|
1236
|
+
ndt_decref(spec->types[0]);
|
|
1237
|
+
ndt_decref(t);
|
|
1238
|
+
return -1;
|
|
1239
|
+
}
|
|
1240
|
+
}
|
|
1241
|
+
else {
|
|
1242
|
+
t = fast_broadcast(&y, dtype_out, shape, size, ctx);
|
|
1243
|
+
if (t == NULL) {
|
|
1244
|
+
ndt_decref(spec->types[0]);
|
|
1245
|
+
return -1;
|
|
1246
|
+
}
|
|
1247
|
+
}
|
|
1248
|
+
}
|
|
1249
|
+
else {
|
|
1250
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
|
1251
|
+
if (t == NULL) {
|
|
1252
|
+
ndt_decref(spec->types[0]);
|
|
1253
|
+
return -1;
|
|
1254
|
+
}
|
|
1255
|
+
}
|
|
1256
|
+
spec->types[1] = t;
|
|
1257
|
+
|
|
1258
|
+
spec->outer_dims = size-inner;
|
|
1259
|
+
spec->nin = 1;
|
|
1260
|
+
spec->nout = 1;
|
|
1261
|
+
spec->nargs = 2;
|
|
1262
|
+
|
|
1263
|
+
if (ndt_select_kernel_strategy(spec, ctx) < 0) {
|
|
1264
|
+
ndt_apply_spec_clear(spec);
|
|
1265
|
+
return -1;
|
|
1266
|
+
}
|
|
1267
|
+
|
|
1268
|
+
return 0;
|
|
1269
|
+
}
|
|
1270
|
+
|
|
1271
|
+
static bool
|
|
1272
|
+
unary_all_ellipses(const ndt_t *t0, const ndt_t *t1, ndt_context_t *ctx)
|
|
1273
|
+
{
|
|
1274
|
+
if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
|
|
1275
|
+
(t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL)) {
|
|
1276
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1277
|
+
"fast unary typecheck expects leading ellipsis dimensions");
|
|
1278
|
+
return false;
|
|
950
1279
|
}
|
|
951
1280
|
|
|
952
1281
|
return true;
|
|
953
1282
|
}
|
|
954
1283
|
|
|
955
|
-
static
|
|
956
|
-
|
|
957
|
-
const ndt_ndarray_t *x, const ndt_ndarray_t *y,
|
|
958
|
-
const ndt_t *in[], const int nin, ndt_t *dtype,
|
|
959
|
-
int inner, ndt_context_t *ctx)
|
|
1284
|
+
static bool
|
|
1285
|
+
unary_all_same_symbol(const ndt_t *t0, const ndt_t *t1)
|
|
960
1286
|
{
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
if (shape_equal(x, y)) {
|
|
965
|
-
spec->nout = 1;
|
|
966
|
-
spec->nbroadcast = 0;
|
|
967
|
-
spec->outer_dims = x->ndim-inner;
|
|
968
|
-
spec->out[0] = fixed_dim_from_shape(x->shape, x->ndim, dtype, ctx);
|
|
969
|
-
if (spec->out[0] == NULL) {
|
|
970
|
-
return -1;
|
|
971
|
-
}
|
|
1287
|
+
if (t0->tag != SymbolicDim || t1->tag != SymbolicDim) {
|
|
1288
|
+
return false;
|
|
972
1289
|
}
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
1290
|
+
|
|
1291
|
+
return strcmp(t0->SymbolicDim.name, t1->SymbolicDim.name) == 0;
|
|
1292
|
+
}
|
|
1293
|
+
|
|
1294
|
+
static bool
|
|
1295
|
+
unary_all_ndim0(const ndt_t *t0, const ndt_t *t1)
|
|
1296
|
+
{
|
|
1297
|
+
return t0->ndim == 0 && t1->ndim == 0;
|
|
1298
|
+
}
|
|
1299
|
+
|
|
1300
|
+
/*
|
|
1301
|
+
* Optimized type checking for very specific signatures. The caller must
|
|
1302
|
+
* have identified the kernel location and the signature. For performance
|
|
1303
|
+
* reasons, no substitution is performed on the dtype, so the dtype must be
|
|
1304
|
+
* concrete.
|
|
1305
|
+
*
|
|
1306
|
+
* Supported signature: 1) ... * T0 -> ... * T1
|
|
1307
|
+
*/
|
|
1308
|
+
int
|
|
1309
|
+
ndt_fast_unary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
1310
|
+
const ndt_t *types[], const int nin, const int nout,
|
|
1311
|
+
const bool check_broadcast, ndt_context_t *ctx)
|
|
1312
|
+
{
|
|
1313
|
+
const ndt_t *p0, *p1;
|
|
1314
|
+
ndt_ndarray_t x;
|
|
1315
|
+
const ndt_t *out = NULL;
|
|
1316
|
+
const ndt_t *dtype = NULL;
|
|
1317
|
+
|
|
1318
|
+
assert(spec->flags == 0);
|
|
1319
|
+
assert(spec->outer_dims == 0);
|
|
1320
|
+
assert(spec->nin == 0);
|
|
1321
|
+
assert(spec->nout == 0);
|
|
1322
|
+
assert(spec->nargs == 0);
|
|
1323
|
+
|
|
1324
|
+
if (sig->tag != Function ||
|
|
1325
|
+
sig->Function.nin != 1) {
|
|
1326
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1327
|
+
"fast unary typecheck expects a signature with one input and "
|
|
1328
|
+
"one output");
|
|
1329
|
+
return -1;
|
|
1330
|
+
}
|
|
1331
|
+
|
|
1332
|
+
if (nin != 1) {
|
|
1333
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1334
|
+
"fast unary typecheck expects one input argument");
|
|
1335
|
+
return -1;
|
|
1336
|
+
}
|
|
1337
|
+
|
|
1338
|
+
p0 = sig->Function.types[0];
|
|
1339
|
+
p1 = sig->Function.types[1];
|
|
1340
|
+
|
|
1341
|
+
if (nout) {
|
|
1342
|
+
if (nout != 1) {
|
|
1343
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1344
|
+
"fast unary typecheck expects at most one explicit out argument");
|
|
1345
|
+
return -1;
|
|
976
1346
|
}
|
|
1347
|
+
out = types[1];
|
|
1348
|
+
dtype = ndt_dtype(types[1]);
|
|
977
1349
|
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
ndt_del(dtype);
|
|
1350
|
+
if (!ndt_equal(dtype, ndt_dtype(p1))) {
|
|
1351
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
1352
|
+
"dtype of the out argument does not match");
|
|
982
1353
|
return -1;
|
|
983
1354
|
}
|
|
1355
|
+
}
|
|
1356
|
+
else {
|
|
1357
|
+
dtype = ndt_dtype(sig->Function.types[1]);
|
|
1358
|
+
}
|
|
1359
|
+
|
|
1360
|
+
if (ndt_is_abstract(dtype)) {
|
|
1361
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1362
|
+
"fast unary typecheck expects a concrete dtype");
|
|
1363
|
+
return -1;
|
|
1364
|
+
}
|
|
1365
|
+
|
|
1366
|
+
if (!unary_all_ellipses(p0, p1, ctx)) {
|
|
1367
|
+
return -1;
|
|
1368
|
+
}
|
|
1369
|
+
|
|
1370
|
+
if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
|
|
1371
|
+
return -1;
|
|
1372
|
+
}
|
|
1373
|
+
|
|
1374
|
+
p0 = p0->EllipsisDim.type;
|
|
1375
|
+
p1 = p1->EllipsisDim.type;
|
|
1376
|
+
|
|
1377
|
+
if (unary_all_same_symbol(p0, p1)) {
|
|
1378
|
+
return _ndt_unary_broadcast(spec,
|
|
1379
|
+
&x, ndt_dtype(types[0]),
|
|
1380
|
+
out, dtype, check_broadcast,
|
|
1381
|
+
1, ctx);
|
|
1382
|
+
}
|
|
1383
|
+
else if (unary_all_ndim0(p0, p1)) {
|
|
1384
|
+
return _ndt_unary_broadcast(spec,
|
|
1385
|
+
&x, ndt_dtype(types[0]),
|
|
1386
|
+
out, dtype, check_broadcast,
|
|
1387
|
+
0, ctx);
|
|
1388
|
+
}
|
|
1389
|
+
else {
|
|
1390
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1391
|
+
"unsupported signature in fast unary typecheck");
|
|
1392
|
+
return -1;
|
|
1393
|
+
}
|
|
1394
|
+
}
|
|
1395
|
+
|
|
1396
|
+
static int
|
|
1397
|
+
_ndt_binary_broadcast(ndt_apply_spec_t *spec,
|
|
1398
|
+
const ndt_ndarray_t *x, const ndt_t *dtype_x,
|
|
1399
|
+
const ndt_ndarray_t *y, const ndt_t *dtype_y,
|
|
1400
|
+
const ndt_t *out, const ndt_t *dtype_out,
|
|
1401
|
+
const bool check_broadcast, const int inner,
|
|
1402
|
+
ndt_context_t *ctx)
|
|
1403
|
+
{
|
|
1404
|
+
int64_t shape[NDT_MAX_DIM];
|
|
1405
|
+
int size = x->ndim;
|
|
1406
|
+
ndt_ndarray_t z;
|
|
1407
|
+
const ndt_t *t;
|
|
984
1408
|
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
1409
|
+
for (int i = 0; i < size; i++) {
|
|
1410
|
+
shape[i] = x->shape[i];
|
|
1411
|
+
}
|
|
1412
|
+
|
|
1413
|
+
size = _resolve_broadcast(shape, size, y->shape, y->ndim);
|
|
1414
|
+
if (size < 0) {
|
|
1415
|
+
return broadcast_error("could not broadcast input arguments", ctx);
|
|
1416
|
+
}
|
|
988
1417
|
|
|
989
|
-
|
|
990
|
-
if (
|
|
1418
|
+
if (out != NULL) {
|
|
1419
|
+
if (ndt_as_ndarray(&z, out, ctx) < 0) {
|
|
991
1420
|
return -1;
|
|
992
1421
|
}
|
|
993
1422
|
|
|
994
|
-
|
|
995
|
-
if (
|
|
996
|
-
|
|
997
|
-
return -1;
|
|
1423
|
+
size = _resolve_broadcast(shape, size, z.shape, z.ndim);
|
|
1424
|
+
if (size < 0) {
|
|
1425
|
+
return broadcast_error("could not broadcast output argument", ctx);
|
|
998
1426
|
}
|
|
1427
|
+
}
|
|
1428
|
+
|
|
1429
|
+
spec->types[0] = fast_broadcast(x, dtype_x, shape, size, ctx);
|
|
1430
|
+
if (spec->types[0] == NULL) {
|
|
1431
|
+
return -1;
|
|
1432
|
+
}
|
|
1433
|
+
|
|
1434
|
+
spec->types[1] = fast_broadcast(y, dtype_y, shape, size, ctx);
|
|
1435
|
+
if (spec->types[1] == NULL) {
|
|
1436
|
+
ndt_decref(spec->types[0]);
|
|
1437
|
+
return -1;
|
|
1438
|
+
}
|
|
1439
|
+
|
|
1440
|
+
if (out != NULL) {
|
|
1441
|
+
if (check_broadcast) {
|
|
1442
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
|
1443
|
+
if (t == NULL) {
|
|
1444
|
+
ndt_decref(spec->types[0]);
|
|
1445
|
+
ndt_decref(spec->types[1]);
|
|
1446
|
+
return -1;
|
|
1447
|
+
}
|
|
999
1448
|
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1449
|
+
if (!ndt_equal(t, out)) {
|
|
1450
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
1451
|
+
"explicit 'out' type not compatible with input types");
|
|
1452
|
+
ndt_decref(spec->types[0]);
|
|
1453
|
+
ndt_decref(spec->types[1]);
|
|
1454
|
+
ndt_decref(t);
|
|
1455
|
+
return -1;
|
|
1456
|
+
}
|
|
1457
|
+
}
|
|
1458
|
+
else {
|
|
1459
|
+
t = fast_broadcast(&z, dtype_out, shape, size, ctx);
|
|
1460
|
+
if (t == NULL) {
|
|
1461
|
+
ndt_decref(spec->types[0]);
|
|
1462
|
+
ndt_decref(spec->types[1]);
|
|
1463
|
+
return -1;
|
|
1464
|
+
}
|
|
1465
|
+
}
|
|
1466
|
+
}
|
|
1467
|
+
else {
|
|
1468
|
+
t = fixed_dim_from_shape(shape, size, dtype_out, ctx);
|
|
1469
|
+
if (t == NULL) {
|
|
1470
|
+
ndt_decref(spec->types[0]);
|
|
1471
|
+
ndt_decref(spec->types[1]);
|
|
1004
1472
|
return -1;
|
|
1005
1473
|
}
|
|
1006
1474
|
}
|
|
1475
|
+
spec->types[2] = t;
|
|
1007
1476
|
|
|
1008
|
-
|
|
1477
|
+
spec->outer_dims = size-inner;
|
|
1478
|
+
spec->nin = 2;
|
|
1479
|
+
spec->nout = 1;
|
|
1480
|
+
spec->nargs = 3;
|
|
1481
|
+
|
|
1482
|
+
if (ndt_select_kernel_strategy(spec, ctx) < 0) {
|
|
1483
|
+
ndt_apply_spec_clear(spec);
|
|
1484
|
+
return -1;
|
|
1485
|
+
}
|
|
1009
1486
|
|
|
1010
1487
|
return 0;
|
|
1011
1488
|
}
|
|
1012
1489
|
|
|
1013
1490
|
static bool
|
|
1014
|
-
|
|
1015
|
-
|
|
1491
|
+
binary_all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
|
|
1492
|
+
ndt_context_t *ctx)
|
|
1016
1493
|
{
|
|
1017
1494
|
if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
|
|
1018
1495
|
(t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL) ||
|
|
@@ -1026,7 +1503,7 @@ all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
|
|
|
1026
1503
|
}
|
|
1027
1504
|
|
|
1028
1505
|
static bool
|
|
1029
|
-
|
|
1506
|
+
binary_all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
|
1030
1507
|
{
|
|
1031
1508
|
if (t0->tag != SymbolicDim || t1->tag != SymbolicDim ||
|
|
1032
1509
|
t2->tag != SymbolicDim) {
|
|
@@ -1038,37 +1515,37 @@ all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
|
|
1038
1515
|
}
|
|
1039
1516
|
|
|
1040
1517
|
static bool
|
|
1041
|
-
|
|
1518
|
+
binary_all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
|
1042
1519
|
{
|
|
1043
1520
|
return t0->ndim == 0 && t1->ndim == 0 && t2->ndim == 0;
|
|
1044
1521
|
}
|
|
1045
1522
|
|
|
1046
1523
|
/*
|
|
1047
1524
|
* Optimized type checking for very specific signatures. The caller must
|
|
1048
|
-
* have identified the kernel location
|
|
1049
|
-
*
|
|
1050
|
-
*
|
|
1525
|
+
* have identified the kernel location and the signature. For performance
|
|
1526
|
+
* reasons, no substitution is performed on the dtype, so the dtype must be
|
|
1527
|
+
* concrete.
|
|
1051
1528
|
*
|
|
1052
|
-
* Supported
|
|
1053
|
-
* 1) ... * N * T0, ... * N * T1 -> N * T2
|
|
1054
|
-
* 2) ... * T0, ... * T1 -> ... * T2
|
|
1529
|
+
* Supported signature: 1) ... * T0, ... * T1 -> ... * T2
|
|
1055
1530
|
*/
|
|
1056
1531
|
int
|
|
1057
1532
|
ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
1058
|
-
const ndt_t *
|
|
1059
|
-
ndt_context_t *ctx)
|
|
1533
|
+
const ndt_t *types[], const int nin, const int nout,
|
|
1534
|
+
const bool check_broadcast, ndt_context_t *ctx)
|
|
1060
1535
|
{
|
|
1061
|
-
ndt_t *p0, *p1, *p2;
|
|
1536
|
+
const ndt_t *p0, *p1, *p2;
|
|
1062
1537
|
ndt_ndarray_t x, y;
|
|
1538
|
+
const ndt_t *out = NULL;
|
|
1539
|
+
const ndt_t *dtype = NULL;
|
|
1063
1540
|
|
|
1064
1541
|
assert(spec->flags == 0);
|
|
1065
|
-
assert(spec->nout == 0);
|
|
1066
|
-
assert(spec->nbroadcast == 0);
|
|
1067
1542
|
assert(spec->outer_dims == 0);
|
|
1543
|
+
assert(spec->nin == 0);
|
|
1544
|
+
assert(spec->nout == 0);
|
|
1545
|
+
assert(spec->nargs == 0);
|
|
1068
1546
|
|
|
1069
1547
|
if (sig->tag != Function ||
|
|
1070
|
-
sig->Function.nin != 2
|
|
1071
|
-
sig->Function.nout != 1) {
|
|
1548
|
+
sig->Function.nin != 2) {
|
|
1072
1549
|
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1073
1550
|
"fast binary typecheck expects a signature with two inputs and "
|
|
1074
1551
|
"one output");
|
|
@@ -1081,27 +1558,44 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
1081
1558
|
return -1;
|
|
1082
1559
|
}
|
|
1083
1560
|
|
|
1561
|
+
p0 = sig->Function.types[0];
|
|
1562
|
+
p1 = sig->Function.types[1];
|
|
1563
|
+
p2 = sig->Function.types[2];
|
|
1564
|
+
|
|
1565
|
+
if (nout) {
|
|
1566
|
+
if (nout != 1) {
|
|
1567
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1568
|
+
"fast binary typecheck expects at most one explicit out argument");
|
|
1569
|
+
return -1;
|
|
1570
|
+
}
|
|
1571
|
+
out = types[2];
|
|
1572
|
+
dtype = ndt_dtype(types[2]);
|
|
1573
|
+
|
|
1574
|
+
if (!ndt_equal(dtype, ndt_dtype(p2))) {
|
|
1575
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
1576
|
+
"dtype of the out argument does not match");
|
|
1577
|
+
return -1;
|
|
1578
|
+
}
|
|
1579
|
+
}
|
|
1580
|
+
else {
|
|
1581
|
+
dtype = ndt_dtype(sig->Function.types[2]);
|
|
1582
|
+
}
|
|
1583
|
+
|
|
1084
1584
|
if (ndt_is_abstract(dtype)) {
|
|
1085
1585
|
ndt_err_format(ctx, NDT_RuntimeError,
|
|
1086
1586
|
"fast binary typecheck expects a concrete dtype");
|
|
1087
1587
|
return -1;
|
|
1088
1588
|
}
|
|
1089
1589
|
|
|
1090
|
-
p0
|
|
1091
|
-
p1 = sig->Function.types[1];
|
|
1092
|
-
p2 = sig->Function.types[2];
|
|
1093
|
-
|
|
1094
|
-
if (!all_ellipses(p0, p1, p2, ctx)) {
|
|
1590
|
+
if (!binary_all_ellipses(p0, p1, p2, ctx)) {
|
|
1095
1591
|
return -1;
|
|
1096
1592
|
}
|
|
1097
1593
|
|
|
1098
|
-
if (ndt_as_ndarray(&x,
|
|
1099
|
-
ndt_del(dtype);
|
|
1594
|
+
if (ndt_as_ndarray(&x, types[0], ctx) < 0) {
|
|
1100
1595
|
return -1;
|
|
1101
1596
|
}
|
|
1102
1597
|
|
|
1103
|
-
if (ndt_as_ndarray(&y,
|
|
1104
|
-
ndt_del(dtype);
|
|
1598
|
+
if (ndt_as_ndarray(&y, types[1], ctx) < 0) {
|
|
1105
1599
|
return -1;
|
|
1106
1600
|
}
|
|
1107
1601
|
|
|
@@ -1109,20 +1603,27 @@ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
|
|
1109
1603
|
p1 = p1->EllipsisDim.type;
|
|
1110
1604
|
p2 = p2->EllipsisDim.type;
|
|
1111
1605
|
|
|
1112
|
-
if (
|
|
1606
|
+
if (binary_all_same_symbol(p0, p1, p2)) {
|
|
1113
1607
|
if (x.ndim > 0 && y.ndim > 0) {
|
|
1114
1608
|
const int64_t xshape = x.shape[x.ndim-1];
|
|
1115
1609
|
const int64_t yshape = y.shape[y.ndim-1];
|
|
1116
1610
|
if (xshape != 1 && yshape != 1 && xshape != yshape) {
|
|
1117
1611
|
ndt_err_format(ctx, NDT_TypeError, "mismatch in inner dimensions");
|
|
1118
|
-
ndt_del(dtype);
|
|
1119
1612
|
return -1;
|
|
1120
1613
|
}
|
|
1121
1614
|
}
|
|
1122
|
-
return _ndt_binary_broadcast(spec,
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1615
|
+
return _ndt_binary_broadcast(spec,
|
|
1616
|
+
&x, ndt_dtype(types[0]),
|
|
1617
|
+
&y, ndt_dtype(types[1]),
|
|
1618
|
+
out, dtype, check_broadcast,
|
|
1619
|
+
1, ctx);
|
|
1620
|
+
}
|
|
1621
|
+
else if (binary_all_ndim0(p0, p1, p2)) {
|
|
1622
|
+
return _ndt_binary_broadcast(spec,
|
|
1623
|
+
&x, ndt_dtype(types[0]),
|
|
1624
|
+
&y, ndt_dtype(types[1]),
|
|
1625
|
+
out, dtype, check_broadcast,
|
|
1626
|
+
0, ctx);
|
|
1126
1627
|
}
|
|
1127
1628
|
else {
|
|
1128
1629
|
ndt_err_format(ctx, NDT_RuntimeError,
|