ndtypes 0.2.0dev4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. checksums.yaml +7 -0
  2. data/CONTRIBUTING.md +50 -0
  3. data/Gemfile +2 -0
  4. data/History.md +0 -0
  5. data/README.md +19 -0
  6. data/Rakefile +125 -0
  7. data/ext/ruby_ndtypes/extconf.rb +55 -0
  8. data/ext/ruby_ndtypes/gc_guard.c +36 -0
  9. data/ext/ruby_ndtypes/gc_guard.h +12 -0
  10. data/ext/ruby_ndtypes/ndtypes/AUTHORS.txt +5 -0
  11. data/ext/ruby_ndtypes/ndtypes/INSTALL.txt +101 -0
  12. data/ext/ruby_ndtypes/ndtypes/LICENSE.txt +29 -0
  13. data/ext/ruby_ndtypes/ndtypes/MANIFEST.in +3 -0
  14. data/ext/ruby_ndtypes/ndtypes/Makefile.in +87 -0
  15. data/ext/ruby_ndtypes/ndtypes/README.rst +47 -0
  16. data/ext/ruby_ndtypes/ndtypes/config.guess +1530 -0
  17. data/ext/ruby_ndtypes/ndtypes/config.h.in +67 -0
  18. data/ext/ruby_ndtypes/ndtypes/config.sub +1782 -0
  19. data/ext/ruby_ndtypes/ndtypes/configure +5260 -0
  20. data/ext/ruby_ndtypes/ndtypes/configure.ac +161 -0
  21. data/ext/ruby_ndtypes/ndtypes/doc/Makefile +14 -0
  22. data/ext/ruby_ndtypes/ndtypes/doc/_static/copybutton.js +66 -0
  23. data/ext/ruby_ndtypes/ndtypes/doc/conf.py +26 -0
  24. data/ext/ruby_ndtypes/ndtypes/doc/grammar/grammar.rst +27 -0
  25. data/ext/ruby_ndtypes/ndtypes/doc/index.rst +56 -0
  26. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/context.rst +131 -0
  27. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/encodings.rst +68 -0
  28. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/fields-values.rst +175 -0
  29. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/functions.rst +72 -0
  30. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/index.rst +43 -0
  31. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/init.rst +48 -0
  32. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/io.rst +100 -0
  33. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/memory.rst +124 -0
  34. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/predicates.rst +110 -0
  35. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/typedef.rst +31 -0
  36. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/types.rst +594 -0
  37. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/util.rst +166 -0
  38. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/buffer-protocol.rst +27 -0
  39. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/index.rst +21 -0
  40. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/pattern-matching.rst +330 -0
  41. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/quickstart.rst +144 -0
  42. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +544 -0
  43. data/ext/ruby_ndtypes/ndtypes/doc/releases/index.rst +35 -0
  44. data/ext/ruby_ndtypes/ndtypes/install-sh +527 -0
  45. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +271 -0
  46. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +269 -0
  47. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +230 -0
  48. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.c +268 -0
  49. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.h +109 -0
  50. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile.in +73 -0
  51. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile.vc +70 -0
  52. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/README.txt +16 -0
  53. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +2179 -0
  54. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +134 -0
  55. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +428 -0
  56. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +2543 -0
  57. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +735 -0
  58. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +176 -0
  59. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +543 -0
  60. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +110 -0
  61. data/ext/ruby_ndtypes/ndtypes/libndtypes/context.c +228 -0
  62. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +634 -0
  63. data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.c +116 -0
  64. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +288 -0
  65. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +3067 -0
  66. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +180 -0
  67. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +417 -0
  68. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +1658 -0
  69. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +2773 -0
  70. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +734 -0
  71. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +222 -0
  72. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +1132 -0
  73. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +2323 -0
  74. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +893 -0
  75. data/ext/ruby_ndtypes/ndtypes/libndtypes/overflow.h +161 -0
  76. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +473 -0
  77. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +92 -0
  78. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +246 -0
  79. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +269 -0
  80. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +197 -0
  81. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile.in +48 -0
  82. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile.vc +46 -0
  83. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +1007 -0
  84. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +442 -0
  85. data/ext/ruby_ndtypes/ndtypes/libndtypes/slice.h +42 -0
  86. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +238 -0
  87. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +50 -0
  88. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +371 -0
  89. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +100 -0
  90. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +55 -0
  91. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +45 -0
  92. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/alloc_fail.c +82 -0
  93. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/alloc_fail.h +49 -0
  94. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +1657 -0
  95. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +85 -0
  96. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +115 -0
  97. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +137 -0
  98. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_indent.c +201 -0
  99. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +2397 -0
  100. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_numba.c +57 -0
  101. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +349 -0
  102. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +27839 -0
  103. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +350 -0
  104. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +231 -0
  105. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +375 -0
  106. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typedef.c +65 -0
  107. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/valgrind.supp +30 -0
  108. data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/bench.c +79 -0
  109. data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/indent.c +94 -0
  110. data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/print_ast.c +96 -0
  111. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +474 -0
  112. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +228 -0
  113. data/ext/ruby_ndtypes/ndtypes/python/bench.py +49 -0
  114. data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +409 -0
  115. data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +14 -0
  116. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +70 -0
  117. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +1332 -0
  118. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/docstrings.h +319 -0
  119. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +154 -0
  120. data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +1977 -0
  121. data/ext/ruby_ndtypes/ndtypes/setup.py +288 -0
  122. data/ext/ruby_ndtypes/ndtypes/vcbuild/INSTALL.txt +41 -0
  123. data/ext/ruby_ndtypes/ndtypes/vcbuild/runtest32.bat +15 -0
  124. data/ext/ruby_ndtypes/ndtypes/vcbuild/runtest64.bat +13 -0
  125. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcbuild32.bat +38 -0
  126. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcbuild64.bat +38 -0
  127. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcclean.bat +13 -0
  128. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcdistclean.bat +14 -0
  129. data/ext/ruby_ndtypes/ruby_ndtypes.c +1003 -0
  130. data/ext/ruby_ndtypes/ruby_ndtypes.h +37 -0
  131. data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +28 -0
  132. data/lib/ndtypes.rb +45 -0
  133. data/lib/ndtypes/errors.rb +2 -0
  134. data/lib/ndtypes/version.rb +6 -0
  135. data/ndtypes.gemspec +47 -0
  136. data/spec/gc_table_spec.rb +10 -0
  137. data/spec/ndtypes_spec.rb +289 -0
  138. data/spec/spec_helper.rb +241 -0
  139. metadata +242 -0
@@ -0,0 +1,222 @@
1
+ %{
2
+ /*
3
+ * BSD 3-Clause License
4
+ *
5
+ * Copyright (c) 2017-2018, plures
6
+ * All rights reserved.
7
+ *
8
+ * Redistribution and use in source and binary forms, with or without
9
+ * modification, are permitted provided that the following conditions are met:
10
+ *
11
+ * 1. Redistributions of source code must retain the above copyright notice,
12
+ * this list of conditions and the following disclaimer.
13
+ *
14
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
15
+ * this list of conditions and the following disclaimer in the documentation
16
+ * and/or other materials provided with the distribution.
17
+ *
18
+ * 3. Neither the name of the copyright holder nor the names of its
19
+ * contributors may be used to endorse or promote products derived from
20
+ * this software without specific prior written permission.
21
+ *
22
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+ */
33
+
34
+
35
+ #include <stdio.h>
36
+ #include <string.h>
37
+ #include <stdint.h>
38
+ #include <setjmp.h>
39
+ #include "ndtypes.h"
40
+ #include "parsefuncs.h"
41
+ #include "grammar.h"
42
+
43
+ /* From PostgreSQL: avoid exit() on fatal scanner errors. */
44
+ #undef fprintf
45
+ #define fprintf(file, fmt, msg) fprintf_to_longjmp(fmt, msg, yyscanner)
46
+
47
+ extern jmp_buf ndt_lexerror;
48
+ static void
49
+ fprintf_to_longjmp(const char *fmt, const char *msg, yyscan_t yyscanner)
50
+ {
51
+ (void)fmt; (void)msg; (void)yyscanner;
52
+
53
+ /* We don't have access to the parse context here: discard the error
54
+ message, which is always either an allocation failure or an internal
55
+ flex error. */
56
+ longjmp(ndt_lexerror, 1);
57
+ }
58
+
59
+ void *
60
+ yyalloc(size_t size, yyscan_t yyscanner)
61
+ {
62
+ (void)yyscanner;
63
+
64
+ return ndt_alloc(1, size);
65
+ }
66
+
67
+ void *
68
+ yyrealloc(void *ptr, size_t size, yyscan_t yyscanner)
69
+ {
70
+ (void)yyscanner;
71
+
72
+ return ndt_realloc(ptr, 1, size);
73
+ }
74
+
75
+ void
76
+ yyfree(void *ptr, yyscan_t yyscanner)
77
+ {
78
+ (void)yyscanner;
79
+
80
+ ndt_free(ptr);
81
+ }
82
+
83
+ %}
84
+
85
%option reentrant bison-bridge bison-locations
%option noyywrap never-interactive
%option nounput noinput
%option noyyalloc noyyrealloc noyyfree
%option yylineno 8bit
%option warn nodefault
%option extra-type="ndt_context_t *"

/* Whitespace and comments. */
newline          [\n\r]
space            [ \t\f]
non_newline      [^\n\r]
comment          #{non_newline}*

/* String literals: single- or double-quoted, with backslash escapes. */
escapeseq        \\.
single_strchar   [^\\\n']
double_strchar   [^\\\n"]
single_str       '({single_strchar}|{escapeseq})*'
double_str       \"({double_strchar}|{escapeseq})*\"
stringlit        {single_str}|{double_str}

/* Integers: decimal, 0o-prefixed octal or 0x-prefixed hex, optional sign. */
octdigit         [0-7]
octinteger       0[oO]{octdigit}+
nonzerodigit     [1-9]
digit            [0-9]
decimalinteger   {nonzerodigit}{digit}*|0+
hexdigit         [0-9a-fA-F]
hexinteger       0[xX]{hexdigit}+
integer          -?({decimalinteger}|{octinteger}|{hexinteger})

/* Floating point numbers. */
intpart          {digit}+
fraction         \.{digit}+
exponent         [eE][+-]?{digit}+
pointfloat       {intpart}?{fraction}|{intpart}\.
exponentfloat    ({intpart}|{pointfloat}){exponent}
floatnumber      -?({pointfloat}|{exponentfloat})

/* Identifiers, classified by their first character. */
name_lower       [a-z][a-zA-Z0-9_]*
name_upper       [A-Z][a-zA-Z0-9_]*
name_other       _[a-zA-Z0-9_]*
125
+
126
+
127
+ %%
128
+
129
%{
    /* Code in this prologue runs at the start of every yylex() call.
       NOTE(review): yylex() returns after a single token, so the column
       counter restarts for each token rather than once per input;
       confirm this is intended. */
    yycolumn = 1;

/* Record the line/column span of every matched token in *yylloc. */
#undef YY_USER_ACTION
#define YY_USER_ACTION \
    yylloc->first_line = yylloc->last_line = yylineno; \
    yylloc->first_column = yycolumn;                   \
    yylloc->last_column = yycolumn+yyleng-1;           \
    yycolumn += yyleng;
%}

"Any"             { return ANY_KIND; }
"Scalar"          { return SCALAR_KIND; }

"void"            { return VOID; }
"bool"            { return BOOL; }

"Signed"          { return SIGNED_KIND; }
"int8"            { return INT8; }
"int16"           { return INT16; }
"int32"           { return INT32; }
"int64"           { return INT64; }

"Unsigned"        { return UNSIGNED_KIND; }
"uint8"           { return UINT8; }
"uint16"          { return UINT16; }
"uint32"          { return UINT32; }
"uint64"          { return UINT64; }

"Float"           { return FLOAT_KIND; }
"float16"         { return FLOAT16; }
"float32"         { return FLOAT32; }
"float64"         { return FLOAT64; }

"Complex"         { return COMPLEX_KIND; }
"complex32"       { return COMPLEX32; }
"complex64"       { return COMPLEX64; }
"complex128"      { return COMPLEX128; }

"intptr"          { return INTPTR; }
"uintptr"         { return UINTPTR; }
"size_t"          { return SIZE; }
"char"            { return CHAR; }
"string"          { return STRING; }
"bytes"           { return BYTES; }

"FixedString"     { return FIXED_STRING_KIND; }
"fixed_string"    { return FIXED_STRING; }

"FixedBytes"      { return FIXED_BYTES_KIND; }
"fixed_bytes"     { return FIXED_BYTES; }

"categorical"     { return CATEGORICAL; }
"NA"              { return NA; }

"ref"             { return REF; }

"fixed"           { return FIXED; }
"var"             { return VAR; }

"..."             { return ELLIPSIS; }
"->"              { return RARROW; }
","               { return COMMA; }
":"               { return COLON; }
"("               { return LPAREN; }
")"               { return RPAREN; }
"{"               { return LBRACE; }
"}"               { return RBRACE; }
"["               { return LBRACK; }
"]"               { return RBRACK; }
"*"               { return STAR; }
"="               { return EQUAL; }
"?"               { return QUESTIONMARK; }
"!"               { return BANG; }
"&"               { return AMPERSAND; }
"|"               { return BAR; }
"<"               { return LESS; }
">"               { return GREATER; }

{name_lower}      { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return NAME_LOWER; }
{name_upper}      { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return NAME_UPPER; }
{name_other}      { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return NAME_OTHER; }

{stringlit}       { yylval->string = mk_stringlit(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return STRINGLIT; }
{integer}         { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return INTEGER; }
{floatnumber}     { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return FLOATNUMBER; }

{newline}         { yycolumn = 1; }
{space}           { /* ignore */ }
{comment}         { /* ignore */ }
.                 { return ERRTOKEN; }
221
+
222
+ %%
@@ -0,0 +1,1132 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <stdio.h>
35
+ #include <stdlib.h>
36
+ #include <stdint.h>
37
+ #include <inttypes.h>
38
+ #include <stdbool.h>
39
+ #include <string.h>
40
+ #include <stdarg.h>
41
+ #include <assert.h>
42
+ #include "ndtypes.h"
43
+ #include "symtable.h"
44
+ #include "substitute.h"
45
+
46
+
47
+ static int match_datashape(const ndt_t *, const ndt_t *, symtable_t *, ndt_context_t *);
48
+
49
+ static int
50
+ _resolve_broadcast(int64_t vshape[NDT_MAX_DIM], int vsize,
51
+ const int64_t wshape[NDT_MAX_DIM], int wsize)
52
+ {
53
+ int64_t n, m;
54
+ int i, k;
55
+
56
+ for (i=vsize-1, k=wsize-1; i>=0 && k>=0; i--, k--) {
57
+ n = vshape[i];
58
+ m = wshape[k];
59
+ if (m != n) {
60
+ if (n == 1) {
61
+ n = m;
62
+ }
63
+ else if (m == 0) {
64
+ n = 0;
65
+ }
66
+ else if (m != 1) {
67
+ return -1;
68
+ }
69
+ }
70
+ vshape[i<k ? k : i] = n;
71
+ }
72
+ for (; k >= 0; k--) {
73
+ vshape[k] = wshape[k];
74
+ }
75
+
76
+ return vsize >= wsize ? vsize : wsize;
77
+ }
78
+
79
+ static int
80
+ resolve_broadcast(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
81
+ {
82
+ const char *key = "00_ELLIPSIS";
83
+ symtable_entry_t *v;
84
+ int vsize;
85
+
86
+ v = symtable_find_ptr(tbl, key);
87
+ if (v == NULL) {
88
+ if (symtable_add(tbl, key, w, ctx) < 0) {
89
+ return -1;
90
+ }
91
+ return 1;
92
+ }
93
+
94
+ vsize = _resolve_broadcast(v->BroadcastSeq.dims, v->BroadcastSeq.size,
95
+ w.BroadcastSeq.dims, w.BroadcastSeq.size);
96
+ if (vsize < 0) {
97
+ ndt_err_format(ctx, NDT_TypeError, "broadcast error");
98
+ return -1;
99
+ }
100
+ v->BroadcastSeq.size = vsize;
101
+
102
+ return 1;
103
+ }
104
+
105
+ static int
106
+ check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
107
+ {
108
+ for (int i = 0; i < nargs; i++) {
109
+ const ndt_t *p = ptypes[i];
110
+ const ndt_t *c = ctypes[i];
111
+
112
+ if (p->tag == EllipsisDim) {
113
+ switch (p->EllipsisDim.tag) {
114
+ case RequireNA:
115
+ break;
116
+ case RequireC:
117
+ if (!ndt_is_c_contiguous(c)) {
118
+ return 0;
119
+ }
120
+ break;
121
+ case RequireF:
122
+ if (!ndt_is_f_contiguous(c)) {
123
+ return 0;
124
+ }
125
+ break;
126
+ }
127
+ }
128
+ }
129
+
130
+ return 1;
131
+ }
132
+
133
+ static ndt_t *
134
+ to_fortran(const ndt_t *p, ndt_t *c, ndt_context_t *ctx)
135
+ {
136
+ if (p->tag == EllipsisDim && p->EllipsisDim.tag == RequireF) {
137
+ ndt_t *t = ndt_to_fortran(c, ctx);
138
+ return t;
139
+ }
140
+ else {
141
+ return c;
142
+ }
143
+ }
144
+
145
+ static int
146
+ resolve_fixed(const char *key, symtable_entry_t w,
147
+ symtable_t *tbl, ndt_context_t *ctx)
148
+ {
149
+ symtable_entry_t v;
150
+
151
+ v = symtable_find(tbl, key);
152
+ if (v.tag == Unbound) {
153
+ if (symtable_add(tbl, key, w, ctx) < 0) {
154
+ return -1;
155
+ }
156
+ return 1;
157
+ }
158
+
159
+ if (w.FixedSeq.size != v.FixedSeq.size) {
160
+ return 0;
161
+ }
162
+
163
+ for (int i = 0; i < v.FixedSeq.size; i++) {
164
+ const ndt_t *t = v.FixedSeq.dims[i];
165
+ const ndt_t *u = w.FixedSeq.dims[i];
166
+ if (u->FixedDim.shape != t->FixedDim.shape) {
167
+ return 0;
168
+ }
169
+ }
170
+
171
+ return 1;
172
+ }
173
+
174
+ static int
175
+ resolve_shape(const char *key, int64_t shape, symtable_t *tbl, ndt_context_t *ctx)
176
+ {
177
+ symtable_entry_t v;
178
+
179
+ v = symtable_find(tbl, key);
180
+ if (v.tag == Unbound) {
181
+ v.tag = Shape;
182
+ v.Shape = shape;
183
+ if (symtable_add(tbl, key, v, ctx) < 0) {
184
+ return -1;
185
+ }
186
+ return 1;
187
+ }
188
+
189
+ if (v.tag != Shape) {
190
+ return 0;
191
+ }
192
+
193
+ return shape == v.Shape;
194
+ }
195
+
196
+ static int
197
+ resolve_typevar(const char *key, symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
198
+ {
199
+ symtable_entry_t v;
200
+
201
+ v = symtable_find(tbl, key);
202
+ if (v.tag == Unbound) {
203
+ if (symtable_add(tbl, key, w, ctx) < 0) {
204
+ return -1;
205
+ }
206
+ return 1;
207
+ }
208
+
209
+ if (v.tag == Symbol && w.tag == Symbol) {
210
+ return strcmp(v.Symbol, w.Symbol) == 0;
211
+ }
212
+ else if (v.tag == Type && w.tag == Type) {
213
+ return ndt_equal(v.Type, w.Type);
214
+ }
215
+ else {
216
+ return 0;
217
+ }
218
+ }
219
+
220
+ static int
221
+ match_concrete_var_dim(const ndt_t *t, int64_t tindex,
222
+ const ndt_t *u, int64_t uindex,
223
+ const int outer_dims, ndt_context_t *ctx)
224
+ {
225
+ int64_t tshape, tstart, tstep;
226
+ int64_t ushape, ustart, ustep;
227
+
228
+ if (outer_dims == 0) {
229
+ return 1;
230
+ }
231
+ if (t->Concrete.VarDim.itemsize != u->Concrete.VarDim.itemsize) {
232
+ return 0;
233
+ }
234
+
235
+ tshape = ndt_var_indices(&tstart, &tstep, t, tindex, ctx);
236
+ if (tshape < 0) {
237
+ return -1;
238
+ }
239
+
240
+ ushape = ndt_var_indices(&ustart, &ustep, u, uindex, ctx);
241
+ if (ushape < 0) {
242
+ return -1;
243
+ }
244
+
245
+ if (ushape != tshape) {
246
+ return 0;
247
+ }
248
+
249
+ for (int64_t i = 0; i < tshape; i++) {
250
+ int64_t tnext = tstart + i * tstep;
251
+ int64_t unext = ustart + i * ustep;
252
+ int ret = match_concrete_var_dim(t->VarDim.type, tnext,
253
+ u->VarDim.type, unext,
254
+ outer_dims-1, ctx);
255
+ if (ret <= 0) {
256
+ return ret;
257
+ }
258
+ }
259
+
260
+ return 1;
261
+ }
262
+
263
+ static int
264
+ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
265
+ {
266
+ const char *key = "var";
267
+ symtable_entry_t v;
268
+
269
+ v = symtable_find(tbl, key);
270
+ if (v.tag == Unbound) {
271
+ if (symtable_add(tbl, key, w, ctx) < 0) {
272
+ return -1;
273
+ }
274
+ return 1;
275
+ }
276
+
277
+ if (w.VarSeq.size != v.VarSeq.size) {
278
+ return 0;
279
+ }
280
+ if (v.VarSeq.size == 0) {
281
+ return 1;
282
+ }
283
+
284
+ return match_concrete_var_dim(w.VarSeq.dims[0], 0,
285
+ v.VarSeq.dims[0], 0,
286
+ v.VarSeq.size, ctx);
287
+ }
288
+
289
+ static int
290
+ match_tuple_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
291
+ ndt_context_t *ctx)
292
+ {
293
+ int64_t i;
294
+ int n;
295
+
296
+ assert(p->tag == Tuple && c->tag == Tuple);
297
+
298
+ if (p->Tuple.shape != c->Tuple.shape) {
299
+ return 0;
300
+ }
301
+
302
+ for (i = 0; i < p->Tuple.shape; i++) {
303
+ n = match_datashape(p->Tuple.types[i], c->Tuple.types[i], tbl, ctx);
304
+ if (n <= 0) return n;
305
+ }
306
+
307
+ return 1;
308
+ }
309
+
310
+ static int
311
+ match_record_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
312
+ ndt_context_t *ctx)
313
+ {
314
+ int64_t i;
315
+ int n;
316
+
317
+ assert(p->tag == Record && c->tag == Record);
318
+
319
+ if (p->Record.shape != c->Record.shape) {
320
+ return 0;
321
+ }
322
+
323
+ for (i = 0; i < p->Record.shape; i++) {
324
+ n = strcmp(p->Record.names[i], c->Record.names[i]);
325
+ if (n != 0) return 0;
326
+
327
+ n = match_datashape(p->Record.types[i], c->Record.types[i], tbl, ctx);
328
+ if (n <= 0) return n;
329
+ }
330
+
331
+ return 1;
332
+ }
333
+
334
+ static int
335
+ match_categorical(ndt_value_t *p, int64_t plen,
336
+ ndt_value_t *c, int64_t clen)
337
+ {
338
+ int64_t i;
339
+
340
+ if (plen != clen) {
341
+ return 0;
342
+ }
343
+
344
+ for (i = 0; i < plen; i++) {
345
+ if (!ndt_value_equal(&p[i], &c[i])) {
346
+ return 0;
347
+ }
348
+ }
349
+
350
+ return 1;
351
+ }
352
+
353
+ static const ndt_t *
354
+ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
355
+ {
356
+ assert(ndt_is_concrete(t));
357
+
358
+ if (t->ndim < ndim) {
359
+ return NULL;
360
+ }
361
+ if (t->ndim == ndim) {
362
+ return t;
363
+ }
364
+
365
+ switch (t->tag) {
366
+ case FixedDim: {
367
+ switch (v->tag) {
368
+ case FixedSeq:
369
+ v->FixedSeq.size = i+1;
370
+ v->FixedSeq.dims[i] = t;
371
+ break;
372
+ case BroadcastSeq:
373
+ v->BroadcastSeq.size = i+1;
374
+ v->BroadcastSeq.dims[i] = t->FixedDim.shape;
375
+ break;
376
+ default:
377
+ return NULL;
378
+ }
379
+ return outer_inner(v, i+1, t->FixedDim.type, ndim);
380
+ }
381
+ case VarDim: {
382
+ switch (v->tag) {
383
+ case VarSeq:
384
+ v->VarSeq.size = i+1;
385
+ v->VarSeq.dims[i] = t;
386
+ break;
387
+ default:
388
+ return NULL;
389
+ }
390
+ return outer_inner(v, i+1, t->VarDim.type, ndim);
391
+ }
392
+ default:
393
+ return NULL;
394
+ }
395
+ }
396
+
397
+ static int
398
+ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
399
+ ndt_context_t *ctx)
400
+ {
401
+ int n;
402
+
403
+ if (ndt_is_optional(c) != ndt_is_optional(p)) return 0;
404
+
405
+ switch (p->tag) {
406
+ case AnyKind: {
407
+ return 1;
408
+ }
409
+
410
+ case FixedDim: {
411
+ if (c->tag != FixedDim || p->FixedDim.shape != c->FixedDim.shape) {
412
+ return 0;
413
+ }
414
+ if (p->FixedDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
415
+ return 0;
416
+ }
417
+ if (p->FixedDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
418
+ return 0;
419
+ }
420
+
421
+ return match_datashape(p->FixedDim.type, c->FixedDim.type, tbl, ctx);
422
+ }
423
+
424
+ case VarDim: {
425
+ if (c->tag != VarDim) {
426
+ return 0;
427
+ }
428
+ return match_datashape(p->VarDim.type, c->VarDim.type, tbl, ctx);
429
+ }
430
+
431
+ case SymbolicDim: {
432
+ if (c->tag != FixedDim) return 0;
433
+
434
+ if (p->SymbolicDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
435
+ return 0;
436
+ }
437
+ if (p->SymbolicDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
438
+ return 0;
439
+ }
440
+
441
+ n = resolve_shape(p->SymbolicDim.name, c->FixedDim.shape, tbl, ctx);
442
+ if (n <= 0) {
443
+ return n;
444
+ }
445
+ return match_datashape(p->SymbolicDim.type, c->FixedDim.type, tbl, ctx);
446
+ }
447
+
448
+ case EllipsisDim: {
449
+ symtable_entry_t outer;
450
+ const ndt_t *inner;
451
+
452
+ if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
453
+ return 0;
454
+ }
455
+ if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
456
+ return 0;
457
+ }
458
+
459
+ if (p->EllipsisDim.name == NULL) {
460
+ outer.tag = BroadcastSeq;
461
+ outer.BroadcastSeq.size = 0;
462
+ }
463
+ else if (strcmp(p->EllipsisDim.name, "var") == 0) {
464
+ outer.tag = VarSeq;
465
+ outer.VarSeq.size = 0;
466
+ }
467
+ else {
468
+ outer.tag = FixedSeq;
469
+ outer.FixedSeq.size = 0;
470
+ }
471
+
472
+ inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
473
+ if (inner == NULL) {
474
+ return 0;
475
+ }
476
+
477
+ n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
478
+ if (n <= 0) {
479
+ return n;
480
+ }
481
+
482
+ switch (outer.tag) {
483
+ case BroadcastSeq:
484
+ return resolve_broadcast(outer, tbl, ctx);
485
+ case FixedSeq:
486
+ return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
487
+ case VarSeq:
488
+ return resolve_var(outer, tbl, ctx);
489
+ default: /* NOT REACHED */
490
+ ndt_internal_error("invalid tag");
491
+ }
492
+ }
493
+
494
+ case Bool:
495
+ case Int8: case Int16: case Int32: case Int64:
496
+ case Uint8: case Uint16: case Uint32: case Uint64:
497
+ case Float16: case Float32: case Float64:
498
+ case Complex32: case Complex64: case Complex128:
499
+ case String:
500
+ return p->tag == c->tag;
501
+ case FixedString:
502
+ return c->tag == FixedString &&
503
+ p->FixedString.size == c->FixedString.size &&
504
+ p->FixedString.encoding == c->FixedString.encoding;
505
+ case FixedBytes:
506
+ return c->tag == FixedBytes &&
507
+ p->FixedBytes.size == c->FixedBytes.size &&
508
+ p->FixedBytes.align == c->FixedBytes.align;
509
+ case SignedKind:
510
+ return c->tag == SignedKind || ndt_is_signed(c);
511
+ case UnsignedKind:
512
+ return c->tag == UnsignedKind || ndt_is_unsigned(c);
513
+ case FloatKind:
514
+ return c->tag == FloatKind || ndt_is_float(c);
515
+ case ComplexKind:
516
+ return c->tag == ComplexKind || ndt_is_complex(c);
517
+ case FixedStringKind:
518
+ return c->tag == FixedStringKind || c->tag == FixedString;
519
+ case FixedBytesKind:
520
+ return c->tag == FixedBytesKind || c->tag == FixedBytes;
521
+ case ScalarKind:
522
+ return c->tag == ScalarKind || ndt_is_scalar(c);
523
+ case Char:
524
+ return c->tag == Char && c->Char.encoding == p->Char.encoding;
525
+ case Bytes:
526
+ return c->tag == Bytes && p->Bytes.target_align == c->Bytes.target_align;
527
+ case Categorical:
528
+ return c->tag == Categorical &&
529
+ match_categorical(p->Categorical.types, p->Categorical.ntypes,
530
+ c->Categorical.types, c->Categorical.ntypes);
531
+ case Ref:
532
+ if (c->tag != Ref) return 0;
533
+ return match_datashape(p->Ref.type, c->Ref.type, tbl, ctx);
534
+ case Tuple:
535
+ if (p->Tuple.flag == Variadic) return 0;
536
+ if (c->tag != Tuple) return 0;
537
+ return match_tuple_fields(p, c, tbl, ctx);
538
+ case Record:
539
+ if (p->Tuple.flag == Variadic) return 0;
540
+ if (c->tag != Record) return 0;
541
+ return match_record_fields(p, c, tbl, ctx);
542
+ case Function: {
543
+ int64_t i;
544
+ if (c->tag != Function ||
545
+ c->Function.nin != p->Function.nin ||
546
+ c->Function.nout != p->Function.nout ||
547
+ c->Function.nargs != p->Function.nargs) {
548
+ return 0;
549
+ }
550
+
551
+ for (i = 0; i < p->Function.nargs; i++) {
552
+ n = match_datashape(p->Function.types[i], c->Function.types[i], tbl, ctx);
553
+ if (n <= 0) return n;
554
+ }
555
+
556
+ return check_contig(p->Function.types, c->Function.types, p->Function.nargs);
557
+ }
558
+ case Typevar: {
559
+ if (c->tag == Typevar) {
560
+ symtable_entry_t entry = { .tag=Symbol, .Symbol=c->Typevar.name };
561
+ return resolve_typevar(p->Typevar.name, entry, tbl, ctx);
562
+ }
563
+ else {
564
+ symtable_entry_t entry = { .tag=Type, .Type=c };
565
+ return resolve_typevar(p->Typevar.name, entry, tbl, ctx);
566
+ }
567
+ }
568
+ case Nominal:
569
+ /* Assume that the type has been created through ndt_nominal(), in
570
+ which case the name is guaranteed to be unique and present in the
571
+ typedef table. */
572
+ return c->tag == Nominal && strcmp(p->Nominal.name, c->Nominal.name) == 0;
573
+ case Module:
574
+ return c->tag == Module && strcmp(p->Module.name, c->Module.name) == 0 &&
575
+ ndt_equal(p->Module.type, c->Module.type);
576
+ case Constr:
577
+ return c->tag == Constr && strcmp(p->Constr.name, c->Constr.name) == 0 &&
578
+ ndt_equal(p->Constr.type, c->Constr.type);
579
+ }
580
+
581
+ /* NOT REACHED: tags should be exhaustive. */
582
+ ndt_internal_error("invalid type");
583
+ }
584
+
585
+ int
586
+ ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
587
+ {
588
+ symtable_t *tbl;
589
+ int ret;
590
+
591
+ if (ndt_is_abstract(c)) {
592
+ return 0;
593
+ }
594
+
595
+ tbl = symtable_new(ctx);
596
+ if (tbl == NULL) {
597
+ return -1;
598
+ }
599
+
600
+ ret = match_datashape(p, c, tbl, ctx);
601
+ symtable_del(tbl);
602
+ return ret;
603
+ }
604
+
605
+ static ndt_t *
606
+ broadcast(const ndt_t *t, const int64_t *shape,
607
+ int outer_dims, int inner_dims,
608
+ bool use_max, ndt_context_t *ctx)
609
+ {
610
+ ndt_ndarray_t u;
611
+ const ndt_t *dtype;
612
+ ndt_t *v;
613
+ int64_t step;
614
+ int ndim;
615
+ int i, k;
616
+
617
+ ndim = ndt_as_ndarray(&u, t, ctx);
618
+ if (ndim < 0) {
619
+ return NULL;
620
+ }
621
+
622
+ dtype = ndt_dtype(t);
623
+ v = ndt_copy(dtype, ctx);
624
+ if (v == NULL) {
625
+ return NULL;
626
+ }
627
+
628
+ for (i=ndim-1; i>=ndim-inner_dims; i--) {
629
+ v = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
630
+ if (v == NULL) {
631
+ return NULL;
632
+ }
633
+ }
634
+
635
+ for (k=outer_dims-1; i>=0 && k>=0; i--, k--) {
636
+ step = u.shape[i]<=1 ? 0 : u.steps[i];
637
+ v = ndt_fixed_dim(v, shape[k], step, ctx);
638
+ if (v == NULL) {
639
+ return NULL;
640
+ }
641
+ }
642
+
643
+ for (; k>=0; k--) {
644
+ if (use_max) {
645
+ v = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
646
+ }
647
+ else {
648
+ v = ndt_fixed_dim(v, shape[k], 0, ctx);
649
+ }
650
+ if (v == NULL) {
651
+ return NULL;
652
+ }
653
+ }
654
+
655
+ return v;
656
+ }
657
+
658
+ int
659
+ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
660
+ const ndt_t *in[], const int nin,
661
+ const int64_t *shape, int outer_dims,
662
+ ndt_context_t *ctx)
663
+ {
664
+ ndt_t *u;
665
+ int inner_dims;
666
+ int i;
667
+
668
+ for (i = 0; i < nin; i++) {
669
+ inner_dims = sig->Function.types[i]->ndim-1;
670
+ spec->broadcast[i] = broadcast(in[i], shape,
671
+ outer_dims, inner_dims, false, ctx);
672
+ if (spec->broadcast[i] == NULL) {
673
+ return -1;
674
+ }
675
+ spec->nbroadcast++;
676
+ }
677
+
678
+ for (i = 0; i < spec->nout; i++) {
679
+ inner_dims = sig->Function.types[nin+i]->ndim-1;
680
+ u = broadcast(spec->out[i], shape,
681
+ outer_dims, inner_dims, true, ctx);
682
+ if (u == NULL) {
683
+ return -1;
684
+ }
685
+ ndt_del(spec->out[i]);
686
+ spec->out[i] = u;
687
+ }
688
+
689
+ spec->outer_dims = outer_dims;
690
+
691
+ return 0;
692
+ }
693
+
694
+ static int
695
+ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
696
+ const ndt_t *in[], const int nin,
697
+ const symtable_t *tbl, ndt_context_t *ctx)
698
+ {
699
+ symtable_entry_t v;
700
+
701
+ v = symtable_find(tbl, "00_ELLIPSIS");
702
+ if (v.tag != BroadcastSeq) {
703
+ ndt_err_format(ctx, NDT_RuntimeError,
704
+ "unexpected missing unnamed ellipsis entry");
705
+ return -1;
706
+ }
707
+
708
+ return ndt_broadcast_all(spec, sig, in, nin,
709
+ v.BroadcastSeq.dims, v.BroadcastSeq.size,
710
+ ctx);
711
+ }
712
+
713
+ static int
714
+ resolve_constraint(const ndt_constraint_t *c, const void *args, symtable_t *tbl,
715
+ ndt_context_t *ctx)
716
+ {
717
+ int64_t shapes[NDT_MAX_SYMBOLS];
718
+ symtable_entry_t v;
719
+
720
+ for (int i = 0; i < c->nin; i++) {
721
+ v = symtable_find(tbl, c->symbols[i]);
722
+ if (v.tag != Shape) {
723
+ ndt_err_format(ctx, NDT_ValueError, "expected dimension variable");
724
+ return -1;
725
+ }
726
+ shapes[i] = v.Shape;
727
+ }
728
+
729
+ if (c->f(shapes, args, ctx) < 0) {
730
+ return -1;
731
+ }
732
+
733
+ for (int i = 0; i < c->nout; i++) {
734
+ if (resolve_shape(c->symbols[c->nin+i], shapes[c->nin+i], tbl, ctx) < 0) {
735
+ return -1;
736
+ }
737
+ }
738
+
739
+ return 0;
740
+ }
741
+
742
+ /*
743
+ * Check the concrete function arguments 'in' against the function
744
+ * signature 'sig'. On success, infer and return the concrete return
745
+ * types and the (possibly broadcasted) 'in' types.
746
+ */
747
+ int
748
+ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
749
+ const ndt_t *in[], const int nin,
750
+ const ndt_constraint_t *c, const void *args,
751
+ ndt_context_t *ctx)
752
+ {
753
+ symtable_t *tbl;
754
+ ndt_t *t;
755
+ const char *name;
756
+ int ret;
757
+ int64_t i;
758
+
759
+ assert(spec->flags == 0);
760
+ assert(spec->nout == 0);
761
+ assert(spec->nbroadcast == 0);
762
+ assert(spec->outer_dims == 0);
763
+
764
+ if (sig->tag != Function) {
765
+ ndt_err_format(ctx, NDT_ValueError,
766
+ "signature must be a function type");
767
+ return -1;
768
+ }
769
+
770
+ if (nin != sig->Function.nin) {
771
+ ndt_err_format(ctx, NDT_ValueError,
772
+ "expected %" PRIi64 " arguments, got %d", sig->Function.nin, nin);
773
+ return -1;
774
+ }
775
+
776
+ for (i = 0; i < nin; i++) {
777
+ if (ndt_is_abstract(in[i])) {
778
+ ndt_err_format(ctx, NDT_ValueError,
779
+ "type checking requires concrete argument types");
780
+ return -1;
781
+ }
782
+ }
783
+
784
+ tbl = symtable_new(ctx);
785
+ if (tbl == NULL) {
786
+ return -1;
787
+ }
788
+
789
+ for (i = 0; i < nin; i++) {
790
+ ret = match_datashape(sig->Function.types[i], in[i], tbl, ctx);
791
+ if (ret <= 0) {
792
+ symtable_del(tbl);
793
+
794
+ if (ret == 0) {
795
+ ndt_err_format(ctx, NDT_TypeError,
796
+ "argument types do not match");
797
+ }
798
+
799
+ return -1;
800
+ }
801
+ }
802
+
803
+ if (c != NULL && resolve_constraint(c, args, tbl, ctx) < 0) {
804
+ symtable_del(tbl);
805
+ return -1;
806
+ }
807
+
808
+ for (i = 0; i < sig->Function.nout; i++) {
809
+ spec->out[i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
810
+ if (spec->out[i] == NULL) {
811
+ ndt_apply_spec_clear(spec);
812
+ symtable_del(tbl);
813
+ return -1;
814
+ }
815
+ spec->nout++;
816
+ }
817
+
818
+ if (sig->flags & NDT_ELLIPSIS) {
819
+ if (sig->Function.nargs == 0 || sig->Function.types[0]->tag != EllipsisDim) {
820
+ ndt_err_format(ctx, NDT_RuntimeError,
821
+ "unexpected configuration of ellipsis flag and function types");
822
+ ndt_apply_spec_clear(spec);
823
+ symtable_del(tbl);
824
+ return -1;
825
+ }
826
+
827
+ t = sig->Function.types[0];
828
+ name = t->EllipsisDim.name;
829
+
830
+ if (name != NULL) {
831
+ symtable_entry_t v = symtable_find(tbl, name);
832
+ switch (v.tag) {
833
+ case FixedSeq:
834
+ spec->outer_dims = v.FixedSeq.size;
835
+ break;
836
+ case VarSeq:
837
+ spec->outer_dims = v.VarSeq.size;
838
+ break;
839
+ default:
840
+ ndt_err_format(ctx, NDT_RuntimeError,
841
+ "unexpected missing dimension list entry");
842
+ ndt_apply_spec_clear(spec);
843
+ symtable_del(tbl);
844
+ return -1;
845
+ }
846
+ }
847
+ else {
848
+ if (broadcast_all(spec, sig, in, nin, tbl, ctx) < 0) {
849
+ ndt_apply_spec_clear(spec);
850
+ symtable_del(tbl);
851
+ return -1;
852
+ }
853
+ }
854
+ }
855
+
856
+ symtable_del(tbl);
857
+
858
+ for (i = 0; i < sig->Function.nout; i++) {
859
+ ndt_t *_p = sig->Function.types[nin+i];
860
+ ndt_t *_c = spec->out[i];
861
+ ndt_t *_t = to_fortran(_p, _c, ctx);
862
+ if (_t == NULL) {
863
+ ndt_apply_spec_clear(spec);
864
+ return -1;
865
+ }
866
+ if (_t != _c) {
867
+ ndt_del(_c);
868
+ }
869
+ spec->out[i] = _t;
870
+ }
871
+
872
+ if (!check_contig(sig->Function.types, (ndt_t **)in, nin)) {
873
+ ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
874
+ return -1;
875
+ }
876
+ if (!check_contig(sig->Function.types+nin, spec->out, spec->nout)) {
877
+ ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
878
+ return -1;
879
+ }
880
+
881
+ ndt_select_kernel_strategy(spec, sig, in, nin);
882
+
883
+ return 0;
884
+ }
885
+
886
+
887
+ /*****************************************************************************/
888
+ /* Optimized binary typecheck for fixed input */
889
+ /*****************************************************************************/
890
+
891
+ static ndt_t *
892
+ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
893
+ const int64_t *shape, int size, ndt_context_t *ctx)
894
+ {
895
+ ndt_t *v;
896
+ int64_t step;
897
+ int i, k;
898
+
899
+ v = ndt_copy(dtype, ctx);
900
+ if (v == NULL) {
901
+ return NULL;
902
+ }
903
+
904
+ for (i=t->ndim-1, k=size-1; i>=0 && k>=0; i--, k--) {
905
+ step = t->shape[i]<=1 ? 0 : t->steps[i];
906
+ v = ndt_fixed_dim(v, shape[k], step, ctx);
907
+ if (v == NULL) {
908
+ return NULL;
909
+ }
910
+ }
911
+
912
+ for (; k>=0; k--) {
913
+ v = ndt_fixed_dim(v, shape[k], 0, ctx);
914
+ if (v == NULL) {
915
+ return NULL;
916
+ }
917
+ }
918
+
919
+ return v;
920
+ }
921
+
922
+ static ndt_t *
923
+ fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
924
+ ndt_context_t *ctx)
925
+ {
926
+ ndt_t *t;
927
+ int i;
928
+
929
+ for (i=len-1, t=dtype; i >= 0; i--) {
930
+ t = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
931
+ if (t == NULL) {
932
+ return NULL;
933
+ }
934
+ }
935
+
936
+ return t;
937
+ }
938
+
939
+ static bool
940
+ shape_equal(const ndt_ndarray_t *a, const ndt_ndarray_t *b)
941
+ {
942
+ if (b->ndim != a->ndim) {
943
+ return false;
944
+ }
945
+
946
+ for (int i = 0; i < a->ndim; i++) {
947
+ if (b->shape[i] != a->shape[i]) {
948
+ return false;
949
+ }
950
+ }
951
+
952
+ return true;
953
+ }
954
+
955
+ static int
956
+ _ndt_binary_broadcast(ndt_apply_spec_t *spec, const ndt_t *sig,
957
+ const ndt_ndarray_t *x, const ndt_ndarray_t *y,
958
+ const ndt_t *in[], const int nin, ndt_t *dtype,
959
+ int inner, ndt_context_t *ctx)
960
+ {
961
+ int64_t shape[NDT_MAX_DIM];
962
+ int size;
963
+
964
+ if (shape_equal(x, y)) {
965
+ spec->nout = 1;
966
+ spec->nbroadcast = 0;
967
+ spec->outer_dims = x->ndim-inner;
968
+ spec->out[0] = fixed_dim_from_shape(x->shape, x->ndim, dtype, ctx);
969
+ if (spec->out[0] == NULL) {
970
+ return -1;
971
+ }
972
+ }
973
+ else {
974
+ for (int i = 0; i < x->ndim; i++) {
975
+ shape[i] = x->shape[i];
976
+ }
977
+
978
+ size = _resolve_broadcast(shape, x->ndim, y->shape, y->ndim);
979
+ if (size < 0) {
980
+ ndt_err_format(ctx, NDT_TypeError, "broadcast error");
981
+ ndt_del(dtype);
982
+ return -1;
983
+ }
984
+
985
+ spec->nout = 1;
986
+ spec->nbroadcast = 2;
987
+ spec->outer_dims = size-inner;
988
+
989
+ spec->out[0] = fixed_dim_from_shape(shape, size, dtype, ctx);
990
+ if (spec->out[0] == NULL) {
991
+ return -1;
992
+ }
993
+
994
+ spec->broadcast[0] = binary_broadcast_1D(x, ndt_dtype(in[0]), shape, size, ctx);
995
+ if (spec->broadcast[0] == NULL) {
996
+ ndt_del(spec->out[0]);
997
+ return -1;
998
+ }
999
+
1000
+ spec->broadcast[1] = binary_broadcast_1D(y, ndt_dtype(in[1]), shape, size, ctx);
1001
+ if (spec->broadcast[1] == NULL) {
1002
+ ndt_del(spec->out[0]);
1003
+ ndt_del(spec->broadcast[0]);
1004
+ return -1;
1005
+ }
1006
+ }
1007
+
1008
+ ndt_select_kernel_strategy(spec, sig, in, nin);
1009
+
1010
+ return 0;
1011
+ }
1012
+
1013
+ static bool
1014
+ all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1015
+ ndt_context_t *ctx)
1016
+ {
1017
+ if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
1018
+ (t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL) ||
1019
+ (t2->tag != EllipsisDim || t2->EllipsisDim.name != NULL)) {
1020
+ ndt_err_format(ctx, NDT_RuntimeError,
1021
+ "fast binary typecheck expects leading ellipsis dimensions");
1022
+ return false;
1023
+ }
1024
+
1025
+ return true;
1026
+ }
1027
+
1028
+ static bool
1029
+ all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1030
+ {
1031
+ if (t0->tag != SymbolicDim || t1->tag != SymbolicDim ||
1032
+ t2->tag != SymbolicDim) {
1033
+ return false;
1034
+ }
1035
+
1036
+ return strcmp(t0->SymbolicDim.name, t1->SymbolicDim.name) == 0 &&
1037
+ strcmp(t0->SymbolicDim.name, t2->SymbolicDim.name) == 0;
1038
+ }
1039
+
1040
+ static bool
1041
+ all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1042
+ {
1043
+ return t0->ndim == 0 && t1->ndim == 0 && t2->ndim == 0;
1044
+ }
1045
+
1046
+ /*
1047
+ * Optimized type checking for very specific signatures. The caller must
1048
+ * have identified the kernel location, signature and the dtype. For
1049
+ * performance reasons, no substitution is performed on the dtype, so
1050
+ * the dtype must be concrete.
1051
+ *
1052
+ * Supported signatures:
1053
+ * 1) ... * N * T0, ... * N * T1 -> N * T2
1054
+ * 2) ... * T0, ... * T1 -> ... * T2
1055
+ */
1056
+ int
1057
+ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1058
+ const ndt_t *in[], const int nin, ndt_t *dtype,
1059
+ ndt_context_t *ctx)
1060
+ {
1061
+ ndt_t *p0, *p1, *p2;
1062
+ ndt_ndarray_t x, y;
1063
+
1064
+ assert(spec->flags == 0);
1065
+ assert(spec->nout == 0);
1066
+ assert(spec->nbroadcast == 0);
1067
+ assert(spec->outer_dims == 0);
1068
+
1069
+ if (sig->tag != Function ||
1070
+ sig->Function.nin != 2 ||
1071
+ sig->Function.nout != 1) {
1072
+ ndt_err_format(ctx, NDT_RuntimeError,
1073
+ "fast binary typecheck expects a signature with two inputs and "
1074
+ "one output");
1075
+ return -1;
1076
+ }
1077
+
1078
+ if (nin != 2) {
1079
+ ndt_err_format(ctx, NDT_RuntimeError,
1080
+ "fast binary typecheck expects two input arguments");
1081
+ return -1;
1082
+ }
1083
+
1084
+ if (ndt_is_abstract(dtype)) {
1085
+ ndt_err_format(ctx, NDT_RuntimeError,
1086
+ "fast binary typecheck expects a concrete dtype");
1087
+ return -1;
1088
+ }
1089
+
1090
+ p0 = sig->Function.types[0];
1091
+ p1 = sig->Function.types[1];
1092
+ p2 = sig->Function.types[2];
1093
+
1094
+ if (!all_ellipses(p0, p1, p2, ctx)) {
1095
+ return -1;
1096
+ }
1097
+
1098
+ if (ndt_as_ndarray(&x, in[0], ctx) < 0) {
1099
+ ndt_del(dtype);
1100
+ return -1;
1101
+ }
1102
+
1103
+ if (ndt_as_ndarray(&y, in[1], ctx) < 0) {
1104
+ ndt_del(dtype);
1105
+ return -1;
1106
+ }
1107
+
1108
+ p0 = p0->EllipsisDim.type;
1109
+ p1 = p1->EllipsisDim.type;
1110
+ p2 = p2->EllipsisDim.type;
1111
+
1112
+ if (all_same_symbol(p0, p1, p2)) {
1113
+ if (x.ndim > 0 && y.ndim > 0) {
1114
+ const int64_t xshape = x.shape[x.ndim-1];
1115
+ const int64_t yshape = y.shape[y.ndim-1];
1116
+ if (xshape != 1 && yshape != 1 && xshape != yshape) {
1117
+ ndt_err_format(ctx, NDT_TypeError, "mismatch in inner dimensions");
1118
+ ndt_del(dtype);
1119
+ return -1;
1120
+ }
1121
+ }
1122
+ return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 1, ctx);
1123
+ }
1124
+ else if (all_ndim0(p0, p1, p2)) {
1125
+ return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 0, ctx);
1126
+ }
1127
+ else {
1128
+ ndt_err_format(ctx, NDT_RuntimeError,
1129
+ "unsupported signature in fast binary typecheck");
1130
+ return -1;
1131
+ }
1132
+ }