relationalai-1.0.0a1-py3-none-any.whl → relationalai-1.0.0a2-py3-none-any.whl
- relationalai/semantics/frontend/base.py +3 -0
- relationalai/semantics/frontend/front_compiler.py +5 -2
- relationalai/semantics/metamodel/builtins.py +2 -1
- relationalai/semantics/metamodel/metamodel.py +32 -4
- relationalai/semantics/metamodel/pprint.py +5 -3
- relationalai/semantics/metamodel/typer.py +324 -297
- relationalai/semantics/std/aggregates.py +0 -1
- relationalai/semantics/std/datetime.py +4 -1
- relationalai/shims/executor.py +22 -4
- relationalai/shims/mm2v0.py +108 -38
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/METADATA +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/RECORD +27 -27
- v0/relationalai/errors.py +23 -0
- v0/relationalai/semantics/internal/internal.py +4 -4
- v0/relationalai/semantics/internal/snowflake.py +2 -1
- v0/relationalai/semantics/lqp/executor.py +16 -11
- v0/relationalai/semantics/lqp/model2lqp.py +42 -4
- v0/relationalai/semantics/lqp/passes.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/cdc.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +53 -12
- v0/relationalai/semantics/metamodel/builtins.py +8 -6
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
- v0/relationalai/semantics/reasoners/graph/core.py +8 -9
- v0/relationalai/semantics/sql/compiler.py +2 -2
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/top_level.txt +0 -0
@@ -81,17 +81,21 @@ class Typer:
 # Propagation Network
 #--------------------------------------------------
 
-# The core idea of the typer is to build a propagation network
-#
-#
-#
-#
-
-#
-#
-#
-
-
+# The core idea of the typer is to build a propagation network that represents data
+# dependencies between nodes. Nodes can be variables, literals, fields and tasks that refer
+# to relations (lookups, aggregates, updates). An edge from node A to node B means that
+# in order to resolve references and types for node B we need to know the type of node A.
+#
+# After building the network, we start with roots that have known types (literals, fields
+# loaded from previous analysis and vars with concrete types) and propagate their types via
+# the edges to other nodes. It is possible that the network contains cycles, so we iterate
+# on the work list until fixpoint is reached, i.e. when we cannot propagate any new type.
+#
+# During the propagation, we are only gathering information about the types and references
+# of nodes. Once we reach a fixpoint, we do a final pass to rewrite the model with the new
+# type information, which may include adding casts to convert between types where needed.
+#
+Node = Union[mm.Var, mm.Field, mm.Literal, mm.Lookup, mm.Aggregate, mm.Update]
 
 class PropagationNetwork():
     def __init__(self, model: mm.Model):
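The rewritten comment block spells out the typer's strategy: build a dependency graph, seed it with nodes whose types are already known, and push types along edges until a fixpoint. The sketch below is a minimal, self-contained illustration of that idea; `TypeNet`, `UNKNOWN` and the string type names are invented for the example and are not part of the relationalai package.

```python
# Minimal sketch of worklist-style type propagation (illustrative names only).
from collections import defaultdict

UNKNOWN = "Unknown"  # stands in for an abstract/unresolved type

class TypeNet:
    def __init__(self):
        self.edges = defaultdict(set)  # node -> nodes that depend on its type
        self.types = {}                # node -> resolved type (or UNKNOWN)

    def add_edge(self, source, target):
        # "target needs source's type to be resolved first"
        self.edges[source].add(target)

    def propagate(self):
        # roots: nodes whose type is already concrete (literals, loaded fields, ...)
        work = [n for n, t in self.types.items() if t != UNKNOWN]
        while work:
            node = work.pop(0)
            for out in self.edges[node]:
                # push the known type across the edge; re-enqueue only on change,
                # so cycles terminate once a fixpoint is reached
                if self.types.get(out, UNKNOWN) == UNKNOWN:
                    self.types[out] = self.types[node]
                    work.append(out)

net = TypeNet()
net.types.update({"lit_1": "Int64", "x": UNKNOWN, "y": UNKNOWN})
net.add_edge("lit_1", "x")  # x = 1
net.add_edge("x", "y")      # y depends on x
net.propagate()
print(net.types)            # {'lit_1': 'Int64', 'x': 'Int64', 'y': 'Int64'}
```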
@@ -104,22 +108,12 @@ class PropagationNetwork():
         # map from unresolved placeholder relations to their potential target replacements
         self.potential_targets: dict[mm.Relation, list[mm.Relation]] = {}
 
-        #
-        self.roots = OrderedSet()
-        # we separately want to track nodes that were loaded from a previous run
-        # so that even if we have edges to them, we _still_ consider them roots
-        # and properly propagate types from them at the beginning
+        # nodes loaded from previous analysis, which we use as roots to start propagation
         self.loaded_roots = set()
 
         # edges in the propagation network, from one node to potentially many
         self.edges:dict[Node, OrderedSet[Node]] = defaultdict(lambda: OrderedSet())
         self.back_edges:dict[Node, OrderedSet[Node]] = defaultdict(lambda: OrderedSet())
-        # all nodes that are the target of an edge (to find roots)
-        self.has_incoming = set()
-
-        # type requirements: for a var with abstract declared type, the set of fields that
-        # it must match the type of because it flows into them
-        self.type_requirements:dict[mm.Var, OrderedSet[mm.Field]] = defaultdict(lambda: OrderedSet())
 
         # all errors collected during inference
         self.errors:list[TyperError] = []
@@ -127,22 +121,22 @@ class PropagationNetwork():
         # overloads resolved for a lookup/update/aggregate, by node id. This is only for
         # relations that declare overloads
         self.resolved_overload:dict[int, mm.Overload] = {}
+
         # placeholders resolved for a lookup, by node id. This is only for relations that
         # are placeholders (i.e. only Any fields) and will be replaced by references to
         # these concrete relations. E.g. a query for "name(Any, Any)" may be replaced by
         # the union of "name(Dog, String)" and name(Cat, String)".
         self.resolved_placeholder:dict[int, list[mm.Relation]] = {}
+
         # for a given lookup/update/aggregate that involves numbers, the specific number
         # type resolved for it.
         self.resolved_number:dict[int, mm.NumberType] = {}
-        # keep track of nodes already resolved to avoid re-resolving
-        self.resolved_nodes:set[int] = set()
 
     #--------------------------------------------------
     # Error reporting
     #--------------------------------------------------
 
-    def type_mismatch(self, node: Node, expected: mm.Type, actual: mm.Type):
+    def type_mismatch(self, node: Node|mm.Update, expected: mm.Type, actual: mm.Type):
         self.errors.append(TypeMismatch(node, expected, actual))
 
     def invalid_type(self, node: Node, type: mm.Type):
@@ -153,6 +147,8 @@ class PropagationNetwork():
         self.errors.append(UnresolvedOverload(node, [self.resolve(a) for a in node.args]))
 
     def unresolved_type(self, node: Node):
+        # TODO - this is not being used yet, we need a pass at the end to check for any node
+        # that we could not resolve
         self.errors.append(UnresolvedType(node))
 
     def has_errors(self, node: Node) -> bool:
@@ -166,15 +162,8 @@ class PropagationNetwork():
     #--------------------------------------------------
 
     def add_edge(self, source: Node, target: Node):
-        # manage roots
-        if target in self.roots and target not in self.loaded_roots:
-            self.roots.remove(target)
-        if source not in self.has_incoming:
-            self.roots.add(source)
-        # register edge
         self.edges[source].add(target)
         self.back_edges[target].add(source)
-        self.has_incoming.add(target)
 
     def add_resolved_type(self, node: Node, type: mm.Type):
         """ Register that this node was resolved to have this type. """
@@ -183,13 +172,8 @@ class PropagationNetwork():
         else:
             self.resolved_types[node] = type
 
-    def add_type_requirement(self, source: mm.Var, field: mm.Field):
-        """ Register that this var, which has an abstract declared type, must match the type
-        of this field as it flows into it.. """
-        self.type_requirements[source].add(field)
-
     #--------------------------------------------------
-    # Load previous
+    # Load types from a previous analysis
     #--------------------------------------------------
 
     def load_types(self, type_dict: dict[Node, mm.Type]):
@@ -197,22 +181,39 @@ class PropagationNetwork():
             if isinstance(node, (mm.Field)):
                 self.add_resolved_type(node, type)
                 self.loaded_roots.add(node)
-                self.roots.add(node)
 
     #--------------------------------------------------
     # Resolve Values
     #--------------------------------------------------
 
-    def resolve(self, value:
-        if isinstance(value, (mm.Var, mm.
+    def resolve(self, value: mm.Value) -> mm.Type:
+        if isinstance(value, (mm.Var, mm.Literal)):
            return self.resolved_types.get(value) or to_type(value)
-
+        if isinstance(value, mm.Field):
+            return self.resolved_types.get(value) or value.type
+        assert not isinstance(value, (mm.Task)), "Should never try to resolve a task"
        return to_type(value)
 
     #--------------------------------------------------
     # Resolve References
     #--------------------------------------------------
 
+    def all_dependencies_resolved(self, op:mm.Lookup|mm.Aggregate):
+        """ True iff all dependencies required to resolve this reference are met. """
+        rel = get_relation(op)
+        # if this is a placeholder, eq or cast, we assume all possible args were resolved
+        if bt.is_placeholder(rel) or rel == b.core.eq or rel == b.core.cast:
+            return True
+
+        # else, find whether all back-edges were resolved
+        for node in self.back_edges[op]:
+            # cannot resolve if a required var, literal or input field is still abstract
+            if isinstance(node, (mm.Var, mm.Literal)) or (isinstance(node, mm.Field) and node.input):
+                node_type = self.resolve(node)
+                if bt.is_abstract(node_type):
+                    return False
+        return True
+
     def resolve_reference(self, op: mm.Lookup|mm.Aggregate) -> Optional[mm.Overload|list[mm.Relation]]:
         # check if all dependencies required to resolve this reference are met
         if not self.all_dependencies_resolved(op):
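The new `all_dependencies_resolved` helper gates reference resolution: a lookup or aggregate is only resolved once every var, literal, or input field feeding it (its back-edges) has a concrete type, with placeholders, `eq` and `cast` exempted. A toy version of that gate, with invented names and plain strings for types, might look like:

```python
# Illustrative gate: only try to resolve an operation once all of its inputs
# have concrete types. "abstract" here just means "not yet known".
def is_abstract(t):
    return t in ("Any", "Unknown", None)

def all_inputs_concrete(op_inputs, resolved_types):
    """op_inputs: nodes with a back-edge into the op; resolved_types: node -> type."""
    return all(not is_abstract(resolved_types.get(n)) for n in op_inputs)

resolved = {"x": "Int64", "y": "Unknown"}
print(all_inputs_concrete(["x"], resolved))       # True  -> safe to resolve the op
print(all_inputs_concrete(["x", "y"], resolved))  # False -> wait for y first
```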
@@ -228,7 +229,6 @@ class PropagationNetwork():
                 if all(type_matches(arg, self.resolve(field))
                        for arg, field in zip(resolved_args, fields)):
                     matches.append(target)
-
             return matches
 
         elif relation.overloads:
@@ -243,24 +243,11 @@ class PropagationNetwork():
                     self.resolved_overload[op.id] = overload
                     return overload
             return [] # no matches found
+
         else:
             # this is a relation with type vars or numbers that needs to be specialized
             return [relation]
 
-
-    def all_dependencies_resolved(self, op:mm.Lookup|mm.Aggregate):
-        # if this is a placeholder, we need assume all possible args were resolved
-        if bt.is_placeholder(get_relation(op)):
-            return True
-        # else, find whether all back-edges were resolved
-        for node in self.back_edges[op]:
-            if isinstance(node, (mm.Var, mm.Field, mm.Literal)):
-                node_type = self.resolve(node)
-                if bt.is_abstract(node_type):
-                    return False
-        return True
-
-
     #--------------------------------------------------
     # Propagation
     #--------------------------------------------------
@@ -269,73 +256,112 @@ class PropagationNetwork():
         edges = self.edges
         work_list = []
 
-        #
-
-
-
-
-            continue
-            node_type = self.resolve(node)
-            if not bt.is_abstract(node_type):
-                work_list.append(node)
-            else:
-                unhandled_roots.add(node)
+        # start with the loaded roots + all literals + sources of edges without back edges
+        work_list.extend(self.loaded_roots)
+        for source in self.edges.keys():
+            if isinstance(source, (mm.Literal)) or not source in self.back_edges:
+                work_list.append(source)
 
-        #
-
-
-            self.resolved_nodes.add(source.id)
-            if source in unhandled_roots:
-                unhandled_roots.remove(source)
-            source_type = self.resolve(source)
-            # check to see if the source has ended up with a set of types that
-            # aren't valid, e.g. a union of primitives
-            if invalid_type(source_type):
-                self.invalid_type(source, source_type)
-
-            # propagate our type to each outgoing edge
-            for out in edges.get(source, []):
-                # if this is an overload then we need to try and resolve it
-                if isinstance(out, (mm.Lookup, mm.Aggregate)):
-                    if not out.id in self.resolved_nodes:
-                        found = self.resolve_reference(out)
-                        if found is not None:
-                            self.resolved_nodes.add(out.id)
-                            self.propagate_reference(out, found)
-                            for arg in out.args:
-                                if arg not in work_list:
-                                    work_list.append(arg)
-                # otherwise, we just add to the outgoing node's type and if it
-                # changes we add it to the work list
-                elif start := self.resolve(out):
-                    self.add_resolved_type(out, source_type)
-                    if out not in work_list and (start != self.resolve(out) or not out.id in self.resolved_nodes):
-                        work_list.append(out)
-
-        for source in unhandled_roots:
-            self.unresolved_type(source)
-
-        # now that we've pushed all the types through the network, we need to validate
-        # that all type requirements of those nodes are met
-        for node, fields in self.type_requirements.items():
-            node_type = self.resolve(node)
-            for field in fields:
-                field_type = self.resolve(field)
-                if not type_matches(node_type, field_type) and not conversion_allowed(node_type, field_type):
-                    self.type_mismatch(node, field_type, node_type)
+        # limit the number of iterations to avoid infinite loops
+        i = 0
+        max_iterations = 100 * len(self.edges)
 
+        # propagate types until we reach a fixed point
+        while work_list:
+            i += 1
+            if i > max_iterations:
+                err("Infinite Loop", "Infinite loop detected in the typer. Please, report this as a bug.")
+                break
 
-
+            node = work_list.pop(0)
+            next = None
+            if isinstance(node, mm.Field):
+                # this is a field loaded from a previous analysis, so all we need to do is
+                # propagate its type to its output edges
+                next = edges.get(node, [])
+
+            elif isinstance(node, (mm.Lookup, mm.Aggregate)):
+                # this is a lookup/aggregate that may be overloaded or a placeholder; try
+                # to resolve its reference, i.e. determine which specific relation or
+                # overload it refers to
+                found = self.resolve_reference(node)
+                if found is not None:
+                    # if found is None it means that we need more info to resolve the
+                    # reference; otherwise it was possible to resolve it (even if to no matches)
+
+                    # keep the resolved args before propagation to see if they will change
+                    resolved_args = [self.resolve(arg) for arg in node.args]
+                    # propagate the reference resolution
+                    self.propagate_reference(node, resolved_args, found)
+                    if found:
+                        # the next nodes to process are all the outgoing edges plus any arg that
+                        # changed during propagation
+                        next = OrderedSet()
+                        next.update(edges.get(node, []))
+                        next.update([arg for idx, arg in enumerate(node.args) if isinstance(arg, mm.Var) and not (self.resolve(arg) == resolved_args[idx])])
+                else:
+                    # the reference is unresolved, so remove from the worklist the args
+                    # because they depend on this being resolved
+                    for arg in node.args:
+                        if arg in work_list:
+                            work_list.remove(arg)
+            else:
+                assert isinstance(node, (mm.Var, mm.Literal))
+                resolved = self.resolve(node)
+                # if we ended up with an invalid type, report it on a next edge (which is
+                # where the var or literal is going to be used)
+                if invalid_type(resolved):
+                    self.invalid_type(edges.get(node, [node])[0], resolved)
+
+                # if the var is still abstract, add it to the work list to try again later
+                if isinstance(node, mm.Var) and bt.is_abstract(resolved):
+                    # but if everything that is left are vars, we cannot make progress
+                    if not all(isinstance(x, mm.Var) for x in work_list):
+                        next = [node]
+                else:
+                    # otherwise, check all outgoing edges
+                    next = []
+                    for out in edges.get(node, []):
+                        if isinstance(out, mm.Update):
+                            # this is an update to a field. We have to check that the type is
+                            # valid for the field and if the field is abstract we can refine the
+                            # type. But if the field is concrete we have to check that the type
+                            # matche or can be converted.
+                            is_population_lookup = isinstance(out.relation, mm.TypeNode)
+                            for field in get_update_fields(node, out):
+                                if not type_matches(resolved, field.type, accept_expected_super_types=is_population_lookup) and not conversion_allowed(resolved, field.type):
+                                    self.type_mismatch(out, field.type, resolved)
+                                elif bt.is_abstract(field.type):
+                                    # if the field type changed, propagate further
+                                    if resolved != self.resolve(field):
+                                        next.append(field)
+                                    self.add_resolved_type(field, resolved)
+                        else:
+                            # this is an arg flowing into a task, so resolve the task next
+                            next.append(out)
+
+            # add the new nodes to the work list
+            if next is not None:
+                for n in next:
+                    if n not in work_list:
+                        work_list.append(n)
+
+
+    def propagate_reference(self, task:mm.Lookup|mm.Aggregate, resolved_args:list[mm.Type], references:mm.Overload|list[mm.Relation]):
         # TODO: distinguish between overloads and placeholders better when raising errors
         if not references:
             return self.unresolved_overload(task)
 
-        resolved_args = [self.resolve(arg) for arg in task.args]
-
-        # we need to determine the final types of our args by taking all the references
-        # and adding the type of their fields back to the args.
         relation = get_relation(task)
 
+        if relation == b.core.cast:
+            # cast is special cased: we just propagate the type to the target
+            cast_type, target = task.args[0], task.args[2]
+            assert isinstance(cast_type, mm.Type)
+            assert isinstance(target, mm.Var)
+            self.add_resolved_type(target, cast_type)
+            return
+
         if bt.is_placeholder(relation):
             assert(references and isinstance(references, list))
             # we've resolved the placeholder, so store that
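The new propagation loop is a bounded worklist fixpoint: it caps the number of iterations at `100 * len(self.edges)` so a cycle that never converges surfaces as an explicit error instead of hanging. A stripped-down sketch of just that guard (the helper name and error type are invented):

```python
# Sketch of a bounded fixpoint loop: process the worklist until empty, but bail
# out (and report a bug) if we exceed a generous cap proportional to the graph size.
def run_to_fixpoint(work_list, edges, step):
    """step(node) returns the next nodes to enqueue; edges is only used for the cap."""
    i, max_iterations = 0, 100 * max(len(edges), 1)
    while work_list:
        i += 1
        if i > max_iterations:
            raise RuntimeError("Infinite loop detected in the typer; please report this as a bug.")
        node = work_list.pop(0)
        for n in step(node):
            if n not in work_list:
                work_list.append(n)

# Example: propagate along a chain a -> b -> c
graph = {"a": ["b"], "b": ["c"], "c": []}
seen = []
run_to_fixpoint(["a"], graph, step=lambda n: (seen.append(n) or graph[n]))
print(seen)  # ['a', 'b', 'c']
```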
@@ -357,22 +383,25 @@ class PropagationNetwork():
             # number specialization, so use that relation's field types
             types = list([self.resolve(f) for f in relation.fields])
 
-            # if our overload preserves types, we check to see if there's a preserved
-            #
-            #
+            # if our overload preserves types, we check to see if there's a preserved output
+            # type given the inputs and if so, shadow the field type with the preserved type.
+            # this is only attempted if all input types match the field types, i.e. no
+            # conversions are needed
             resolved_fields = types
-            if bt.is_function(relation) and len(set(resolved_fields)) == 1
+            if bt.is_function(relation) and len(set(resolved_fields)) == 1 and not relation in self.NON_TYPE_PRESERVERS and\
+               all(type_matches(arg_type, field_type) for arg_type, field_type, field in zip(resolved_args, types, relation.fields) if field.input):
+
                 input_types = set([arg_type for field, arg_type
                                    in zip(relation.fields, resolved_args) if field.input])
                 if out_type := self.try_preserve_type(input_types):
                     resolved_fields = [field_type if field.input else out_type
                                        for field, field_type in zip(relation.fields, types)]
 
-
-            #
-            #
-            if b.core.Number in types or (b.core.TypeVar in types and any(bt.is_number(t) for t in resolved_args)):
-                # this
+
+            # eq is special cased because we don't want to specialize a number for it, as it
+            # can be just comparing numbers of different types.
+            if relation != b.core.eq and (b.core.Number in types or (b.core.TypeVar in types and any(bt.is_number(t) for t in resolved_args))):
+                # this relation contains generic numbers or typevars bound to numbers, so
                 # find which specific type of number to use given the arguments being passed
                 number, resolved_fields = self.specialize_number(relation, resolved_fields, resolved_args)
                 self.resolved_number[task.id] = number
@@ -380,16 +409,62 @@ class PropagationNetwork():
                 for field, field_type, arg in zip(relation.fields, resolved_fields, task.args):
                     if not field.input and isinstance(arg, mm.Var):
                         self.add_resolved_type(arg, field_type)
+
+            elif b.core.TypeVar in types:
+                # this relation contains type vars, so we have to make sure that all args
+                # bound to the same type var are consistent
+
+                # find which arg is bound to the type var and check that they are all consistent
+                typevar_type = None
+                for arg_type, field_type in zip(resolved_args, types):
+                    if field_type == b.core.TypeVar:
+                        if typevar_type is None:
+                            typevar_type = arg_type
+                        else:
+                            typevar_type = merge_types(typevar_type, arg_type)
+                assert typevar_type is not None
+
+                # compute the final arg types by replacing type vars with the typevar_type
+                computed_arg_types = []
+                for arg_type, field_type in zip(resolved_args, types):
+                    # check that the arg type matches the type var type
+                    if not type_matches(typevar_type, arg_type) and not conversion_allowed(arg_type, typevar_type):
+                        self.type_mismatch(task, typevar_type, arg_type)
+                    if field_type == b.core.TypeVar:
+                        computed_arg_types.append(typevar_type)
+                    else:
+                        computed_arg_types.append(arg_type)
+
+                # if no mismatches were found, propagate the computed arg types back to the args
+                if len(computed_arg_types) == len(task.args):
+                    for computed_type, arg, arg_type in zip(computed_arg_types, task.args, resolved_args):
+                        # TODO: we could allow for non-numeric/string literals because this means
+                        # the literal is being used as a value type, but the backend emitters
+                        # would have to deal with that.
+                        if isinstance(arg, mm.Var) or (isinstance(arg, mm.Literal) and (bt.is_numeric(computed_type) or computed_type == b.core.String)):
+                            self.add_resolved_type(arg, computed_type)
+
             else:
-
-
+                # no typevar or number specialization shenanigans, just propagate field types to args
+                for field, field_type, arg, arg_type in zip(relation.fields, resolved_fields, task.args, resolved_args):
+                    # add a resolved type if we learned more about the arg's type
+                    if isinstance(arg, mm.Var) and (bt.is_abstract(arg_type) or bt.extends(field_type, arg_type)):
                         self.add_resolved_type(arg, field_type)
+                    elif not type_matches(arg_type, field_type) and not conversion_allowed(arg_type, field_type):
+                        self.type_mismatch(task, field_type, arg_type)
+
 
+    # we try to preserve types for relations that are functions (i.e. potentially multiple
+    # input but a single output) and where all types are the same. However, there are some
+    # exceptions to this rule, e.g. range() always returns int regardless of whether the
+    # input type is an int or a value type that extends int.
+    NON_TYPE_PRESERVERS = [
+        b.common.range
+    ]
 
     def try_preserve_type(self, types:set[mm.Type]) -> Optional[mm.Type]:
-        # we keep the input type as the output type if either all inputs
-        #
-        # type, e.g. USD + Decimal
+        # we keep the input type as the output type if either all inputs are the exact same
+        # type or there's one nominal and its base primitive type, e.g. USD + Decimal
        if len(types) == 1:
            return next(iter(types))
        if len(types) == 2:
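The TypeVar branch added above merges every argument bound to the same type variable into a single binding and flags arguments that cannot match or convert to it. A simplified, self-contained version of that consistency check, with plain strings standing in for metamodel types:

```python
# Toy version of binding a single type variable across several arguments.
# merge/matches are stand-ins for the diff's merge_types/type_matches.
def bind_typevar(arg_types, field_types, merge, matches):
    binding = None
    for arg_t, field_t in zip(arg_types, field_types):
        if field_t == "TypeVar":
            binding = arg_t if binding is None else merge(binding, arg_t)
    mismatches = [a for a in arg_types if not matches(binding, a)]
    # callers would report each mismatch and otherwise substitute `binding`
    # wherever the field type was "TypeVar"
    return binding, mismatches

binding, bad = bind_typevar(
    ["Int64", "Int64"], ["TypeVar", "TypeVar"],
    merge=lambda a, b: a if a == b else "Any",
    matches=lambda expected, actual: expected == actual,
)
print(binding, bad)  # Int64 []
```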
@@ -411,13 +486,15 @@ class PropagationNetwork():
 
     def specialize_number(self, op, field_types:list[mm.Type], arg_types:list[mm.Type]) -> Tuple[mm.NumberType, list[mm.Type]]:
         """
-        Find the number type to use for an overload that has Number in its field_types,
-
+        Find the number type to use for an overload that has Number in its field_types, and
+        which is being referred to with these arg_types.
 
         Return a tuple where the first element is the specialized number type, and the second
         element is a new list that contains the same types as field_types but with
         Number replaced by this specialized number.
         """
+        # special case a few operators according to Snowflake's rules in
+        # https://docs.snowflake.com/en/sql-reference/operators-arithmetic#scale-and-precision-in-arithmetic-operations
         if op == b.core.div:
             # see https://docs.snowflake.com/en/sql-reference/operators-arithmetic#division
             numerator, denominator = get_number_type(arg_types[0]), get_number_type(arg_types[1])
@@ -436,29 +513,21 @@ class PropagationNetwork():
             # TODO!! - implement proper avg specialization
             pass
 
+        # fall back to the the current specialization policy, which is to select the number
+        # with largest scale and, if there multiple with the largest scale, the one with the
+        # largest precision. This is safe because when converting a number to the
+        # specialized number, we never truncate fractional digits (because we selected the
+        # largest scale) and, if the non-fractional digits are too large to fit the
+        # specialized number, we will have a runtime overflow, which should alert the user
+        # of the problem.
         number = None
         for arg_type in arg_types:
             x = bt.get_number_supertype(arg_type)
             if isinstance(x, mm.NumberType):
-                # the current specialization policy is to select the number with largest
-                # scale and, if there multiple with the largest scale, the one with the
-                # largest precision. This is safe because when converting a number to the
-                # specialized number, we never truncate fractional digits (because we
-                # selected the largest scale) and, if the non-fractional digits are too
-                # large to fit the specialized number, we will have a runtime overflow,
-                # which should alert the user of the problem.
-                #
-                # In the future we can implement more complex policies. For example,
-                # snowflake has well documented behavior for how the output of operations
-                # behave in face of different number types, and we may use that:
-                # https://docs.snowflake.com/en/sql-reference/operators-arithmetic#scale-and-precision-in-arithmetic-operations
                 if number is None or x.scale > number.scale or (x.scale == number.scale and x.precision > number.precision):
                     number = x
-
-
-        # assert(isinstance(number, mm.NumberType))
-        return number, [number if bt.is_number(field_type) else field_type
-                        for field_type in field_types]
+        assert(number is not None)
+        return number, [number if bt.is_number(t) else t for t in field_types]
 
 
     #--------------------------------------------------
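Outside the operator-specific Snowflake rules, the relocated comment describes the fallback policy: pick the argument number type with the largest scale, breaking ties by the largest precision, so fractional digits are never truncated. A small standalone illustration (`Num` is just a namedtuple, not the metamodel's NumberType):

```python
from collections import namedtuple

# Stand-in for a decimal number type; only precision/scale matter here.
Num = namedtuple("Num", "precision scale")

def pick_specialized_number(arg_numbers):
    """Largest scale wins; among equal scales, largest precision wins."""
    best = None
    for x in arg_numbers:
        if best is None or x.scale > best.scale or (x.scale == best.scale and x.precision > best.precision):
            best = x
    return best

print(pick_specialized_number([Num(38, 0), Num(10, 2), Num(20, 2)]))  # Num(precision=20, scale=2)
```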
@@ -467,7 +536,6 @@ class PropagationNetwork():
 
     # draw the network as a mermaid graph for the debugger
     def to_mermaid(self, max_edges=500) -> str:
-
         # add links for edges while collecting nodes
         nodes = OrderedSet()
         link_strs = []
@@ -481,22 +549,12 @@ class PropagationNetwork():
                 link_strs.append(f"n{src.id} --> n{dst.id}")
                 if len(link_strs) > max_edges:
                     break
-        # type requirements
-        for src, dsts in self.type_requirements.items():
-            nodes.add(src)
-            for dst in dsts:
-                if len(link_strs) > max_edges:
-                    break
-                nodes.add(dst)
-                link_strs.append(f"n{src.id} -.-> n{dst.id}")
-            if len(link_strs) > max_edges:
-                break
 
         def type_span(t:mm.Type) -> str:
             type_str = t.name if isinstance(t, mm.ScalarType) else str(t)
             return f"<span style='color:cyan;'>{type_str.strip()}</span>"
 
-        def reference_span(rel:mm.Relation, arg_types:list[mm.Type]
+        def reference_span(rel:mm.Relation, arg_types:list[mm.Type]) -> str:
             args = []
             for field, arg_type in zip(rel.fields, arg_types):
                 field_type = self.resolve(field)
@@ -506,22 +564,21 @@ class PropagationNetwork():
                     args.append(type_span(field_type))
                 else:
                     args.append(type_span(arg_type))
-            return f'{rel.name}
+            return f'{rel.name}({", ".join(args)})'
 
         resolved = self.resolved_types
         node_strs = []
         for node in nodes:
             klass = ""
-            root = "(*)" if node in self.roots else ""
             if isinstance(node, mm.Var):
                 ir_type = resolved.get(node) or self.resolve(node)
                 type_str = type_span(ir_type)
-                label = f'(["{node.name}
+                label = f'(["{node.name}:{type_str}"])'
             elif isinstance(node, mm.Literal):
                 ir_type = resolved.get(node) or self.resolve(node)
                 type_str = type_span(ir_type)
                 klass = ":::literal"
-                label = f'[
+                label = f'(["{node.value}: {type_str}"])'
             elif isinstance(node, mm.Field):
                 ir_type = resolved.get(node) or self.resolve(node)
                 type_str = type_span(ir_type)
@@ -529,19 +586,22 @@ class PropagationNetwork():
                 rel = node._relation
                 if rel is not None:
                     rel = str(node._relation)
-                    label = f'
+                    label = f'[/"{node.name}:{type_str}\nfrom {rel}"\\]'
                 else:
-                    label = f'
-            elif isinstance(node, (mm.Lookup, mm.
+                    label = f'[/"{node.name}:\n{type_str}"\\]'
+            elif isinstance(node, (mm.Lookup, mm.Aggregate, mm.Update)):
                 arg_types = [self.resolve(arg) for arg in node.args]
                 if node.id in self.resolved_placeholder:
                     overloads = self.resolved_placeholder[node.id]
-                    content = "<br/>".join([reference_span(o, arg_types
+                    content = "<br/>".join([reference_span(o, arg_types) for o in overloads])
+                else:
+                    content = reference_span(get_relation(node), arg_types)
+                if isinstance(node, mm.Update):
+                    klass = ":::update"
+                    label = f'{{{{"{content}"}}}}'
                 else:
-
-
-            # elif isinstance(node, mm.Relation):
-            #     label = f'[("{node}")]'
+                    klass = ":::reference"
+                    label = f'[/"{content}"/]'
             else:
                 raise NotImplementedError(f"Unknown node type: {type(node)}")
             if self.has_errors(node):
@@ -555,6 +615,7 @@ class PropagationNetwork():
         flowchart TD
         linkStyle default stroke:#666
         classDef field fill:#245,stroke:#478
+        classDef update fill:#245,stroke:#478
         classDef literal fill:#452,stroke:#784
         classDef error fill:#624,stroke:#945,color:#f9a
         classDef default stroke:#444,stroke-width:2px, font-size:12px
@@ -598,121 +659,42 @@ class Analyzer(Walker):
         rel = node.relation
         self.compute_potential_targets(rel)
 
-        #
-        # type; so, it's fine to pass a super-type in to the population e.g. Employee(Person)
-        # should be a valid way to populate that a particular Person is also an Employee.
-        is_type_relation = isinstance(rel, mm.TypeNode)
+        # arg is flowing into a field
         for arg, field in zip(node.args, rel.fields):
-
-
-
-            # if the arg is abstract, but the field isn't, then we need to make sure that
-            # once the arg is resolved we check that it matches the field type
-            if isinstance(arg, mm.Var) and bt.is_abstract(arg_type) and bt.is_concrete(field_type):
-                self.net.add_type_requirement(arg, field)
-
-            if bt.is_abstract(field_type) and isinstance(arg, (mm.Var, mm.Literal)):
-                # if the field is abstract, then eventually this arg will help determine
-                # the field's type, so add an edge from the arg to the field
-                self.net.add_edge(arg, field)
-            elif not type_matches(arg_type, field_type, accept_expected_super_types=is_type_relation):
-                if not conversion_allowed(arg_type, field_type):
-                    self.net.type_mismatch(node, field_type, arg_type)
+            if isinstance(arg, (mm.Var, mm.Literal)):
+                self.net.add_edge(arg, node)
 
     #--------------------------------------------------
     # Walk Lookups + Aggregates
     #--------------------------------------------------
 
-    def lookup(self,
-        self.compute_potential_targets(
-        self.visit_rel_op(
-
-    def aggregate(self, node: mm.Aggregate):
-        self.visit_rel_op(node)
-
-    def visit_rel_op(self, node: mm.Lookup|mm.Aggregate):
-        rel = get_relation(node)
-
-        # special case eq lookups
-        if isinstance(node, mm.Lookup) and rel == b.core.eq:
-            # if both args for an eq are abstract, link them, otherwise do normal processing
-            (left, right) = node.args
-            left_type = self.net.resolve(left)
-            right_type = self.net.resolve(right)
-            if bt.is_abstract(left_type) and bt.is_abstract(right_type):
-                assert isinstance(left, mm.Var) and isinstance(right, mm.Var)
-                # if both sides are abstract, then whatever we find out about
-                # either should propagate to the other
-                self.net.add_edge(left, right)
-                self.net.add_edge(right, left)
-                return
-
-        # special case when the relation needs to be resolved as there are overloads, placeholders,
-        # type vars or it needs number specialization
-        if self.requires_resolution(rel):
-            return self.visit_unresolved_reference(node)
-
-        # if this is a population check, then it's fine to pass a subtype in to do the check
-        # e.g. Employee(Person) is a valid way to check if a person is an employee
-        is_population_lookup = isinstance(rel, mm.TypeNode)
-        for arg, field in zip(node.args, rel.fields):
-            field_type = self.net.resolve(field)
-            arg_type = self.net.resolve(arg)
-            if not type_matches(arg_type, field_type, is_population_lookup):
-                # Do not complain if we can convert the arg to the field type.
-                if not conversion_allowed(arg_type, field_type):
-                    # if the arg is a var and it matches when allowing for super types of
-                    # the expected we can expect to refine it later; but we add a type
-                    # requirement to check at the end
-                    if isinstance(arg, mm.Var) and type_matches(arg_type, field_type, True):
-                        self.net.add_type_requirement(arg, field)
-                    else:
-                        self.net.type_mismatch(node, field_type, arg_type)
-            # if we have an abstract var then this field will ultimately propagate to that
-            # var's type; also, if this is a population lookup, the type of the population
-            # being looked up will flow back to the var
-            if isinstance(arg, mm.Var):
-                if not field.input:
-                    self.net.add_edge(field, arg)
-                else:
-                    self.net.add_type_requirement(arg, field)
+    def lookup(self, task: mm.Lookup):
+        self.compute_potential_targets(task.relation)
+        self.visit_rel_op(task)
 
+    def aggregate(self, task: mm.Aggregate):
+        self.visit_rel_op(task)
 
-    def
-
-        if rel.overloads or bt.is_placeholder(rel):
-            return True
-        # there are type vars or numbers in the fields that need specialization
-        for field in rel.fields:
-            t = self.net.resolve(field)
-            if bt.is_type_var(t) or t == b.core.Number:
-                return True
-        return False
-
+    def visit_rel_op(self, task: mm.Lookup|mm.Aggregate):
+        relation = get_relation(task)
 
-    def visit_unresolved_reference(self, node: mm.Lookup|mm.Aggregate):
-        relation = get_relation(node)
-        # functions have their outputs determined by their inputs
-        is_function = bt.is_function(relation)
         is_placeholder = bt.is_placeholder(relation)
-
-        for field, arg in zip(relation.fields, node.args):
+        for field, arg in zip(relation.fields, task.args):
             if isinstance(arg, (mm.Var, mm.Literal)):
                 if field.input:
-                    #
-                    self.net.add_edge(arg,
+                    # we need to resolve all inputs before resolving the relation
+                    self.net.add_edge(arg, task)
                 else:
-
-
-                    self.net.add_edge(
+                    # placeholders also need the output to be resolved
+                    if is_placeholder:
+                        self.net.add_edge(arg, task)
                     else:
-
-
-                if
-
-
-
-                self.net.add_edge(node, arg)
+                        # args bound to outputs can be resolved after
+                        self.net.add_edge(task, arg)
+                        # if the field is abstract, it needs to be resolved before we can
+                        # resolve the task.
+                        if bt.is_abstract(self.net.resolve(field)):
+                            self.net.add_edge(field, task)
 
 
     #--------------------------------------------------
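The slimmed-down Analyzer no longer checks types while walking; it only records data-flow edges, and the edge direction encodes the dependency rules: input args must resolve before the task, output args resolve after it (unless the relation is a placeholder), and an abstract field must resolve before the task that reads it. A compact restatement of those rules over plain dicts (nesting approximated, all names hypothetical):

```python
# Edge-building rules from the Analyzer, restated over plain dicts.
# Each arg is {"kind": ..., "name": ...}; each field has "name", "input" and "abstract".
def add_reference_edges(task, args, fields, is_placeholder, add_edge, is_abstract):
    for arg, field in zip(args, fields):
        if arg["kind"] in ("var", "literal"):
            if field["input"]:
                add_edge(arg["name"], task)        # inputs resolve before the task
            elif is_placeholder:
                add_edge(arg["name"], task)        # placeholders need outputs too
            else:
                add_edge(task, arg["name"])        # outputs resolve after the task
                if is_abstract(field):
                    add_edge(field["name"], task)  # abstract field gates the task

edges = []
add_reference_edges(
    task="lookup_1",
    args=[{"kind": "var", "name": "x"}, {"kind": "var", "name": "y"}],
    fields=[{"name": "f_in", "input": True, "abstract": False},
            {"name": "f_out", "input": False, "abstract": True}],
    is_placeholder=False,
    add_edge=lambda a, b: edges.append((a, b)),
    is_abstract=lambda f: f["abstract"],
)
print(edges)  # [('x', 'lookup_1'), ('lookup_1', 'y'), ('f_out', 'lookup_1')]
```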
@@ -750,7 +732,7 @@ class Replacer(Rewriter):
         # TODO - this is only modifying the relation in the model, but then we have a new
         # relation there, which is different than the object referenced by tasks.
         if node in self.net.resolved_types:
-            return mm.Field(node.name, self.net.resolved_types[node], node.input)
+            return mm.Field(node.name, self.net.resolved_types[node], node.input, _relation = node._relation)
         return node
 
     def var(self, node: mm.Var):
@@ -773,31 +755,40 @@ class Replacer(Rewriter):
 
         args = types = None
         if node.id in self.net.resolved_placeholder:
+            # placeholder resolved to multiple relations
             resolved_relations = self.net.resolved_placeholder[node.id]
             args = get_lookup_args(node, resolved_relations[0])
             types = [f.type for f in resolved_relations[0].fields]
         elif node.id in self.net.resolved_overload:
+            # overload resolved to a specific relation
             resolved_relations = [node.relation]
             types = self.net.resolved_overload[node.id].types
         else:
+            # single relation
             resolved_relations = [node.relation]
 
         if len(resolved_relations) == 1:
+            # single relation, just convert arguments
             x = self.convert_arguments(node, resolved_relations[0], args, types)
             if isinstance(x, mm.Logical) and len(x.body) == 1:
                 return x.body[0]
             else:
                 return x
 
+        # multiple relations, create a union
         branches:list = []
         for target in resolved_relations:
             args = get_lookup_args(node, target)
             types = [f.type for f in get_relation_fields(resolved_relations[0], node.relation.name)]
             # adding this logical to avoid issues in the old backend
-            branches.append(mm.Logical((self.convert_arguments(node, target, args, types=types),)))
+            branches.append(mm.Logical((self.convert_arguments(node, target, args, types=types, force_copy=True),)))
         return mm.Union(tuple(branches))
 
-    def convert_arguments(self, node: mm.Lookup|mm.Update, relation: mm.Relation, args: Iterable[mm.Value]|None=None, types: Iterable[mm.Type]|None=None) -> mm.Logical|mm.Lookup|mm.Update:
+    def convert_arguments(self, node: mm.Lookup|mm.Update, relation: mm.Relation, args: Iterable[mm.Value]|None=None, types: Iterable[mm.Type]|None=None, force_copy=False) -> mm.Logical|mm.Lookup|mm.Update:
+        """ This node was resolved to target this relation using these args, which should
+        have these types. Convert any arguments as needed and return a new node with the
+        proper relation and converted args. If multiple conversions are needed, return a
+        logical that contains all the conversion tasks plus the final node. """
         args = args or node.args
         types = types or [self.net.resolve(f) for f in relation.fields]
         number_type = self.net.resolved_number.get(node.id)
@@ -817,15 +808,20 @@ class Replacer(Rewriter):
                 final_args.append(arg)
             else:
                 final_args.append(arg)
-
-
+        # add the original node with the proper target relation and converted args;
+        # here we want to use mut because we keep information about nodes based on their ids
+        # (e.g. we store resolved types based on ids), so we want to keep the same id. But
+        # when a lookup is being converted as part of a union (i.e. multiple targets for a
+        # placeholder), we need to create nodes with new ids to avoid conflicts.
+        if force_copy:
+            tasks.append(node.replace(relation = relation, args = tuple(final_args)))
         else:
             tasks.append(node.mut(relation = relation, args = tuple(final_args)))
+        # if we need conversion tasks, wrap in a logical
         if len(tasks) == 1:
             return tasks[0]
         return mm.Logical(tuple(tasks))
 
-
     def visit_eq_lookup(self, node: mm.Lookup):
         (left, right) = node.args
         left_type = to_type(left)
@@ -842,7 +838,7 @@ class Replacer(Rewriter):
         elif conversion_allowed(right_type, left_type):
             final_args = [left, convert(right, left_type, tasks)]
         else:
-
+            # this type mismatch was reported during propagation, so just return the node
             return node
 
         tasks.append(mm.Lookup(b.core.eq, tuple(final_args)))
@@ -873,6 +869,15 @@ def get_name(type: mm.Type) -> str:
 # Type and Relation helpers
 #--------------------------------------------------
 
+def get_update_fields(arg, update: mm.Update) -> Iterable[mm.Field]:
+    """ Get the fields of the relation being updated by this arg. Note that an arg can be
+    bound to multiple fields at the same time. """
+    if arg in update.args:
+        for x, field in zip(update.args, update.relation.fields):
+            if arg == x:
+                yield field
+    return []
+
 def get_relation_fields(relation: mm.Relation, name: str) -> Iterable[mm.Field]:
     """ Get the fields of this relation, potentially reordered to match the reading with the given name."""
     if name == relation.name:
@@ -915,18 +920,20 @@ def get_potential_targets(model: mm.Model, placeholder: mm.Relation) -> list[mm.
     return list(filter(lambda r: is_potential_target(placeholder, r), model.relations))
 
 def to_type(value: mm.Value|mm.Field|mm.Literal) -> mm.Type:
-    if isinstance(value, (mm.Var, mm.
+    if isinstance(value, (mm.Var, mm.Literal)):
         return value.type
 
     if isinstance(value, mm.Type):
         return b.core.Type
 
+    if isinstance(value, mm.Field):
+        return b.core.Field
+
     if isinstance(value, tuple):
         return mm.TupleType(element_types=tuple(to_type(v) for v in value))
 
     raise TypeError(f"Cannot determine IR type for value: {value} of type {type(value).__name__}")
 
-
 def convert(value: mm.Var|mm.Literal, to_type: mm.Type, tasks: list[mm.Task]) -> mm.Value:
     # if the arg is a literal, we can just change its type
     # TODO - we may want to check that the value is actually convertible
@@ -940,8 +947,10 @@ def convert(value: mm.Var|mm.Literal, to_type: mm.Type, tasks: list[mm.Task]) ->
     tasks.append(mm.Lookup(b.core.cast, (to_type_base, value, new_value)))
     return new_value
 
-
 def conversion_allowed(from_type: mm.Type, to_type: mm.Type) -> bool:
+    # value type conversion is allowed only if the value types are related by inheritance
+    if bt.is_value_type(from_type) and bt.is_value_type(to_type) and not bt.extends(from_type, to_type):
+        return False
     # value type conversion is allowed
     x = bt.get_primitive_supertype(from_type)
     y = bt.get_primitive_supertype(to_type)
@@ -954,7 +963,7 @@ def conversion_allowed(from_type: mm.Type, to_type: mm.Type) -> bool:
 
     # a number can be converted to another number of larger scale
     if isinstance(from_type, mm.NumberType) and isinstance(to_type, mm.NumberType):
-        if to_type.scale
+        if to_type.scale >= from_type.scale:
             return True
 
     if from_type == b.core.Number and isinstance(to_type, mm.NumberType):
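The completed condition now reads `if to_type.scale >= from_type.scale:`, i.e. a number may be converted to another number whose scale is at least as large, so no fractional digits are dropped. A tiny illustration with a stand-in number type:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Num:            # stand-in for a decimal number type (precision, scale)
    precision: int
    scale: int

def number_conversion_allowed(src: Num, dst: Num) -> bool:
    # converting is safe when the destination keeps all fractional digits
    return dst.scale >= src.scale

print(number_conversion_allowed(Num(10, 2), Num(38, 4)))  # True: 2 fractional digits fit in 4
print(number_conversion_allowed(Num(10, 4), Num(38, 2)))  # False: would truncate fractional digits
```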
@@ -990,11 +999,6 @@ def type_matches(actual: mm.Type, expected: mm.Type, accept_expected_super_types
     if expected == b.core.TypeVar:
         return True
 
-    # TODO - remove this once we make them singletons per precision/scale
-    if isinstance(actual, mm.NumberType) and isinstance(expected, mm.NumberType):
-        if actual.precision == expected.precision and actual.scale == expected.scale:
-            return True
-
     # if an entity type var or any entity is expected, it matches any actual entity type
     if (expected == b.core.EntityTypeVar or bt.extends(expected, b.core.AnyEntity)) and not bt.is_primitive(actual):
         return True
@@ -1006,6 +1010,10 @@ def type_matches(actual: mm.Type, expected: mm.Type, accept_expected_super_types
     if (expected == b.core.Numeric) and bt.is_numeric(actual):
         return True
 
+    # different value types never match
+    if bt.is_value_type(actual) and bt.is_value_type(expected) and not bt.extends(actual, expected):
+        return False
+
     # if actual is scalar, any of its parents may match the expected type
     if isinstance(actual, mm.ScalarType) and any([type_matches(parent, expected) for parent in actual.super_types]):
         return True
@@ -1049,18 +1057,12 @@ def type_matches(actual: mm.Type, expected: mm.Type, accept_expected_super_types
 
     # accept tuples with a single element type to match a list with that type
     if isinstance(actual, mm.TupleType) and isinstance(expected, mm.ListType):
-
-            return type_matches(actual.element_types[0], expected.element_type)
+        return type_matches(actual.element_types[0], expected.element_type)
 
     # otherwise no match
     return False
 
-
-def merge_types(type1: mm.Type, type2: mm.Type) -> mm.Type:
-    if type1 == type2:
-        return type1
-    types_to_process = [type1, type2]
-
+def merge_numeric_types(type1: mm.Type, type2: mm.Type) -> Optional[mm.Type]:
     # if one of them is the abstract Number type, pick the other
     if type1 == b.core.Number and isinstance(type2, mm.NumberType):
         return type2
@@ -1077,6 +1079,24 @@ def merge_types(type1: mm.Type, type2: mm.Type) -> mm.Type:
     # if we are overriding a number with a float, pick float
     if isinstance(type1, mm.NumberType) and type2 == b.core.Float:
         return type2
+    if isinstance(type2, mm.NumberType) and type1 == b.core.Float:
+        return type1
+
+    return None
+
+def merge_types(type1: mm.Type, type2: mm.Type) -> mm.Type:
+    if type1 == type2:
+        return type1
+    if bt.is_type_var(type1):
+        return type2
+    if bt.is_type_var(type2):
+        return type1
+
+    types_to_process = [type1, type2]
+
+    numeric_merge = merge_numeric_types(type1, type2)
+    if numeric_merge is not None:
+        return numeric_merge
 
     # if one extends the other, pick the most specific one
     if bt.extends(type1, type2):
@@ -1093,6 +1113,13 @@ def merge_types(type1: mm.Type, type2: mm.Type) -> mm.Type:
     elif bt.is_primitive(type2):
         return type1
 
+    if base_primitive_type1 and base_primitive_type2:
+        numeric_merge = merge_numeric_types(base_primitive_type1, base_primitive_type2)
+        if numeric_merge == base_primitive_type1:
+            return type1
+        elif numeric_merge == base_primitive_type2:
+            return type2
+
     combined = OrderedSet()
     # Iterative flattening of union types
     while types_to_process:
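The refactor pulls the numeric special cases out into `merge_numeric_types` and has `merge_types` consult it after the exact-match and type-var checks (and again on the base primitive types of two value types). A simplified standalone version of that ordering, using strings for types and an invented numeric-merge helper:

```python
# Simplified merge order, mirroring the refactor: exact match, type vars,
# numeric special cases, then (not shown) subtyping / union flattening.
def merge_numeric(a, b):
    # stand-in for merge_numeric_types: prefer the concrete/float side
    if a == "Number" and b.startswith("Decimal"):
        return b
    if b == "Number" and a.startswith("Decimal"):
        return a
    if "Float" in (a, b) and (a.startswith("Decimal") or b.startswith("Decimal")):
        return "Float"
    return None

def merge(a, b, is_type_var=lambda t: t == "TypeVar"):
    if a == b:
        return a
    if is_type_var(a):
        return b
    if is_type_var(b):
        return a
    numeric = merge_numeric(a, b)
    if numeric is not None:
        return numeric
    return f"Union[{a}, {b}]"   # placeholder for the remaining merge logic

print(merge("TypeVar", "Int64"))        # Int64
print(merge("Decimal(10,2)", "Float"))  # Float
```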