ndtypes 0.2.0dev5 → 0.2.0dev6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +12 -0
- data/Rakefile +8 -0
- data/ext/ruby_ndtypes/GPATH +0 -0
- data/ext/ruby_ndtypes/GRTAGS +0 -0
- data/ext/ruby_ndtypes/GTAGS +0 -0
- data/ext/ruby_ndtypes/extconf.rb +1 -1
- data/ext/ruby_ndtypes/include/ndtypes.h +231 -122
- data/ext/ruby_ndtypes/include/ruby_ndtypes.h +1 -1
- data/ext/ruby_ndtypes/lib/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/lib/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/Makefile +87 -0
- data/ext/ruby_ndtypes/ndtypes/config.h +68 -0
- data/ext/ruby_ndtypes/ndtypes/config.log +477 -0
- data/ext/ruby_ndtypes/ndtypes/config.status +1027 -0
- data/ext/ruby_ndtypes/ndtypes/doc/_static/style.css +7 -0
- data/ext/ruby_ndtypes/ndtypes/doc/_templates/layout.html +2 -0
- data/ext/ruby_ndtypes/ndtypes/doc/conf.py +40 -4
- data/ext/ruby_ndtypes/ndtypes/doc/images/xndlogo.png +0 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +1 -1
- data/ext/ruby_ndtypes/ndtypes/doc/requirements.txt +2 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile +287 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +20 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +22 -3
- data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile +73 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +246 -229
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +15 -11
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +38 -28
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +91 -91
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +4 -3
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +8 -7
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/context.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +263 -182
- data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +67 -7
- data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +1112 -1000
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +69 -58
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +150 -99
- data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +185 -15
- data/ext/ruby_ndtypes/ndtypes/libndtypes/io.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +301 -276
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +9 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so +1 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0 +1 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +729 -228
- data/ext/ruby_ndtypes/ndtypes/libndtypes/match.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +768 -403
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h +1002 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +231 -122
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +176 -84
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +26 -14
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +57 -35
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.c +420 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/primitive.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +8 -8
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile +48 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +200 -116
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +46 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +58 -27
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +1 -1
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +3 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +12 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile +55 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +8 -8
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +5 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +274 -172
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +24 -4
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +14 -14
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +32 -30
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +37 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +36 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +16 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +5 -5
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +706 -253
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_unify.c +132 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.c +703 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/unify.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +335 -127
- data/ext/ruby_ndtypes/ndtypes/libndtypes/util.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +2 -2
- data/ext/ruby_ndtypes/ndtypes/libndtypes/values.o +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +88 -71
- data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +0 -1
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +10 -13
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +395 -314
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.a +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so +1 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0 +1 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/libndtypes.so.0.2.0dev3 +0 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/ndtypes.h +1002 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +15 -33
- data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +340 -132
- data/ext/ruby_ndtypes/ndtypes/setup.py +11 -2
- data/ext/ruby_ndtypes/ruby_ndtypes.c +364 -241
- data/ext/ruby_ndtypes/ruby_ndtypes.h +1 -1
- data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +0 -1
- data/lib/ndtypes.rb +11 -0
- data/lib/ndtypes/version.rb +2 -2
- data/lib/ruby_ndtypes.so +0 -0
- data/ndtypes.gemspec +3 -0
- data/spec/ndtypes_spec.rb +6 -0
- metadata +98 -4
- data/ext/ruby_ndtypes/gc_guard.c +0 -36
- data/ext/ruby_ndtypes/gc_guard.h +0 -12
|
Binary file
|
|
@@ -95,21 +95,46 @@ ndt_asprintf(ndt_context_t *ctx, const char *fmt, ...)
|
|
|
95
95
|
/* Type functions (unstable API) */
|
|
96
96
|
/*****************************************************************************/
|
|
97
97
|
|
|
98
|
+
/* Number of elements, currently only for ndarrays. Return -1 on error. */
|
|
99
|
+
int64_t
|
|
100
|
+
ndt_nelem(const ndt_t *t)
|
|
101
|
+
{
|
|
102
|
+
NDT_STATIC_CONTEXT(ctx);
|
|
103
|
+
ndt_ndarray_t x;
|
|
104
|
+
int64_t n;
|
|
105
|
+
|
|
106
|
+
if (ndt_as_ndarray(&x, t, &ctx) < 0) {
|
|
107
|
+
ndt_err_clear(&ctx);
|
|
108
|
+
return -1;
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
n = 1;
|
|
112
|
+
for (int i = 0; i < x.ndim; i++) {
|
|
113
|
+
n *= x.shape[i];
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
return n;
|
|
117
|
+
}
|
|
118
|
+
|
|
98
119
|
/* Return the next type in a dimension chain. Undefined for non-dimensions. */
|
|
99
120
|
static const ndt_t *
|
|
100
|
-
next_dim(const ndt_t *
|
|
121
|
+
next_dim(const ndt_t *t)
|
|
101
122
|
{
|
|
102
|
-
assert(
|
|
123
|
+
assert(t->ndim > 0);
|
|
103
124
|
|
|
104
|
-
switch (
|
|
125
|
+
switch (t->tag) {
|
|
105
126
|
case FixedDim:
|
|
106
|
-
return
|
|
127
|
+
return t->FixedDim.type;
|
|
107
128
|
case VarDim:
|
|
108
|
-
return
|
|
129
|
+
return t->VarDim.type;
|
|
130
|
+
case VarDimElem:
|
|
131
|
+
return t->VarDimElem.type;
|
|
109
132
|
case SymbolicDim:
|
|
110
|
-
return
|
|
133
|
+
return t->SymbolicDim.type;
|
|
111
134
|
case EllipsisDim:
|
|
112
|
-
return
|
|
135
|
+
return t->EllipsisDim.type;
|
|
136
|
+
case Array:
|
|
137
|
+
return t->Array.type;
|
|
113
138
|
default:
|
|
114
139
|
/* NOT REACHED: tags should be exhaustive. */
|
|
115
140
|
ndt_internal_error("invalid value");
|
|
@@ -126,16 +151,19 @@ ndt_dtype(const ndt_t *t)
|
|
|
126
151
|
return t;
|
|
127
152
|
}
|
|
128
153
|
|
|
129
|
-
|
|
130
|
-
|
|
154
|
+
int
|
|
155
|
+
ndt_logical_ndim(const ndt_t *t)
|
|
131
156
|
{
|
|
132
|
-
|
|
157
|
+
int ndim = t->ndim;
|
|
133
158
|
|
|
134
|
-
|
|
135
|
-
t
|
|
159
|
+
while (t->tag == VarDim || t->tag == VarDimElem) {
|
|
160
|
+
if (t->tag == VarDimElem) {
|
|
161
|
+
--ndim;
|
|
162
|
+
}
|
|
163
|
+
t = t->VarDim.type;
|
|
136
164
|
}
|
|
137
165
|
|
|
138
|
-
return
|
|
166
|
+
return ndim;
|
|
139
167
|
}
|
|
140
168
|
|
|
141
169
|
const ndt_t *
|
|
@@ -170,6 +198,75 @@ ndt_dims_dtype(const ndt_t *dims[NDT_MAX_DIM], const ndt_t **dtype, const ndt_t
|
|
|
170
198
|
return n;
|
|
171
199
|
}
|
|
172
200
|
|
|
201
|
+
static const ndt_t *
|
|
202
|
+
compress(const ndt_t *t)
|
|
203
|
+
{
|
|
204
|
+
while (t->tag == VarDimElem) {
|
|
205
|
+
t = t->VarDimElem.type;
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
return t;
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
/*
|
|
212
|
+
* Return the next logical type in a dimension chain. Undefined for
|
|
213
|
+
* non-dimensions.
|
|
214
|
+
*/
|
|
215
|
+
static const ndt_t *
|
|
216
|
+
next_logical_dim(const ndt_t *t)
|
|
217
|
+
{
|
|
218
|
+
switch (t->tag) {
|
|
219
|
+
case FixedDim:
|
|
220
|
+
return t->FixedDim.type;
|
|
221
|
+
case VarDim:
|
|
222
|
+
return compress(t->VarDim.type);
|
|
223
|
+
case VarDimElem:
|
|
224
|
+
return next_logical_dim(compress(t));
|
|
225
|
+
case SymbolicDim:
|
|
226
|
+
return t->SymbolicDim.type;
|
|
227
|
+
case EllipsisDim:
|
|
228
|
+
return t->EllipsisDim.type;
|
|
229
|
+
default:
|
|
230
|
+
return t;
|
|
231
|
+
}
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
const ndt_t *
|
|
235
|
+
ndt_logical_dim_at(const ndt_t *t, int n)
|
|
236
|
+
{
|
|
237
|
+
for (int i = 0; i < n; i++) {
|
|
238
|
+
t = next_logical_dim(t);
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
return t;
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
const ndt_t *
|
|
245
|
+
ndt_copy_contiguous_at(const ndt_t *t, int n, const ndt_t *dtype, ndt_context_t *ctx)
|
|
246
|
+
{
|
|
247
|
+
const ndt_t *u;
|
|
248
|
+
|
|
249
|
+
if (!ndt_is_ndarray(t)) {
|
|
250
|
+
ndt_err_format(ctx, NDT_NotImplementedError,
|
|
251
|
+
"partial copies are currently restricted to fixed dimensions");
|
|
252
|
+
return NULL;
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
if (n < 0 || n > t->ndim) {
|
|
256
|
+
ndt_err_format(ctx, NDT_ValueError, "n out of bounds");
|
|
257
|
+
return NULL;
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
u = ndt_logical_dim_at(t, n);
|
|
261
|
+
|
|
262
|
+
if (dtype != NULL) {
|
|
263
|
+
return ndt_copy_contiguous_dtype(u, dtype, 0, ctx);
|
|
264
|
+
}
|
|
265
|
+
else {
|
|
266
|
+
return ndt_copy_contiguous(u, 0, ctx);
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
|
|
173
270
|
int
|
|
174
271
|
ndt_as_ndarray(ndt_ndarray_t *a, const ndt_t *t, ndt_context_t *ctx)
|
|
175
272
|
{
|
|
@@ -205,9 +302,89 @@ ndt_as_ndarray(ndt_ndarray_t *a, const ndt_t *t, ndt_context_t *ctx)
|
|
|
205
302
|
return n;
|
|
206
303
|
}
|
|
207
304
|
|
|
305
|
+
static const ndt_t *
|
|
306
|
+
_ndt_transpose(const ndt_ndarray_t *a, const int p[], const ndt_t *dtype,
|
|
307
|
+
ndt_context_t *ctx)
|
|
308
|
+
{
|
|
309
|
+
const ndt_t *t;
|
|
310
|
+
|
|
311
|
+
t = dtype;
|
|
312
|
+
ndt_incref(t);
|
|
313
|
+
|
|
314
|
+
for (int i = a->ndim-1; i >= 0; i--) {
|
|
315
|
+
const ndt_t *u = ndt_fixed_dim(t, a->shape[p[i]], a->steps[p[i]], ctx);
|
|
316
|
+
ndt_decref(t);
|
|
317
|
+
if (u == NULL) {
|
|
318
|
+
return NULL;
|
|
319
|
+
}
|
|
320
|
+
t = u;
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
return t;
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
const ndt_t *
|
|
327
|
+
ndt_transpose(const ndt_t *t, const int *p, int ndim, ndt_context_t *ctx)
|
|
328
|
+
{
|
|
329
|
+
ndt_ndarray_t a;
|
|
330
|
+
int permute[NDT_MAX_DIM];
|
|
331
|
+
|
|
332
|
+
if (ndt_as_ndarray(&a, t, ctx) < 0) {
|
|
333
|
+
return NULL;
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
if (p == NULL) {
|
|
337
|
+
if (ndim != 0) {
|
|
338
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
339
|
+
"internal error: ndim != 0 but no permutations given");
|
|
340
|
+
return NULL;
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
ndim = a.ndim;
|
|
344
|
+
if (ndim <= 1) {
|
|
345
|
+
ndt_incref(t);
|
|
346
|
+
return t;
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
for (int i = 0; i < ndim; i++) {
|
|
350
|
+
permute[i] = ndim-i-1;
|
|
351
|
+
}
|
|
352
|
+
p = permute;
|
|
353
|
+
}
|
|
354
|
+
else {
|
|
355
|
+
if (ndim != a.ndim) {
|
|
356
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
357
|
+
"number of permutations != ndim");
|
|
358
|
+
return NULL;
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
for (int i = 0; i < ndim; i++) {
|
|
362
|
+
permute[i] = 0;
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
for (int i = 0; i < ndim; i++) {
|
|
366
|
+
if (p[i] < 0 || p[i] >= ndim) {
|
|
367
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
368
|
+
"permutation with value %d out of bounds", p[i]);
|
|
369
|
+
return NULL;
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
if (permute[p[i]] != 0) {
|
|
373
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
374
|
+
"duplicate permutation index=%d", p[i]);
|
|
375
|
+
return NULL;
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
permute[p[i]] = 1;
|
|
379
|
+
}
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
return _ndt_transpose(&a, p, ndt_dtype(t), ctx);
|
|
383
|
+
}
|
|
384
|
+
|
|
208
385
|
/* Unoptimized hash function for experimenting. */
|
|
209
386
|
ndt_ssize_t
|
|
210
|
-
ndt_hash(ndt_t *t, ndt_context_t *ctx)
|
|
387
|
+
ndt_hash(const ndt_t *t, ndt_context_t *ctx)
|
|
211
388
|
{
|
|
212
389
|
unsigned char *s, *cp;
|
|
213
390
|
size_t len;
|
|
@@ -242,11 +419,11 @@ ndt_hash(ndt_t *t, ndt_context_t *ctx)
|
|
|
242
419
|
|
|
243
420
|
const ndt_apply_spec_t ndt_apply_spec_empty = {
|
|
244
421
|
.flags = 0U,
|
|
245
|
-
.nout = 0,
|
|
246
|
-
.nbroadcast = 0,
|
|
247
422
|
.outer_dims = 0,
|
|
248
|
-
.
|
|
249
|
-
.
|
|
423
|
+
.nin = 0,
|
|
424
|
+
.nout = 0,
|
|
425
|
+
.nargs = 0,
|
|
426
|
+
.types = {NULL}
|
|
250
427
|
};
|
|
251
428
|
|
|
252
429
|
ndt_apply_spec_t *
|
|
@@ -259,9 +436,10 @@ ndt_apply_spec_new(ndt_context_t *ctx)
|
|
|
259
436
|
return ndt_memory_error(ctx);
|
|
260
437
|
}
|
|
261
438
|
spec->flags = 0U;
|
|
262
|
-
spec->nout = 0;
|
|
263
|
-
spec->nbroadcast = 0;
|
|
264
439
|
spec->outer_dims = 0;
|
|
440
|
+
spec->nin = 0;
|
|
441
|
+
spec->nout = 0;
|
|
442
|
+
spec->nargs = 0;
|
|
265
443
|
|
|
266
444
|
return spec;
|
|
267
445
|
}
|
|
@@ -269,24 +447,20 @@ ndt_apply_spec_new(ndt_context_t *ctx)
|
|
|
269
447
|
void
|
|
270
448
|
ndt_apply_spec_clear(ndt_apply_spec_t *spec)
|
|
271
449
|
{
|
|
272
|
-
int i;
|
|
273
|
-
|
|
274
450
|
if (spec == NULL) {
|
|
275
451
|
return;
|
|
276
452
|
}
|
|
277
453
|
|
|
278
|
-
for (i = 0; i < spec->
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
for (i = 0; i < spec->nout; i++) {
|
|
283
|
-
ndt_del(spec->out[i]);
|
|
454
|
+
for (int i = 0; i < spec->nin+spec->nout; i++) {
|
|
455
|
+
ndt_decref(spec->types[i]);
|
|
456
|
+
spec->types[i] = NULL;
|
|
284
457
|
}
|
|
285
458
|
|
|
286
459
|
spec->flags = 0U;
|
|
287
|
-
spec->nout = 0;
|
|
288
|
-
spec->nbroadcast = 0;
|
|
289
460
|
spec->outer_dims = 0;
|
|
461
|
+
spec->nin = 0;
|
|
462
|
+
spec->nout = 0;
|
|
463
|
+
spec->nargs = 0;
|
|
290
464
|
}
|
|
291
465
|
|
|
292
466
|
void
|
|
@@ -300,6 +474,15 @@ ndt_apply_spec_del(ndt_apply_spec_t *spec)
|
|
|
300
474
|
ndt_free(spec);
|
|
301
475
|
}
|
|
302
476
|
|
|
477
|
+
#define X (NDT_INNER_XND)
|
|
478
|
+
#define S (NDT_INNER_STRIDED)
|
|
479
|
+
#define C (NDT_INNER_C)
|
|
480
|
+
#define F (NDT_INNER_F)
|
|
481
|
+
|
|
482
|
+
#define ES (NDT_EXT_STRIDED)
|
|
483
|
+
#define EC (NDT_EXT_C)
|
|
484
|
+
#define EZ (NDT_EXT_ZERO)
|
|
485
|
+
|
|
303
486
|
/* This function is used in places where it is _really_ convenient not to
|
|
304
487
|
have to deal with deallocating an error message. */
|
|
305
488
|
const char *
|
|
@@ -307,38 +490,35 @@ ndt_apply_flags_as_string(const ndt_apply_spec_t *spec)
|
|
|
307
490
|
{
|
|
308
491
|
switch (spec->flags) {
|
|
309
492
|
case 0: return "None";
|
|
310
|
-
case
|
|
311
|
-
|
|
312
|
-
case
|
|
313
|
-
case
|
|
314
|
-
case
|
|
315
|
-
case
|
|
316
|
-
|
|
317
|
-
case
|
|
318
|
-
case
|
|
319
|
-
case
|
|
320
|
-
case
|
|
321
|
-
case
|
|
322
|
-
case
|
|
323
|
-
case
|
|
324
|
-
case
|
|
325
|
-
case
|
|
326
|
-
case
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_FORTRAN: return "Elemwise1D|Fortran|Strided|Xnd";
|
|
340
|
-
case NDT_ELEMWISE_1D|NDT_XND|NDT_STRIDED|NDT_FORTRAN|NDT_C: return "Elemwise1D|C|Fortran|Strided|Xnd";
|
|
341
|
-
default: return "error: invalid combination of spec->flags";
|
|
493
|
+
case X: return "Xnd";
|
|
494
|
+
|
|
495
|
+
case S|X: return "Strided|Xnd";
|
|
496
|
+
case C|S|X: return "C|Strided|Xnd";
|
|
497
|
+
case F|S|X: return "Fortran|Strided|Xnd";
|
|
498
|
+
case C|F|S|X: return "C|Fortran|Strided|Xnd";
|
|
499
|
+
|
|
500
|
+
case ES|S|X: return "OptS|Strided|Xnd";
|
|
501
|
+
case ES|C|S|X: return "OptS|C|Strided|Xnd";
|
|
502
|
+
case ES|F|S|X: return "OptS|Fortran|Strided|Xnd";
|
|
503
|
+
case ES|C|F|S|X: return "OptS|C|Fortran|Strided|Xnd";
|
|
504
|
+
case EC|ES|C|S|X: return "OptC|OptS|C|Strided|Xnd";
|
|
505
|
+
case EC|ES|C|F|S|X: return "OptC|OptS|C|Fortran|Strided|Xnd";
|
|
506
|
+
case EZ|EC|ES|C|F|S|X: return "OptZ|OptC|OptS|C|Fortran|Strided|Xnd";
|
|
507
|
+
case EZ|EC|ES|C|S|X: return "OptZ|OptC|OptS|C|Strided|Xnd";
|
|
508
|
+
case EZ|ES|C|S|X: return "OptZ|OptS|C|Strided|Xnd";
|
|
509
|
+
case EZ|ES|C|F|S|X: return "OptZ|OptS|C|Fortran|Strided|Xnd";
|
|
510
|
+
|
|
511
|
+
default:
|
|
512
|
+
if (spec->flags & NDT_EXT_ZERO) fprintf(stderr, "EZ ");
|
|
513
|
+
if (spec->flags & NDT_EXT_C) fprintf(stderr, "EC ");
|
|
514
|
+
if (spec->flags & NDT_EXT_STRIDED) fprintf(stderr, "ES ");
|
|
515
|
+
if (spec->flags & NDT_INNER_C) fprintf(stderr, "C ");
|
|
516
|
+
if (spec->flags & NDT_INNER_F) fprintf(stderr, "F ");
|
|
517
|
+
if (spec->flags & NDT_INNER_STRIDED) fprintf(stderr, "S ");
|
|
518
|
+
if (spec->flags & NDT_INNER_XND) fprintf(stderr, "X ");
|
|
519
|
+
fprintf(stderr, "\n");
|
|
520
|
+
|
|
521
|
+
return "unknown flags";
|
|
342
522
|
}
|
|
343
523
|
}
|
|
344
524
|
|
|
@@ -362,16 +542,21 @@ ndt_meta_new(ndt_context_t *ctx)
|
|
|
362
542
|
}
|
|
363
543
|
|
|
364
544
|
void
|
|
365
|
-
|
|
545
|
+
ndt_meta_clear(ndt_meta_t *m)
|
|
366
546
|
{
|
|
367
547
|
if (m == NULL) {
|
|
368
548
|
return;
|
|
369
549
|
}
|
|
370
550
|
|
|
371
551
|
for (int i = 0; i < m->ndims; i++) {
|
|
372
|
-
|
|
552
|
+
ndt_decref_offsets(m->offsets[i]);
|
|
373
553
|
}
|
|
554
|
+
}
|
|
374
555
|
|
|
556
|
+
void
|
|
557
|
+
ndt_meta_del(ndt_meta_t *m)
|
|
558
|
+
{
|
|
559
|
+
ndt_meta_clear(m);
|
|
375
560
|
ndt_free(m);
|
|
376
561
|
}
|
|
377
562
|
|
|
@@ -380,95 +565,118 @@ ndt_meta_del(ndt_meta_t *m)
|
|
|
380
565
|
/* Optimized kernel strategy (unstable API) */
|
|
381
566
|
/*****************************************************************************/
|
|
382
567
|
|
|
383
|
-
static
|
|
384
|
-
|
|
568
|
+
static uint32_t
|
|
569
|
+
check_c(uint32_t flags, ndt_ndarray_t *x, int outer)
|
|
385
570
|
{
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
571
|
+
int inner = x->ndim-outer;
|
|
572
|
+
int64_t *shape = x->shape+outer;
|
|
573
|
+
int64_t *steps = x->steps+outer;
|
|
574
|
+
int64_t step = 1;
|
|
575
|
+
|
|
576
|
+
for (int i = inner-1; i >= 0; i--) {
|
|
577
|
+
if (shape[i] > 1 && steps[i] != step) {
|
|
578
|
+
return flags & ~(NDT_INNER_C|NDT_EXT_C|NDT_EXT_ZERO);
|
|
389
579
|
}
|
|
390
|
-
|
|
391
|
-
|
|
580
|
+
step *= shape[i];
|
|
581
|
+
}
|
|
582
|
+
|
|
583
|
+
if (x->ndim == inner) {
|
|
584
|
+
return flags & ~(NDT_EXT_C|NDT_EXT_ZERO);
|
|
585
|
+
}
|
|
586
|
+
else if (x->ndim > inner) {
|
|
587
|
+
if (shape[-1] > 1 && steps[-1] != step) {
|
|
588
|
+
flags &= ~NDT_EXT_C;
|
|
589
|
+
if (steps[-1] != 0) {
|
|
590
|
+
flags &= ~NDT_EXT_ZERO;
|
|
591
|
+
}
|
|
392
592
|
}
|
|
393
593
|
}
|
|
394
594
|
|
|
395
|
-
return
|
|
595
|
+
return flags;
|
|
396
596
|
}
|
|
397
597
|
|
|
398
|
-
static
|
|
399
|
-
|
|
598
|
+
static uint32_t
|
|
599
|
+
check_f(uint32_t flags, ndt_ndarray_t *x, int outer)
|
|
400
600
|
{
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
601
|
+
int inner = x->ndim-outer;
|
|
602
|
+
int64_t *shape = x->shape+outer;
|
|
603
|
+
int64_t *steps = x->steps+outer;
|
|
604
|
+
int64_t step = 1;
|
|
605
|
+
|
|
606
|
+
for (int i = 0; i < inner; i++) {
|
|
607
|
+
if (shape[i] > 1 && steps[i] != step) {
|
|
608
|
+
return flags & ~NDT_INNER_F;
|
|
404
609
|
}
|
|
610
|
+
step *= shape[i];
|
|
405
611
|
}
|
|
406
612
|
|
|
407
|
-
return
|
|
613
|
+
return flags;
|
|
408
614
|
}
|
|
409
615
|
|
|
410
|
-
static
|
|
411
|
-
|
|
616
|
+
static uint32_t
|
|
617
|
+
check_strided(uint32_t flags, int outer)
|
|
412
618
|
{
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
return false;
|
|
416
|
-
}
|
|
619
|
+
if (outer == 0) {
|
|
620
|
+
return flags & ~NDT_EXT_STRIDED;
|
|
417
621
|
}
|
|
418
622
|
|
|
419
|
-
return
|
|
623
|
+
return flags;
|
|
420
624
|
}
|
|
421
625
|
|
|
422
|
-
static
|
|
423
|
-
|
|
626
|
+
static uint32_t
|
|
627
|
+
check_var(uint32_t flags, const ndt_t *t, int outer)
|
|
424
628
|
{
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
}
|
|
433
|
-
|
|
434
|
-
return
|
|
629
|
+
#if 0
|
|
630
|
+
/*
|
|
631
|
+
* This currently does not handle a corner case with zeros in a shape.
|
|
632
|
+
* The loop in gumath is safe, however, since it stops at shape==0.
|
|
633
|
+
*/
|
|
634
|
+
if (ndt_logical_ndim(t) == outer) {
|
|
635
|
+
return flags & NDT_INNER_XND;
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
return 0U;
|
|
639
|
+
#endif
|
|
640
|
+
(void)t;
|
|
641
|
+
(void)outer;
|
|
642
|
+
return flags & NDT_INNER_XND;
|
|
435
643
|
}
|
|
436
644
|
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
const ndt_t *in[], int nin)
|
|
645
|
+
static uint32_t
|
|
646
|
+
select_flags(const ndt_t *types[], int n, int outer, ndt_context_t *ctx)
|
|
440
647
|
{
|
|
441
|
-
|
|
442
|
-
|
|
648
|
+
uint32_t flags = NDT_SPEC_FLAGS_ALL;
|
|
649
|
+
ndt_ndarray_t x;
|
|
443
650
|
|
|
444
|
-
|
|
445
|
-
|
|
651
|
+
for (int i = 0; i < n; i++) {
|
|
652
|
+
const ndt_t *t = types[i];
|
|
446
653
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
654
|
+
if (ndt_as_ndarray(&x, t, ctx) < 0) { /* var dimension */
|
|
655
|
+
ndt_err_clear(ctx);
|
|
656
|
+
if (t->tag == VarDim || t->tag == VarDimElem) {
|
|
657
|
+
flags = check_var(flags, t, outer);
|
|
658
|
+
}
|
|
659
|
+
}
|
|
660
|
+
else {
|
|
661
|
+
if (outer > t->ndim) {
|
|
662
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
663
|
+
"number of outer dimensions greater than ndim");
|
|
664
|
+
return UINT32_MAX;
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
flags = check_strided(flags, outer);
|
|
668
|
+
flags = check_c(flags, &x, outer);
|
|
669
|
+
flags = check_f(flags, &x, outer);
|
|
670
|
+
}
|
|
450
671
|
}
|
|
451
672
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
out_inner_c = all_inner_c_contiguous(out, spec->nout, spec->outer_dims);
|
|
455
|
-
|
|
456
|
-
spec->flags = NDT_XND;
|
|
673
|
+
return flags;
|
|
674
|
+
}
|
|
457
675
|
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
}
|
|
463
|
-
}
|
|
464
|
-
if (in_inner_c && out_inner_c) {
|
|
465
|
-
spec->flags |= NDT_C;
|
|
466
|
-
}
|
|
467
|
-
if (in_inner_f && out_inner_c) {
|
|
468
|
-
spec->flags |= NDT_FORTRAN;
|
|
469
|
-
}
|
|
676
|
+
int
|
|
677
|
+
ndt_select_kernel_strategy(ndt_apply_spec_t *spec, ndt_context_t *ctx)
|
|
678
|
+
{
|
|
679
|
+
spec->flags = select_flags(spec->types, spec->nargs, spec->outer_dims, ctx);
|
|
470
680
|
|
|
471
|
-
|
|
472
|
-
spec->flags |= NDT_STRIDED;
|
|
473
|
-
}
|
|
681
|
+
return spec->flags == UINT32_MAX ? -1 : 0;
|
|
474
682
|
}
|