ndtypes 0.2.0dev4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (139)
  1. checksums.yaml +7 -0
  2. data/CONTRIBUTING.md +50 -0
  3. data/Gemfile +2 -0
  4. data/History.md +0 -0
  5. data/README.md +19 -0
  6. data/Rakefile +125 -0
  7. data/ext/ruby_ndtypes/extconf.rb +55 -0
  8. data/ext/ruby_ndtypes/gc_guard.c +36 -0
  9. data/ext/ruby_ndtypes/gc_guard.h +12 -0
  10. data/ext/ruby_ndtypes/ndtypes/AUTHORS.txt +5 -0
  11. data/ext/ruby_ndtypes/ndtypes/INSTALL.txt +101 -0
  12. data/ext/ruby_ndtypes/ndtypes/LICENSE.txt +29 -0
  13. data/ext/ruby_ndtypes/ndtypes/MANIFEST.in +3 -0
  14. data/ext/ruby_ndtypes/ndtypes/Makefile.in +87 -0
  15. data/ext/ruby_ndtypes/ndtypes/README.rst +47 -0
  16. data/ext/ruby_ndtypes/ndtypes/config.guess +1530 -0
  17. data/ext/ruby_ndtypes/ndtypes/config.h.in +67 -0
  18. data/ext/ruby_ndtypes/ndtypes/config.sub +1782 -0
  19. data/ext/ruby_ndtypes/ndtypes/configure +5260 -0
  20. data/ext/ruby_ndtypes/ndtypes/configure.ac +161 -0
  21. data/ext/ruby_ndtypes/ndtypes/doc/Makefile +14 -0
  22. data/ext/ruby_ndtypes/ndtypes/doc/_static/copybutton.js +66 -0
  23. data/ext/ruby_ndtypes/ndtypes/doc/conf.py +26 -0
  24. data/ext/ruby_ndtypes/ndtypes/doc/grammar/grammar.rst +27 -0
  25. data/ext/ruby_ndtypes/ndtypes/doc/index.rst +56 -0
  26. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/context.rst +131 -0
  27. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/encodings.rst +68 -0
  28. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/fields-values.rst +175 -0
  29. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/functions.rst +72 -0
  30. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/index.rst +43 -0
  31. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/init.rst +48 -0
  32. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/io.rst +100 -0
  33. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/memory.rst +124 -0
  34. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/predicates.rst +110 -0
  35. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/typedef.rst +31 -0
  36. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/types.rst +594 -0
  37. data/ext/ruby_ndtypes/ndtypes/doc/libndtypes/util.rst +166 -0
  38. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/buffer-protocol.rst +27 -0
  39. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/index.rst +21 -0
  40. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/pattern-matching.rst +330 -0
  41. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/quickstart.rst +144 -0
  42. data/ext/ruby_ndtypes/ndtypes/doc/ndtypes/types.rst +544 -0
  43. data/ext/ruby_ndtypes/ndtypes/doc/releases/index.rst +35 -0
  44. data/ext/ruby_ndtypes/ndtypes/install-sh +527 -0
  45. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.in +271 -0
  46. data/ext/ruby_ndtypes/ndtypes/libndtypes/Makefile.vc +269 -0
  47. data/ext/ruby_ndtypes/ndtypes/libndtypes/alloc.c +230 -0
  48. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.c +268 -0
  49. data/ext/ruby_ndtypes/ndtypes/libndtypes/attr.h +109 -0
  50. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile.in +73 -0
  51. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/Makefile.vc +70 -0
  52. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/README.txt +16 -0
  53. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.c +2179 -0
  54. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.h +134 -0
  55. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bpgrammar.y +428 -0
  56. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.c +2543 -0
  57. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.h +735 -0
  58. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/bplexer.l +176 -0
  59. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/export.c +543 -0
  60. data/ext/ruby_ndtypes/ndtypes/libndtypes/compat/import.c +110 -0
  61. data/ext/ruby_ndtypes/ndtypes/libndtypes/context.c +228 -0
  62. data/ext/ruby_ndtypes/ndtypes/libndtypes/copy.c +634 -0
  63. data/ext/ruby_ndtypes/ndtypes/libndtypes/encodings.c +116 -0
  64. data/ext/ruby_ndtypes/ndtypes/libndtypes/equal.c +288 -0
  65. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.c +3067 -0
  66. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.h +180 -0
  67. data/ext/ruby_ndtypes/ndtypes/libndtypes/grammar.y +417 -0
  68. data/ext/ruby_ndtypes/ndtypes/libndtypes/io.c +1658 -0
  69. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.c +2773 -0
  70. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.h +734 -0
  71. data/ext/ruby_ndtypes/ndtypes/libndtypes/lexer.l +222 -0
  72. data/ext/ruby_ndtypes/ndtypes/libndtypes/match.c +1132 -0
  73. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.c +2323 -0
  74. data/ext/ruby_ndtypes/ndtypes/libndtypes/ndtypes.h.in +893 -0
  75. data/ext/ruby_ndtypes/ndtypes/libndtypes/overflow.h +161 -0
  76. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.c +473 -0
  77. data/ext/ruby_ndtypes/ndtypes/libndtypes/parsefuncs.h +92 -0
  78. data/ext/ruby_ndtypes/ndtypes/libndtypes/parser.c +246 -0
  79. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.c +269 -0
  80. data/ext/ruby_ndtypes/ndtypes/libndtypes/seq.h +197 -0
  81. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile.in +48 -0
  82. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/Makefile.vc +46 -0
  83. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/deserialize.c +1007 -0
  84. data/ext/ruby_ndtypes/ndtypes/libndtypes/serialize/serialize.c +442 -0
  85. data/ext/ruby_ndtypes/ndtypes/libndtypes/slice.h +42 -0
  86. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.c +238 -0
  87. data/ext/ruby_ndtypes/ndtypes/libndtypes/substitute.h +50 -0
  88. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.c +371 -0
  89. data/ext/ruby_ndtypes/ndtypes/libndtypes/symtable.h +100 -0
  90. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.in +55 -0
  91. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/Makefile.vc +45 -0
  92. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/alloc_fail.c +82 -0
  93. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/alloc_fail.h +49 -0
  94. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/runtest.c +1657 -0
  95. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test.h +85 -0
  96. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_array.c +115 -0
  97. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_buffer.c +137 -0
  98. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_indent.c +201 -0
  99. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_match.c +2397 -0
  100. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_numba.c +57 -0
  101. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse.c +349 -0
  102. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_error.c +27839 -0
  103. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_parse_roundtrip.c +350 -0
  104. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_record.c +231 -0
  105. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typecheck.c +375 -0
  106. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/test_typedef.c +65 -0
  107. data/ext/ruby_ndtypes/ndtypes/libndtypes/tests/valgrind.supp +30 -0
  108. data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/bench.c +79 -0
  109. data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/indent.c +94 -0
  110. data/ext/ruby_ndtypes/ndtypes/libndtypes/tools/print_ast.c +96 -0
  111. data/ext/ruby_ndtypes/ndtypes/libndtypes/util.c +474 -0
  112. data/ext/ruby_ndtypes/ndtypes/libndtypes/values.c +228 -0
  113. data/ext/ruby_ndtypes/ndtypes/python/bench.py +49 -0
  114. data/ext/ruby_ndtypes/ndtypes/python/ndt_randtype.py +409 -0
  115. data/ext/ruby_ndtypes/ndtypes/python/ndt_support.py +14 -0
  116. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/__init__.py +70 -0
  117. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/_ndtypes.c +1332 -0
  118. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/docstrings.h +319 -0
  119. data/ext/ruby_ndtypes/ndtypes/python/ndtypes/pyndtypes.h +154 -0
  120. data/ext/ruby_ndtypes/ndtypes/python/test_ndtypes.py +1977 -0
  121. data/ext/ruby_ndtypes/ndtypes/setup.py +288 -0
  122. data/ext/ruby_ndtypes/ndtypes/vcbuild/INSTALL.txt +41 -0
  123. data/ext/ruby_ndtypes/ndtypes/vcbuild/runtest32.bat +15 -0
  124. data/ext/ruby_ndtypes/ndtypes/vcbuild/runtest64.bat +13 -0
  125. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcbuild32.bat +38 -0
  126. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcbuild64.bat +38 -0
  127. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcclean.bat +13 -0
  128. data/ext/ruby_ndtypes/ndtypes/vcbuild/vcdistclean.bat +14 -0
  129. data/ext/ruby_ndtypes/ruby_ndtypes.c +1003 -0
  130. data/ext/ruby_ndtypes/ruby_ndtypes.h +37 -0
  131. data/ext/ruby_ndtypes/ruby_ndtypes_internal.h +28 -0
  132. data/lib/ndtypes.rb +45 -0
  133. data/lib/ndtypes/errors.rb +2 -0
  134. data/lib/ndtypes/version.rb +6 -0
  135. data/ndtypes.gemspec +47 -0
  136. data/spec/gc_table_spec.rb +10 -0
  137. data/spec/ndtypes_spec.rb +289 -0
  138. data/spec/spec_helper.rb +241 -0
  139. metadata +242 -0
@@ -0,0 +1,222 @@
1
+ %{
2
+ /*
3
+ * BSD 3-Clause License
4
+ *
5
+ * Copyright (c) 2017-2018, plures
6
+ * All rights reserved.
7
+ *
8
+ * Redistribution and use in source and binary forms, with or without
9
+ * modification, are permitted provided that the following conditions are met:
10
+ *
11
+ * 1. Redistributions of source code must retain the above copyright notice,
12
+ * this list of conditions and the following disclaimer.
13
+ *
14
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
15
+ * this list of conditions and the following disclaimer in the documentation
16
+ * and/or other materials provided with the distribution.
17
+ *
18
+ * 3. Neither the name of the copyright holder nor the names of its
19
+ * contributors may be used to endorse or promote products derived from
20
+ * this software without specific prior written permission.
21
+ *
22
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+ */
33
+
34
+
35
+ #include <stdio.h>
36
+ #include <string.h>
37
+ #include <stdint.h>
38
+ #include <setjmp.h>
39
+ #include "ndtypes.h"
40
+ #include "parsefuncs.h"
41
+ #include "grammar.h"
42
+
43
+ /* From PostgreSQL: avoid exit() on fatal scanner errors. */
44
+ #undef fprintf
45
+ #define fprintf(file, fmt, msg) fprintf_to_longjmp(fmt, msg, yyscanner)
46
+
47
+ extern jmp_buf ndt_lexerror;
48
+ static void
49
+ fprintf_to_longjmp(const char *fmt, const char *msg, yyscan_t yyscanner)
50
+ {
51
+ (void)fmt; (void)msg; (void)yyscanner;
52
+
53
+ /* We don't have access to the parse context here: discard the error
54
+ message, which is always either an allocation failure or an internal
55
+ flex error. */
56
+ longjmp(ndt_lexerror, 1);
57
+ }
58
+
59
+ void *
60
+ yyalloc(size_t size, yyscan_t yyscanner)
61
+ {
62
+ (void)yyscanner;
63
+
64
+ return ndt_alloc(1, size);
65
+ }
66
+
67
+ void *
68
+ yyrealloc(void *ptr, size_t size, yyscan_t yyscanner)
69
+ {
70
+ (void)yyscanner;
71
+
72
+ return ndt_realloc(ptr, 1, size);
73
+ }
74
+
75
+ void
76
+ yyfree(void *ptr, yyscan_t yyscanner)
77
+ {
78
+ (void)yyscanner;
79
+
80
+ ndt_free(ptr);
81
+ }
82
+
83
+ %}
84
+
85
/* Scanner options:
 *   - reentrant, bison-bridge, bison-locations: yylex() works on a yyscan_t
 *     handle and receives yylval/yylloc pointers (the rules below write
 *     through yylval-> and yylloc->).
 *   - noyyalloc/noyyrealloc/noyyfree: yyalloc/yyrealloc/yyfree are supplied
 *     in the prologue above.
 *   - yylineno: line numbers are tracked and copied into yylloc.
 *   - extra-type: the per-scanner extra data is an ndt_context_t *
 *     (presumably what 'ctx' in the actions refers to -- confirm).
 *   - nodefault: there is no default echo rule; the final "." rule maps
 *     any unrecognized character to ERRTOKEN.
 */
%option bison-bridge bison-locations reentrant noyywrap
%option nounput noinput noyyalloc noyyrealloc noyyfree
%option never-interactive
%option yylineno
%option 8bit
%option extra-type="ndt_context_t *"
%option warn nodefault


/* Whitespace and comments: '#' starts a comment that runs to end of line. */
newline         [\n\r]
space           [ \t\f]
non_newline     [^\n\r]
comment         #{non_newline}*

/* String literals in single or double quotes.  A backslash followed by any
   non-newline character is an escape; a raw \n may not appear inside. */
escapeseq       \\.
single_strchar  [^\\\n']
double_strchar  [^\\\n"]
single_str      '({single_strchar}|{escapeseq})*'
double_str      \"({double_strchar}|{escapeseq})*\"
stringlit       {single_str}|{double_str}

/* Integers: decimal (no leading zero unless all zeros), 0o/0O octal or
   0x/0X hexadecimal, each with an optional leading minus sign. */
octdigit        [0-7]
octinteger      0[oO]{octdigit}+
nonzerodigit    [1-9]
digit           [0-9]
decimalinteger  {nonzerodigit}{digit}*|0+
hexdigit        {digit}|[a-f]|[A-F]
hexinteger      0[xX]{hexdigit}+
integer         -?({decimalinteger}|{octinteger}|{hexinteger})

/* Floats: a number with a decimal point and/or an exponent, optional minus. */
intpart         {digit}+
fraction        \.{digit}+
exponent        [eE][+-]?{digit}+
pointfloat      {intpart}?{fraction}|{intpart}\.
exponentfloat   ({intpart}|{pointfloat}){exponent}
floatnumber     -?({pointfloat}|{exponentfloat})

/* Identifiers, classified by their first character: lowercase letter,
   uppercase letter, or underscore. */
name_lower      [a-z][a-zA-Z0-9_]*
name_upper      [A-Z][a-zA-Z0-9_]*
name_other      _[a-zA-Z0-9_]*


%%

%code {
    /* NOTE(review): if this block is emitted at the top of yylex(), it runs
       on every yylex() call, so yycolumn restarts at 1 for every token rather
       than only at line starts -- confirm the intended column tracking. */
    yycolumn = 1;

    /* YY_USER_ACTION runs before each matched rule's action: store the
       token's line/column span in *yylloc and advance the column counter. */
    #undef YY_USER_ACTION
    #define YY_USER_ACTION \
        yylloc->first_line = yylloc->last_line = yylineno; \
        yylloc->first_column = yycolumn; \
        yylloc->last_column = yycolumn+yyleng-1; \
        yycolumn += yyleng;

}

"Any"           { return ANY_KIND; }
"Scalar"        { return SCALAR_KIND; }

"void"          { return VOID; }
"bool"          { return BOOL; }

"Signed"        { return SIGNED_KIND; }
"int8"          { return INT8; }
"int16"         { return INT16; }
"int32"         { return INT32; }
"int64"         { return INT64; }

"Unsigned"      { return UNSIGNED_KIND; }
"uint8"         { return UINT8; }
"uint16"        { return UINT16; }
"uint32"        { return UINT32; }
"uint64"        { return UINT64; }

"Float"         { return FLOAT_KIND; }
"float16"       { return FLOAT16; }
"float32"       { return FLOAT32; }
"float64"       { return FLOAT64; }

"Complex"       { return COMPLEX_KIND; }
"complex32"     { return COMPLEX32; }
"complex64"     { return COMPLEX64; }
"complex128"    { return COMPLEX128; }

"intptr"        { return INTPTR; }
"uintptr"       { return UINTPTR; }
"size_t"        { return SIZE; }
"char"          { return CHAR; }
"string"        { return STRING; }
"bytes"         { return BYTES; }

"FixedString"   { return FIXED_STRING_KIND; }
"fixed_string"  { return FIXED_STRING; }

"FixedBytes"    { return FIXED_BYTES_KIND; }
"fixed_bytes"   { return FIXED_BYTES; }

"categorical"   { return CATEGORICAL; }
"NA"            { return NA; }

"ref"           { return REF; }

"fixed"         { return FIXED; }
"var"           { return VAR; }

"..."           { return ELLIPSIS; }
"->"            { return RARROW; }
","             { return COMMA; }
":"             { return COLON; }
"("             { return LPAREN; }
")"             { return RPAREN; }
"{"             { return LBRACE; }
"}"             { return RBRACE; }
"["             { return LBRACK; }
"]"             { return RBRACK; }
"*"             { return STAR; }
"="             { return EQUAL; }
"?"             { return QUESTIONMARK; }
"!"             { return BANG; }
"&"             { return AMPERSAND; }
"|"             { return BAR; }
"<"             { return LESS; }
">"             { return GREATER; }

{name_lower}    { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return NAME_LOWER; } /* after the keyword rules, so an exact keyword wins the equal-length tie */
{name_upper}    { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return NAME_UPPER; }
{name_other}    { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return NAME_OTHER; }

{stringlit}     { yylval->string = mk_stringlit(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return STRINGLIT; } /* value built by mk_stringlit() from the quoted text */
{integer}       { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return INTEGER; } /* numbers are passed to the parser as text */
{floatnumber}   { yylval->string = ndt_strdup(yytext, ctx); if (yylval->string == NULL) return ERRTOKEN; return FLOATNUMBER; }

{newline}       { yycolumn = 1; } /* restart column numbering */
{space}         {} /* ignore */
{comment}       {} /* ignore */
.               { return ERRTOKEN; } /* any other character is a lexical error */

%%
@@ -0,0 +1,1132 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <stdio.h>
35
+ #include <stdlib.h>
36
+ #include <stdint.h>
37
+ #include <inttypes.h>
38
+ #include <stdbool.h>
39
+ #include <string.h>
40
+ #include <stdarg.h>
41
+ #include <assert.h>
42
+ #include "ndtypes.h"
43
+ #include "symtable.h"
44
+ #include "substitute.h"
45
+
46
+
47
+ static int match_datashape(const ndt_t *, const ndt_t *, symtable_t *, ndt_context_t *);
48
+
49
+ static int
50
+ _resolve_broadcast(int64_t vshape[NDT_MAX_DIM], int vsize,
51
+ const int64_t wshape[NDT_MAX_DIM], int wsize)
52
+ {
53
+ int64_t n, m;
54
+ int i, k;
55
+
56
+ for (i=vsize-1, k=wsize-1; i>=0 && k>=0; i--, k--) {
57
+ n = vshape[i];
58
+ m = wshape[k];
59
+ if (m != n) {
60
+ if (n == 1) {
61
+ n = m;
62
+ }
63
+ else if (m == 0) {
64
+ n = 0;
65
+ }
66
+ else if (m != 1) {
67
+ return -1;
68
+ }
69
+ }
70
+ vshape[i<k ? k : i] = n;
71
+ }
72
+ for (; k >= 0; k--) {
73
+ vshape[k] = wshape[k];
74
+ }
75
+
76
+ return vsize >= wsize ? vsize : wsize;
77
+ }
78
+
79
+ static int
80
+ resolve_broadcast(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
81
+ {
82
+ const char *key = "00_ELLIPSIS";
83
+ symtable_entry_t *v;
84
+ int vsize;
85
+
86
+ v = symtable_find_ptr(tbl, key);
87
+ if (v == NULL) {
88
+ if (symtable_add(tbl, key, w, ctx) < 0) {
89
+ return -1;
90
+ }
91
+ return 1;
92
+ }
93
+
94
+ vsize = _resolve_broadcast(v->BroadcastSeq.dims, v->BroadcastSeq.size,
95
+ w.BroadcastSeq.dims, w.BroadcastSeq.size);
96
+ if (vsize < 0) {
97
+ ndt_err_format(ctx, NDT_TypeError, "broadcast error");
98
+ return -1;
99
+ }
100
+ v->BroadcastSeq.size = vsize;
101
+
102
+ return 1;
103
+ }
104
+
105
+ static int
106
+ check_contig(ndt_t *ptypes[], ndt_t *ctypes[], int64_t nargs)
107
+ {
108
+ for (int i = 0; i < nargs; i++) {
109
+ const ndt_t *p = ptypes[i];
110
+ const ndt_t *c = ctypes[i];
111
+
112
+ if (p->tag == EllipsisDim) {
113
+ switch (p->EllipsisDim.tag) {
114
+ case RequireNA:
115
+ break;
116
+ case RequireC:
117
+ if (!ndt_is_c_contiguous(c)) {
118
+ return 0;
119
+ }
120
+ break;
121
+ case RequireF:
122
+ if (!ndt_is_f_contiguous(c)) {
123
+ return 0;
124
+ }
125
+ break;
126
+ }
127
+ }
128
+ }
129
+
130
+ return 1;
131
+ }
132
+
133
+ static ndt_t *
134
+ to_fortran(const ndt_t *p, ndt_t *c, ndt_context_t *ctx)
135
+ {
136
+ if (p->tag == EllipsisDim && p->EllipsisDim.tag == RequireF) {
137
+ ndt_t *t = ndt_to_fortran(c, ctx);
138
+ return t;
139
+ }
140
+ else {
141
+ return c;
142
+ }
143
+ }
144
+
145
+ static int
146
+ resolve_fixed(const char *key, symtable_entry_t w,
147
+ symtable_t *tbl, ndt_context_t *ctx)
148
+ {
149
+ symtable_entry_t v;
150
+
151
+ v = symtable_find(tbl, key);
152
+ if (v.tag == Unbound) {
153
+ if (symtable_add(tbl, key, w, ctx) < 0) {
154
+ return -1;
155
+ }
156
+ return 1;
157
+ }
158
+
159
+ if (w.FixedSeq.size != v.FixedSeq.size) {
160
+ return 0;
161
+ }
162
+
163
+ for (int i = 0; i < v.FixedSeq.size; i++) {
164
+ const ndt_t *t = v.FixedSeq.dims[i];
165
+ const ndt_t *u = w.FixedSeq.dims[i];
166
+ if (u->FixedDim.shape != t->FixedDim.shape) {
167
+ return 0;
168
+ }
169
+ }
170
+
171
+ return 1;
172
+ }
173
+
174
+ static int
175
+ resolve_shape(const char *key, int64_t shape, symtable_t *tbl, ndt_context_t *ctx)
176
+ {
177
+ symtable_entry_t v;
178
+
179
+ v = symtable_find(tbl, key);
180
+ if (v.tag == Unbound) {
181
+ v.tag = Shape;
182
+ v.Shape = shape;
183
+ if (symtable_add(tbl, key, v, ctx) < 0) {
184
+ return -1;
185
+ }
186
+ return 1;
187
+ }
188
+
189
+ if (v.tag != Shape) {
190
+ return 0;
191
+ }
192
+
193
+ return shape == v.Shape;
194
+ }
195
+
196
+ static int
197
+ resolve_typevar(const char *key, symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
198
+ {
199
+ symtable_entry_t v;
200
+
201
+ v = symtable_find(tbl, key);
202
+ if (v.tag == Unbound) {
203
+ if (symtable_add(tbl, key, w, ctx) < 0) {
204
+ return -1;
205
+ }
206
+ return 1;
207
+ }
208
+
209
+ if (v.tag == Symbol && w.tag == Symbol) {
210
+ return strcmp(v.Symbol, w.Symbol) == 0;
211
+ }
212
+ else if (v.tag == Type && w.tag == Type) {
213
+ return ndt_equal(v.Type, w.Type);
214
+ }
215
+ else {
216
+ return 0;
217
+ }
218
+ }
219
+
220
+ static int
221
+ match_concrete_var_dim(const ndt_t *t, int64_t tindex,
222
+ const ndt_t *u, int64_t uindex,
223
+ const int outer_dims, ndt_context_t *ctx)
224
+ {
225
+ int64_t tshape, tstart, tstep;
226
+ int64_t ushape, ustart, ustep;
227
+
228
+ if (outer_dims == 0) {
229
+ return 1;
230
+ }
231
+ if (t->Concrete.VarDim.itemsize != u->Concrete.VarDim.itemsize) {
232
+ return 0;
233
+ }
234
+
235
+ tshape = ndt_var_indices(&tstart, &tstep, t, tindex, ctx);
236
+ if (tshape < 0) {
237
+ return -1;
238
+ }
239
+
240
+ ushape = ndt_var_indices(&ustart, &ustep, u, uindex, ctx);
241
+ if (ushape < 0) {
242
+ return -1;
243
+ }
244
+
245
+ if (ushape != tshape) {
246
+ return 0;
247
+ }
248
+
249
+ for (int64_t i = 0; i < tshape; i++) {
250
+ int64_t tnext = tstart + i * tstep;
251
+ int64_t unext = ustart + i * ustep;
252
+ int ret = match_concrete_var_dim(t->VarDim.type, tnext,
253
+ u->VarDim.type, unext,
254
+ outer_dims-1, ctx);
255
+ if (ret <= 0) {
256
+ return ret;
257
+ }
258
+ }
259
+
260
+ return 1;
261
+ }
262
+
263
+ static int
264
+ resolve_var(symtable_entry_t w, symtable_t *tbl, ndt_context_t *ctx)
265
+ {
266
+ const char *key = "var";
267
+ symtable_entry_t v;
268
+
269
+ v = symtable_find(tbl, key);
270
+ if (v.tag == Unbound) {
271
+ if (symtable_add(tbl, key, w, ctx) < 0) {
272
+ return -1;
273
+ }
274
+ return 1;
275
+ }
276
+
277
+ if (w.VarSeq.size != v.VarSeq.size) {
278
+ return 0;
279
+ }
280
+ if (v.VarSeq.size == 0) {
281
+ return 1;
282
+ }
283
+
284
+ return match_concrete_var_dim(w.VarSeq.dims[0], 0,
285
+ v.VarSeq.dims[0], 0,
286
+ v.VarSeq.size, ctx);
287
+ }
288
+
289
+ static int
290
+ match_tuple_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
291
+ ndt_context_t *ctx)
292
+ {
293
+ int64_t i;
294
+ int n;
295
+
296
+ assert(p->tag == Tuple && c->tag == Tuple);
297
+
298
+ if (p->Tuple.shape != c->Tuple.shape) {
299
+ return 0;
300
+ }
301
+
302
+ for (i = 0; i < p->Tuple.shape; i++) {
303
+ n = match_datashape(p->Tuple.types[i], c->Tuple.types[i], tbl, ctx);
304
+ if (n <= 0) return n;
305
+ }
306
+
307
+ return 1;
308
+ }
309
+
310
+ static int
311
+ match_record_fields(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
312
+ ndt_context_t *ctx)
313
+ {
314
+ int64_t i;
315
+ int n;
316
+
317
+ assert(p->tag == Record && c->tag == Record);
318
+
319
+ if (p->Record.shape != c->Record.shape) {
320
+ return 0;
321
+ }
322
+
323
+ for (i = 0; i < p->Record.shape; i++) {
324
+ n = strcmp(p->Record.names[i], c->Record.names[i]);
325
+ if (n != 0) return 0;
326
+
327
+ n = match_datashape(p->Record.types[i], c->Record.types[i], tbl, ctx);
328
+ if (n <= 0) return n;
329
+ }
330
+
331
+ return 1;
332
+ }
333
+
334
+ static int
335
+ match_categorical(ndt_value_t *p, int64_t plen,
336
+ ndt_value_t *c, int64_t clen)
337
+ {
338
+ int64_t i;
339
+
340
+ if (plen != clen) {
341
+ return 0;
342
+ }
343
+
344
+ for (i = 0; i < plen; i++) {
345
+ if (!ndt_value_equal(&p[i], &c[i])) {
346
+ return 0;
347
+ }
348
+ }
349
+
350
+ return 1;
351
+ }
352
+
353
+ static const ndt_t *
354
+ outer_inner(symtable_entry_t *v, int i, const ndt_t *t, int ndim)
355
+ {
356
+ assert(ndt_is_concrete(t));
357
+
358
+ if (t->ndim < ndim) {
359
+ return NULL;
360
+ }
361
+ if (t->ndim == ndim) {
362
+ return t;
363
+ }
364
+
365
+ switch (t->tag) {
366
+ case FixedDim: {
367
+ switch (v->tag) {
368
+ case FixedSeq:
369
+ v->FixedSeq.size = i+1;
370
+ v->FixedSeq.dims[i] = t;
371
+ break;
372
+ case BroadcastSeq:
373
+ v->BroadcastSeq.size = i+1;
374
+ v->BroadcastSeq.dims[i] = t->FixedDim.shape;
375
+ break;
376
+ default:
377
+ return NULL;
378
+ }
379
+ return outer_inner(v, i+1, t->FixedDim.type, ndim);
380
+ }
381
+ case VarDim: {
382
+ switch (v->tag) {
383
+ case VarSeq:
384
+ v->VarSeq.size = i+1;
385
+ v->VarSeq.dims[i] = t;
386
+ break;
387
+ default:
388
+ return NULL;
389
+ }
390
+ return outer_inner(v, i+1, t->VarDim.type, ndim);
391
+ }
392
+ default:
393
+ return NULL;
394
+ }
395
+ }
396
+
397
+ static int
398
+ match_datashape(const ndt_t *p, const ndt_t *c, symtable_t *tbl,
399
+ ndt_context_t *ctx)
400
+ {
401
+ int n;
402
+
403
+ if (ndt_is_optional(c) != ndt_is_optional(p)) return 0;
404
+
405
+ switch (p->tag) {
406
+ case AnyKind: {
407
+ return 1;
408
+ }
409
+
410
+ case FixedDim: {
411
+ if (c->tag != FixedDim || p->FixedDim.shape != c->FixedDim.shape) {
412
+ return 0;
413
+ }
414
+ if (p->FixedDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
415
+ return 0;
416
+ }
417
+ if (p->FixedDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
418
+ return 0;
419
+ }
420
+
421
+ return match_datashape(p->FixedDim.type, c->FixedDim.type, tbl, ctx);
422
+ }
423
+
424
+ case VarDim: {
425
+ if (c->tag != VarDim) {
426
+ return 0;
427
+ }
428
+ return match_datashape(p->VarDim.type, c->VarDim.type, tbl, ctx);
429
+ }
430
+
431
+ case SymbolicDim: {
432
+ if (c->tag != FixedDim) return 0;
433
+
434
+ if (p->SymbolicDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
435
+ return 0;
436
+ }
437
+ if (p->SymbolicDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
438
+ return 0;
439
+ }
440
+
441
+ n = resolve_shape(p->SymbolicDim.name, c->FixedDim.shape, tbl, ctx);
442
+ if (n <= 0) {
443
+ return n;
444
+ }
445
+ return match_datashape(p->SymbolicDim.type, c->FixedDim.type, tbl, ctx);
446
+ }
447
+
448
+ case EllipsisDim: {
449
+ symtable_entry_t outer;
450
+ const ndt_t *inner;
451
+
452
+ if (p->EllipsisDim.tag == RequireC && !ndt_is_c_contiguous(c)) {
453
+ return 0;
454
+ }
455
+ if (p->EllipsisDim.tag == RequireF && !ndt_is_f_contiguous(c)) {
456
+ return 0;
457
+ }
458
+
459
+ if (p->EllipsisDim.name == NULL) {
460
+ outer.tag = BroadcastSeq;
461
+ outer.BroadcastSeq.size = 0;
462
+ }
463
+ else if (strcmp(p->EllipsisDim.name, "var") == 0) {
464
+ outer.tag = VarSeq;
465
+ outer.VarSeq.size = 0;
466
+ }
467
+ else {
468
+ outer.tag = FixedSeq;
469
+ outer.FixedSeq.size = 0;
470
+ }
471
+
472
+ inner = outer_inner(&outer, 0, c, p->EllipsisDim.type->ndim);
473
+ if (inner == NULL) {
474
+ return 0;
475
+ }
476
+
477
+ n = match_datashape(p->EllipsisDim.type, inner, tbl, ctx);
478
+ if (n <= 0) {
479
+ return n;
480
+ }
481
+
482
+ switch (outer.tag) {
483
+ case BroadcastSeq:
484
+ return resolve_broadcast(outer, tbl, ctx);
485
+ case FixedSeq:
486
+ return resolve_fixed(p->EllipsisDim.name, outer, tbl, ctx);
487
+ case VarSeq:
488
+ return resolve_var(outer, tbl, ctx);
489
+ default: /* NOT REACHED */
490
+ ndt_internal_error("invalid tag");
491
+ }
492
+ }
493
+
494
+ case Bool:
495
+ case Int8: case Int16: case Int32: case Int64:
496
+ case Uint8: case Uint16: case Uint32: case Uint64:
497
+ case Float16: case Float32: case Float64:
498
+ case Complex32: case Complex64: case Complex128:
499
+ case String:
500
+ return p->tag == c->tag;
501
+ case FixedString:
502
+ return c->tag == FixedString &&
503
+ p->FixedString.size == c->FixedString.size &&
504
+ p->FixedString.encoding == c->FixedString.encoding;
505
+ case FixedBytes:
506
+ return c->tag == FixedBytes &&
507
+ p->FixedBytes.size == c->FixedBytes.size &&
508
+ p->FixedBytes.align == c->FixedBytes.align;
509
+ case SignedKind:
510
+ return c->tag == SignedKind || ndt_is_signed(c);
511
+ case UnsignedKind:
512
+ return c->tag == UnsignedKind || ndt_is_unsigned(c);
513
+ case FloatKind:
514
+ return c->tag == FloatKind || ndt_is_float(c);
515
+ case ComplexKind:
516
+ return c->tag == ComplexKind || ndt_is_complex(c);
517
+ case FixedStringKind:
518
+ return c->tag == FixedStringKind || c->tag == FixedString;
519
+ case FixedBytesKind:
520
+ return c->tag == FixedBytesKind || c->tag == FixedBytes;
521
+ case ScalarKind:
522
+ return c->tag == ScalarKind || ndt_is_scalar(c);
523
+ case Char:
524
+ return c->tag == Char && c->Char.encoding == p->Char.encoding;
525
+ case Bytes:
526
+ return c->tag == Bytes && p->Bytes.target_align == c->Bytes.target_align;
527
+ case Categorical:
528
+ return c->tag == Categorical &&
529
+ match_categorical(p->Categorical.types, p->Categorical.ntypes,
530
+ c->Categorical.types, c->Categorical.ntypes);
531
+ case Ref:
532
+ if (c->tag != Ref) return 0;
533
+ return match_datashape(p->Ref.type, c->Ref.type, tbl, ctx);
534
+ case Tuple:
535
+ if (p->Tuple.flag == Variadic) return 0;
536
+ if (c->tag != Tuple) return 0;
537
+ return match_tuple_fields(p, c, tbl, ctx);
538
+ case Record:
539
+ if (p->Tuple.flag == Variadic) return 0;
540
+ if (c->tag != Record) return 0;
541
+ return match_record_fields(p, c, tbl, ctx);
542
+ case Function: {
543
+ int64_t i;
544
+ if (c->tag != Function ||
545
+ c->Function.nin != p->Function.nin ||
546
+ c->Function.nout != p->Function.nout ||
547
+ c->Function.nargs != p->Function.nargs) {
548
+ return 0;
549
+ }
550
+
551
+ for (i = 0; i < p->Function.nargs; i++) {
552
+ n = match_datashape(p->Function.types[i], c->Function.types[i], tbl, ctx);
553
+ if (n <= 0) return n;
554
+ }
555
+
556
+ return check_contig(p->Function.types, c->Function.types, p->Function.nargs);
557
+ }
558
+ case Typevar: {
559
+ if (c->tag == Typevar) {
560
+ symtable_entry_t entry = { .tag=Symbol, .Symbol=c->Typevar.name };
561
+ return resolve_typevar(p->Typevar.name, entry, tbl, ctx);
562
+ }
563
+ else {
564
+ symtable_entry_t entry = { .tag=Type, .Type=c };
565
+ return resolve_typevar(p->Typevar.name, entry, tbl, ctx);
566
+ }
567
+ }
568
+ case Nominal:
569
+ /* Assume that the type has been created through ndt_nominal(), in
570
+ which case the name is guaranteed to be unique and present in the
571
+ typedef table. */
572
+ return c->tag == Nominal && strcmp(p->Nominal.name, c->Nominal.name) == 0;
573
+ case Module:
574
+ return c->tag == Module && strcmp(p->Module.name, c->Module.name) == 0 &&
575
+ ndt_equal(p->Module.type, c->Module.type);
576
+ case Constr:
577
+ return c->tag == Constr && strcmp(p->Constr.name, c->Constr.name) == 0 &&
578
+ ndt_equal(p->Constr.type, c->Constr.type);
579
+ }
580
+
581
+ /* NOT REACHED: tags should be exhaustive. */
582
+ ndt_internal_error("invalid type");
583
+ }
584
+
585
+ int
586
+ ndt_match(const ndt_t *p, const ndt_t *c, ndt_context_t *ctx)
587
+ {
588
+ symtable_t *tbl;
589
+ int ret;
590
+
591
+ if (ndt_is_abstract(c)) {
592
+ return 0;
593
+ }
594
+
595
+ tbl = symtable_new(ctx);
596
+ if (tbl == NULL) {
597
+ return -1;
598
+ }
599
+
600
+ ret = match_datashape(p, c, tbl, ctx);
601
+ symtable_del(tbl);
602
+ return ret;
603
+ }
604
+
605
+ static ndt_t *
606
+ broadcast(const ndt_t *t, const int64_t *shape,
607
+ int outer_dims, int inner_dims,
608
+ bool use_max, ndt_context_t *ctx)
609
+ {
610
+ ndt_ndarray_t u;
611
+ const ndt_t *dtype;
612
+ ndt_t *v;
613
+ int64_t step;
614
+ int ndim;
615
+ int i, k;
616
+
617
+ ndim = ndt_as_ndarray(&u, t, ctx);
618
+ if (ndim < 0) {
619
+ return NULL;
620
+ }
621
+
622
+ dtype = ndt_dtype(t);
623
+ v = ndt_copy(dtype, ctx);
624
+ if (v == NULL) {
625
+ return NULL;
626
+ }
627
+
628
+ for (i=ndim-1; i>=ndim-inner_dims; i--) {
629
+ v = ndt_fixed_dim(v, u.shape[i], u.steps[i], ctx);
630
+ if (v == NULL) {
631
+ return NULL;
632
+ }
633
+ }
634
+
635
+ for (k=outer_dims-1; i>=0 && k>=0; i--, k--) {
636
+ step = u.shape[i]<=1 ? 0 : u.steps[i];
637
+ v = ndt_fixed_dim(v, shape[k], step, ctx);
638
+ if (v == NULL) {
639
+ return NULL;
640
+ }
641
+ }
642
+
643
+ for (; k>=0; k--) {
644
+ if (use_max) {
645
+ v = ndt_fixed_dim(v, shape[k], INT64_MAX, ctx);
646
+ }
647
+ else {
648
+ v = ndt_fixed_dim(v, shape[k], 0, ctx);
649
+ }
650
+ if (v == NULL) {
651
+ return NULL;
652
+ }
653
+ }
654
+
655
+ return v;
656
+ }
657
+
658
+ int
659
+ ndt_broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
660
+ const ndt_t *in[], const int nin,
661
+ const int64_t *shape, int outer_dims,
662
+ ndt_context_t *ctx)
663
+ {
664
+ ndt_t *u;
665
+ int inner_dims;
666
+ int i;
667
+
668
+ for (i = 0; i < nin; i++) {
669
+ inner_dims = sig->Function.types[i]->ndim-1;
670
+ spec->broadcast[i] = broadcast(in[i], shape,
671
+ outer_dims, inner_dims, false, ctx);
672
+ if (spec->broadcast[i] == NULL) {
673
+ return -1;
674
+ }
675
+ spec->nbroadcast++;
676
+ }
677
+
678
+ for (i = 0; i < spec->nout; i++) {
679
+ inner_dims = sig->Function.types[nin+i]->ndim-1;
680
+ u = broadcast(spec->out[i], shape,
681
+ outer_dims, inner_dims, true, ctx);
682
+ if (u == NULL) {
683
+ return -1;
684
+ }
685
+ ndt_del(spec->out[i]);
686
+ spec->out[i] = u;
687
+ }
688
+
689
+ spec->outer_dims = outer_dims;
690
+
691
+ return 0;
692
+ }
693
+
694
+ static int
695
+ broadcast_all(ndt_apply_spec_t *spec, const ndt_t *sig,
696
+ const ndt_t *in[], const int nin,
697
+ const symtable_t *tbl, ndt_context_t *ctx)
698
+ {
699
+ symtable_entry_t v;
700
+
701
+ v = symtable_find(tbl, "00_ELLIPSIS");
702
+ if (v.tag != BroadcastSeq) {
703
+ ndt_err_format(ctx, NDT_RuntimeError,
704
+ "unexpected missing unnamed ellipsis entry");
705
+ return -1;
706
+ }
707
+
708
+ return ndt_broadcast_all(spec, sig, in, nin,
709
+ v.BroadcastSeq.dims, v.BroadcastSeq.size,
710
+ ctx);
711
+ }
712
+
713
+ static int
714
+ resolve_constraint(const ndt_constraint_t *c, const void *args, symtable_t *tbl,
715
+ ndt_context_t *ctx)
716
+ {
717
+ int64_t shapes[NDT_MAX_SYMBOLS];
718
+ symtable_entry_t v;
719
+
720
+ for (int i = 0; i < c->nin; i++) {
721
+ v = symtable_find(tbl, c->symbols[i]);
722
+ if (v.tag != Shape) {
723
+ ndt_err_format(ctx, NDT_ValueError, "expected dimension variable");
724
+ return -1;
725
+ }
726
+ shapes[i] = v.Shape;
727
+ }
728
+
729
+ if (c->f(shapes, args, ctx) < 0) {
730
+ return -1;
731
+ }
732
+
733
+ for (int i = 0; i < c->nout; i++) {
734
+ if (resolve_shape(c->symbols[c->nin+i], shapes[c->nin+i], tbl, ctx) < 0) {
735
+ return -1;
736
+ }
737
+ }
738
+
739
+ return 0;
740
+ }
741
+
742
+ /*
743
+ * Check the concrete function arguments 'in' against the function
744
+ * signature 'sig'. On success, infer and return the concrete return
745
+ * types and the (possibly broadcasted) 'in' types.
746
+ */
747
+ int
748
+ ndt_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
749
+ const ndt_t *in[], const int nin,
750
+ const ndt_constraint_t *c, const void *args,
751
+ ndt_context_t *ctx)
752
+ {
753
+ symtable_t *tbl;
754
+ ndt_t *t;
755
+ const char *name;
756
+ int ret;
757
+ int64_t i;
758
+
759
+ assert(spec->flags == 0);
760
+ assert(spec->nout == 0);
761
+ assert(spec->nbroadcast == 0);
762
+ assert(spec->outer_dims == 0);
763
+
764
+ if (sig->tag != Function) {
765
+ ndt_err_format(ctx, NDT_ValueError,
766
+ "signature must be a function type");
767
+ return -1;
768
+ }
769
+
770
+ if (nin != sig->Function.nin) {
771
+ ndt_err_format(ctx, NDT_ValueError,
772
+ "expected %" PRIi64 " arguments, got %d", sig->Function.nin, nin);
773
+ return -1;
774
+ }
775
+
776
+ for (i = 0; i < nin; i++) {
777
+ if (ndt_is_abstract(in[i])) {
778
+ ndt_err_format(ctx, NDT_ValueError,
779
+ "type checking requires concrete argument types");
780
+ return -1;
781
+ }
782
+ }
783
+
784
+ tbl = symtable_new(ctx);
785
+ if (tbl == NULL) {
786
+ return -1;
787
+ }
788
+
789
+ for (i = 0; i < nin; i++) {
790
+ ret = match_datashape(sig->Function.types[i], in[i], tbl, ctx);
791
+ if (ret <= 0) {
792
+ symtable_del(tbl);
793
+
794
+ if (ret == 0) {
795
+ ndt_err_format(ctx, NDT_TypeError,
796
+ "argument types do not match");
797
+ }
798
+
799
+ return -1;
800
+ }
801
+ }
802
+
803
+ if (c != NULL && resolve_constraint(c, args, tbl, ctx) < 0) {
804
+ symtable_del(tbl);
805
+ return -1;
806
+ }
807
+
808
+ for (i = 0; i < sig->Function.nout; i++) {
809
+ spec->out[i] = ndt_substitute(sig->Function.types[nin+i], tbl, false, ctx);
810
+ if (spec->out[i] == NULL) {
811
+ ndt_apply_spec_clear(spec);
812
+ symtable_del(tbl);
813
+ return -1;
814
+ }
815
+ spec->nout++;
816
+ }
817
+
818
+ if (sig->flags & NDT_ELLIPSIS) {
819
+ if (sig->Function.nargs == 0 || sig->Function.types[0]->tag != EllipsisDim) {
820
+ ndt_err_format(ctx, NDT_RuntimeError,
821
+ "unexpected configuration of ellipsis flag and function types");
822
+ ndt_apply_spec_clear(spec);
823
+ symtable_del(tbl);
824
+ return -1;
825
+ }
826
+
827
+ t = sig->Function.types[0];
828
+ name = t->EllipsisDim.name;
829
+
830
+ if (name != NULL) {
831
+ symtable_entry_t v = symtable_find(tbl, name);
832
+ switch (v.tag) {
833
+ case FixedSeq:
834
+ spec->outer_dims = v.FixedSeq.size;
835
+ break;
836
+ case VarSeq:
837
+ spec->outer_dims = v.VarSeq.size;
838
+ break;
839
+ default:
840
+ ndt_err_format(ctx, NDT_RuntimeError,
841
+ "unexpected missing dimension list entry");
842
+ ndt_apply_spec_clear(spec);
843
+ symtable_del(tbl);
844
+ return -1;
845
+ }
846
+ }
847
+ else {
848
+ if (broadcast_all(spec, sig, in, nin, tbl, ctx) < 0) {
849
+ ndt_apply_spec_clear(spec);
850
+ symtable_del(tbl);
851
+ return -1;
852
+ }
853
+ }
854
+ }
855
+
856
+ symtable_del(tbl);
857
+
858
+ for (i = 0; i < sig->Function.nout; i++) {
859
+ ndt_t *_p = sig->Function.types[nin+i];
860
+ ndt_t *_c = spec->out[i];
861
+ ndt_t *_t = to_fortran(_p, _c, ctx);
862
+ if (_t == NULL) {
863
+ ndt_apply_spec_clear(spec);
864
+ return -1;
865
+ }
866
+ if (_t != _c) {
867
+ ndt_del(_c);
868
+ }
869
+ spec->out[i] = _t;
870
+ }
871
+
872
+ if (!check_contig(sig->Function.types, (ndt_t **)in, nin)) {
873
+ ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
874
+ return -1;
875
+ }
876
+ if (!check_contig(sig->Function.types+nin, spec->out, spec->nout)) {
877
+ ndt_err_format(ctx, NDT_TypeError, "argument types do not match");
878
+ return -1;
879
+ }
880
+
881
+ ndt_select_kernel_strategy(spec, sig, in, nin);
882
+
883
+ return 0;
884
+ }
885
+
886
+
887
+ /*****************************************************************************/
888
+ /* Optimized binary typecheck for fixed input */
889
+ /*****************************************************************************/
890
+
891
+ static ndt_t *
892
+ binary_broadcast_1D(const ndt_ndarray_t *t, const ndt_t *dtype,
893
+ const int64_t *shape, int size, ndt_context_t *ctx)
894
+ {
895
+ ndt_t *v;
896
+ int64_t step;
897
+ int i, k;
898
+
899
+ v = ndt_copy(dtype, ctx);
900
+ if (v == NULL) {
901
+ return NULL;
902
+ }
903
+
904
+ for (i=t->ndim-1, k=size-1; i>=0 && k>=0; i--, k--) {
905
+ step = t->shape[i]<=1 ? 0 : t->steps[i];
906
+ v = ndt_fixed_dim(v, shape[k], step, ctx);
907
+ if (v == NULL) {
908
+ return NULL;
909
+ }
910
+ }
911
+
912
+ for (; k>=0; k--) {
913
+ v = ndt_fixed_dim(v, shape[k], 0, ctx);
914
+ if (v == NULL) {
915
+ return NULL;
916
+ }
917
+ }
918
+
919
+ return v;
920
+ }
921
+
922
+ static ndt_t *
923
+ fixed_dim_from_shape(const int64_t shape[], int len, ndt_t *dtype,
924
+ ndt_context_t *ctx)
925
+ {
926
+ ndt_t *t;
927
+ int i;
928
+
929
+ for (i=len-1, t=dtype; i >= 0; i--) {
930
+ t = ndt_fixed_dim(t, shape[i], INT64_MAX, ctx);
931
+ if (t == NULL) {
932
+ return NULL;
933
+ }
934
+ }
935
+
936
+ return t;
937
+ }
938
+
939
+ static bool
940
+ shape_equal(const ndt_ndarray_t *a, const ndt_ndarray_t *b)
941
+ {
942
+ if (b->ndim != a->ndim) {
943
+ return false;
944
+ }
945
+
946
+ for (int i = 0; i < a->ndim; i++) {
947
+ if (b->shape[i] != a->shape[i]) {
948
+ return false;
949
+ }
950
+ }
951
+
952
+ return true;
953
+ }
954
+
955
+ static int
956
+ _ndt_binary_broadcast(ndt_apply_spec_t *spec, const ndt_t *sig,
957
+ const ndt_ndarray_t *x, const ndt_ndarray_t *y,
958
+ const ndt_t *in[], const int nin, ndt_t *dtype,
959
+ int inner, ndt_context_t *ctx)
960
+ {
961
+ int64_t shape[NDT_MAX_DIM];
962
+ int size;
963
+
964
+ if (shape_equal(x, y)) {
965
+ spec->nout = 1;
966
+ spec->nbroadcast = 0;
967
+ spec->outer_dims = x->ndim-inner;
968
+ spec->out[0] = fixed_dim_from_shape(x->shape, x->ndim, dtype, ctx);
969
+ if (spec->out[0] == NULL) {
970
+ return -1;
971
+ }
972
+ }
973
+ else {
974
+ for (int i = 0; i < x->ndim; i++) {
975
+ shape[i] = x->shape[i];
976
+ }
977
+
978
+ size = _resolve_broadcast(shape, x->ndim, y->shape, y->ndim);
979
+ if (size < 0) {
980
+ ndt_err_format(ctx, NDT_TypeError, "broadcast error");
981
+ ndt_del(dtype);
982
+ return -1;
983
+ }
984
+
985
+ spec->nout = 1;
986
+ spec->nbroadcast = 2;
987
+ spec->outer_dims = size-inner;
988
+
989
+ spec->out[0] = fixed_dim_from_shape(shape, size, dtype, ctx);
990
+ if (spec->out[0] == NULL) {
991
+ return -1;
992
+ }
993
+
994
+ spec->broadcast[0] = binary_broadcast_1D(x, ndt_dtype(in[0]), shape, size, ctx);
995
+ if (spec->broadcast[0] == NULL) {
996
+ ndt_del(spec->out[0]);
997
+ return -1;
998
+ }
999
+
1000
+ spec->broadcast[1] = binary_broadcast_1D(y, ndt_dtype(in[1]), shape, size, ctx);
1001
+ if (spec->broadcast[1] == NULL) {
1002
+ ndt_del(spec->out[0]);
1003
+ ndt_del(spec->broadcast[0]);
1004
+ return -1;
1005
+ }
1006
+ }
1007
+
1008
+ ndt_select_kernel_strategy(spec, sig, in, nin);
1009
+
1010
+ return 0;
1011
+ }
1012
+
1013
+ static bool
1014
+ all_ellipses(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2,
1015
+ ndt_context_t *ctx)
1016
+ {
1017
+ if ((t0->tag != EllipsisDim || t0->EllipsisDim.name != NULL) ||
1018
+ (t1->tag != EllipsisDim || t1->EllipsisDim.name != NULL) ||
1019
+ (t2->tag != EllipsisDim || t2->EllipsisDim.name != NULL)) {
1020
+ ndt_err_format(ctx, NDT_RuntimeError,
1021
+ "fast binary typecheck expects leading ellipsis dimensions");
1022
+ return false;
1023
+ }
1024
+
1025
+ return true;
1026
+ }
1027
+
1028
+ static bool
1029
+ all_same_symbol(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1030
+ {
1031
+ if (t0->tag != SymbolicDim || t1->tag != SymbolicDim ||
1032
+ t2->tag != SymbolicDim) {
1033
+ return false;
1034
+ }
1035
+
1036
+ return strcmp(t0->SymbolicDim.name, t1->SymbolicDim.name) == 0 &&
1037
+ strcmp(t0->SymbolicDim.name, t2->SymbolicDim.name) == 0;
1038
+ }
1039
+
1040
+ static bool
1041
+ all_ndim0(const ndt_t *t0, const ndt_t *t1, const ndt_t *t2)
1042
+ {
1043
+ return t0->ndim == 0 && t1->ndim == 0 && t2->ndim == 0;
1044
+ }
1045
+
1046
+ /*
1047
+ * Optimized type checking for very specific signatures. The caller must
1048
+ * have identified the kernel location, signature and the dtype. For
1049
+ * performance reasons, no substitution is performed on the dtype, so
1050
+ * the dtype must be concrete.
1051
+ *
1052
+ * Supported signatures:
1053
+ * 1) ... * N * T0, ... * N * T1 -> N * T2
1054
+ * 2) ... * T0, ... * T1 -> ... * T2
1055
+ */
1056
+ int
1057
+ ndt_fast_binary_fixed_typecheck(ndt_apply_spec_t *spec, const ndt_t *sig,
1058
+ const ndt_t *in[], const int nin, ndt_t *dtype,
1059
+ ndt_context_t *ctx)
1060
+ {
1061
+ ndt_t *p0, *p1, *p2;
1062
+ ndt_ndarray_t x, y;
1063
+
1064
+ assert(spec->flags == 0);
1065
+ assert(spec->nout == 0);
1066
+ assert(spec->nbroadcast == 0);
1067
+ assert(spec->outer_dims == 0);
1068
+
1069
+ if (sig->tag != Function ||
1070
+ sig->Function.nin != 2 ||
1071
+ sig->Function.nout != 1) {
1072
+ ndt_err_format(ctx, NDT_RuntimeError,
1073
+ "fast binary typecheck expects a signature with two inputs and "
1074
+ "one output");
1075
+ return -1;
1076
+ }
1077
+
1078
+ if (nin != 2) {
1079
+ ndt_err_format(ctx, NDT_RuntimeError,
1080
+ "fast binary typecheck expects two input arguments");
1081
+ return -1;
1082
+ }
1083
+
1084
+ if (ndt_is_abstract(dtype)) {
1085
+ ndt_err_format(ctx, NDT_RuntimeError,
1086
+ "fast binary typecheck expects a concrete dtype");
1087
+ return -1;
1088
+ }
1089
+
1090
+ p0 = sig->Function.types[0];
1091
+ p1 = sig->Function.types[1];
1092
+ p2 = sig->Function.types[2];
1093
+
1094
+ if (!all_ellipses(p0, p1, p2, ctx)) {
1095
+ return -1;
1096
+ }
1097
+
1098
+ if (ndt_as_ndarray(&x, in[0], ctx) < 0) {
1099
+ ndt_del(dtype);
1100
+ return -1;
1101
+ }
1102
+
1103
+ if (ndt_as_ndarray(&y, in[1], ctx) < 0) {
1104
+ ndt_del(dtype);
1105
+ return -1;
1106
+ }
1107
+
1108
+ p0 = p0->EllipsisDim.type;
1109
+ p1 = p1->EllipsisDim.type;
1110
+ p2 = p2->EllipsisDim.type;
1111
+
1112
+ if (all_same_symbol(p0, p1, p2)) {
1113
+ if (x.ndim > 0 && y.ndim > 0) {
1114
+ const int64_t xshape = x.shape[x.ndim-1];
1115
+ const int64_t yshape = y.shape[y.ndim-1];
1116
+ if (xshape != 1 && yshape != 1 && xshape != yshape) {
1117
+ ndt_err_format(ctx, NDT_TypeError, "mismatch in inner dimensions");
1118
+ ndt_del(dtype);
1119
+ return -1;
1120
+ }
1121
+ }
1122
+ return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 1, ctx);
1123
+ }
1124
+ else if (all_ndim0(p0, p1, p2)) {
1125
+ return _ndt_binary_broadcast(spec, sig, &x, &y, in, nin, dtype, 0, ctx);
1126
+ }
1127
+ else {
1128
+ ndt_err_format(ctx, NDT_RuntimeError,
1129
+ "unsupported signature in fast binary typecheck");
1130
+ return -1;
1131
+ }
1132
+ }