relationalai 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. relationalai/semantics/frontend/base.py +3 -0
  2. relationalai/semantics/frontend/front_compiler.py +5 -2
  3. relationalai/semantics/metamodel/builtins.py +2 -1
  4. relationalai/semantics/metamodel/metamodel.py +32 -4
  5. relationalai/semantics/metamodel/pprint.py +5 -3
  6. relationalai/semantics/metamodel/typer.py +324 -297
  7. relationalai/semantics/std/aggregates.py +0 -1
  8. relationalai/semantics/std/datetime.py +4 -1
  9. relationalai/shims/executor.py +26 -5
  10. relationalai/shims/mm2v0.py +119 -44
  11. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/METADATA +1 -1
  12. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/RECORD +57 -48
  13. v0/relationalai/__init__.py +69 -22
  14. v0/relationalai/clients/__init__.py +15 -2
  15. v0/relationalai/clients/client.py +4 -4
  16. v0/relationalai/clients/local.py +5 -5
  17. v0/relationalai/clients/resources/__init__.py +8 -0
  18. v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
  19. v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
  20. v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
  21. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
  22. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
  23. v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
  24. v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  25. v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
  26. v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
  27. v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  28. v0/relationalai/clients/resources/snowflake/util.py +387 -0
  29. v0/relationalai/early_access/dsl/ir/executor.py +4 -4
  30. v0/relationalai/early_access/dsl/snow/api.py +2 -1
  31. v0/relationalai/errors.py +23 -0
  32. v0/relationalai/experimental/solvers.py +7 -7
  33. v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
  34. v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
  35. v0/relationalai/semantics/internal/internal.py +4 -4
  36. v0/relationalai/semantics/internal/snowflake.py +3 -2
  37. v0/relationalai/semantics/lqp/executor.py +20 -22
  38. v0/relationalai/semantics/lqp/model2lqp.py +42 -4
  39. v0/relationalai/semantics/lqp/passes.py +1 -1
  40. v0/relationalai/semantics/lqp/rewrite/cdc.py +1 -1
  41. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +53 -12
  42. v0/relationalai/semantics/metamodel/builtins.py +8 -6
  43. v0/relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
  44. v0/relationalai/semantics/metamodel/util.py +6 -5
  45. v0/relationalai/semantics/reasoners/graph/core.py +8 -9
  46. v0/relationalai/semantics/rel/executor.py +14 -11
  47. v0/relationalai/semantics/sql/compiler.py +2 -2
  48. v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
  49. v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
  50. v0/relationalai/tools/cli.py +26 -30
  51. v0/relationalai/tools/cli_helpers.py +10 -2
  52. v0/relationalai/util/otel_configuration.py +2 -1
  53. v0/relationalai/util/otel_handler.py +1 -1
  54. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/WHEEL +0 -0
  55. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/entry_points.txt +0 -0
  56. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/top_level.txt +0 -0
  57. /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
@@ -81,17 +81,21 @@ class Typer:
  # Propagation Network
  #--------------------------------------------------

- # The core idea of the typer is to build a propagation network where nodes
- # are vars, fields, or overloaded lookups/updates/aggregates. The intuition
- # is that _all_ types in the IR ultimately flow from relation fields, so if
- # we figure those out we just need to propagate their types to unknown vars, which
- # may then flow into other fields and so on.
-
- # This means the network only needs to contain nodes that either directly flow into
- # an abstract node or are themselves abstract. We need to track overloads because
- # their arguments effectively act like abstract vars until we've resolved the final types.
-
- Node = Union[mm.Var, mm.Field, mm.Literal, mm.Lookup, mm.Update, mm.Aggregate]
+ # The core idea of the typer is to build a propagation network that represents data
+ # dependencies between nodes. Nodes can be variables, literals, fields and tasks that refer
+ # to relations (lookups, aggregates, updates). An edge from node A to node B means that
+ # in order to resolve references and types for node B we need to know the type of node A.
+ #
+ # After building the network, we start with roots that have known types (literals, fields
+ # loaded from previous analysis and vars with concrete types) and propagate their types via
+ # the edges to other nodes. It is possible that the network contains cycles, so we iterate
+ # on the work list until fixpoint is reached, i.e. when we cannot propagate any new type.
+ #
+ # During the propagation, we are only gathering information about the types and references
+ # of nodes. Once we reach a fixpoint, we do a final pass to rewrite the model with the new
+ # type information, which may include adding casts to convert between types where needed.
+ #
+ Node = Union[mm.Var, mm.Field, mm.Literal, mm.Lookup, mm.Aggregate, mm.Update]

  class PropagationNetwork():
  def __init__(self, model: mm.Model):
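The comment block added in the hunk above describes the new worklist-driven, fixpoint-based propagation. As a rough illustration only (simplified stand-in node names and string "types", not the actual metamodel classes or typer API), the propagation idea looks roughly like this:

# Minimal sketch of worklist/fixpoint type propagation over a dependency graph.
def propagate(edges: dict[str, set[str]], types: dict[str, str | None]) -> dict[str, str | None]:
    # seed the worklist with the roots: nodes whose types are already known
    work = [n for n, t in types.items() if t is not None]
    while work:  # keep going until no propagation changes anything (fixpoint)
        node = work.pop(0)
        for succ in edges.get(node, set()):
            if types.get(succ) is None:
                types[succ] = types[node]  # the successor learns its type
                work.append(succ)          # and may propagate it further
    return types

# toy network: a literal flows into a var, which flows into a relation field
edges = {"lit_1": {"var_x"}, "var_x": {"field_age"}}
types = {"lit_1": "Int", "var_x": None, "field_age": None}
print(propagate(edges, types))  # {'lit_1': 'Int', 'var_x': 'Int', 'field_age': 'Int'}

The real network also has to handle cycles and references that cannot yet be resolved, which is why typer.py bounds the number of iterations and keeps re-queueing nodes until no new type can be propagated.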
@@ -104,22 +108,12 @@ class PropagationNetwork():
  # map from unresolved placeholder relations to their potential target replacements
  self.potential_targets: dict[mm.Relation, list[mm.Relation]] = {}

- # track the set of nodes that represent entry points into the network
- self.roots = OrderedSet()
- # we separately want to track nodes that were loaded from a previous run
- # so that even if we have edges to them, we _still_ consider them roots
- # and properly propagate types from them at the beginning
+ # nodes loaded from previous analysis, which we use as roots to start propagation
  self.loaded_roots = set()

  # edges in the propagation network, from one node to potentially many
  self.edges:dict[Node, OrderedSet[Node]] = defaultdict(lambda: OrderedSet())
  self.back_edges:dict[Node, OrderedSet[Node]] = defaultdict(lambda: OrderedSet())
- # all nodes that are the target of an edge (to find roots)
- self.has_incoming = set()
-
- # type requirements: for a var with abstract declared type, the set of fields that
- # it must match the type of because it flows into them
- self.type_requirements:dict[mm.Var, OrderedSet[mm.Field]] = defaultdict(lambda: OrderedSet())

  # all errors collected during inference
  self.errors:list[TyperError] = []
@@ -127,22 +121,22 @@ class PropagationNetwork():
  # overloads resolved for a lookup/update/aggregate, by node id. This is only for
  # relations that declare overloads
  self.resolved_overload:dict[int, mm.Overload] = {}
+
  # placeholders resolved for a lookup, by node id. This is only for relations that
  # are placeholders (i.e. only Any fields) and will be replaced by references to
  # these concrete relations. E.g. a query for "name(Any, Any)" may be replaced by
  # the union of "name(Dog, String)" and name(Cat, String)".
  self.resolved_placeholder:dict[int, list[mm.Relation]] = {}
+
  # for a given lookup/update/aggregate that involves numbers, the specific number
  # type resolved for it.
  self.resolved_number:dict[int, mm.NumberType] = {}
- # keep track of nodes already resolved to avoid re-resolving
- self.resolved_nodes:set[int] = set()

  #--------------------------------------------------
  # Error reporting
  #--------------------------------------------------

- def type_mismatch(self, node: Node, expected: mm.Type, actual: mm.Type):
+ def type_mismatch(self, node: Node|mm.Update, expected: mm.Type, actual: mm.Type):
  self.errors.append(TypeMismatch(node, expected, actual))

  def invalid_type(self, node: Node, type: mm.Type):
@@ -153,6 +147,8 @@ class PropagationNetwork():
  self.errors.append(UnresolvedOverload(node, [self.resolve(a) for a in node.args]))

  def unresolved_type(self, node: Node):
+ # TODO - this is not being used yet, we need a pass at the end to check for any node
+ # that we could not resolve
  self.errors.append(UnresolvedType(node))

  def has_errors(self, node: Node) -> bool:
@@ -166,15 +162,8 @@ class PropagationNetwork():
  #--------------------------------------------------

  def add_edge(self, source: Node, target: Node):
- # manage roots
- if target in self.roots and target not in self.loaded_roots:
- self.roots.remove(target)
- if source not in self.has_incoming:
- self.roots.add(source)
- # register edge
  self.edges[source].add(target)
  self.back_edges[target].add(source)
- self.has_incoming.add(target)

  def add_resolved_type(self, node: Node, type: mm.Type):
  """ Register that this node was resolved to have this type. """
@@ -183,13 +172,8 @@ class PropagationNetwork():
  else:
  self.resolved_types[node] = type

- def add_type_requirement(self, source: mm.Var, field: mm.Field):
- """ Register that this var, which has an abstract declared type, must match the type
- of this field as it flows into it.. """
- self.type_requirements[source].add(field)
-
  #--------------------------------------------------
- # Load previous types
+ # Load types from a previous analysis
  #--------------------------------------------------

  def load_types(self, type_dict: dict[Node, mm.Type]):
@@ -197,22 +181,39 @@ class PropagationNetwork():
  if isinstance(node, (mm.Field)):
  self.add_resolved_type(node, type)
  self.loaded_roots.add(node)
- self.roots.add(node)

  #--------------------------------------------------
  # Resolve Values
  #--------------------------------------------------

- def resolve(self, value: Node|mm.Value) -> mm.Type:
- if isinstance(value, (mm.Var, mm.Field, mm.Literal)):
+ def resolve(self, value: mm.Value) -> mm.Type:
+ if isinstance(value, (mm.Var, mm.Literal)):
  return self.resolved_types.get(value) or to_type(value)
- assert not isinstance(value, (mm.Lookup, mm.Update, mm.Aggregate)), "Should never try to resolve a task"
+ if isinstance(value, mm.Field):
+ return self.resolved_types.get(value) or value.type
+ assert not isinstance(value, (mm.Task)), "Should never try to resolve a task"
  return to_type(value)

  #--------------------------------------------------
  # Resolve References
  #--------------------------------------------------

+ def all_dependencies_resolved(self, op:mm.Lookup|mm.Aggregate):
+ """ True iff all dependencies required to resolve this reference are met. """
+ rel = get_relation(op)
+ # if this is a placeholder, eq or cast, we assume all possible args were resolved
+ if bt.is_placeholder(rel) or rel == b.core.eq or rel == b.core.cast:
+ return True
+
+ # else, find whether all back-edges were resolved
+ for node in self.back_edges[op]:
+ # cannot resolve if a required var, literal or input field is still abstract
+ if isinstance(node, (mm.Var, mm.Literal)) or (isinstance(node, mm.Field) and node.input):
+ node_type = self.resolve(node)
+ if bt.is_abstract(node_type):
+ return False
+ return True
+
  def resolve_reference(self, op: mm.Lookup|mm.Aggregate) -> Optional[mm.Overload|list[mm.Relation]]:
  # check if all dependencies required to resolve this reference are met
  if not self.all_dependencies_resolved(op):
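The new all_dependencies_resolved helper above only lets a lookup or aggregate be resolved once the nodes it depends on (vars, literals and input fields reachable through back-edges) have concrete types. A hedged, self-contained sketch of that rule, using plain dicts and strings rather than the metamodel types:

# Simplified version of the dependency test: an overloaded lookup such as
# add(x, y, z) can only be resolved once every *input* argument has a concrete
# type; the output may still be unknown ("Any") at that point.
def all_dependencies_resolved(arg_types: dict[str, str], inputs: list[str]) -> bool:
    # every input argument must already have a concrete (non-"Any") type
    return all(arg_types.get(name, "Any") != "Any" for name in inputs)

print(all_dependencies_resolved({"x": "Int", "y": "Any", "z": "Any"}, ["x", "y"]))  # False
print(all_dependencies_resolved({"x": "Int", "y": "Int", "z": "Any"}, ["x", "y"]))  # True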
@@ -228,7 +229,6 @@ class PropagationNetwork():
  if all(type_matches(arg, self.resolve(field))
  for arg, field in zip(resolved_args, fields)):
  matches.append(target)
-
  return matches

  elif relation.overloads:
@@ -243,24 +243,11 @@ class PropagationNetwork():
  self.resolved_overload[op.id] = overload
  return overload
  return [] # no matches found
+
  else:
  # this is a relation with type vars or numbers that needs to be specialized
  return [relation]

-
- def all_dependencies_resolved(self, op:mm.Lookup|mm.Aggregate):
- # if this is a placeholder, we need assume all possible args were resolved
- if bt.is_placeholder(get_relation(op)):
- return True
- # else, find whether all back-edges were resolved
- for node in self.back_edges[op]:
- if isinstance(node, (mm.Var, mm.Field, mm.Literal)):
- node_type = self.resolve(node)
- if bt.is_abstract(node_type):
- return False
- return True
-
-
  #--------------------------------------------------
  # Propagation
  #--------------------------------------------------
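resolve_reference, shown in the two hunks above, picks the overloads or placeholder targets whose field types are compatible with the argument types resolved so far. A toy illustration of that filtering (hypothetical relation names, not the real API):

# Toy overload filtering: keep the candidate readings whose field types are
# compatible with the argument types resolved so far ("Any" is still unknown
# and matches anything).
overloads = {
    "add(Int, Int, Int)": ("Int", "Int", "Int"),
    "add(Float, Float, Float)": ("Float", "Float", "Float"),
}

def matching_overloads(arg_types, overloads):
    def compatible(arg, field):
        return arg == "Any" or arg == field
    return [name for name, fields in overloads.items()
            if all(compatible(a, f) for a, f in zip(arg_types, fields))]

print(matching_overloads(("Int", "Int", "Any"), overloads))  # ['add(Int, Int, Int)']
print(matching_overloads(("Any", "Any", "Any"), overloads))  # both remain: more info needed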
@@ -269,73 +256,112 @@ class PropagationNetwork():
  edges = self.edges
  work_list = []

- # go through all the roots and find any that are not abstract, they'll be the first
- # nodes to push types through the network
- unhandled_roots = OrderedSet()
- for node in self.roots:
- if not isinstance(node, (mm.Var, mm.Field, mm.Literal)):
- continue
- node_type = self.resolve(node)
- if not bt.is_abstract(node_type):
- work_list.append(node)
- else:
- unhandled_roots.add(node)
+ # start with the loaded roots + all literals + sources of edges without back edges
+ work_list.extend(self.loaded_roots)
+ for source in self.edges.keys():
+ if isinstance(source, (mm.Literal)) or not source in self.back_edges:
+ work_list.append(source)

- # push known type nodes through the edges
- while work_list:
- source = work_list.pop(0)
- self.resolved_nodes.add(source.id)
- if source in unhandled_roots:
- unhandled_roots.remove(source)
- source_type = self.resolve(source)
- # check to see if the source has ended up with a set of types that
- # aren't valid, e.g. a union of primitives
- if invalid_type(source_type):
- self.invalid_type(source, source_type)
-
- # propagate our type to each outgoing edge
- for out in edges.get(source, []):
- # if this is an overload then we need to try and resolve it
- if isinstance(out, (mm.Lookup, mm.Aggregate)):
- if not out.id in self.resolved_nodes:
- found = self.resolve_reference(out)
- if found is not None:
- self.resolved_nodes.add(out.id)
- self.propagate_reference(out, found)
- for arg in out.args:
- if arg not in work_list:
- work_list.append(arg)
- # otherwise, we just add to the outgoing node's type and if it
- # changes we add it to the work list
- elif start := self.resolve(out):
- self.add_resolved_type(out, source_type)
- if out not in work_list and (start != self.resolve(out) or not out.id in self.resolved_nodes):
- work_list.append(out)
-
- for source in unhandled_roots:
- self.unresolved_type(source)
-
- # now that we've pushed all the types through the network, we need to validate
- # that all type requirements of those nodes are met
- for node, fields in self.type_requirements.items():
- node_type = self.resolve(node)
- for field in fields:
- field_type = self.resolve(field)
- if not type_matches(node_type, field_type) and not conversion_allowed(node_type, field_type):
- self.type_mismatch(node, field_type, node_type)
+ # limit the number of iterations to avoid infinite loops
+ i = 0
+ max_iterations = 100 * len(self.edges)

+ # propagate types until we reach a fixed point
+ while work_list:
+ i += 1
+ if i > max_iterations:
+ err("Infinite Loop", "Infinite loop detected in the typer. Please, report this as a bug.")
+ break

- def propagate_reference(self, task:mm.Lookup|mm.Aggregate, references:mm.Overload|list[mm.Relation]):
+ node = work_list.pop(0)
+ next = None
+ if isinstance(node, mm.Field):
+ # this is a field loaded from a previous analysis, so all we need to do is
+ # propagate its type to its output edges
+ next = edges.get(node, [])
+
+ elif isinstance(node, (mm.Lookup, mm.Aggregate)):
+ # this is a lookup/aggregate that may be overloaded or a placeholder; try
+ # to resolve its reference, i.e. determine which specific relation or
+ # overload it refers to
+ found = self.resolve_reference(node)
+ if found is not None:
+ # if found is None it means that we need more info to resolve the
+ # reference; otherwise it was possible to resolve it (even if to no matches)
+
+ # keep the resolved args before propagation to see if they will change
+ resolved_args = [self.resolve(arg) for arg in node.args]
+ # propagate the reference resolution
+ self.propagate_reference(node, resolved_args, found)
+ if found:
+ # the next nodes to process are all the outgoing edges plus any arg that
+ # changed during propagation
+ next = OrderedSet()
+ next.update(edges.get(node, []))
+ next.update([arg for idx, arg in enumerate(node.args) if isinstance(arg, mm.Var) and not (self.resolve(arg) == resolved_args[idx])])
+ else:
+ # the reference is unresolved, so remove from the worklist the args
+ # because they depend on this being resolved
+ for arg in node.args:
+ if arg in work_list:
+ work_list.remove(arg)
+ else:
+ assert isinstance(node, (mm.Var, mm.Literal))
+ resolved = self.resolve(node)
+ # if we ended up with an invalid type, report it on a next edge (which is
+ # where the var or literal is going to be used)
+ if invalid_type(resolved):
+ self.invalid_type(edges.get(node, [node])[0], resolved)
+
+ # if the var is still abstract, add it to the work list to try again later
+ if isinstance(node, mm.Var) and bt.is_abstract(resolved):
+ # but if everything that is left are vars, we cannot make progress
+ if not all(isinstance(x, mm.Var) for x in work_list):
+ next = [node]
+ else:
+ # otherwise, check all outgoing edges
+ next = []
+ for out in edges.get(node, []):
+ if isinstance(out, mm.Update):
+ # this is an update to a field. We have to check that the type is
+ # valid for the field and if the field is abstract we can refine the
+ # type. But if the field is concrete we have to check that the type
+ # matche or can be converted.
+ is_population_lookup = isinstance(out.relation, mm.TypeNode)
+ for field in get_update_fields(node, out):
+ if not type_matches(resolved, field.type, accept_expected_super_types=is_population_lookup) and not conversion_allowed(resolved, field.type):
+ self.type_mismatch(out, field.type, resolved)
+ elif bt.is_abstract(field.type):
+ # if the field type changed, propagate further
+ if resolved != self.resolve(field):
+ next.append(field)
+ self.add_resolved_type(field, resolved)
+ else:
+ # this is an arg flowing into a task, so resolve the task next
+ next.append(out)
+
+ # add the new nodes to the work list
+ if next is not None:
+ for n in next:
+ if n not in work_list:
+ work_list.append(n)
+
+
+ def propagate_reference(self, task:mm.Lookup|mm.Aggregate, resolved_args:list[mm.Type], references:mm.Overload|list[mm.Relation]):
  # TODO: distinguish between overloads and placeholders better when raising errors
  if not references:
  return self.unresolved_overload(task)

- resolved_args = [self.resolve(arg) for arg in task.args]
-
- # we need to determine the final types of our args by taking all the references
- # and adding the type of their fields back to the args.
  relation = get_relation(task)

+ if relation == b.core.cast:
+ # cast is special cased: we just propagate the type to the target
+ cast_type, target = task.args[0], task.args[2]
+ assert isinstance(cast_type, mm.Type)
+ assert isinstance(target, mm.Var)
+ self.add_resolved_type(target, cast_type)
+ return
+
  if bt.is_placeholder(relation):
  assert(references and isinstance(references, list))
  # we've resolved the placeholder, so store that
@@ -357,22 +383,25 @@ class PropagationNetwork():
  # number specialization, so use that relation's field types
  types = list([self.resolve(f) for f in relation.fields])

- # if our overload preserves types, we check to see if there's a preserved
- # output type given the inputs and if so, shadow the field's type with the
- # preserved type
+ # if our overload preserves types, we check to see if there's a preserved output
+ # type given the inputs and if so, shadow the field type with the preserved type.
+ # this is only attempted if all input types match the field types, i.e. no
+ # conversions are needed
  resolved_fields = types
- if bt.is_function(relation) and len(set(resolved_fields)) == 1:
+ if bt.is_function(relation) and len(set(resolved_fields)) == 1 and not relation in self.NON_TYPE_PRESERVERS and\
+ all(type_matches(arg_type, field_type) for arg_type, field_type, field in zip(resolved_args, types, relation.fields) if field.input):
+
  input_types = set([arg_type for field, arg_type
  in zip(relation.fields, resolved_args) if field.input])
  if out_type := self.try_preserve_type(input_types):
  resolved_fields = [field_type if field.input else out_type
  for field, field_type in zip(relation.fields, types)]

- # TODO - we also need to make sure the type vars are constently resolved here
- # i.e. if types contain typevars, check that the args that are bound to those
- # typevars are consistent
- if b.core.Number in types or (b.core.TypeVar in types and any(bt.is_number(t) for t in resolved_args)):
- # this overload contains generic numbers or typevars bound to numbers, so
+
+ # eq is special cased because we don't want to specialize a number for it, as it
+ # can be just comparing numbers of different types.
+ if relation != b.core.eq and (b.core.Number in types or (b.core.TypeVar in types and any(bt.is_number(t) for t in resolved_args))):
+ # this relation contains generic numbers or typevars bound to numbers, so
  # find which specific type of number to use given the arguments being passed
  number, resolved_fields = self.specialize_number(relation, resolved_fields, resolved_args)
  self.resolved_number[task.id] = number
@@ -380,16 +409,62 @@ class PropagationNetwork():
  for field, field_type, arg in zip(relation.fields, resolved_fields, task.args):
  if not field.input and isinstance(arg, mm.Var):
  self.add_resolved_type(arg, field_type)
+
+ elif b.core.TypeVar in types:
+ # this relation contains type vars, so we have to make sure that all args
+ # bound to the same type var are consistent
+
+ # find which arg is bound to the type var and check that they are all consistent
+ typevar_type = None
+ for arg_type, field_type in zip(resolved_args, types):
+ if field_type == b.core.TypeVar:
+ if typevar_type is None:
+ typevar_type = arg_type
+ else:
+ typevar_type = merge_types(typevar_type, arg_type)
+ assert typevar_type is not None
+
+ # compute the final arg types by replacing type vars with the typevar_type
+ computed_arg_types = []
+ for arg_type, field_type in zip(resolved_args, types):
+ # check that the arg type matches the type var type
+ if not type_matches(typevar_type, arg_type) and not conversion_allowed(arg_type, typevar_type):
+ self.type_mismatch(task, typevar_type, arg_type)
+ if field_type == b.core.TypeVar:
+ computed_arg_types.append(typevar_type)
+ else:
+ computed_arg_types.append(arg_type)
+
+ # if no mismatches were found, propagate the computed arg types back to the args
+ if len(computed_arg_types) == len(task.args):
+ for computed_type, arg, arg_type in zip(computed_arg_types, task.args, resolved_args):
+ # TODO: we could allow for non-numeric/string literals because this means
+ # the literal is being used as a value type, but the backend emitters
+ # would have to deal with that.
+ if isinstance(arg, mm.Var) or (isinstance(arg, mm.Literal) and (bt.is_numeric(computed_type) or computed_type == b.core.String)):
+ self.add_resolved_type(arg, computed_type)
+
  else:
- for field_type, arg, arg_type in zip(resolved_fields, task.args, resolved_args):
- if bt.is_abstract(arg_type) and isinstance(arg, mm.Var):
+ # no typevar or number specialization shenanigans, just propagate field types to args
+ for field, field_type, arg, arg_type in zip(relation.fields, resolved_fields, task.args, resolved_args):
+ # add a resolved type if we learned more about the arg's type
+ if isinstance(arg, mm.Var) and (bt.is_abstract(arg_type) or bt.extends(field_type, arg_type)):
  self.add_resolved_type(arg, field_type)
+ elif not type_matches(arg_type, field_type) and not conversion_allowed(arg_type, field_type):
+ self.type_mismatch(task, field_type, arg_type)
+

+ # we try to preserve types for relations that are functions (i.e. potentially multiple
+ # input but a single output) and where all types are the same. However, there are some
+ # exceptions to this rule, e.g. range() always returns int regardless of whether the
+ # input type is an int or a value type that extends int.
+ NON_TYPE_PRESERVERS = [
+ b.common.range
+ ]

  def try_preserve_type(self, types:set[mm.Type]) -> Optional[mm.Type]:
- # we keep the input type as the output type if either all inputs
- # are the exact same type or there's one nominal and its base primitive
- # type, e.g. USD + Decimal
+ # we keep the input type as the output type if either all inputs are the exact same
+ # type or there's one nominal and its base primitive type, e.g. USD + Decimal
  if len(types) == 1:
  return next(iter(types))
  if len(types) == 2:
@@ -411,13 +486,15 @@ class PropagationNetwork():

  def specialize_number(self, op, field_types:list[mm.Type], arg_types:list[mm.Type]) -> Tuple[mm.NumberType, list[mm.Type]]:
  """
- Find the number type to use for an overload that has Number in its field_types,
- and which is being referred to with these arg_types.
+ Find the number type to use for an overload that has Number in its field_types, and
+ which is being referred to with these arg_types.

  Return a tuple where the first element is the specialized number type, and the second
  element is a new list that contains the same types as field_types but with
  Number replaced by this specialized number.
  """
+ # special case a few operators according to Snowflake's rules in
+ # https://docs.snowflake.com/en/sql-reference/operators-arithmetic#scale-and-precision-in-arithmetic-operations
  if op == b.core.div:
  # see https://docs.snowflake.com/en/sql-reference/operators-arithmetic#division
  numerator, denominator = get_number_type(arg_types[0]), get_number_type(arg_types[1])
@@ -436,29 +513,21 @@ class PropagationNetwork():
  # TODO!! - implement proper avg specialization
  pass

+ # fall back to the the current specialization policy, which is to select the number
+ # with largest scale and, if there multiple with the largest scale, the one with the
+ # largest precision. This is safe because when converting a number to the
+ # specialized number, we never truncate fractional digits (because we selected the
+ # largest scale) and, if the non-fractional digits are too large to fit the
+ # specialized number, we will have a runtime overflow, which should alert the user
+ # of the problem.
  number = None
  for arg_type in arg_types:
  x = bt.get_number_supertype(arg_type)
  if isinstance(x, mm.NumberType):
- # the current specialization policy is to select the number with largest
- # scale and, if there multiple with the largest scale, the one with the
- # largest precision. This is safe because when converting a number to the
- # specialized number, we never truncate fractional digits (because we
- # selected the largest scale) and, if the non-fractional digits are too
- # large to fit the specialized number, we will have a runtime overflow,
- # which should alert the user of the problem.
- #
- # In the future we can implement more complex policies. For example,
- # snowflake has well documented behavior for how the output of operations
- # behave in face of different number types, and we may use that:
- # https://docs.snowflake.com/en/sql-reference/operators-arithmetic#scale-and-precision-in-arithmetic-operations
  if number is None or x.scale > number.scale or (x.scale == number.scale and x.precision > number.precision):
  number = x
- if number is None:
- number = b.core.DefaultNumber
- # assert(isinstance(number, mm.NumberType))
- return number, [number if bt.is_number(field_type) else field_type
- for field_type in field_types]
+ assert(number is not None)
+ return number, [number if bt.is_number(t) else t for t in field_types]


  #--------------------------------------------------
@@ -467,7 +536,6 @@ class PropagationNetwork():

  # draw the network as a mermaid graph for the debugger
  def to_mermaid(self, max_edges=500) -> str:
-
  # add links for edges while collecting nodes
  nodes = OrderedSet()
  link_strs = []
@@ -481,22 +549,12 @@ class PropagationNetwork():
  link_strs.append(f"n{src.id} --> n{dst.id}")
  if len(link_strs) > max_edges:
  break
- # type requirements
- for src, dsts in self.type_requirements.items():
- nodes.add(src)
- for dst in dsts:
- if len(link_strs) > max_edges:
- break
- nodes.add(dst)
- link_strs.append(f"n{src.id} -.-> n{dst.id}")
- if len(link_strs) > max_edges:
- break

  def type_span(t:mm.Type) -> str:
  type_str = t.name if isinstance(t, mm.ScalarType) else str(t)
  return f"<span style='color:cyan;'>{type_str.strip()}</span>"

- def reference_span(rel:mm.Relation, arg_types:list[mm.Type], root:str) -> str:
+ def reference_span(rel:mm.Relation, arg_types:list[mm.Type]) -> str:
  args = []
  for field, arg_type in zip(rel.fields, arg_types):
  field_type = self.resolve(field)
@@ -506,22 +564,21 @@ class PropagationNetwork():
  args.append(type_span(field_type))
  else:
  args.append(type_span(arg_type))
- return f'{rel.name}{root}({", ".join(args)})'
+ return f'{rel.name}({", ".join(args)})'

  resolved = self.resolved_types
  node_strs = []
  for node in nodes:
  klass = ""
- root = "(*)" if node in self.roots else ""
  if isinstance(node, mm.Var):
  ir_type = resolved.get(node) or self.resolve(node)
  type_str = type_span(ir_type)
- label = f'(["{node.name}{root}:{type_str}"])'
+ label = f'(["{node.name}:{type_str}"])'
  elif isinstance(node, mm.Literal):
  ir_type = resolved.get(node) or self.resolve(node)
  type_str = type_span(ir_type)
  klass = ":::literal"
- label = f'[/"{node.value}{root}: {type_str}"\\]'
+ label = f'(["{node.value}: {type_str}"])'
  elif isinstance(node, mm.Field):
  ir_type = resolved.get(node) or self.resolve(node)
  type_str = type_span(ir_type)
@@ -529,19 +586,22 @@ class PropagationNetwork():
  rel = node._relation
  if rel is not None:
  rel = str(node._relation)
- label = f'{{{{"{node.name}{root}:{type_str}\nfrom {rel}"}}}}'
+ label = f'[/"{node.name}:{type_str}\nfrom {rel}"\\]'
  else:
- label = f'{{{{"{node.name}{root}:\n{type_str}"}}}}'
- elif isinstance(node, (mm.Lookup, mm.Update, mm.Aggregate)):
+ label = f'[/"{node.name}:\n{type_str}"\\]'
+ elif isinstance(node, (mm.Lookup, mm.Aggregate, mm.Update)):
  arg_types = [self.resolve(arg) for arg in node.args]
  if node.id in self.resolved_placeholder:
  overloads = self.resolved_placeholder[node.id]
- content = "<br/>".join([reference_span(o, arg_types, root) for o in overloads])
+ content = "<br/>".join([reference_span(o, arg_types) for o in overloads])
+ else:
+ content = reference_span(get_relation(node), arg_types)
+ if isinstance(node, mm.Update):
+ klass = ":::update"
+ label = f'{{{{"{content}"}}}}'
  else:
- content = reference_span(get_relation(node), arg_types, root)
- label = f'[/"{content}"/]'
- # elif isinstance(node, mm.Relation):
- # label = f'[("{node}")]'
+ klass = ":::reference"
+ label = f'[/"{content}"/]'
  else:
  raise NotImplementedError(f"Unknown node type: {type(node)}")
  if self.has_errors(node):
@@ -555,6 +615,7 @@ class PropagationNetwork():
  flowchart TD
  linkStyle default stroke:#666
  classDef field fill:#245,stroke:#478
+ classDef update fill:#245,stroke:#478
  classDef literal fill:#452,stroke:#784
  classDef error fill:#624,stroke:#945,color:#f9a
  classDef default stroke:#444,stroke-width:2px, font-size:12px
@@ -598,121 +659,42 @@ class Analyzer(Walker):
  rel = node.relation
  self.compute_potential_targets(rel)

- # if this is a type relation, the update is asserting that the argument is of that
- # type; so, it's fine to pass a super-type in to the population e.g. Employee(Person)
- # should be a valid way to populate that a particular Person is also an Employee.
- is_type_relation = isinstance(rel, mm.TypeNode)
+ # arg is flowing into a field
  for arg, field in zip(node.args, rel.fields):
- field_type = field.type
- arg_type = self.net.resolve(arg)
-
- # if the arg is abstract, but the field isn't, then we need to make sure that
- # once the arg is resolved we check that it matches the field type
- if isinstance(arg, mm.Var) and bt.is_abstract(arg_type) and bt.is_concrete(field_type):
- self.net.add_type_requirement(arg, field)
-
- if bt.is_abstract(field_type) and isinstance(arg, (mm.Var, mm.Literal)):
- # if the field is abstract, then eventually this arg will help determine
- # the field's type, so add an edge from the arg to the field
- self.net.add_edge(arg, field)
- elif not type_matches(arg_type, field_type, accept_expected_super_types=is_type_relation):
- if not conversion_allowed(arg_type, field_type):
- self.net.type_mismatch(node, field_type, arg_type)
+ if isinstance(arg, (mm.Var, mm.Literal)):
+ self.net.add_edge(arg, node)

  #--------------------------------------------------
  # Walk Lookups + Aggregates
  #--------------------------------------------------

- def lookup(self, node: mm.Lookup):
- self.compute_potential_targets(node.relation)
- self.visit_rel_op(node)
-
- def aggregate(self, node: mm.Aggregate):
- self.visit_rel_op(node)
-
- def visit_rel_op(self, node: mm.Lookup|mm.Aggregate):
- rel = get_relation(node)
-
- # special case eq lookups
- if isinstance(node, mm.Lookup) and rel == b.core.eq:
- # if both args for an eq are abstract, link them, otherwise do normal processing
- (left, right) = node.args
- left_type = self.net.resolve(left)
- right_type = self.net.resolve(right)
- if bt.is_abstract(left_type) and bt.is_abstract(right_type):
- assert isinstance(left, mm.Var) and isinstance(right, mm.Var)
- # if both sides are abstract, then whatever we find out about
- # either should propagate to the other
- self.net.add_edge(left, right)
- self.net.add_edge(right, left)
- return
-
- # special case when the relation needs to be resolved as there are overloads, placeholders,
- # type vars or it needs number specialization
- if self.requires_resolution(rel):
- return self.visit_unresolved_reference(node)
-
- # if this is a population check, then it's fine to pass a subtype in to do the check
- # e.g. Employee(Person) is a valid way to check if a person is an employee
- is_population_lookup = isinstance(rel, mm.TypeNode)
- for arg, field in zip(node.args, rel.fields):
- field_type = self.net.resolve(field)
- arg_type = self.net.resolve(arg)
- if not type_matches(arg_type, field_type, is_population_lookup):
- # Do not complain if we can convert the arg to the field type.
- if not conversion_allowed(arg_type, field_type):
- # if the arg is a var and it matches when allowing for super types of
- # the expected we can expect to refine it later; but we add a type
- # requirement to check at the end
- if isinstance(arg, mm.Var) and type_matches(arg_type, field_type, True):
- self.net.add_type_requirement(arg, field)
- else:
- self.net.type_mismatch(node, field_type, arg_type)
- # if we have an abstract var then this field will ultimately propagate to that
- # var's type; also, if this is a population lookup, the type of the population
- # being looked up will flow back to the var
- if isinstance(arg, mm.Var):
- if not field.input:
- self.net.add_edge(field, arg)
- else:
- self.net.add_type_requirement(arg, field)
+ def lookup(self, task: mm.Lookup):
+ self.compute_potential_targets(task.relation)
+ self.visit_rel_op(task)

+ def aggregate(self, task: mm.Aggregate):
+ self.visit_rel_op(task)

- def requires_resolution(self, rel: mm.Relation) -> bool:
- # has overloads or is a placeholder relation that needs replacement
- if rel.overloads or bt.is_placeholder(rel):
- return True
- # there are type vars or numbers in the fields that need specialization
- for field in rel.fields:
- t = self.net.resolve(field)
- if bt.is_type_var(t) or t == b.core.Number:
- return True
- return False
-
+ def visit_rel_op(self, task: mm.Lookup|mm.Aggregate):
+ relation = get_relation(task)

- def visit_unresolved_reference(self, node: mm.Lookup|mm.Aggregate):
- relation = get_relation(node)
- # functions have their outputs determined by their inputs
- is_function = bt.is_function(relation)
  is_placeholder = bt.is_placeholder(relation)
- # add edges between args and the relation based on input/output
- for field, arg in zip(relation.fields, node.args):
+ for field, arg in zip(relation.fields, task.args):
  if isinstance(arg, (mm.Var, mm.Literal)):
  if field.input:
- # the arg type will flow into the input
- self.net.add_edge(arg, node)
+ # we need to resolve all inputs before resolving the relation
+ self.net.add_edge(arg, task)
  else:
- if is_function:
- # this is an output of a function, so the field type will flow to the arg
- self.net.add_edge(node, arg)
+ # placeholders also need the output to be resolved
+ if is_placeholder:
+ self.net.add_edge(arg, task)
  else:
- if is_placeholder:
- self.net.add_edge(arg, node)
- if bt.is_abstract(field.type) and not is_placeholder:
- self.net.add_edge(field, node)
-
- if bt.is_abstract(self.net.resolve(arg)):
- self.net.add_edge(node, arg)
+ # args bound to outputs can be resolved after
+ self.net.add_edge(task, arg)
+ # if the field is abstract, it needs to be resolved before we can
+ # resolve the task.
+ if bt.is_abstract(self.net.resolve(field)):
+ self.net.add_edge(field, task)


  #--------------------------------------------------
@@ -750,7 +732,7 @@ class Replacer(Rewriter):
  # TODO - this is only modifying the relation in the model, but then we have a new
  # relation there, which is different than the object referenced by tasks.
  if node in self.net.resolved_types:
- return mm.Field(node.name, self.net.resolved_types[node], node.input)
+ return mm.Field(node.name, self.net.resolved_types[node], node.input, _relation = node._relation)
  return node

  def var(self, node: mm.Var):
@@ -773,31 +755,40 @@

  args = types = None
  if node.id in self.net.resolved_placeholder:
+ # placeholder resolved to multiple relations
  resolved_relations = self.net.resolved_placeholder[node.id]
  args = get_lookup_args(node, resolved_relations[0])
  types = [f.type for f in resolved_relations[0].fields]
  elif node.id in self.net.resolved_overload:
+ # overload resolved to a specific relation
  resolved_relations = [node.relation]
  types = self.net.resolved_overload[node.id].types
  else:
+ # single relation
  resolved_relations = [node.relation]

  if len(resolved_relations) == 1:
+ # single relation, just convert arguments
  x = self.convert_arguments(node, resolved_relations[0], args, types)
  if isinstance(x, mm.Logical) and len(x.body) == 1:
  return x.body[0]
  else:
  return x

+ # multiple relations, create a union
  branches:list = []
  for target in resolved_relations:
  args = get_lookup_args(node, target)
  types = [f.type for f in get_relation_fields(resolved_relations[0], node.relation.name)]
  # adding this logical to avoid issues in the old backend
- branches.append(mm.Logical((self.convert_arguments(node, target, args, types=types),)))
+ branches.append(mm.Logical((self.convert_arguments(node, target, args, types=types, force_copy=True),)))
  return mm.Union(tuple(branches))

- def convert_arguments(self, node: mm.Lookup|mm.Update, relation: mm.Relation, args: Iterable[mm.Value]|None=None, types: Iterable[mm.Type]|None=None) -> mm.Logical|mm.Lookup|mm.Update:
+ def convert_arguments(self, node: mm.Lookup|mm.Update, relation: mm.Relation, args: Iterable[mm.Value]|None=None, types: Iterable[mm.Type]|None=None, force_copy=False) -> mm.Logical|mm.Lookup|mm.Update:
+ """ This node was resolved to target this relation using these args, which should
+ have these types. Convert any arguments as needed and return a new node with the
+ proper relation and converted args. If multiple conversions are needed, return a
+ logical that contains all the conversion tasks plus the final node. """
  args = args or node.args
  types = types or [self.net.resolve(f) for f in relation.fields]
  number_type = self.net.resolved_number.get(node.id)
@@ -817,15 +808,20 @@
  final_args.append(arg)
  else:
  final_args.append(arg)
- if isinstance(node, mm.Lookup):
- tasks.append(node.mut(relation = relation, args = tuple(final_args)))
+ # add the original node with the proper target relation and converted args;
+ # here we want to use mut because we keep information about nodes based on their ids
+ # (e.g. we store resolved types based on ids), so we want to keep the same id. But
+ # when a lookup is being converted as part of a union (i.e. multiple targets for a
+ # placeholder), we need to create nodes with new ids to avoid conflicts.
+ if force_copy:
+ tasks.append(node.replace(relation = relation, args = tuple(final_args)))
  else:
  tasks.append(node.mut(relation = relation, args = tuple(final_args)))
+ # if we need conversion tasks, wrap in a logical
  if len(tasks) == 1:
  return tasks[0]
  return mm.Logical(tuple(tasks))

-
  def visit_eq_lookup(self, node: mm.Lookup):
  (left, right) = node.args
  left_type = to_type(left)
@@ -842,7 +838,7 @@
  elif conversion_allowed(right_type, left_type):
  final_args = [left, convert(right, left_type, tasks)]
  else:
- self.net.type_mismatch(node, left_type, right_type)
+ # this type mismatch was reported during propagation, so just return the node
  return node

  tasks.append(mm.Lookup(b.core.eq, tuple(final_args)))
@@ -873,6 +869,15 @@ def get_name(type: mm.Type) -> str:
  # Type and Relation helpers
  #--------------------------------------------------

+ def get_update_fields(arg, update: mm.Update) -> Iterable[mm.Field]:
+ """ Get the fields of the relation being updated by this arg. Note that an arg can be
+ bound to multiple fields at the same time. """
+ if arg in update.args:
+ for x, field in zip(update.args, update.relation.fields):
+ if arg == x:
+ yield field
+ return []
+
  def get_relation_fields(relation: mm.Relation, name: str) -> Iterable[mm.Field]:
  """ Get the fields of this relation, potentially reordered to match the reading with the given name."""
  if name == relation.name:
@@ -915,18 +920,20 @@ def get_potential_targets(model: mm.Model, placeholder: mm.Relation) -> list[mm.
  return list(filter(lambda r: is_potential_target(placeholder, r), model.relations))

  def to_type(value: mm.Value|mm.Field|mm.Literal) -> mm.Type:
- if isinstance(value, (mm.Var, mm.Field, mm.Literal)):
+ if isinstance(value, (mm.Var, mm.Literal)):
  return value.type

  if isinstance(value, mm.Type):
  return b.core.Type

+ if isinstance(value, mm.Field):
+ return b.core.Field
+
  if isinstance(value, tuple):
  return mm.TupleType(element_types=tuple(to_type(v) for v in value))

  raise TypeError(f"Cannot determine IR type for value: {value} of type {type(value).__name__}")

-
  def convert(value: mm.Var|mm.Literal, to_type: mm.Type, tasks: list[mm.Task]) -> mm.Value:
  # if the arg is a literal, we can just change its type
  # TODO - we may want to check that the value is actually convertible
@@ -940,8 +947,10 @@ def convert(value: mm.Var|mm.Literal, to_type: mm.Type, tasks: list[mm.Task]) ->
  tasks.append(mm.Lookup(b.core.cast, (to_type_base, value, new_value)))
  return new_value

-
  def conversion_allowed(from_type: mm.Type, to_type: mm.Type) -> bool:
+ # value type conversion is allowed only if the value types are related by inheritance
+ if bt.is_value_type(from_type) and bt.is_value_type(to_type) and not bt.extends(from_type, to_type):
+ return False
  # value type conversion is allowed
  x = bt.get_primitive_supertype(from_type)
  y = bt.get_primitive_supertype(to_type)
@@ -954,7 +963,7 @@

  # a number can be converted to another number of larger scale
  if isinstance(from_type, mm.NumberType) and isinstance(to_type, mm.NumberType):
- if to_type.scale > from_type.scale:
+ if to_type.scale >= from_type.scale:
  return True

  if from_type == b.core.Number and isinstance(to_type, mm.NumberType):
@@ -990,11 +999,6 @@ def type_matches(actual: mm.Type, expected: mm.Type, accept_expected_super_types
  if expected == b.core.TypeVar:
  return True

- # TODO - remove this once we make them singletons per precision/scale
- if isinstance(actual, mm.NumberType) and isinstance(expected, mm.NumberType):
- if actual.precision == expected.precision and actual.scale == expected.scale:
- return True
-
  # if an entity type var or any entity is expected, it matches any actual entity type
  if (expected == b.core.EntityTypeVar or bt.extends(expected, b.core.AnyEntity)) and not bt.is_primitive(actual):
  return True
@@ -1006,6 +1010,10 @@
  if (expected == b.core.Numeric) and bt.is_numeric(actual):
  return True

+ # different value types never match
+ if bt.is_value_type(actual) and bt.is_value_type(expected) and not bt.extends(actual, expected):
+ return False
+
  # if actual is scalar, any of its parents may match the expected type
  if isinstance(actual, mm.ScalarType) and any([type_matches(parent, expected) for parent in actual.super_types]):
  return True
@@ -1049,18 +1057,12 @@

  # accept tuples with a single element type to match a list with that type
  if isinstance(actual, mm.TupleType) and isinstance(expected, mm.ListType):
- if len(set(actual.element_types)) == 1:
- return type_matches(actual.element_types[0], expected.element_type)
+ return type_matches(actual.element_types[0], expected.element_type)

  # otherwise no match
  return False

-
- def merge_types(type1: mm.Type, type2: mm.Type) -> mm.Type:
- if type1 == type2:
- return type1
- types_to_process = [type1, type2]
-
+ def merge_numeric_types(type1: mm.Type, type2: mm.Type) -> Optional[mm.Type]:
  # if one of them is the abstract Number type, pick the other
  if type1 == b.core.Number and isinstance(type2, mm.NumberType):
  return type2
@@ -1077,6 +1079,24 @@
  # if we are overriding a number with a float, pick float
  if isinstance(type1, mm.NumberType) and type2 == b.core.Float:
  return type2
+ if isinstance(type2, mm.NumberType) and type1 == b.core.Float:
+ return type1
+
+ return None
+
+ def merge_types(type1: mm.Type, type2: mm.Type) -> mm.Type:
+ if type1 == type2:
+ return type1
+ if bt.is_type_var(type1):
+ return type2
+ if bt.is_type_var(type2):
+ return type1
+
+ types_to_process = [type1, type2]
+
+ numeric_merge = merge_numeric_types(type1, type2)
+ if numeric_merge is not None:
+ return numeric_merge

  # if one extends the other, pick the most specific one
  if bt.extends(type1, type2):
@@ -1093,6 +1113,13 @@
  elif bt.is_primitive(type2):
  return type1

+ if base_primitive_type1 and base_primitive_type2:
+ numeric_merge = merge_numeric_types(base_primitive_type1, base_primitive_type2)
+ if numeric_merge == base_primitive_type1:
+ return type1
+ elif numeric_merge == base_primitive_type2:
+ return type2
+
  combined = OrderedSet()
  # Iterative flattening of union types
  while types_to_process:
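The merge_numeric_types/merge_types changes in the last hunks, together with the specialize_number policy earlier in the diff (pick the number type with the largest scale, breaking ties by the largest precision), determine which concrete number type wins when several flow into the same relation. A standalone sketch of that selection rule, with a stand-in NumberType dataclass rather than the metamodel class:

# Standalone sketch of the "largest scale, then largest precision" selection
# described in specialize_number. NumberType here is an illustrative stand-in.
from dataclasses import dataclass

@dataclass(frozen=True)
class NumberType:
    precision: int
    scale: int

def pick_specialized_number(arg_types: list[NumberType]) -> NumberType:
    number = None
    for x in arg_types:
        if number is None or x.scale > number.scale or \
           (x.scale == number.scale and x.precision > number.precision):
            number = x
    assert number is not None
    return number

# Decimal(10, 2) and Decimal(38, 4): the (38, 4) type wins on scale
print(pick_specialized_number([NumberType(10, 2), NumberType(38, 4)]))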