relationalai 0.12.13__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. relationalai/__init__.py +69 -22
  2. relationalai/clients/__init__.py +15 -2
  3. relationalai/clients/client.py +4 -4
  4. relationalai/clients/local.py +5 -5
  5. relationalai/clients/resources/__init__.py +8 -0
  6. relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
  7. relationalai/clients/resources/snowflake/__init__.py +20 -0
  8. relationalai/clients/resources/snowflake/cli_resources.py +87 -0
  9. relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
  10. relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
  11. relationalai/clients/resources/snowflake/error_handlers.py +199 -0
  12. relationalai/clients/{export_procedure.py.jinja → resources/snowflake/export_procedure.py.jinja} +1 -1
  13. relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  14. relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
  15. relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
  16. relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  17. relationalai/clients/resources/snowflake/util.py +387 -0
  18. relationalai/early_access/dsl/ir/executor.py +4 -4
  19. relationalai/early_access/dsl/snow/api.py +2 -1
  20. relationalai/errors.py +23 -0
  21. relationalai/experimental/solvers.py +7 -7
  22. relationalai/semantics/devtools/benchmark_lqp.py +4 -5
  23. relationalai/semantics/devtools/extract_lqp.py +1 -1
  24. relationalai/semantics/internal/internal.py +4 -4
  25. relationalai/semantics/internal/snowflake.py +3 -2
  26. relationalai/semantics/lqp/executor.py +22 -22
  27. relationalai/semantics/lqp/model2lqp.py +42 -4
  28. relationalai/semantics/lqp/passes.py +1 -1
  29. relationalai/semantics/lqp/rewrite/cdc.py +1 -1
  30. relationalai/semantics/lqp/rewrite/extract_keys.py +72 -15
  31. relationalai/semantics/metamodel/builtins.py +8 -6
  32. relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
  33. relationalai/semantics/metamodel/util.py +6 -5
  34. relationalai/semantics/reasoners/graph/core.py +8 -9
  35. relationalai/semantics/rel/executor.py +14 -11
  36. relationalai/semantics/sql/compiler.py +2 -2
  37. relationalai/semantics/sql/executor/snowflake.py +9 -5
  38. relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
  39. relationalai/tools/cli.py +26 -30
  40. relationalai/tools/cli_helpers.py +10 -2
  41. relationalai/util/otel_configuration.py +2 -1
  42. relationalai/util/otel_handler.py +1 -1
  43. {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/METADATA +1 -1
  44. {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/RECORD +49 -40
  45. relationalai_test_util/fixtures.py +2 -1
  46. /relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
  47. {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/WHEEL +0 -0
  48. {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/entry_points.txt +0 -0
  49. {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -396,6 +396,41 @@ def _translate_ascending_rank(ctx: TranslationCtx, limit: int, result_var: lqp.V
396
396
  terms=terms,
397
397
  )
398
398
 
399
+ def _rename_shadowed_abstraction_vars(
400
+ ctx: TranslationCtx,
401
+ aggr: ir.Aggregate,
402
+ abstr_args: list[Tuple[lqp.Var, lqp.Type]],
403
+ body_conjs: list[lqp.Formula]
404
+ ) -> list[Tuple[lqp.Var, lqp.Type]]:
405
+ """
406
+ Rename abstraction variables that shadow group-by variables.
407
+
408
+ This can happen when the same variable appears in both aggr.group and as an input
409
+ to the aggregation, e.g., min(Person.age).per(Person.age). The group-by variables
410
+ are in the outer scope, while the abstraction parameters are in the inner scope,
411
+ so we need different names to avoid shadowing.
412
+ """
413
+ # Get the LQP names of group-by variables
414
+ group_var_names = set()
415
+ for group_var in aggr.group:
416
+ lqp_var = _translate_var(ctx, group_var)
417
+ group_var_names.add(lqp_var.name)
418
+
419
+ # Rename any abstraction parameters that conflict with group-by variables
420
+ renamed_abstr_args = []
421
+ for var, typ in abstr_args:
422
+ if var.name in group_var_names:
423
+ # This variable shadows a group-by variable, so rename it
424
+ fresh_var = gen_unique_var(ctx, var.name)
425
+ # Add an equality constraint: fresh_var == var
426
+ # var is a free variable referring to the outer scope group-by variable
427
+ body_conjs.append(mk_primitive("rel_primitive_eq", [fresh_var, var]))
428
+ renamed_abstr_args.append((fresh_var, typ))
429
+ else:
430
+ renamed_abstr_args.append((var, typ))
431
+
432
+ return renamed_abstr_args
433
+
399
434
  def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Formula) -> Union[lqp.Reduce, lqp.Formula]:
400
435
  # TODO: handle this properly
401
436
  aggr_name = aggr.aggregation.name
@@ -432,6 +467,9 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
432
467
  body_conjs.extend(projected_eqs)
433
468
  abstr_args: list[Tuple[lqp.Var, lqp.Type]] = projected_args + input_args
434
469
 
470
+ # Rename abstraction variables that shadow group-by variables
471
+ abstr_args = _rename_shadowed_abstraction_vars(ctx, aggr, abstr_args, body_conjs)
472
+
435
473
  if aggr_name == "count":
436
474
  assert len(output_terms) == 1, "Count and avg expect a single output variable"
437
475
  assert isinstance(meta_output_terms[0], ir.Var)
@@ -441,7 +479,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
441
479
  one_var, eq = constant_to_var(ctx, to_lqp_value(1, meta_output_terms[0].type), "one")
442
480
  body_conjs.append(eq)
443
481
  abstr_args.append((one_var, typ))
444
- body = mk_and(body_conjs)
445
482
 
446
483
  # Average needs to wrap the reduce in Exists(Conjunction(Reduce, div))
447
484
  if aggr_name == "avg":
@@ -454,7 +491,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
454
491
  one_var, eq = constant_to_var(ctx, to_lqp_value(1, types.Int64), "one")
455
492
  body_conjs.append(eq)
456
493
  abstr_args.append((one_var, count_type))
457
- body = mk_and(body_conjs)
458
494
 
459
495
  # The average will produce two output variables: sum and count.
460
496
  sum_result = gen_unique_var(ctx, "sum")
@@ -462,6 +498,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
462
498
 
463
499
  # Second to last is the variable we're summing over.
464
500
  (sum_var, sum_type) = abstr_args[-2]
501
+ body = mk_and(body_conjs)
465
502
 
466
503
  result = lqp.Reduce(
467
504
  op=lqp_avg_op(ctx, aggr.aggregation.name, sum_var.name, sum_type),
@@ -494,6 +531,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
494
531
  # `input_args` hold the types of the input arguments, but they may have been modified
495
532
  # if we're dealing with a count, so we use `abstr_args` to find the type.
496
533
  (aggr_arg, aggr_arg_type) = abstr_args[-1]
534
+ body = mk_and(body_conjs)
497
535
 
498
536
  # Group-bys do not need to be handled at all, since they are introduced outside already
499
537
  reduce = lqp.Reduce(
@@ -668,11 +706,11 @@ def to_lqp_value(value: ir.PyValue, value_type: ir.Type) -> lqp.Value:
668
706
  val = value
669
707
  elif typ.type_name == lqp.TypeName.STRING and isinstance(value, str):
670
708
  val = value
671
- elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, PyDecimal):
709
+ elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, (int, float, PyDecimal)):
672
710
  precision = typ.parameters[0].value
673
711
  scale = typ.parameters[1].value
674
712
  assert isinstance(precision, int) and isinstance(scale, int)
675
- val = lqp.DecimalValue(precision=precision, scale=scale, value=value, meta=None)
713
+ val = lqp.DecimalValue(precision=precision, scale=scale, value=PyDecimal(value), meta=None)
676
714
  elif typ.type_name == lqp.TypeName.DATE and isinstance(value, date):
677
715
  val = lqp.DateValue(value=value, meta=None)
678
716
  elif typ.type_name == lqp.TypeName.DATETIME and isinstance(value, datetime):
@@ -390,7 +390,7 @@ class EliminateData(Pass):
390
390
  [
391
391
  f.logical(
392
392
  [
393
- f.lookup(rel_builtins.eq, [f.literal(val), var])
393
+ f.lookup(rel_builtins.eq, [f.literal(val, var.type), var])
394
394
  for (val, var) in zip(row, node.vars)
395
395
  ],
396
396
  )
@@ -200,7 +200,7 @@ class CDC(Pass):
200
200
  Get the relation that represents this property var in this wide_cdc_relation. If the
201
201
  relation is not yet available in the context, this method will create and register it.
202
202
  """
203
- relation_name = wide_cdc_relation.name.lower().replace(".", "_")
203
+ relation_name = helpers.sanitize(wide_cdc_relation.name).replace("-", "_")
204
204
  key = (relation_name, property.name)
205
205
  if key not in ctx.cdc_relations:
206
206
  # the property relation is overloaded for all properties of the same wide cdc relation, so they have
@@ -118,6 +118,17 @@ class ExtractKeys(Pass):
118
118
  the same here).
119
119
  """
120
120
  class ExtractKeysRewriter(Rewriter):
121
+ def __init__(self):
122
+ super().__init__()
123
+ self.compound_keys: dict[Any, ir.Var] = {}
124
+
125
+ def _get_compound_key(self, orig_keys: Iterable[ir.Var]) -> ir.Var:
126
+ if orig_keys in self.compound_keys:
127
+ return self.compound_keys[orig_keys]
128
+ compound_key = f.var("compound_key", types.Hash)
129
+ self.compound_keys[orig_keys] = compound_key
130
+ return compound_key
131
+
121
132
  def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
122
133
  outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
123
134
  # We are not in a logical with an output at this level.
@@ -170,7 +181,7 @@ class ExtractKeysRewriter(Rewriter):
170
181
  annos = list(output.annotations)
171
182
  annos.append(f.annotation(builtins.output_keys, tuple(output_keys)))
172
183
  # Create a compound key that will be used in place of the original keys.
173
- compound_key = f.var("compound_key", types.Hash)
184
+ compound_key = self._get_compound_key(output_keys)
174
185
 
175
186
  for key_combination in combinations:
176
187
  missing_keys = OrderedSet.from_iterable(output_keys)
@@ -192,8 +203,13 @@ class ExtractKeysRewriter(Rewriter):
192
203
  # handle the construct node in each clone
193
204
  values: list[ir.Value] = [compound_key.type]
194
205
  for key in output_keys:
195
- assert isinstance(key.type, ir.ScalarType)
196
- values.append(ir.Literal(types.String, key.type.name))
206
+ if isinstance(key.type, ir.UnionType):
207
+ # the typer can derive union types when multiple distinct entities flow
208
+ # into a relation's field, so use AnyEntity as the type marker
209
+ values.append(ir.Literal(types.String, "AnyEntity"))
210
+ else:
211
+ assert isinstance(key.type, ir.ScalarType)
212
+ values.append(ir.Literal(types.String, key.type.name))
197
213
  if key in key_combination:
198
214
  values.append(key)
199
215
  body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
@@ -335,9 +351,45 @@ class ExtractKeysRewriter(Rewriter):
335
351
  partitions:dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
336
352
  dependencies:dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
337
353
 
338
- worklist = list(vars)
339
- while worklist:
340
- var = worklist.pop()
354
+ def dfs_collect_deps(task, deps):
355
+ if isinstance(task, ir.Lookup):
356
+ args = helpers.vars(task.args)
357
+ for i, v in enumerate(args):
358
+ # v depends on all previous vars
359
+ for j in range(i):
360
+ deps[v].add(args[j])
361
+ # for ternary+ lookups, a var also depends on the next vars
362
+ if i > 0 and len(args) >= 3:
363
+ for j in range(i+1, len(args)):
364
+ deps[v].add(args[j])
365
+ elif isinstance(task, ir.Construct):
366
+ vars = helpers.vars(task.values)
367
+ for val_var in vars:
368
+ deps[task.id_var].add(val_var)
369
+ elif isinstance(task, ir.Logical):
370
+ for child in task.body:
371
+ dfs_collect_deps(child, deps)
372
+ elif isinstance(task, (ir.Match, ir.Union)):
373
+ for child in task.tasks:
374
+ dfs_collect_deps(child, deps)
375
+
376
+ for task in tasks:
377
+ dfs_collect_deps(task, dependencies)
378
+
379
+ def dfs_transitive_deps(var, visited):
380
+ for dep_var in dependencies[var]:
381
+ if dep_var not in visited:
382
+ visited.add(dep_var)
383
+ dfs_transitive_deps(dep_var, visited)
384
+
385
+ transitive_deps = defaultdict(OrderedSet)
386
+ for var in list(dependencies.keys()):
387
+ visited = OrderedSet()
388
+ dfs_transitive_deps(var, visited)
389
+ transitive_deps[var] = visited
390
+ dependencies = transitive_deps
391
+
392
+ for var in vars:
341
393
  extended_vars = OrderedSet[ir.Var]()
342
394
  extended_vars.add(var)
343
395
 
@@ -347,28 +399,33 @@ class ExtractKeysRewriter(Rewriter):
347
399
  for task in tasks:
348
400
  if task in partitions[var]:
349
401
  continue
350
- # Already added this task to this partition
402
+
403
+ if isinstance(task, (ir.Logical, ir.Match, ir.Union)):
404
+ hoisted = helpers.hoisted_vars(task.hoisted)
405
+ if var in hoisted:
406
+ partitions[var].add(task)
407
+ there_is_progress = True
408
+ elif isinstance(task, ir.Construct):
409
+ if task.id_var == var:
410
+ partitions[var].add(task)
411
+ there_is_progress = True
351
412
  elif isinstance(task, ir.Lookup):
352
413
  args = helpers.vars(task.args)
353
414
  if len(args) == 1 and args[0] in extended_vars:
354
415
  partitions[var].add(task)
355
- # TODO: hack to have dot_joins work
416
+ there_is_progress = True
417
+ # NOTE: heuristics to have dot_joins work
356
418
  elif len(args) >= 3 and args[-2] in extended_vars:
357
419
  partitions[var].add(task)
358
420
  extended_vars.add(args[-1])
359
- dependencies[var].add(args[-1])
360
421
  there_is_progress = True
361
422
  elif len(args) > 1 and args[-1] in extended_vars:
362
423
  partitions[var].add(task)
363
424
  for arg in args[:-1]:
364
425
  extended_vars.add(arg)
365
- dependencies[var].add(arg)
366
- there_is_progress = True
367
- elif isinstance(task, ir.Logical):
368
- hoisted = helpers.hoisted_vars(task.hoisted)
369
- if var in hoisted:
370
- partitions[var].add(task)
371
426
  there_is_progress = True
427
+ else:
428
+ assert False, f"invalid node kind {type(task)}"
372
429
 
373
430
  return partitions, dependencies
374
431
 
@@ -443,12 +443,14 @@ datetime_second = f.relation("datetime_second", [f.input_field("a", types.DateTi
443
443
  datetime_weekday = f.relation("datetime_weekday", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
444
444
 
445
445
  # Other
446
- range = f.relation("range", [
447
- f.input_field("start", types.Int64),
448
- f.input_field("stop", types.Int64),
449
- f.input_field("step", types.Int64),
450
- f.field("result", types.Int64),
451
- ])
446
+ range = f.relation(
447
+ "range",
448
+ [f.input_field("start", types.Number), f.input_field("stop", types.Number), f.input_field("step", types.Number), f.field("result", types.Number)],
449
+ overloads=[
450
+ f.relation("range", [f.input_field("start", types.Int64), f.input_field("stop", types.Int64), f.input_field("step", types.Int64), f.field("result", types.Int64)]),
451
+ f.relation("range", [f.input_field("start", types.Int128), f.input_field("stop", types.Int128), f.input_field("step", types.Int128), f.field("result", types.Int128)]),
452
+ ],
453
+ )
452
454
 
453
455
  hash = f.relation("hash", [f.input_field("args", types.AnyList), f.field("hash", types.Hash)])
454
456
 
@@ -124,10 +124,10 @@ class Flatten(Pass):
124
124
  output
125
125
  """
126
126
 
127
- def __init__(self, handle_outputs: bool=True):
127
+ def __init__(self, use_sql: bool=False):
128
128
  super().__init__()
129
129
  self.name_cache = NameCache(start_from_one=True)
130
- self._handle_outputs = handle_outputs
130
+ self._use_sql = use_sql
131
131
 
132
132
 
133
133
  #--------------------------------------------------
@@ -181,7 +181,12 @@ class Flatten(Pass):
181
181
  def handle(self, task: ir.Task, ctx: Context) -> Flatten.HandleResult:
182
182
  if isinstance(task, ir.Logical):
183
183
  return self.handle_logical(task, ctx)
184
- elif isinstance(task, ir.Union):
184
+ elif isinstance(task, ir.Union) and (task.hoisted or self._use_sql):
185
+ # Only flatten Unions which hoist variables. If there are no hoisted variables,
186
+ # then the Union acts as a filter, and it can be inefficient to flatten it.
187
+ #
188
+ # However, for the SQL backend, we always need to flatten Unions for correct SQL
189
+ # generation.
185
190
  return self.handle_union(task, ctx)
186
191
  elif isinstance(task, ir.Match):
187
192
  return self.handle_match(task, ctx)
@@ -238,7 +243,7 @@ class Flatten(Pass):
238
243
  # If there are outputs, flatten each into its own top-level rule, along with its
239
244
  # dependencies.
240
245
  if groups["outputs"]:
241
- if not self._handle_outputs:
246
+ if self._use_sql:
242
247
  ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations))
243
248
  return Flatten.HandleResult(None)
244
249
 
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
- from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable
2
+ from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable, Union
3
+ import types
3
4
  from dataclasses import dataclass, field
4
5
 
5
6
  #--------------------------------------------------
@@ -345,7 +346,7 @@ def rewrite_set(t: type[T], f: Callable[[T], T], items: FrozenOrderedSet[T]) ->
345
346
  return items
346
347
  return ordered_set(*new_items).frozen()
347
348
 
348
- def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
349
+ def rewrite_list(t: Union[type[T], types.UnionType], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
349
350
  """ Map a function over a list, returning a new list with the results. Avoid allocating a new list if the function is the identity. """
350
351
  new_items: Optional[list[T]] = None
351
352
  for i in range(len(items)):
@@ -359,15 +360,15 @@ def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple
359
360
  return items
360
361
  return tuple(new_items)
361
362
 
362
- def flatten_iter(items: Iterable[object], t: type[T]) -> Generator[T, None, None]:
363
+ def flatten_iter(items: Iterable[object], t: Union[type[T], types.UnionType]) -> Generator[T, None, None]:
363
364
  """Yield items from a nested iterable structure one at a time."""
364
365
  for item in items:
365
366
  if isinstance(item, (list, tuple, OrderedSet)):
366
367
  yield from flatten_iter(item, t)
367
368
  elif isinstance(item, t):
368
- yield item
369
+ yield cast(T, item)
369
370
 
370
- def flatten_tuple(items: Iterable[object], t: type[T]) -> tuple[T, ...]:
371
+ def flatten_tuple(items: Iterable[object], t: Union[type[T], types.UnionType]) -> tuple[T, ...]:
371
372
  """ Flatten the nested iterable structure into a tuple."""
372
373
  return tuple(flatten_iter(items, t))
373
374
 
@@ -6222,9 +6222,9 @@ class Graph():
6222
6222
  def _distance_reversed_non_weighted(self):
6223
6223
  """Lazily define and cache the self._distance_reversed_non_weighted relationship, a non-public helper."""
6224
6224
  _distance_reversed_non_weighted_rel = self._model.Relationship(f"{{node_u:{self._NodeConceptStr}}} and {{node_v:{self._NodeConceptStr}}} have a reversed distance of {{d:Integer}}")
6225
- node_u, node_v, node_n, d1 = self.Node.ref(), self.Node.ref(), self.Node.ref(), Integer.ref()
6225
+ node_u, node_v, node_n = self.Node.ref(), self.Node.ref(), self.Node.ref()
6226
6226
  node_u, node_v, d = union(
6227
- where(node_u == node_v, d1 == 0).select(node_u, node_v, d1), # Base case.
6227
+ where(node_u == node_v, d1 := 0).select(node_u, node_v, d1), # Base case.
6228
6228
  where(self._edge(node_v, node_n),
6229
6229
  d2 := _distance_reversed_non_weighted_rel(node_u, node_n, Integer) + 1).select(node_u, node_v, d2) # Recursive case.
6230
6230
  )
@@ -6326,13 +6326,12 @@ class Graph():
6326
6326
  _is_connected_rel.annotate(annotations.track("graphs", "is_connected"))
6327
6327
 
6328
6328
  where(
6329
- self._num_nodes(0) |
6330
- count(self._reachable_from_min_node(self.Node.ref())) == self._num_nodes(Integer.ref())
6331
- ).define(_is_connected_rel(True))
6332
-
6333
- where(
6334
- not_(_is_connected_rel(True))
6335
- ).define(_is_connected_rel(False))
6329
+ union(
6330
+ self._num_nodes(0),
6331
+ count(self._reachable_from_min_node(self.Node.ref())) == self._num_nodes(Integer.ref())
6332
+ )
6333
+ ).define(_is_connected_rel(True)) \
6334
+ | define(_is_connected_rel(False))
6336
6335
 
6337
6336
  return _is_connected_rel
6338
6337
 
@@ -9,16 +9,15 @@ import uuid
9
9
  from pandas import DataFrame
10
10
  from typing import Any, Optional, Literal, TYPE_CHECKING
11
11
  from snowflake.snowpark import Session
12
- import relationalai as rai
13
12
 
14
13
  from relationalai import debugging
15
14
  from relationalai.clients import result_helpers
16
15
  from relationalai.clients.util import IdentityParser, escape_for_f_string
17
- from relationalai.clients.snowflake import APP_NAME
16
+ from relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
18
17
  from relationalai.semantics.metamodel import ir, executor as e, factory as f
19
18
  from relationalai.semantics.rel import Compiler
20
19
  from relationalai.clients.config import Config
21
- from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation, QUERY_ATTRIBUTES_HEADER
20
+ from relationalai.tools.constants import Generation, QUERY_ATTRIBUTES_HEADER
22
21
  from relationalai.tools.query_utils import prepare_metadata_for_headers
23
22
 
24
23
  if TYPE_CHECKING:
@@ -53,15 +52,11 @@ class RelExecutor(e.Executor):
53
52
  if not self._resources:
54
53
  with debugging.span("create_session"):
55
54
  self.dry_run |= bool(self.config.get("compiler.dry_run", False))
56
- resource_class = rai.clients.snowflake.Resources
57
- if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
58
- resource_class = rai.clients.snowflake.DirectAccessResources
59
55
  # NOTE: language="rel" is required for Rel execution. It is the default, but
60
56
  # we set it explicitly here to be sure.
61
- self._resources = resource_class(
62
- dry_run=self.dry_run,
57
+ self._resources = create_resources_instance(
63
58
  config=self.config,
64
- generation=rai.Generation.QB,
59
+ dry_run=self.dry_run,
65
60
  connection=self.connection,
66
61
  language="rel",
67
62
  )
@@ -163,13 +158,20 @@ class RelExecutor(e.Executor):
163
158
  raise errors.RAIExceptionSet(all_errors)
164
159
 
165
160
  def _export(self, raw_code: str, dest: Table, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
161
+ # _export is Snowflake-specific and requires Snowflake Resources
162
+ # It calls Snowflake stored procedures (APP_NAME.api.exec_into_table, etc.)
163
+ # LocalResources doesn't support this functionality
164
+ from relationalai.clients.local import LocalResources
165
+ if isinstance(self.resources, LocalResources):
166
+ raise NotImplementedError("Export functionality is not supported in local mode. Use Snowflake Resources instead.")
167
+
166
168
  _exec = self.resources._exec
167
169
  output_table = "out" + str(uuid.uuid4()).replace("-", "_")
168
170
  txn_id = None
169
171
  artifacts = None
170
172
  dest_database, dest_schema, dest_table, _ = IdentityParser(dest._fqn, require_all_parts=True).to_list()
171
173
  dest_fqn = dest._fqn
172
- assert self.resources._session
174
+ assert self.resources._session # All Snowflake Resources have _session
173
175
  with debugging.span("transaction"):
174
176
  try:
175
177
  with debugging.span("exec_format") as span:
@@ -258,7 +260,7 @@ class RelExecutor(e.Executor):
258
260
  SELECT 1
259
261
  FROM {dest_database}.INFORMATION_SCHEMA.TABLES
260
262
  WHERE table_schema = '{dest_schema}'
261
- AND table_name = '{dest_table}'
263
+ AND table_name = '{dest_table}'
262
264
  )) THEN
263
265
  EXECUTE IMMEDIATE 'TRUNCATE TABLE {dest_fqn}';
264
266
  END IF;
@@ -267,6 +269,7 @@ class RelExecutor(e.Executor):
267
269
  else:
268
270
  raise e
269
271
  if txn_id:
272
+ # These methods are available on all Snowflake Resources
270
273
  artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
271
274
  with debugging.span("fetch"):
272
275
  artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")
@@ -33,7 +33,7 @@ class Compiler(c.Compiler):
33
33
  ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
34
34
  InferTypes(),
35
35
  DNFUnionSplitter(),
36
- Flatten(handle_outputs=False),
36
+ Flatten(use_sql=True),
37
37
  rewrite.RecursiveUnion(),
38
38
  rewrite.DoubleNegation(),
39
39
  rewrite.SortOutputQuery()
@@ -1264,7 +1264,7 @@ class ModelToSQL:
1264
1264
  assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and fourth arguments (index, part) must be variables"
1265
1265
  builtin_vars[part] = part_expr
1266
1266
  builtin_vars[index] = index_expr
1267
- elif relation == builtins.range:
1267
+ elif relation == builtins.range or relation in builtins.range.overloads:
1268
1268
  assert len(args) == 4, f"Expected 4 args for `range`, got {len(args)}: {args}"
1269
1269
  start_raw, stop_raw, step_raw, result = args
1270
1270
  start = self._var_to_expr(start_raw, reference, resolve_builtin_var, var_to_construct)
@@ -17,6 +17,7 @@ from relationalai.semantics.sql.executor.result_helpers import format_columns
17
17
  from relationalai.semantics.metamodel.visitor import collect_by_type
18
18
  from relationalai.semantics.metamodel.typer import typer
19
19
  from relationalai.tools.constants import USE_DIRECT_ACCESS
20
+ from relationalai.clients.resources.snowflake import Resources, DirectAccessResources, Provider
20
21
 
21
22
  if TYPE_CHECKING:
22
23
  from relationalai.semantics.snowflake import Table
@@ -51,17 +52,20 @@ class SnowflakeExecutor(e.Executor):
51
52
  if not self._resources:
52
53
  with debugging.span("create_session"):
53
54
  self.dry_run |= bool(self.config.get("compiler.dry_run", False))
54
- resource_class = rai.clients.snowflake.Resources
55
+ resource_class: type = Resources
55
56
  if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
56
- resource_class = rai.clients.snowflake.DirectAccessResources
57
- self._resources = resource_class(dry_run=self.dry_run, config=self.config, generation=rai.Generation.QB,
58
- connection=self.connection)
57
+ resource_class = DirectAccessResources
58
+ self._resources = resource_class(
59
+ dry_run=self.dry_run, config=self.config, generation=rai.Generation.QB,
60
+ connection=self.connection,
61
+ language="sql",
62
+ )
59
63
  return self._resources
60
64
 
61
65
  @property
62
66
  def provider(self):
63
67
  if not self._provider:
64
- self._provider = rai.clients.snowflake.Provider(resources=self.resources)
68
+ self._provider = Provider(resources=self.resources)
65
69
  return self._provider
66
70
 
67
71
  def execute(self, model: ir.Model, task: ir.Task, format:Literal["pandas", "snowpark"]="pandas",
@@ -10,7 +10,7 @@ from relationalai.semantics.tests.utils import reset_state
10
10
  from relationalai.semantics.internal import internal
11
11
  from relationalai.clients.result_helpers import sort_data_frame_result
12
12
  from relationalai.clients.util import IdentityParser
13
- from relationalai.clients.snowflake import Provider as SFProvider
13
+ from relationalai.clients.resources.snowflake import Provider as SFProvider
14
14
  from relationalai import Provider
15
15
  from typing import cast, Dict, Union
16
16
  from pathlib import Path