ndtypes 0.2.0dev4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CONTRIBUTING.md +50 -0
- data/Gemfile +2 -0
- data/History.md +0 -0
- data/README.md +19 -0
- data/Rakefile +125 -0
- data/ext/ruby_ndtypes/extconf.rb +55 -0
- data/ext/ruby_ndtypes/gc_guard.c +36 -0
- data/ext/ruby_ndtypes/gc_guard.h +12 -0
- data/ext/ruby_ndtypes/ndtypes/AUTHORS.txt +5 -0
- data/ext/ruby_ndtypes/ndtypes/INSTALL.txt +101 -0
- data/ext/ruby_ndtypes/ndtypes/LICENSE.txt +29 -0
- data/ext/ruby_ndtypes/ndtypes/MANIFEST.in +3 -0
- data/ext/ruby_ndtypes/ndtypes/Makefile.in +87 -0
- data/ext/ruby_ndtypes/ndtypes/README.rst +47 -0
- data/ext/ruby_ndtypes/ndtypes/config.guess +1530 -0
- data/ext/ruby_ndtypes/ndtypes/config.h.in +67 -0
- data/ext/ruby_ndtypes/ndtypes/config.sub +1782 -0
- data/ext/ruby_ndtypes/ndtypes/configure +5260 -0
- data/ext/ruby_ndtypes/ndtypes/configure.ac +161 -0
- data/ext/ruby_ndtypes/ndtypes/doc/Makefile +14 -0
- data/ext/ruby_ndtypes/ndtypes/doc/_static/copybutton.js +66 -0
- data/ext/ruby_ndtypes/ndtypes/doc/conf.py +26 -0
- data/ext/ruby_ndtypes/ndtypes/doc/grammar/grammar.rst +27 -0
- data/ext/ruby_ndtypes/ndtypes/doc/index.rst +56 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/context.rst +131 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/encodings.rst +68 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/fields-values.rst +175 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/functions.rst +72 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/index.rst +43 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/init.rst +48 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/io.rst +100 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/memory.rst +124 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/predicates.rst +110 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/typedef.rst +31 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/types.rst +594 -0
- data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/util.rst +166 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/buffer-protocol.rst +27 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/index.rst +21 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/pattern-matching.rst +330 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/quickstart.rst +144 -0
- data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +544 -0
- data/ext/ruby_ndtypes/ndtypes/doc/releases/index.rst +35 -0
- data/ext/ruby_ndtypes/ndtypes/install-sh +527 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +271 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +269 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +230 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.c +268 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.h +109 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile.in +73 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile.vc +70 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/README.txt +16 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +2179 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +134 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +428 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +2543 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +735 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +176 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +543 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +110 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/context.c +228 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +634 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.c +116 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +288 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +3067 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +180 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +417 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +1658 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +2773 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +734 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +222 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +1132 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +2323 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +893 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/overflow.h +161 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +473 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +92 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +246 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +269 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +197 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile.in +48 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile.vc +46 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +1007 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +442 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/slice.h +42 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +238 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +50 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +371 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +100 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +55 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +45 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/alloc_fail.c +82 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/alloc_fail.h +49 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +1657 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +85 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +115 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +137 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_indent.c +201 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +2397 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_numba.c +57 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +349 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +27839 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +350 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +231 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +375 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typedef.c +65 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/valgrind.supp +30 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/bench.c +79 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/indent.c +94 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/print_ast.c +96 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +474 -0
- data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +228 -0
- data/ext/ruby_ndtypes/ndtypes/python/bench.py +49 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +409 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +14 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +70 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +1332 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/docstrings.h +319 -0
- data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +154 -0
- data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +1977 -0
- data/ext/ruby_ndtypes/ndtypes/setup.py +288 -0
- data/ext/ruby_ndtypes/ndtypes/vcbuild/INSTALL.txt +41 -0
- data/ext/ruby_ndtypes/ndtypes/vcbuild/runtest32.bat +15 -0
- data/ext/ruby_ndtypes/ndtypes/vcbuild/runtest64.bat +13 -0
- data/ext/ruby_ndtypes/ndtypes/vcbuild/vcbuild32.bat +38 -0
- data/ext/ruby_ndtypes/ndtypes/vcbuild/vcbuild64.bat +38 -0
- data/ext/ruby_ndtypes/ndtypes/vcbuild/vcclean.bat +13 -0
- data/ext/ruby_ndtypes/ndtypes/vcbuild/vcdistclean.bat +14 -0
- data/ext/ruby_ndtypes/ruby_ndtypes.c +1003 -0
- data/ext/ruby_ndtypes/ruby_ndtypes.h +37 -0
- data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +28 -0
- data/lib/ndtypes.rb +45 -0
- data/lib/ndtypes/errors.rb +2 -0
- data/lib/ndtypes/version.rb +6 -0
- data/ndtypes.gemspec +47 -0
- data/spec/gc_table_spec.rb +10 -0
- data/spec/ndtypes_spec.rb +289 -0
- data/spec/spec_helper.rb +241 -0
- metadata +242 -0
@@ -0,0 +1,222 @@
|
|
1
|
+
%{
|
2
|
+
/*
|
3
|
+
* BSD 3-Clause License
|
4
|
+
*
|
5
|
+
* Copyright (c) 2017-2018, plures
|
6
|
+
* All rights reserved.
|
7
|
+
*
|
8
|
+
* Redistribution and use in source and binary forms, with or without
|
9
|
+
* modification, are permitted provided that the following conditions are met:
|
10
|
+
*
|
11
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
12
|
+
* this list of conditions and the following disclaimer.
|
13
|
+
*
|
14
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
15
|
+
* this list of conditions and the following disclaimer in the documentation
|
16
|
+
* and/or other materials provided with the distribution.
|
17
|
+
*
|
18
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
19
|
+
* contributors may be used to endorse or promote products derived from
|
20
|
+
* this software without specific prior written permission.
|
21
|
+
*
|
22
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
23
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
24
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
25
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
26
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
27
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
28
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
29
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
30
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
31
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
32
|
+
*/
|
33
|
+
|
34
|
+
|
35
|
+
#include <stdio.h>
|
36
|
+
#include <string.h>
|
37
|
+
#include <stdint.h>
|
38
|
+
#include <setjmp.h>
|
39
|
+
#include "ndtypes.h"
|
40
|
+
#include "parsefuncs.h"
|
41
|
+
#include "grammar.h"
|
42
|
+
|
43
|
+
/* From PostgreSQL: avoid exit() on fatal scanner errors.
   Every three-argument fprintf() call in the code that follows (including
   the flex-generated scanner) is rerouted to fprintf_to_longjmp(), which
   longjmps to ndt_lexerror instead of printing.  The 'file' argument is
   dropped, and the expansion names 'yyscanner', so that identifier must be
   in scope wherever fprintf() is used -- NOTE(review): confirm this holds
   for every fprintf() the flex version in use emits. */
|
44
|
+
#undef fprintf
|
45
|
+
#define fprintf(file, fmt, msg) fprintf_to_longjmp(fmt, msg, yyscanner)
|
46
|
+
|
47
|
+
/* Jump target used by fprintf_to_longjmp(); only declared here, defined
   elsewhere (presumably next to the setjmp() in the parser driver).  This is
   one global jmp_buf for all scanner instances -- NOTE(review): concurrent
   parses in different threads would share it; confirm callers serialize. */
extern jmp_buf ndt_lexerror;
|
48
|
+
static void
|
49
|
+
fprintf_to_longjmp(const char *fmt, const char *msg, yyscan_t yyscanner)
|
50
|
+
{
|
51
|
+
(void)fmt; (void)msg; (void)yyscanner;
|
52
|
+
|
53
|
+
/* We don't have access to the parse context here: discard the error
|
54
|
+
message, which is always either an allocation failure or an internal
|
55
|
+
flex error. */
|
56
|
+
longjmp(ndt_lexerror, 1);
|
57
|
+
}
|
58
|
+
|
59
|
+
void *
|
60
|
+
yyalloc(size_t size, yyscan_t yyscanner)
|
61
|
+
{
|
62
|
+
(void)yyscanner;
|
63
|
+
|
64
|
+
return ndt_alloc(1, size);
|
65
|
+
}
|
66
|
+
|
67
|
+
void *
|
68
|
+
yyrealloc(void *ptr, size_t size, yyscan_t yyscanner)
|
69
|
+
{
|
70
|
+
(void)yyscanner;
|
71
|
+
|
72
|
+
return ndt_realloc(ptr, 1, size);
|
73
|
+
}
|
74
|
+
|
75
|
+
void
|
76
|
+
yyfree(void *ptr, yyscan_t yyscanner)
|
77
|
+
{
|
78
|
+
(void)yyscanner;
|
79
|
+
|
80
|
+
ndt_free(ptr);
|
81
|
+
}
|
82
|
+
|
83
|
+
%}
|
84
|
+
|
85
|
+
/* Reentrant scanner that exchanges yylval/yylloc with a bison parser and
   carries an ndt_context_t * in its extra slot. */
%option reentrant
%option bison-bridge
%option bison-locations
%option extra-type="ndt_context_t *"

/* No yywrap()/unput()/input(); memory is managed by the user-supplied
   yyalloc()/yyrealloc()/yyfree() hooks. */
%option noyywrap
%option nounput
%option noinput
%option noyyalloc
%option noyyrealloc
%option noyyfree

%option never-interactive
%option yylineno
%option 8bit

/* Warn about specification problems and suppress the default ECHO rule, so
   every input character must be matched by an explicit rule. */
%option warn
%option nodefault

/* Line breaks, blanks and '#' comments running to the end of the line. */
newline         [\n\r]
space           [ \t\f]
non_newline     [^\n\r]
comment         #{non_newline}*

/* Single- or double-quoted strings; a backslash escapes the following
   character, and neither form may contain a raw '\n'. */
escapeseq       \\.
single_strchar  [^\\\n']
double_strchar  [^\\\n"]
single_str      '({single_strchar}|{escapeseq})*'
double_str      \"({double_strchar}|{escapeseq})*\"
stringlit       {single_str}|{double_str}

/* Optionally negative decimal, octal (0o...) or hexadecimal (0x...)
   integers. */
octdigit        [0-7]
octinteger      0[oO]{octdigit}+
nonzerodigit    [1-9]
digit           [0-9]
decimalinteger  {nonzerodigit}{digit}*|0+
hexdigit        {digit}|[a-f]|[A-F]
hexinteger      0[xX]{hexdigit}+
integer         -?({decimalinteger}|{octinteger}|{hexinteger})

/* Optionally negative floating point numbers with a fractional part, an
   exponent, or both. */
intpart         {digit}+
fraction        \.{digit}+
exponent        [eE][+-]?{digit}+
pointfloat      {intpart}?{fraction}|{intpart}\.
exponentfloat   ({intpart}|{pointfloat}){exponent}
floatnumber     -?({pointfloat}|{exponentfloat})

/* Identifiers, classified by their first character. */
name_lower      [a-z][a-zA-Z0-9_]*
name_upper      [A-Z][a-zA-Z0-9_]*
name_other      _[a-zA-Z0-9_]*
|
125
|
+
|
126
|
+
|
127
|
+
%%
|
128
|
+
|
129
|
+
%{
/*
 * Code in %{ ... %} placed before the first rule is executed each time
 * yylex() is entered (flex manual, "Actions").  The previous "%code {"
 * spelling is a bison directive that flex does not recognize as a code
 * block (at best it is parsed as a rule for the literal text "%code"), so
 * this setup would not run on scanner entry.
 *
 * NOTE(review): yylex() returns once per token, so resetting yycolumn here
 * makes each token's column relative to the previously returned token, not
 * to the start of the line -- confirm this is the intended location model.
 */
    yycolumn = 1;

/* Record the line/column span of every matched token in *yylloc, then move
   the column counter past the matched text. */
#undef YY_USER_ACTION
#define YY_USER_ACTION \
    yylloc->first_line = yylloc->last_line = yylineno; \
    yylloc->first_column = yycolumn; \
    yylloc->last_column = yycolumn+yyleng-1; \
    yycolumn += yyleng;
%}
|
140
|
+
"Any"           { return ANY_KIND; }
"Scalar"        { return SCALAR_KIND; }

"void"          { return VOID; }
"bool"          { return BOOL; }

"Signed"        { return SIGNED_KIND; }
"int8"          { return INT8; }
"int16"         { return INT16; }
"int32"         { return INT32; }
"int64"         { return INT64; }

"Unsigned"      { return UNSIGNED_KIND; }
"uint8"         { return UINT8; }
"uint16"        { return UINT16; }
"uint32"        { return UINT32; }
"uint64"        { return UINT64; }

"Float"         { return FLOAT_KIND; }
"float16"       { return FLOAT16; }
"float32"       { return FLOAT32; }
"float64"       { return FLOAT64; }

"Complex"       { return COMPLEX_KIND; }
"complex32"     { return COMPLEX32; }
"complex64"     { return COMPLEX64; }
"complex128"    { return COMPLEX128; }

"intptr"        { return INTPTR; }
"uintptr"       { return UINTPTR; }
"size_t"        { return SIZE; }
"char"          { return CHAR; }
"string"        { return STRING; }
"bytes"         { return BYTES; }

"FixedString"   { return FIXED_STRING_KIND; }
"fixed_string"  { return FIXED_STRING; }

"FixedBytes"    { return FIXED_BYTES_KIND; }
"fixed_bytes"   { return FIXED_BYTES; }

"categorical"   { return CATEGORICAL; }
"NA"            { return NA; }

"ref"           { return REF; }

"fixed"         { return FIXED; }
"var"           { return VAR; }

"..."           { return ELLIPSIS; }
"->"            { return RARROW; }
","             { return COMMA; }
":"             { return COLON; }
"("             { return LPAREN; }
")"             { return RPAREN; }
"{"             { return LBRACE; }
"}"             { return RBRACE; }
"["             { return LBRACK; }
"]"             { return RBRACK; }
"*"             { return STAR; }
"="             { return EQUAL; }
"?"             { return QUESTIONMARK; }
"!"             { return BANG; }
"&"             { return AMPERSAND; }
"|"             { return BAR; }
"<"             { return LESS; }
">"             { return GREATER; }

{name_lower}    {
                    /* Names and literals hand a copy of their text to the
                       parser; a NULL copy is reported as ERRTOKEN. */
                    yylval->string = ndt_strdup(yytext, ctx);
                    return yylval->string == NULL ? ERRTOKEN : NAME_LOWER;
                }
{name_upper}    {
                    yylval->string = ndt_strdup(yytext, ctx);
                    return yylval->string == NULL ? ERRTOKEN : NAME_UPPER;
                }
{name_other}    {
                    yylval->string = ndt_strdup(yytext, ctx);
                    return yylval->string == NULL ? ERRTOKEN : NAME_OTHER;
                }

{stringlit}     {
                    yylval->string = mk_stringlit(yytext, ctx);
                    return yylval->string == NULL ? ERRTOKEN : STRINGLIT;
                }
{integer}       {
                    yylval->string = ndt_strdup(yytext, ctx);
                    return yylval->string == NULL ? ERRTOKEN : INTEGER;
                }
{floatnumber}   {
                    yylval->string = ndt_strdup(yytext, ctx);
                    return yylval->string == NULL ? ERRTOKEN : FLOATNUMBER;
                }

{newline}       { yycolumn = 1; }
{space}         { /* skipped */ }
{comment}       { /* skipped */ }
.               { return ERRTOKEN; }
|
221
|
+
|
222
|
+
%%
|
@@ -0,0 +1,1132 @@
|
|
1
|
+
/*
|
2
|
+
* BSD 3-Clause License
|
3
|
+
*
|
4
|
+
* Copyright (c) 2017-2018, plures
|
5
|
+
* All rights reserved.
|
6
|
+
*
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
9
|
+
*
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
11
|
+
* this list of conditions and the following disclaimer.
|
12
|
+
*
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
15
|
+
* and/or other materials provided with the distribution.
|
16
|
+
*
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
19
|
+
* this software without specific prior written permission.
|
20
|
+
*
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31
|
+
*/
|
32
|
+
|
33
|
+
|
34
|
+
#include <stdio.h>
|
35
|
+
#include <stdlib.h>
|
36
|
+
#include <stdint.h>
|
37
|
+
#include <inttypes.h>
|
38
|
+
#include <stdbool.h>
|
39
|
+
#include <string.h>
|
40
|
+
#include <stdarg.h>
|
41
|
+
#include <assert.h>
|
42
|
+
#include "ndtypes.h"
|
43
|
+
#include "symtable.h"
|
44
|
+
#include "substitute.h"
|
45
|
+
|
46
|
+
|
47
|
+
static int match_datashape(const ndt_t *, const ndt_t *, symtable_t *, ndt_context_t *);
|
48
|
+
|
49
|
+
static int
|
50
|
+
_resolve_broadcast(int64_t vshape[NDT_MAX_DIM], int vsize,
|
51
|
+
const int64_t wshape[NDT_MAX_DIM], int wsize)
|
52
|
+
{
|
53
|
+
int64_t n, m;
|
54
|
+
int i, k;
|
55
|
+
|
56
|
+
for (i=vsize-1, k=wsize-1; i>=0 && k>=0; i--, k--) {
|
57
|
+
n = vshape[i];
|
58
|
+
m = wshape[k];
|
59
|
+
if (m != n) {
|
60
|
+
if (n == 1) {
|
61
|
+
n = m;
|
62
|
+
}
|
63
|
+
else if (m == 0) {
|
64
|
+
n = 0;
|
65
|
+
}
|
66
|
+
else if (m != 1) {
|
67
|
+
return -1;
|
68
|
+
}
|
69
|
+
}
|
70
|
+
vshape[i<k ? k : i] = n;
|
71
|
+
}
|
72
|
+
for (; k >= 0; k--) {
|
73
|
+
vshape[k] = wshape[k];
|
74
|
+
}
|
75
|
+
|
76
|
+
return vsize >= wsize ? vsize : wsize;
|
77
|
+
}
|
78
|
+
|
79
|
+
static int
|
80
|
+
resolve_broadcast(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
81
|
+
{
|
82
|
+
const char *key = "00_ELLIPSIS";
|
83
|
+
symtable_entry_t *v;
|
84
|
+
int vsize;
|
85
|
+
|
86
|
+
v = symtable_find_ptr(tbl, key);
|
87
|
+
if (v == NULL) {
|
88
|
+
if (symtable_add(tbl, key, w, ctx) < 0) {
|
89
|
+
return -1;
|
90
|
+
}
|
91
|
+
return 1;
|
92
|
+
}
|
93
|
+
|
94
|
+
vsize = _resolve_broadcast(v->BroadcastSeq.dims, v->BroadcastSeq.size,
|
95
|
+
w.BroadcastSeq.dims, w.BroadcastSeq.size);
|
96
|
+
if (vsize < 0) {
|
97
|
+
ndt_err_format(ctx, NDT_TypeError, "broadcast error");
|
98
|
+
return -1;
|
99
|
+
}
|
100
|
+
v->BroadcastSeq.size = vsize;
|
101
|
+
|
102
|
+
return 1;
|
103
|
+
}
|
104
|
+
|
105
|
+
static int
|
106
|
+
check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
|
107
|
+
{
|
108
|
+
for (int i = 0; i < nargs; i++) {
|
109
|
+
const ndt_t *p = ptypes[i];
|
110
|
+
const ndt_t *c = ctypes[i];
|
111
|
+
|
112
|
+
if (p->tag == EllipsisDim) {
|
113
|
+
switch (p->EllipsisDim.tag) {
|
114
|
+
case RequireNA:
|
115
|
+
break;
|
116
|
+
case RequireC:
|
117
|
+
if (!ndt_is_c_contiguous(c)) {
|
118
|
+
return 0;
|
119
|
+
}
|
120
|
+
break;
|
121
|
+
case RequireF:
|
122
|
+
if (!ndt_is_f_contiguous(c)) {
|
123
|
+
return 0;
|
124
|
+
}
|
125
|
+
break;
|
126
|
+
}
|
127
|
+
}
|
128
|
+
}
|
129
|
+
|
130
|
+
return 1;
|
131
|
+
}
|
132
|
+
|
133
|
+
static ndt_t *
|
134
|
+
to_fortran(const ndt_t *p, ndt_t *c, ndt_context_t *ctx)
|
135
|
+
{
|
136
|
+
if (p->tag == EllipsisDim && p->EllipsisDim.tag == RequireF) {
|
137
|
+
ndt_t *t = ndt_to_fortran(c, ctx);
|
138
|
+
return t;
|
139
|
+
}
|
140
|
+
else {
|
141
|
+
return c;
|
142
|
+
}
|
143
|
+
}
|
144
|
+
|
145
|
+
static int
|
146
|
+
resolve_fixed(const char *key, symtable_entry_t w,
|
147
|
+
symtable_t *tbl, ndt_context_t *ctx)
|
148
|
+
{
|
149
|
+
symtable_entry_t v;
|
150
|
+
|
151
|
+
v = symtable_find(tbl, key);
|
152
|
+
if (v.tag == Unbound) {
|
153
|
+
if (symtable_add(tbl, key, w, ctx) < 0) {
|
154
|
+
return -1;
|
155
|
+
}
|
156
|
+
return 1;
|
157
|
+
}
|
158
|
+
|
159
|
+
if (w.FixedSeq.size != v.FixedSeq.size) {
|
160
|
+
return 0;
|
161
|
+
}
|
162
|
+
|
163
|
+
for (int i = 0; i < v.FixedSeq.size; i++) {
|
164
|
+
const ndt_t *t = v.FixedSeq.dims[i];
|
165
|
+
const ndt_t *u = w.FixedSeq.dims[i];
|
166
|
+
if (u->FixedDim.shape != t->FixedDim.shape) {
|
167
|
+
return 0;
|
168
|
+
}
|
169
|
+
}
|
170
|
+
|
171
|
+
return 1;
|
172
|
+
}
|
173
|
+
|
174
|
+
static int
|
175
|
+
resolve_shape(const char *key, int64_t shape, symtable_t *tbl, ndt_context_t *ctx)
|
176
|
+
{
|
177
|
+
symtable_entry_t v;
|
178
|
+
|
179
|
+
v = symtable_find(tbl, key);
|
180
|
+
if (v.tag == Unbound) {
|
181
|
+
v.tag = Shape;
|
182
|
+
v.Shape = shape;
|
183
|
+
if (symtable_add(tbl, key, v, ctx) < 0) {
|
184
|
+
return -1;
|
185
|
+
}
|
186
|
+
return 1;
|
187
|
+
}
|
188
|
+
|
189
|
+
if (v.tag != Shape) {
|
190
|
+
return 0;
|
191
|
+
}
|
192
|
+
|
193
|
+
return shape == v.Shape;
|
194
|
+
}
|
195
|
+
|
196
|
+
static int
|
197
|
+
resolve_typevar(const char *key, symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
198
|
+
{
|
199
|
+
symtable_entry_t v;
|
200
|
+
|
201
|
+
v = symtable_find(tbl, key);
|
202
|
+
if (v.tag == Unbound) {
|
203
|
+
if (symtable_add(tbl, key, w, ctx) < 0) {
|
204
|
+
return -1;
|
205
|
+
}
|
206
|
+
return 1;
|
207
|
+
}
|
208
|
+
|
209
|
+
if (v.tag == Symbol && w.tag == Symbol) {
|
210
|
+
return strcmp(v.Symbol, w.Symbol) == 0;
|
211
|
+
}
|
212
|
+
else if (v.tag == Type && w.tag == Type) {
|
213
|
+
return ndt_equal(v.Type, w.Type);
|
214
|
+
}
|
215
|
+
else {
|
216
|
+
return 0;
|
217
|
+
}
|
218
|
+
}
|
219
|
+
|
220
|
+
static int
|
221
|
+
match_concrete_var_dim(const ndt_t *t, int64_t tindex,
|
222
|
+
const ndt_t *u, int64_t uindex,
|
223
|
+
const int outer_dims, ndt_context_t *ctx)
|
224
|
+
{
|
225
|
+
int64_t tshape, tstart, tstep;
|
226
|
+
int64_t ushape, ustart, ustep;
|
227
|
+
|
228
|
+
if (outer_dims == 0) {
|
229
|
+
return 1;
|
230
|
+
}
|
231
|
+
if (t->Concrete.VarDim.itemsize != u->Concrete.VarDim.itemsize) {
|
232
|
+
return 0;
|
233
|
+
}
|
234
|
+
|
235
|
+
tshape = ndt_var_indices(&tstart, &tstep, t, tindex, ctx);
|
236
|
+
if (tshape < 0) {
|
237
|
+
return -1;
|
238
|
+
}
|
239
|
+
|
240
|
+
ushape = ndt_var_indices(&ustart, &ustep, u, uindex, ctx);
|
241
|
+
if (ushape < 0) {
|
242
|
+
return -1;
|
243
|
+
}
|
244
|
+
|
245
|
+
if (ushape != tshape) {
|
246
|
+
return 0;
|
247
|
+
}
|
248
|
+
|
249
|
+
for (int64_t i = 0; i < tshape; i++) {
|
250
|
+
int64_t tnext = tstart + i * tstep;
|
251
|
+
int64_t unext = ustart + i * ustep;
|
252
|
+
int ret = match_concrete_var_dim(t->VarDim.type, tnext,
|
253
|
+
u->VarDim.type, unext,
|
254
|
+
outer_dims-1, ctx);
|
255
|
+
if (ret <= 0) {
|
256
|
+
return ret;
|
257
|
+
}
|
258
|
+
}
|
259
|
+
|
260
|
+
return 1;
|
261
|
+
}
|
262
|
+
|
263
|
+
static int
|
264
|
+
resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
|
265
|
+
{
|
266
|
+
const char *key = "var";
|
267
|
+
symtable_entry_t v;
|
268
|
+
|
269
|
+
v = symtable_find(tbl, key);
|
270
|
+
if (v.tag == Unbound) {
|
271
|
+
if (symtable_add(tbl, key, w, ctx) < 0) {
|
272
|
+
return -1;
|
273
|
+
}
|
274
|
+
return 1;
|
275
|
+
}
|
276
|
+
|
277
|
+
if (w.VarSeq.size != v.VarSeq.size) {
|
278
|
+
return 0;
|
279
|
+
}
|
280
|
+
if (v.VarSeq.size == 0) {
|
281
|
+
return 1;
|
282
|
+
}
|
283
|
+
|
284
|
+
return match_concrete_var_dim(w.VarSeq.dims[0], 0,
|
285
|
+
v.VarSeq.dims[0], 0,
|
286
|
+
v.VarSeq.size, ctx);
|
287
|
+
}
|
288
|
+
|
289
|
+
static int
|
290
|
+
match_tuple_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
291
|
+
ndt_context_t *ctx)
|
292
|
+
{
|
293
|
+
int64_t i;
|
294
|
+
int n;
|
295
|
+
|
296
|
+
assert(p->tag == Tuple && c->tag == Tuple);
|
297
|
+
|
298
|
+
if (p->Tuple.shape != c->Tuple.shape) {
|
299
|
+
return 0;
|
300
|
+
}
|
301
|
+
|
302
|
+
for (i = 0; i < p->Tuple.shape; i++) {
|
303
|
+
n = match_datashape(p->Tuple.types[i], c->Tuple.types[i], tbl, ctx);
|
304
|
+
if (n <= 0) return n;
|
305
|
+
}
|
306
|
+
|
307
|
+
return 1;
|
308
|
+
}
|
309
|
+
|
310
|
+
static int
|
311
|
+
match_record_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
312
|
+
ndt_context_t *ctx)
|
313
|
+
{
|
314
|
+
int64_t i;
|
315
|
+
int n;
|
316
|
+
|
317
|
+
assert(p->tag == Record && c->tag == Record);
|
318
|
+
|
319
|
+
if (p->Record.shape != c->Record.shape) {
|
320
|
+
return 0;
|
321
|
+
}
|
322
|
+
|
323
|
+
for (i = 0; i < p->Record.shape; i++) {
|
324
|
+
n = strcmp(p->Record.names[i], c->Record.names[i]);
|
325
|
+
if (n != 0) return 0;
|
326
|
+
|
327
|
+
n = match_datashape(p->Record.types[i], c->Record.types[i], tbl, ctx);
|
328
|
+
if (n <= 0) return n;
|
329
|
+
}
|
330
|
+
|
331
|
+
return 1;
|
332
|
+
}
|
333
|
+
|
334
|
+
static int
|
335
|
+
match_categorical(ndt_value_t *p, int64_t plen,
|
336
|
+
ndt_value_t *c, int64_t clen)
|
337
|
+
{
|
338
|
+
int64_t i;
|
339
|
+
|
340
|
+
if (plen != clen) {
|
341
|
+
return 0;
|
342
|
+
}
|
343
|
+
|
344
|
+
for (i = 0; i < plen; i++) {
|
345
|
+
if (!ndt_value_equal(&p[i], &c[i])) {
|
346
|
+
return 0;
|
347
|
+
}
|
348
|
+
}
|
349
|
+
|
350
|
+
return 1;
|
351
|
+
}
|
352
|
+
|
353
|
+
static const ndt_t *
|
354
|
+
outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
|
355
|
+
{
|
356
|
+
assert(ndt_is_concrete(t));
|
357
|
+
|
358
|
+
if (t->ndim < ndim) {
|
359
|
+
return NULL;
|
360
|
+
}
|
361
|
+
if (t->ndim == ndim) {
|
362
|
+
return t;
|
363
|
+
}
|
364
|
+
|
365
|
+
switch (t->tag) {
|
366
|
+
case FixedDim: {
|
367
|
+
switch (v->tag) {
|
368
|
+
case FixedSeq:
|
369
|
+
v->FixedSeq.size = i+1;
|
370
|
+
v->FixedSeq.dims[i] = t;
|
371
|
+
break;
|
372
|
+
case BroadcastSeq:
|
373
|
+
v->BroadcastSeq.size = i+1;
|
374
|
+
v->BroadcastSeq.dims[i] = t->FixedDim.shape;
|
375
|
+
break;
|
376
|
+
default:
|
377
|
+
return NULL;
|
378
|
+
}
|
379
|
+
return outer_inner(v, i+1, t->FixedDim.type, ndim);
|
380
|
+
}
|
381
|
+
case VarDim: {
|
382
|
+
switch (v->tag) {
|
383
|
+
case VarSeq:
|
384
|
+
v->VarSeq.size = i+1;
|
385
|
+
v->VarSeq.dims[i] = t;
|
386
|
+
break;
|
387
|
+
default:
|
388
|
+
return NULL;
|
389
|
+
}
|
390
|
+
return outer_inner(v, i+1, t->VarDim.type, ndim);
|
391
|
+
}
|
392
|
+
default:
|
393
|
+
return NULL;
|
394
|
+
}
|
395
|
+
}
|
396
|
+
|
397
|
+
static int
|
398
|
+
match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
|
399
|
+
ndt_context_t *ctx)
|
400
|
+
{
|
401
|
+
int n;
|
402
|
+
|
403
|
+
if (ndt_is_optional(c) != ndt_is_optional(p)) return 0;
|
404
|
+
|
405
|
+
switch (p->tag) {
|
406
|
+
case AnyKind: {
|
407
|
+
return 1;
|
408
|
+
}
|
409
|
+
|
410
|
+
case FixedDim: {
|
411
|
+
if (c->tag != FixedDim || p->FixedDim.shape != c->FixedDim.shape) {
|
412
|
+
return 0;
|
413
|
+
}
|
414
|
+
if (p->FixedDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
|
415
|
+
return 0;
|
416
|
+
}
|
417
|
+
if (p->FixedDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
|
418
|
+
return 0;
|
419
|
+
}
|
420
|
+
|
421
|
+
return match_datashape(p->FixedDim.type, c->FixedDim.type, tbl, ctx);
|
422
|
+
}
|
423
|
+
|
424
|
+
case VarDim: {
|
425
|
+
if (c->tag != VarDim) {
|
426
|
+
return 0;
|
427
|
+
}
|
428
|
+
return match_datashape(p->VarDim.type, c->VarDim.type, tbl, ctx);
|
429
|
+
}
|
430
|
+
|
431
|
+
case SymbolicDim: {
|
432
|
+
if (c->tag != FixedDim) return 0;
|
433
|
+
|
434
|
+
if (p->SymbolicDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
|
435
|
+
return 0;
|
436
|
+
}
|
437
|
+
if (p->SymbolicDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
|
438
|
+
return 0;
|
439
|
+
}
|
440
|
+
|
441
|
+
n = resolve_shape(p->SymbolicDim.name, c->FixedDim.shape, tbl, ctx);
|
442
|
+
if (n <= 0) {
|
443
|
+
return n;
|
444
|
+
}
|
445
|
+
return match_datashape(p->SymbolicDim.type, c->FixedDim.type, tbl, ctx);
|
446
|
+
}
|
447
|
+
|
448
|
+
case EllipsisDim: {
|
449
|
+
symtable_entry_t outer;
|
450
|
+
const ndt_t *inner;
|
451
|
+
|
452
|
+
if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
|
453
|
+
return 0;
|
454
|
+
}
|
455
|
+
if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
|
456
|
+
return 0;
|
457
|
+
}
|
458
|
+
|
459
|
+
if (p->EllipsisDim.name == NULL) {
|
460
|
+
outer.tag = BroadcastSeq;
|
461
|
+
outer.BroadcastSeq.size = 0;
|
462
|
+
}
|
463
|
+
else if (strcmp(p->EllipsisDim.name, "var") == 0) {
|
464
|
+
outer.tag = VarSeq;
|
465
|
+
outer.VarSeq.size = 0;
|
466
|
+
}
|
467
|
+
else {
|
468
|
+
outer.tag = FixedSeq;
|
469
|
+
outer.FixedSeq.size = 0;
|
470
|
+
}
|
471
|
+
|
472
|
+
inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
|
473
|
+
if (inner == NULL) {
|
474
|
+
return 0;
|
475
|
+
}
|
476
|
+
|
477
|
+
n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
|
478
|
+
if (n <= 0) {
|
479
|
+
return n;
|
480
|
+
}
|
481
|
+
|
482
|
+
switch (outer.tag) {
|
483
|
+
case BroadcastSeq:
|
484
|
+
return resolve_broadcast(outer, tbl, ctx);
|
485
|
+
case FixedSeq:
|
486
|
+
return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
|
487
|
+
case VarSeq:
|
488
|
+
return resolve_var(outer, tbl, ctx);
|
489
|
+
default: /* NOT REACHED */
|
490
|
+
ndt_internal_error("invalid tag");
|
491
|
+
}
|
492
|
+
}
|
493
|
+
|
494
|
+
case Bool:
|
495
|
+
case Int8: case Int16: case Int32: case Int64:
|
496
|
+
case Uint8: case Uint16: case Uint32: case Uint64:
|
497
|
+
case Float16: case Float32: case Float64:
|
498
|
+
case Complex32: case Complex64: case Complex128:
|
499
|
+
case String:
|
500
|
+
return p->tag == c->tag;
|
501
|
+
case FixedString:
|
502
|
+
return c->tag == FixedString &&
|
503
|
+
p->FixedString.size == c->FixedString.size &&
|
504
|
+
p->FixedString.encoding == c->FixedString.encoding;
|
505
|
+
case FixedBytes:
|
506
|
+
return c->tag == FixedBytes &&
|
507
|
+
p->FixedBytes.size == c->FixedBytes.size &&
|
508
|
+
p->FixedBytes.align == c->FixedBytes.align;
|
509
|
+
case SignedKind:
|
510
|
+
return c->tag == SignedKind || ndt_is_signed(c);
|
511
|
+
case UnsignedKind:
|
512
|
+
return c->tag == UnsignedKind || ndt_is_unsigned(c);
|
513
|
+
case FloatKind:
|
514
|
+
return c->tag == FloatKind || ndt_is_float(c);
|
515
|
+
case ComplexKind:
|
516
|
+
return c->tag == ComplexKind || ndt_is_complex(c);
|
517
|
+
case FixedStringKind:
|
518
|
+
return c->tag == FixedStringKind || c->tag == FixedString;
|
519
|
+
case FixedBytesKind:
|
520
|
+
return c->tag == FixedBytesKind || c->tag == FixedBytes;
|
521
|
+
case ScalarKind:
|
522
|
+
return c->tag == ScalarKind || ndt_is_scalar(c);
|
523
|
+
case Char:
|
524
|
+
return c->tag == Char && c->Char.encoding == p->Char.encoding;
|
525
|
+
case Bytes:
|
526
|
+
return c->tag == Bytes && p->Bytes.target_align == c->Bytes.target_align;
|
527
|
+
case Categorical:
|
528
|
+
return c->tag == Categorical &&
|
529
|
+
match_categorical(p->Categorical.types, p->Categorical.ntypes,
|
530
|
+
c->Categorical.types, c->Categorical.ntypes);
|
531
|
+
case Ref:
|
532
|
+
if (c->tag != Ref) return 0;
|
533
|
+
return match_datashape(p->Ref.type, c->Ref.type, tbl, ctx);
|
534
|
+
case Tuple:
|
535
|
+
if (p->Tuple.flag == Variadic) return 0;
|
536
|
+
if (c->tag != Tuple) return 0;
|
537
|
+
return match_tuple_fields(p, c, tbl, ctx);
|
538
|
+
case Record:
|
539
|
+
if (p->Tuple.flag == Variadic) return 0;
|
540
|
+
if (c->tag != Record) return 0;
|
541
|
+
return match_record_fields(p, c, tbl, ctx);
|
542
|
+
case Function: {
|
543
|
+
int64_t i;
|
544
|
+
if (c->tag != Function ||
|
545
|
+
c->Function.nin != p->Function.nin ||
|
546
|
+
c->Function.nout != p->Function.nout ||
|
547
|
+
c->Function.nargs != p->Function.nargs) {
|
548
|
+
return 0;
|
549
|
+
}
|
550
|
+
|
551
|
+
for (i = 0; i < p->Function.nargs; i++) {
|
552
|
+
n = match_datashape(p->Function.types[i], c->Function.types[i], tbl, ctx);
|
553
|
+
if (n <= 0) return n;
|
554
|
+
}
|
555
|
+
|
556
|
+
return check_contig(p->Function.types, c->Function.types, p->Function.nargs);
|
557
|
+
}
|
558
|
+
case Typevar: {
|
559
|
+
if (c->tag == Typevar) {
|
560
|
+
symtable_entry_t entry = { .tag=Symbol, .Symbol=c->Typevar.name };
|
561
|
+
return resolve_typevar(p->Typevar.name, entry, tbl, ctx);
|
562
|
+
}
|
563
|
+
else {
|
564
|
+
symtable_entry_t entry = { .tag=Type, .Type=c };
|
565
|
+
return resolve_typevar(p->Typevar.name, entry, tbl, ctx);
|
566
|
+
}
|
567
|
+
}
|
568
|
+
case Nominal:
|
569
|
+
/* Assume that the type has been created through ndt_nominal(), in
|
570
|
+
which case the name is guaranteed to be unique and present in the
|
571
|
+
typedef table. */
|
572
|
+
return c->tag == Nominal && strcmp(p->Nominal.name, c->Nominal.name) == 0;
|
573
|
+
case Module:
|
574
|
+
return c->tag == Module && strcmp(p->Module.name, c->Module.name) == 0 &&
|
575
|
+
ndt_equal(p->Module.type, c->Module.type);
|
576
|
+
case Constr:
|
577
|
+
return c->tag == Constr && strcmp(p->Constr.name, c->Constr.name) == 0 &&
|
578
|
+
ndt_equal(p->Constr.type, c->Constr.type);
|
579
|
+
}
|
580
|
+
|
581
|
+
/* NOT REACHED: tags should be exhaustive. */
|
582
|
+
ndt_internal_error("invalid type");
|
583
|
+
}
|
584
|
+
|
585
|
+
int
|
586
|
+
ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
|
587
|
+
{
|
588
|
+
symtable_t *tbl;
|
589
|
+
int ret;
|
590
|
+
|
591
|
+
if (ndt_is_abstract(c)) {
|
592
|
+
return 0;
|
593
|
+
}
|
594
|
+
|
595
|
+
tbl = symtable_new(ctx);
|
596
|
+
if (tbl == NULL) {
|
597
|
+
return -1;
|
598
|
+
}
|
599
|
+
|
600
|
+
ret = match_datashape(p, c, tbl, ctx);
|
601
|
+
symtable_del(tbl);
|
602
|
+
return ret;
|
603
|
+
}
|
604
|
+
|
605
|
+
static ndt_t *
|
606
|
+
broadcast(const ndt_t *t, const int64_t *shape,
|
607
|
+
int outer_dims, int inner_dims,
|
608
|
+
bool use_max, ndt_context_t *ctx)
|
609
|
+
{
|
610
|
+
ndt_ndarray_t u;
|
611
|
+
const ndt_t *dtype;
|
612
|
+
ndt_t *v;
|
613
|
+
int64_t step;
|
614
|
+
int ndim;
|
615
|
+
int i, k;
|
616
|
+
|
617
|
+
ndim = ndt_as_ndarray(&u, t, ctx);
|
618
|
+
if (ndim < 0) {
|
619
|
+
return NULL;
|
620
|
+
}
|
621
|
+
|
622
|
+
dtype = ndt_dtype(t);
|
623
|
+
v = ndt_copy(dtype, ctx);
|
624
|
+
if (v == NULL) {
|
625
|
+
return NULL;
|
626
|
+
}
|
627
|
+
|
628
|
+
for (i=ndim-1; i>=ndim-inner_dims; i--) {
|
629
|
+
v = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
|
630
|
+
if (v == NULL) {
|
631
|
+
return NULL;
|
632
|
+
}
|
633
|
+
}
|
634
|
+
|
635
|
+
for (k=outer_dims-1; i>=0 && k>=0; i--, k--) {
|
636
|
+
step = u.shape[i]<=1 ? 0 : u.steps[i];
|
637
|
+
v = ndt_fixed_dim(v, shape[k], step, ctx);
|
638
|
+
if (v == NULL) {
|
639
|
+
return NULL;
|
640
|
+
}
|
641
|
+
}
|
642
|
+
|
643
|
+
for (; k>=0; k--) {
|
644
|
+
if (use_max) {
|
645
|
+
v = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
|
646
|
+
}
|
647
|
+
else {
|
648
|
+
v = ndt_fixed_dim(v, shape[k], 0, ctx);
|
649
|
+
}
|
650
|
+
if (v == NULL) {
|
651
|
+
return NULL;
|
652
|
+
}
|
653
|
+
}
|
654
|
+
|
655
|
+
return v;
|
656
|
+
}
|
657
|
+
|
658
|
+
int
|
659
|
+
ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
660
|
+
const ndt_t *in[], const int nin,
|
661
|
+
const int64_t *shape, int outer_dims,
|
662
|
+
ndt_context_t *ctx)
|
663
|
+
{
|
664
|
+
ndt_t *u;
|
665
|
+
int inner_dims;
|
666
|
+
int i;
|
667
|
+
|
668
|
+
for (i = 0; i < nin; i++) {
|
669
|
+
inner_dims = sig->Function.types[i]->ndim-1;
|
670
|
+
spec->broadcast[i] = broadcast(in[i], shape,
|
671
|
+
outer_dims, inner_dims, false, ctx);
|
672
|
+
if (spec->broadcast[i] == NULL) {
|
673
|
+
return -1;
|
674
|
+
}
|
675
|
+
spec->nbroadcast++;
|
676
|
+
}
|
677
|
+
|
678
|
+
for (i = 0; i < spec->nout; i++) {
|
679
|
+
inner_dims = sig->Function.types[nin+i]->ndim-1;
|
680
|
+
u = broadcast(spec->out[i], shape,
|
681
|
+
outer_dims, inner_dims, true, ctx);
|
682
|
+
if (u == NULL) {
|
683
|
+
return -1;
|
684
|
+
}
|
685
|
+
ndt_del(spec->out[i]);
|
686
|
+
spec->out[i] = u;
|
687
|
+
}
|
688
|
+
|
689
|
+
spec->outer_dims = outer_dims;
|
690
|
+
|
691
|
+
return 0;
|
692
|
+
}
|
693
|
+
|
694
|
+
static int
|
695
|
+
broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
|
696
|
+
const ndt_t *in[], const int nin,
|
697
|
+
const symtable_t *tbl, ndt_context_t *ctx)
|
698
|
+
{
|
699
|
+
symtable_entry_t v;
|
700
|
+
|
701
|
+
v = symtable_find(tbl, "00_ELLIPSIS");
|
702
|
+
if (v.tag != BroadcastSeq) {
|
703
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
704
|
+
"unexpected missing unnamed ellipsis entry");
|
705
|
+
return -1;
|
706
|
+
}
|
707
|
+
|
708
|
+
return ndt_broadcast_all(spec, sig, in, nin,
|
709
|
+
v.BroadcastSeq.dims, v.BroadcastSeq.size,
|
710
|
+
ctx);
|
711
|
+
}
|
712
|
+
|
713
|
+
static int
|
714
|
+
resolve_constraint(const ndt_constraint_t *c, const void *args, symtable_t *tbl,
|
715
|
+
ndt_context_t *ctx)
|
716
|
+
{
|
717
|
+
int64_t shapes[NDT_MAX_SYMBOLS];
|
718
|
+
symtable_entry_t v;
|
719
|
+
|
720
|
+
for (int i = 0; i < c->nin; i++) {
|
721
|
+
v = symtable_find(tbl, c->symbols[i]);
|
722
|
+
if (v.tag != Shape) {
|
723
|
+
ndt_err_format(ctx, NDT_ValueError, "expected dimension variable");
|
724
|
+
return -1;
|
725
|
+
}
|
726
|
+
shapes[i] = v.Shape;
|
727
|
+
}
|
728
|
+
|
729
|
+
if (c->f(shapes, args, ctx) < 0) {
|
730
|
+
return -1;
|
731
|
+
}
|
732
|
+
|
733
|
+
for (int i = 0; i < c->nout; i++) {
|
734
|
+
if (resolve_shape(c->symbols[c->nin+i], shapes[c->nin+i], tbl, ctx) < 0) {
|
735
|
+
return -1;
|
736
|
+
}
|
737
|
+
}
|
738
|
+
|
739
|
+
return 0;
|
740
|
+
}
|
741
|
+
|
742
|
+
/*
|
743
|
+
* Check the concrete function arguments 'in' against the function
|
744
|
+
* signature 'sig'. On success, infer and return the concrete return
|
745
|
+
* types and the (possibly broadcasted) 'in' types.
|
746
|
+
*/
|
747
|
+
int
|
748
|
+
ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
749
|
+
const ndt_t *in[], const int nin,
|
750
|
+
const ndt_constraint_t *c, const void *args,
|
751
|
+
ndt_context_t *ctx)
|
752
|
+
{
|
753
|
+
symtable_t *tbl;
|
754
|
+
ndt_t *t;
|
755
|
+
const char *name;
|
756
|
+
int ret;
|
757
|
+
int64_t i;
|
758
|
+
|
759
|
+
assert(spec->flags == 0);
|
760
|
+
assert(spec->nout == 0);
|
761
|
+
assert(spec->nbroadcast == 0);
|
762
|
+
assert(spec->outer_dims == 0);
|
763
|
+
|
764
|
+
if (sig->tag != Function) {
|
765
|
+
ndt_err_format(ctx, NDT_ValueError,
|
766
|
+
"signature must be a function type");
|
767
|
+
return -1;
|
768
|
+
}
|
769
|
+
|
770
|
+
if (nin != sig->Function.nin) {
|
771
|
+
ndt_err_format(ctx, NDT_ValueError,
|
772
|
+
"expected %" PRIi64 " arguments, got %d", sig->Function.nin, nin);
|
773
|
+
return -1;
|
774
|
+
}
|
775
|
+
|
776
|
+
for (i = 0; i < nin; i++) {
|
777
|
+
if (ndt_is_abstract(in[i])) {
|
778
|
+
ndt_err_format(ctx, NDT_ValueError,
|
779
|
+
"type checking requires concrete argument types");
|
780
|
+
return -1;
|
781
|
+
}
|
782
|
+
}
|
783
|
+
|
784
|
+
tbl = symtable_new(ctx);
|
785
|
+
if (tbl == NULL) {
|
786
|
+
return -1;
|
787
|
+
}
|
788
|
+
|
789
|
+
for (i = 0; i < nin; i++) {
|
790
|
+
ret = match_datashape(sig->Function.types[i], in[i], tbl, ctx);
|
791
|
+
if (ret <= 0) {
|
792
|
+
symtable_del(tbl);
|
793
|
+
|
794
|
+
if (ret == 0) {
|
795
|
+
ndt_err_format(ctx, NDT_TypeError,
|
796
|
+
"argument types do not match");
|
797
|
+
}
|
798
|
+
|
799
|
+
return -1;
|
800
|
+
}
|
801
|
+
}
|
802
|
+
|
803
|
+
if (c != NULL && resolve_constraint(c, args, tbl, ctx) < 0) {
|
804
|
+
symtable_del(tbl);
|
805
|
+
return -1;
|
806
|
+
}
|
807
|
+
|
808
|
+
for (i = 0; i < sig->Function.nout; i++) {
|
809
|
+
spec->out[i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
|
810
|
+
if (spec->out[i] == NULL) {
|
811
|
+
ndt_apply_spec_clear(spec);
|
812
|
+
symtable_del(tbl);
|
813
|
+
return -1;
|
814
|
+
}
|
815
|
+
spec->nout++;
|
816
|
+
}
|
817
|
+
|
818
|
+
if (sig->flags & NDT_ELLIPSIS) {
|
819
|
+
if (sig->Function.nargs == 0 || sig->Function.types[0]->tag != EllipsisDim) {
|
820
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
821
|
+
"unexpected configuration of ellipsis flag and function types");
|
822
|
+
ndt_apply_spec_clear(spec);
|
823
|
+
symtable_del(tbl);
|
824
|
+
return -1;
|
825
|
+
}
|
826
|
+
|
827
|
+
t = sig->Function.types[0];
|
828
|
+
name = t->EllipsisDim.name;
|
829
|
+
|
830
|
+
if (name != NULL) {
|
831
|
+
symtable_entry_t v = symtable_find(tbl, name);
|
832
|
+
switch (v.tag) {
|
833
|
+
case FixedSeq:
|
834
|
+
spec->outer_dims = v.FixedSeq.size;
|
835
|
+
break;
|
836
|
+
case VarSeq:
|
837
|
+
spec->outer_dims = v.VarSeq.size;
|
838
|
+
break;
|
839
|
+
default:
|
840
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
841
|
+
"unexpected missing dimension list entry");
|
842
|
+
ndt_apply_spec_clear(spec);
|
843
|
+
symtable_del(tbl);
|
844
|
+
return -1;
|
845
|
+
}
|
846
|
+
}
|
847
|
+
else {
|
848
|
+
if (broadcast_all(spec, sig, in, nin, tbl, ctx) < 0) {
|
849
|
+
ndt_apply_spec_clear(spec);
|
850
|
+
symtable_del(tbl);
|
851
|
+
return -1;
|
852
|
+
}
|
853
|
+
}
|
854
|
+
}
|
855
|
+
|
856
|
+
symtable_del(tbl);
|
857
|
+
|
858
|
+
for (i = 0; i < sig->Function.nout; i++) {
|
859
|
+
ndt_t *_p = sig->Function.types[nin+i];
|
860
|
+
ndt_t *_c = spec->out[i];
|
861
|
+
ndt_t *_t = to_fortran(_p, _c, ctx);
|
862
|
+
if (_t == NULL) {
|
863
|
+
ndt_apply_spec_clear(spec);
|
864
|
+
return -1;
|
865
|
+
}
|
866
|
+
if (_t != _c) {
|
867
|
+
ndt_del(_c);
|
868
|
+
}
|
869
|
+
spec->out[i] = _t;
|
870
|
+
}
|
871
|
+
|
872
|
+
if (!check_contig(sig->Function.types, (ndt_t **)in, nin)) {
|
873
|
+
ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
|
874
|
+
return -1;
|
875
|
+
}
|
876
|
+
if (!check_contig(sig->Function.types+nin, spec->out, spec->nout)) {
|
877
|
+
ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
|
878
|
+
return -1;
|
879
|
+
}
|
880
|
+
|
881
|
+
ndt_select_kernel_strategy(spec, sig, in, nin);
|
882
|
+
|
883
|
+
return 0;
|
884
|
+
}
|
885
|
+
|
886
|
+
|
887
|
+
/*****************************************************************************/
|
888
|
+
/* Optimized binary typecheck for fixed input */
|
889
|
+
/*****************************************************************************/
|
890
|
+
|
891
|
+
static ndt_t *
|
892
|
+
binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
|
893
|
+
const int64_t *shape, int size, ndt_context_t *ctx)
|
894
|
+
{
|
895
|
+
ndt_t *v;
|
896
|
+
int64_t step;
|
897
|
+
int i, k;
|
898
|
+
|
899
|
+
v = ndt_copy(dtype, ctx);
|
900
|
+
if (v == NULL) {
|
901
|
+
return NULL;
|
902
|
+
}
|
903
|
+
|
904
|
+
for (i=t->ndim-1, k=size-1; i>=0 && k>=0; i--, k--) {
|
905
|
+
step = t->shape[i]<=1 ? 0 : t->steps[i];
|
906
|
+
v = ndt_fixed_dim(v, shape[k], step, ctx);
|
907
|
+
if (v == NULL) {
|
908
|
+
return NULL;
|
909
|
+
}
|
910
|
+
}
|
911
|
+
|
912
|
+
for (; k>=0; k--) {
|
913
|
+
v = ndt_fixed_dim(v, shape[k], 0, ctx);
|
914
|
+
if (v == NULL) {
|
915
|
+
return NULL;
|
916
|
+
}
|
917
|
+
}
|
918
|
+
|
919
|
+
return v;
|
920
|
+
}
|
921
|
+
|
922
|
+
static ndt_t *
|
923
|
+
fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
|
924
|
+
ndt_context_t *ctx)
|
925
|
+
{
|
926
|
+
ndt_t *t;
|
927
|
+
int i;
|
928
|
+
|
929
|
+
for (i=len-1, t=dtype; i >= 0; i--) {
|
930
|
+
t = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
|
931
|
+
if (t == NULL) {
|
932
|
+
return NULL;
|
933
|
+
}
|
934
|
+
}
|
935
|
+
|
936
|
+
return t;
|
937
|
+
}
|
938
|
+
|
939
|
+
static bool
|
940
|
+
shape_equal(const ndt_ndarray_t *a, const ndt_ndarray_t *b)
|
941
|
+
{
|
942
|
+
if (b->ndim != a->ndim) {
|
943
|
+
return false;
|
944
|
+
}
|
945
|
+
|
946
|
+
for (int i = 0; i < a->ndim; i++) {
|
947
|
+
if (b->shape[i] != a->shape[i]) {
|
948
|
+
return false;
|
949
|
+
}
|
950
|
+
}
|
951
|
+
|
952
|
+
return true;
|
953
|
+
}
|
954
|
+
|
955
|
+
static int
|
956
|
+
_ndt_binary_broadcast(ndt_apply_spec_t *spec, const ndt_t *sig,
|
957
|
+
const ndt_ndarray_t *x, const ndt_ndarray_t *y,
|
958
|
+
const ndt_t *in[], const int nin, ndt_t *dtype,
|
959
|
+
int inner, ndt_context_t *ctx)
|
960
|
+
{
|
961
|
+
int64_t shape[NDT_MAX_DIM];
|
962
|
+
int size;
|
963
|
+
|
964
|
+
if (shape_equal(x, y)) {
|
965
|
+
spec->nout = 1;
|
966
|
+
spec->nbroadcast = 0;
|
967
|
+
spec->outer_dims = x->ndim-inner;
|
968
|
+
spec->out[0] = fixed_dim_from_shape(x->shape, x->ndim, dtype, ctx);
|
969
|
+
if (spec->out[0] == NULL) {
|
970
|
+
return -1;
|
971
|
+
}
|
972
|
+
}
|
973
|
+
else {
|
974
|
+
for (int i = 0; i < x->ndim; i++) {
|
975
|
+
shape[i] = x->shape[i];
|
976
|
+
}
|
977
|
+
|
978
|
+
size = _resolve_broadcast(shape, x->ndim, y->shape, y->ndim);
|
979
|
+
if (size < 0) {
|
980
|
+
ndt_err_format(ctx, NDT_TypeError, "broadcast error");
|
981
|
+
ndt_del(dtype);
|
982
|
+
return -1;
|
983
|
+
}
|
984
|
+
|
985
|
+
spec->nout = 1;
|
986
|
+
spec->nbroadcast = 2;
|
987
|
+
spec->outer_dims = size-inner;
|
988
|
+
|
989
|
+
spec->out[0] = fixed_dim_from_shape(shape, size, dtype, ctx);
|
990
|
+
if (spec->out[0] == NULL) {
|
991
|
+
return -1;
|
992
|
+
}
|
993
|
+
|
994
|
+
spec->broadcast[0] = binary_broadcast_1D(x, ndt_dtype(in[0]), shape, size, ctx);
|
995
|
+
if (spec->broadcast[0] == NULL) {
|
996
|
+
ndt_del(spec->out[0]);
|
997
|
+
return -1;
|
998
|
+
}
|
999
|
+
|
1000
|
+
spec->broadcast[1] = binary_broadcast_1D(y, ndt_dtype(in[1]), shape, size, ctx);
|
1001
|
+
if (spec->broadcast[1] == NULL) {
|
1002
|
+
ndt_del(spec->out[0]);
|
1003
|
+
ndt_del(spec->broadcast[0]);
|
1004
|
+
return -1;
|
1005
|
+
}
|
1006
|
+
}
|
1007
|
+
|
1008
|
+
ndt_select_kernel_strategy(spec, sig, in, nin);
|
1009
|
+
|
1010
|
+
return 0;
|
1011
|
+
}
|
1012
|
+
|
1013
|
+
static bool
|
1014
|
+
all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
|
1015
|
+
ndt_context_t *ctx)
|
1016
|
+
{
|
1017
|
+
if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
|
1018
|
+
(t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL) ||
|
1019
|
+
(t2->tag != EllipsisDim || t2->EllipsisDim.name != NULL)) {
|
1020
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1021
|
+
"fast binary typecheck expects leading ellipsis dimensions");
|
1022
|
+
return false;
|
1023
|
+
}
|
1024
|
+
|
1025
|
+
return true;
|
1026
|
+
}
|
1027
|
+
|
1028
|
+
static bool
|
1029
|
+
all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
1030
|
+
{
|
1031
|
+
if (t0->tag != SymbolicDim || t1->tag != SymbolicDim ||
|
1032
|
+
t2->tag != SymbolicDim) {
|
1033
|
+
return false;
|
1034
|
+
}
|
1035
|
+
|
1036
|
+
return strcmp(t0->SymbolicDim.name, t1->SymbolicDim.name) == 0 &&
|
1037
|
+
strcmp(t0->SymbolicDim.name, t2->SymbolicDim.name) == 0;
|
1038
|
+
}
|
1039
|
+
|
1040
|
+
static bool
|
1041
|
+
all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
|
1042
|
+
{
|
1043
|
+
return t0->ndim == 0 && t1->ndim == 0 && t2->ndim == 0;
|
1044
|
+
}
|
1045
|
+
|
1046
|
+
/*
|
1047
|
+
* Optimized type checking for very specific signatures. The caller must
|
1048
|
+
* have identified the kernel location, signature and the dtype. For
|
1049
|
+
* performance reasons, no substitution is performed on the dtype, so
|
1050
|
+
* the dtype must be concrete.
|
1051
|
+
*
|
1052
|
+
* Supported signatures:
|
1053
|
+
* 1) ... * N * T0, ... * N * T1 -> N * T2
|
1054
|
+
* 2) ... * T0, ... * T1 -> ... * T2
|
1055
|
+
*/
|
1056
|
+
int
|
1057
|
+
ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
|
1058
|
+
const ndt_t *in[], const int nin, ndt_t *dtype,
|
1059
|
+
ndt_context_t *ctx)
|
1060
|
+
{
|
1061
|
+
ndt_t *p0, *p1, *p2;
|
1062
|
+
ndt_ndarray_t x, y;
|
1063
|
+
|
1064
|
+
assert(spec->flags == 0);
|
1065
|
+
assert(spec->nout == 0);
|
1066
|
+
assert(spec->nbroadcast == 0);
|
1067
|
+
assert(spec->outer_dims == 0);
|
1068
|
+
|
1069
|
+
if (sig->tag != Function ||
|
1070
|
+
sig->Function.nin != 2 ||
|
1071
|
+
sig->Function.nout != 1) {
|
1072
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1073
|
+
"fast binary typecheck expects a signature with two inputs and "
|
1074
|
+
"one output");
|
1075
|
+
return -1;
|
1076
|
+
}
|
1077
|
+
|
1078
|
+
if (nin != 2) {
|
1079
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1080
|
+
"fast binary typecheck expects two input arguments");
|
1081
|
+
return -1;
|
1082
|
+
}
|
1083
|
+
|
1084
|
+
if (ndt_is_abstract(dtype)) {
|
1085
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1086
|
+
"fast binary typecheck expects a concrete dtype");
|
1087
|
+
return -1;
|
1088
|
+
}
|
1089
|
+
|
1090
|
+
p0 = sig->Function.types[0];
|
1091
|
+
p1 = sig->Function.types[1];
|
1092
|
+
p2 = sig->Function.types[2];
|
1093
|
+
|
1094
|
+
if (!all_ellipses(p0, p1, p2, ctx)) {
|
1095
|
+
return -1;
|
1096
|
+
}
|
1097
|
+
|
1098
|
+
if (ndt_as_ndarray(&x, in[0], ctx) < 0) {
|
1099
|
+
ndt_del(dtype);
|
1100
|
+
return -1;
|
1101
|
+
}
|
1102
|
+
|
1103
|
+
if (ndt_as_ndarray(&y, in[1], ctx) < 0) {
|
1104
|
+
ndt_del(dtype);
|
1105
|
+
return -1;
|
1106
|
+
}
|
1107
|
+
|
1108
|
+
p0 = p0->EllipsisDim.type;
|
1109
|
+
p1 = p1->EllipsisDim.type;
|
1110
|
+
p2 = p2->EllipsisDim.type;
|
1111
|
+
|
1112
|
+
if (all_same_symbol(p0, p1, p2)) {
|
1113
|
+
if (x.ndim > 0 && y.ndim > 0) {
|
1114
|
+
const int64_t xshape = x.shape[x.ndim-1];
|
1115
|
+
const int64_t yshape = y.shape[y.ndim-1];
|
1116
|
+
if (xshape != 1 && yshape != 1 && xshape != yshape) {
|
1117
|
+
ndt_err_format(ctx, NDT_TypeError, "mismatch in inner dimensions");
|
1118
|
+
ndt_del(dtype);
|
1119
|
+
return -1;
|
1120
|
+
}
|
1121
|
+
}
|
1122
|
+
return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 1, ctx);
|
1123
|
+
}
|
1124
|
+
else if (all_ndim0(p0, p1, p2)) {
|
1125
|
+
return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 0, ctx);
|
1126
|
+
}
|
1127
|
+
else {
|
1128
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
1129
|
+
"unsupported signature in fast binary typecheck");
|
1130
|
+
return -1;
|
1131
|
+
}
|
1132
|
+
}
|