relationalai 0.12.13__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/__init__.py +69 -22
- relationalai/clients/__init__.py +15 -2
- relationalai/clients/client.py +4 -4
- relationalai/clients/local.py +5 -5
- relationalai/clients/resources/__init__.py +8 -0
- relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
- relationalai/clients/resources/snowflake/__init__.py +20 -0
- relationalai/clients/resources/snowflake/cli_resources.py +87 -0
- relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
- relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
- relationalai/clients/resources/snowflake/error_handlers.py +199 -0
- relationalai/clients/{export_procedure.py.jinja → resources/snowflake/export_procedure.py.jinja} +1 -1
- relationalai/clients/resources/snowflake/resources_factory.py +99 -0
- relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
- relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
- relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
- relationalai/clients/resources/snowflake/util.py +387 -0
- relationalai/early_access/dsl/ir/executor.py +4 -4
- relationalai/early_access/dsl/snow/api.py +2 -1
- relationalai/errors.py +23 -0
- relationalai/experimental/solvers.py +7 -7
- relationalai/semantics/devtools/benchmark_lqp.py +4 -5
- relationalai/semantics/devtools/extract_lqp.py +1 -1
- relationalai/semantics/internal/internal.py +4 -4
- relationalai/semantics/internal/snowflake.py +3 -2
- relationalai/semantics/lqp/executor.py +22 -22
- relationalai/semantics/lqp/model2lqp.py +42 -4
- relationalai/semantics/lqp/passes.py +1 -1
- relationalai/semantics/lqp/rewrite/cdc.py +1 -1
- relationalai/semantics/lqp/rewrite/extract_keys.py +72 -15
- relationalai/semantics/metamodel/builtins.py +8 -6
- relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
- relationalai/semantics/metamodel/util.py +6 -5
- relationalai/semantics/reasoners/graph/core.py +8 -9
- relationalai/semantics/rel/executor.py +14 -11
- relationalai/semantics/sql/compiler.py +2 -2
- relationalai/semantics/sql/executor/snowflake.py +9 -5
- relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
- relationalai/tools/cli.py +26 -30
- relationalai/tools/cli_helpers.py +10 -2
- relationalai/util/otel_configuration.py +2 -1
- relationalai/util/otel_handler.py +1 -1
- {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/METADATA +1 -1
- {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/RECORD +49 -40
- relationalai_test_util/fixtures.py +2 -1
- /relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
- {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/WHEEL +0 -0
- {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.13.dist-info → relationalai-0.13.0.dist-info}/licenses/LICENSE +0 -0
relationalai/semantics/lqp/model2lqp.py:

@@ -396,6 +396,41 @@ def _translate_ascending_rank(ctx: TranslationCtx, limit: int, result_var: lqp.V
         terms=terms,
     )
 
+def _rename_shadowed_abstraction_vars(
+    ctx: TranslationCtx,
+    aggr: ir.Aggregate,
+    abstr_args: list[Tuple[lqp.Var, lqp.Type]],
+    body_conjs: list[lqp.Formula]
+) -> list[Tuple[lqp.Var, lqp.Type]]:
+    """
+    Rename abstraction variables that shadow group-by variables.
+
+    This can happen when the same variable appears in both aggr.group and as an input
+    to the aggregation, e.g., min(Person.age).per(Person.age). The group-by variables
+    are in the outer scope, while the abstraction parameters are in the inner scope,
+    so we need different names to avoid shadowing.
+    """
+    # Get the LQP names of group-by variables
+    group_var_names = set()
+    for group_var in aggr.group:
+        lqp_var = _translate_var(ctx, group_var)
+        group_var_names.add(lqp_var.name)
+
+    # Rename any abstraction parameters that conflict with group-by variables
+    renamed_abstr_args = []
+    for var, typ in abstr_args:
+        if var.name in group_var_names:
+            # This variable shadows a group-by variable, so rename it
+            fresh_var = gen_unique_var(ctx, var.name)
+            # Add an equality constraint: fresh_var == var
+            # var is a free variable referring to the outer scope group-by variable
+            body_conjs.append(mk_primitive("rel_primitive_eq", [fresh_var, var]))
+            renamed_abstr_args.append((fresh_var, typ))
+        else:
+            renamed_abstr_args.append((var, typ))
+
+    return renamed_abstr_args
+
 def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Formula) -> Union[lqp.Reduce, lqp.Formula]:
     # TODO: handle this properly
     aggr_name = aggr.aggregation.name
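The shadowing case the new helper guards against is easiest to see in isolation. Below is a minimal standalone sketch of the same rename-and-equate pattern; `Var`, `gen_unique_var`, and `rename_shadowed` here are illustrative stand-ins, not the package's actual API.

```python
from dataclasses import dataclass
from itertools import count

@dataclass(frozen=True)
class Var:
    name: str

_fresh = count()

def gen_unique_var(base: str) -> Var:
    # illustrative stand-in for the translator's fresh-name generator
    return Var(f"{base}_{next(_fresh)}")

def rename_shadowed(abstr_args: list[Var], group_names: set[str]):
    """Rename params that collide with group-by names; collect (fresh == outer) equalities."""
    renamed: list[Var] = []
    equalities: list[tuple[Var, Var]] = []
    for v in abstr_args:
        if v.name in group_names:
            fresh = gen_unique_var(v.name)
            equalities.append((fresh, v))  # constrain the fresh inner var to the outer one
            renamed.append(fresh)
        else:
            renamed.append(v)
    return renamed, equalities

# min(Person.age).per(Person.age): `age` is both a group-by and the aggregation input
params, eqs = rename_shadowed([Var("age")], {"age"})
print(params)  # [Var(name='age_0')]
print(eqs)     # [(Var(name='age_0'), Var(name='age'))]
```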
relationalai/semantics/lqp/model2lqp.py (continued):

@@ -432,6 +467,9 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
     body_conjs.extend(projected_eqs)
     abstr_args: list[Tuple[lqp.Var, lqp.Type]] = projected_args + input_args
 
+    # Rename abstraction variables that shadow group-by variables
+    abstr_args = _rename_shadowed_abstraction_vars(ctx, aggr, abstr_args, body_conjs)
+
     if aggr_name == "count":
         assert len(output_terms) == 1, "Count and avg expect a single output variable"
         assert isinstance(meta_output_terms[0], ir.Var)
@@ -441,7 +479,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
         one_var, eq = constant_to_var(ctx, to_lqp_value(1, meta_output_terms[0].type), "one")
         body_conjs.append(eq)
         abstr_args.append((one_var, typ))
-        body = mk_and(body_conjs)
 
     # Average needs to wrap the reduce in Exists(Conjunction(Reduce, div))
     if aggr_name == "avg":
@@ -454,7 +491,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
         one_var, eq = constant_to_var(ctx, to_lqp_value(1, types.Int64), "one")
         body_conjs.append(eq)
         abstr_args.append((one_var, count_type))
-        body = mk_and(body_conjs)
 
         # The average will produce two output variables: sum and count.
         sum_result = gen_unique_var(ctx, "sum")
@@ -462,6 +498,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
 
         # Second to last is the variable we're summing over.
         (sum_var, sum_type) = abstr_args[-2]
+        body = mk_and(body_conjs)
 
         result = lqp.Reduce(
             op=lqp_avg_op(ctx, aggr.aggregation.name, sum_var.name, sum_type),
@@ -494,6 +531,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
     # `input_args` hold the types of the input arguments, but they may have been modified
     # if we're dealing with a count, so we use `abstr_args` to find the type.
     (aggr_arg, aggr_arg_type) = abstr_args[-1]
+    body = mk_and(body_conjs)
 
     # Group-bys do not need to be handled at all, since they are introduced outside already
     reduce = lqp.Reduce(
@@ -668,11 +706,11 @@ def to_lqp_value(value: ir.PyValue, value_type: ir.Type) -> lqp.Value:
         val = value
     elif typ.type_name == lqp.TypeName.STRING and isinstance(value, str):
         val = value
-    elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, PyDecimal):
+    elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, (int, float, PyDecimal)):
         precision = typ.parameters[0].value
         scale = typ.parameters[1].value
         assert isinstance(precision, int) and isinstance(scale, int)
-        val = lqp.DecimalValue(precision=precision, scale=scale, value=value, meta=None)
+        val = lqp.DecimalValue(precision=precision, scale=scale, value=PyDecimal(value), meta=None)
     elif typ.type_name == lqp.TypeName.DATE and isinstance(value, date):
         val = lqp.DateValue(value=value, meta=None)
     elif typ.type_name == lqp.TypeName.DATETIME and isinstance(value, datetime):
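The `to_lqp_value` change widens the accepted Python types for DECIMAL literals and normalizes them through `PyDecimal` (the stdlib `decimal.Decimal`). One detail worth keeping in mind: converting a float goes through its exact binary value, so the usual stdlib caveats apply. A quick illustration:

```python
from decimal import Decimal

print(Decimal(1))      # 1 -- int conversion is exact
print(Decimal("0.1"))  # 0.1 -- string conversion keeps the decimal literal
print(Decimal(0.1))    # 0.1000000000000000055511151231257827021181583404541015625
# floats carry their binary representation into the Decimal, so quantize when a
# fixed scale is needed:
print(Decimal(0.1).quantize(Decimal("0.01")))  # 0.10
```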
relationalai/semantics/lqp/rewrite/cdc.py:

@@ -200,7 +200,7 @@ class CDC(Pass):
         Get the relation that represents this property var in this wide_cdc_relation. If the
         relation is not yet available in the context, this method will create and register it.
         """
-        relation_name = wide_cdc_relation.name
+        relation_name = helpers.sanitize(wide_cdc_relation.name).replace("-", "_")
         key = (relation_name, property.name)
         if key not in ctx.cdc_relations:
             # the property relation is overloaded for all properties of the same wide cdc relation, so they have
relationalai/semantics/lqp/rewrite/extract_keys.py:

@@ -118,6 +118,17 @@ class ExtractKeys(Pass):
     the same here).
     """
 class ExtractKeysRewriter(Rewriter):
+    def __init__(self):
+        super().__init__()
+        self.compound_keys: dict[Any, ir.Var] = {}
+
+    def _get_compound_key(self, orig_keys: Iterable[ir.Var]) -> ir.Var:
+        if orig_keys in self.compound_keys:
+            return self.compound_keys[orig_keys]
+        compound_key = f.var("compound_key", types.Hash)
+        self.compound_keys[orig_keys] = compound_key
+        return compound_key
+
     def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
         outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
         # We are not in a logical with an output at this level.
@@ -170,7 +181,7 @@ class ExtractKeysRewriter(Rewriter):
         annos = list(output.annotations)
         annos.append(f.annotation(builtins.output_keys, tuple(output_keys)))
         # Create a compound key that will be used in place of the original keys.
-        compound_key =
+        compound_key = self._get_compound_key(output_keys)
 
         for key_combination in combinations:
             missing_keys = OrderedSet.from_iterable(output_keys)
@@ -192,8 +203,13 @@ class ExtractKeysRewriter(Rewriter):
             # handle the construct node in each clone
             values: list[ir.Value] = [compound_key.type]
             for key in output_keys:
-
-
+                if isinstance(key.type, ir.UnionType):
+                    # the typer can derive union types when multiple distinct entities flow
+                    # into a relation's field, so use AnyEntity as the type marker
+                    values.append(ir.Literal(types.String, "AnyEntity"))
+                else:
+                    assert isinstance(key.type, ir.ScalarType)
+                    values.append(ir.Literal(types.String, key.type.name))
                 if key in key_combination:
                     values.append(key)
             body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
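The new `_get_compound_key` memoizes one compound-key variable per distinct key collection, so every clone produced for the same output reuses the same variable. A minimal sketch of the pattern, assuming the key collection is hashable (e.g., a tuple); the class and names are illustrative:

```python
class CompoundKeyCache:
    """One synthetic key name per distinct tuple of original keys."""
    def __init__(self) -> None:
        self._cache: dict[tuple, str] = {}

    def get(self, keys: tuple) -> str:
        if keys in self._cache:
            return self._cache[keys]
        name = f"compound_key_{len(self._cache) + 1}"
        self._cache[keys] = name
        return name

cache = CompoundKeyCache()
assert cache.get(("person", "year")) == cache.get(("person", "year"))  # reused
assert cache.get(("person",)) != cache.get(("person", "year"))         # distinct
```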
relationalai/semantics/lqp/rewrite/extract_keys.py (continued):

@@ -335,9 +351,45 @@ class ExtractKeysRewriter(Rewriter):
         partitions:dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
         dependencies:dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
 
-
-
-
+        def dfs_collect_deps(task, deps):
+            if isinstance(task, ir.Lookup):
+                args = helpers.vars(task.args)
+                for i, v in enumerate(args):
+                    # v depends on all previous vars
+                    for j in range(i):
+                        deps[v].add(args[j])
+                    # for ternary+ lookups, a var also depends on the next vars
+                    if i > 0 and len(args) >= 3:
+                        for j in range(i+1, len(args)):
+                            deps[v].add(args[j])
+            elif isinstance(task, ir.Construct):
+                vars = helpers.vars(task.values)
+                for val_var in vars:
+                    deps[task.id_var].add(val_var)
+            elif isinstance(task, ir.Logical):
+                for child in task.body:
+                    dfs_collect_deps(child, deps)
+            elif isinstance(task, (ir.Match, ir.Union)):
+                for child in task.tasks:
+                    dfs_collect_deps(child, deps)
+
+        for task in tasks:
+            dfs_collect_deps(task, dependencies)
+
+        def dfs_transitive_deps(var, visited):
+            for dep_var in dependencies[var]:
+                if dep_var not in visited:
+                    visited.add(dep_var)
+                    dfs_transitive_deps(dep_var, visited)
+
+        transitive_deps = defaultdict(OrderedSet)
+        for var in list(dependencies.keys()):
+            visited = OrderedSet()
+            dfs_transitive_deps(var, visited)
+            transitive_deps[var] = visited
+        dependencies = transitive_deps
+
+        for var in vars:
             extended_vars = OrderedSet[ir.Var]()
             extended_vars.add(var)
 
@@ -347,28 +399,33 @@ class ExtractKeysRewriter(Rewriter):
             for task in tasks:
                 if task in partitions[var]:
                     continue
-
+
+                if isinstance(task, (ir.Logical, ir.Match, ir.Union)):
+                    hoisted = helpers.hoisted_vars(task.hoisted)
+                    if var in hoisted:
+                        partitions[var].add(task)
+                        there_is_progress = True
+                elif isinstance(task, ir.Construct):
+                    if task.id_var == var:
+                        partitions[var].add(task)
+                        there_is_progress = True
                 elif isinstance(task, ir.Lookup):
                     args = helpers.vars(task.args)
                     if len(args) == 1 and args[0] in extended_vars:
                         partitions[var].add(task)
-
+                        there_is_progress = True
+                    # NOTE: heuristics to have dot_joins work
                     elif len(args) >= 3 and args[-2] in extended_vars:
                         partitions[var].add(task)
                         extended_vars.add(args[-1])
-                        dependencies[var].add(args[-1])
                         there_is_progress = True
                     elif len(args) > 1 and args[-1] in extended_vars:
                         partitions[var].add(task)
                         for arg in args[:-1]:
                             extended_vars.add(arg)
-                            dependencies[var].add(arg)
-                        there_is_progress = True
-                elif isinstance(task, ir.Logical):
-                    hoisted = helpers.hoisted_vars(task.hoisted)
-                    if var in hoisted:
-                        partitions[var].add(task)
                         there_is_progress = True
+                else:
+                    assert False, f"invalid node kind {type(task)}"
 
         return partitions, dependencies
 
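The second hunk replaces the incremental `dependencies[var].add(...)` bookkeeping with an explicit two-phase computation: first collect direct variable dependencies, then close them transitively with a depth-first search. The same closure step in plain-dict form, as a sketch (the pass itself uses `OrderedSet`):

```python
from collections import defaultdict

def transitive_closure(direct: dict[str, set[str]]) -> dict[str, set[str]]:
    """Expand direct dependencies into transitive ones via depth-first search."""
    def dfs(var: str, visited: set[str]) -> None:
        for dep in direct.get(var, ()):
            if dep not in visited:
                visited.add(dep)
                dfs(dep, visited)

    closed: dict[str, set[str]] = defaultdict(set)
    for var in direct:
        visited: set[str] = set()
        dfs(var, visited)
        closed[var] = visited
    return dict(closed)

# c -> b -> a: c transitively depends on both b and a
print(transitive_closure({"b": {"a"}, "c": {"b"}}))
# {'b': {'a'}, 'c': {'a', 'b'}} (set print order may vary)
```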
relationalai/semantics/metamodel/builtins.py:

@@ -443,12 +443,14 @@ datetime_second = f.relation("datetime_second", [f.input_field("a", types.DateTi
 datetime_weekday = f.relation("datetime_weekday", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
 
 # Other
-range = f.relation(
-
-    f.input_field("stop", types.
-
-
-])
+range = f.relation(
+    "range",
+    [f.input_field("start", types.Number), f.input_field("stop", types.Number), f.input_field("step", types.Number), f.field("result", types.Number)],
+    overloads=[
+        f.relation("range", [f.input_field("start", types.Int64), f.input_field("stop", types.Int64), f.input_field("step", types.Int64), f.field("result", types.Int64)]),
+        f.relation("range", [f.input_field("start", types.Int128), f.input_field("stop", types.Int128), f.input_field("step", types.Int128), f.field("result", types.Int128)]),
+    ],
+)
 
 hash = f.relation("hash", [f.input_field("args", types.AnyList), f.field("hash", types.Hash)])
 
relationalai/semantics/metamodel/rewrite/flatten.py:

@@ -124,10 +124,10 @@ class Flatten(Pass):
     output
     """
 
-    def __init__(self,
+    def __init__(self, use_sql: bool=False):
         super().__init__()
         self.name_cache = NameCache(start_from_one=True)
-        self.
+        self._use_sql = use_sql
 
 
     #--------------------------------------------------
@@ -181,7 +181,12 @@ class Flatten(Pass):
     def handle(self, task: ir.Task, ctx: Context) -> Flatten.HandleResult:
         if isinstance(task, ir.Logical):
             return self.handle_logical(task, ctx)
-        elif isinstance(task, ir.Union):
+        elif isinstance(task, ir.Union) and (task.hoisted or self._use_sql):
+            # Only flatten Unions which hoist variables. If there are no hoisted variables,
+            # then the Union acts as a filter, and it can be inefficient to flatten it.
+            #
+            # However, for the SQL backend, we always need to flatten Unions for correct SQL
+            # generation.
             return self.handle_union(task, ctx)
         elif isinstance(task, ir.Match):
             return self.handle_match(task, ctx)
@@ -238,7 +243,7 @@ class Flatten(Pass):
         # If there are outputs, flatten each into its own top-level rule, along with its
        # dependencies.
         if groups["outputs"]:
-            if
+            if self._use_sql:
                 ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations))
                 return Flatten.HandleResult(None)
 
relationalai/semantics/metamodel/util.py:

@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable
+from typing import Callable, Generator, Generic, IO, Iterable, Optional, Sequence, Tuple, TypeVar, cast, Hashable, Union
+import types
 from dataclasses import dataclass, field
 
 #--------------------------------------------------
@@ -345,7 +346,7 @@ def rewrite_set(t: type[T], f: Callable[[T], T], items: FrozenOrderedSet[T]) ->
         return items
     return ordered_set(*new_items).frozen()
 
-def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
+def rewrite_list(t: Union[type[T], types.UnionType], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple[T, ...]:
     """ Map a function over a list, returning a new list with the results. Avoid allocating a new list if the function is the identity. """
     new_items: Optional[list[T]] = None
     for i in range(len(items)):
@@ -359,15 +360,15 @@ def rewrite_list(t: type[T], f: Callable[[T], T], items: Tuple[T, ...]) -> Tuple
         return items
     return tuple(new_items)
 
-def flatten_iter(items: Iterable[object], t: type[T]) -> Generator[T, None, None]:
+def flatten_iter(items: Iterable[object], t: Union[type[T], types.UnionType]) -> Generator[T, None, None]:
     """Yield items from a nested iterable structure one at a time."""
     for item in items:
         if isinstance(item, (list, tuple, OrderedSet)):
             yield from flatten_iter(item, t)
         elif isinstance(item, t):
-            yield item
+            yield cast(T, item)
 
-def flatten_tuple(items: Iterable[object], t: type[T]) -> tuple[T, ...]:
+def flatten_tuple(items: Iterable[object], t: Union[type[T], types.UnionType]) -> tuple[T, ...]:
     """ Flatten the nested iterable structure into a tuple."""
     return tuple(flatten_iter(items, t))
 
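The `Union[type[T], types.UnionType]` signatures work because, since Python 3.10, an `X | Y` expression evaluates to a `types.UnionType` object that `isinstance()` accepts directly. A small self-contained version of the same idea:

```python
import types
from typing import Iterable, Union

def flatten(items: Iterable[object], t: Union[type, types.UnionType]):
    """Yield leaves of a nested list/tuple structure that match t."""
    for item in items:
        if isinstance(item, (list, tuple)):
            yield from flatten(item, t)
        elif isinstance(item, t):
            yield item

print(isinstance(int | float, types.UnionType))             # True
print(list(flatten([1, ["a", (2.5, None)]], int | float)))  # [1, 2.5]
```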
relationalai/semantics/reasoners/graph/core.py:

@@ -6222,9 +6222,9 @@ class Graph():
     def _distance_reversed_non_weighted(self):
         """Lazily define and cache the self._distance_reversed_non_weighted relationship, a non-public helper."""
         _distance_reversed_non_weighted_rel = self._model.Relationship(f"{{node_u:{self._NodeConceptStr}}} and {{node_v:{self._NodeConceptStr}}} have a reversed distance of {{d:Integer}}")
-        node_u, node_v, node_n
+        node_u, node_v, node_n = self.Node.ref(), self.Node.ref(), self.Node.ref()
         node_u, node_v, d = union(
-            where(node_u == node_v, d1
+            where(node_u == node_v, d1 := 0).select(node_u, node_v, d1), # Base case.
             where(self._edge(node_v, node_n),
                   d2 := _distance_reversed_non_weighted_rel(node_u, node_n, Integer) + 1).select(node_u, node_v, d2) # Recursive case.
         )
@@ -6326,13 +6326,12 @@ class Graph():
         _is_connected_rel.annotate(annotations.track("graphs", "is_connected"))
 
         where(
-
-
-
-
-
-
-        ).define(_is_connected_rel(False))
+            union(
+                self._num_nodes(0),
+                count(self._reachable_from_min_node(self.Node.ref())) == self._num_nodes(Integer.ref())
+            )
+        ).define(_is_connected_rel(True)) \
+            | define(_is_connected_rel(False))
 
         return _is_connected_rel
 
relationalai/semantics/rel/executor.py:

@@ -9,16 +9,15 @@ import uuid
 from pandas import DataFrame
 from typing import Any, Optional, Literal, TYPE_CHECKING
 from snowflake.snowpark import Session
-import relationalai as rai
 
 from relationalai import debugging
 from relationalai.clients import result_helpers
 from relationalai.clients.util import IdentityParser, escape_for_f_string
-from relationalai.clients.snowflake import APP_NAME
+from relationalai.clients.resources.snowflake import APP_NAME, create_resources_instance
 from relationalai.semantics.metamodel import ir, executor as e, factory as f
 from relationalai.semantics.rel import Compiler
 from relationalai.clients.config import Config
-from relationalai.tools.constants import
+from relationalai.tools.constants import Generation, QUERY_ATTRIBUTES_HEADER
 from relationalai.tools.query_utils import prepare_metadata_for_headers
 
 if TYPE_CHECKING:
@@ -53,15 +52,11 @@ class RelExecutor(e.Executor):
         if not self._resources:
             with debugging.span("create_session"):
                 self.dry_run |= bool(self.config.get("compiler.dry_run", False))
-                resource_class = rai.clients.snowflake.Resources
-                if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
-                    resource_class = rai.clients.snowflake.DirectAccessResources
                 # NOTE: language="rel" is required for Rel execution. It is the default, but
                 # we set it explicitly here to be sure.
-                self._resources =
-                    dry_run=self.dry_run,
+                self._resources = create_resources_instance(
                     config=self.config,
-
+                    dry_run=self.dry_run,
                     connection=self.connection,
                     language="rel",
                 )
@@ -163,13 +158,20 @@ class RelExecutor(e.Executor):
             raise errors.RAIExceptionSet(all_errors)
 
     def _export(self, raw_code: str, dest: Table, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
+        # _export is Snowflake-specific and requires Snowflake Resources
+        # It calls Snowflake stored procedures (APP_NAME.api.exec_into_table, etc.)
+        # LocalResources doesn't support this functionality
+        from relationalai.clients.local import LocalResources
+        if isinstance(self.resources, LocalResources):
+            raise NotImplementedError("Export functionality is not supported in local mode. Use Snowflake Resources instead.")
+
         _exec = self.resources._exec
         output_table = "out" + str(uuid.uuid4()).replace("-", "_")
         txn_id = None
         artifacts = None
         dest_database, dest_schema, dest_table, _ = IdentityParser(dest._fqn, require_all_parts=True).to_list()
         dest_fqn = dest._fqn
-        assert self.resources._session
+        assert self.resources._session # All Snowflake Resources have _session
         with debugging.span("transaction"):
             try:
                 with debugging.span("exec_format") as span:
@@ -258,7 +260,7 @@ class RelExecutor(e.Executor):
                         SELECT 1
                         FROM {dest_database}.INFORMATION_SCHEMA.TABLES
                         WHERE table_schema = '{dest_schema}'
-
+                        AND table_name = '{dest_table}'
                     )) THEN
                         EXECUTE IMMEDIATE 'TRUNCATE TABLE {dest_fqn}';
                     END IF;
@@ -267,6 +269,7 @@ class RelExecutor(e.Executor):
             else:
                 raise e
         if txn_id:
+            # These methods are available on all Snowflake Resources
            artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
            with debugging.span("fetch"):
                artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")
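The executor no longer hard-codes the `Resources` vs. `DirectAccessResources` choice; it delegates to the new `create_resources_instance` factory (added in `resources_factory.py`, per the file list above). A hedged sketch of what a config-driven factory of this shape looks like; class and key names here are illustrative, not the package's actual signatures:

```python
class Resources:
    def __init__(self, **kwargs):
        self.options = kwargs

class DirectAccessResources(Resources):
    pass

def create_resources(config: dict, **kwargs) -> Resources:
    # Centralizes the selection every executor previously duplicated inline.
    cls = DirectAccessResources if config.get("use_direct_access", False) else Resources
    return cls(**kwargs)

r = create_resources({"use_direct_access": True}, language="rel")
print(type(r).__name__, r.options)  # DirectAccessResources {'language': 'rel'}
```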
relationalai/semantics/sql/compiler.py:

@@ -33,7 +33,7 @@ class Compiler(c.Compiler):
             ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
             InferTypes(),
             DNFUnionSplitter(),
-            Flatten(
+            Flatten(use_sql=True),
             rewrite.RecursiveUnion(),
             rewrite.DoubleNegation(),
             rewrite.SortOutputQuery()
@@ -1264,7 +1264,7 @@ class ModelToSQL:
             assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and fourth arguments (index, part) must be variables"
             builtin_vars[part] = part_expr
             builtin_vars[index] = index_expr
-        elif relation == builtins.range:
+        elif relation == builtins.range or relation in builtins.range.overloads:
             assert len(args) == 4, f"Expected 4 args for `range`, got {len(args)}: {args}"
             start_raw, stop_raw, step_raw, result = args
             start = self._var_to_expr(start_raw, reference, resolve_builtin_var, var_to_construct)
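The `range` builtin now carries typed overloads (see the `builtins.py` hunk above), so the SQL compiler must treat the generic relation and any of its overloads as the same builtin. A toy model of that membership check, with a simplified `Relation` standing in for the metamodel's:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Relation:
    # simplified stand-in for the metamodel's relation objects
    name: str
    types: tuple[str, ...]
    overloads: tuple["Relation", ...] = ()

range_rel = Relation("range", ("Number",) * 4, overloads=(
    Relation("range", ("Int64",) * 4),
    Relation("range", ("Int128",) * 4),
))

def is_range(rel: Relation) -> bool:
    # mirrors: relation == builtins.range or relation in builtins.range.overloads
    return rel == range_rel or rel in range_rel.overloads

print(is_range(range_rel))               # True
print(is_range(range_rel.overloads[0]))  # True (the Int64 overload)
print(is_range(Relation("hash", ())))    # False
```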
relationalai/semantics/sql/executor/snowflake.py:

@@ -17,6 +17,7 @@ from relationalai.semantics.sql.executor.result_helpers import format_columns
 from relationalai.semantics.metamodel.visitor import collect_by_type
 from relationalai.semantics.metamodel.typer import typer
 from relationalai.tools.constants import USE_DIRECT_ACCESS
+from relationalai.clients.resources.snowflake import Resources, DirectAccessResources, Provider
 
 if TYPE_CHECKING:
     from relationalai.semantics.snowflake import Table
@@ -51,17 +52,20 @@ class SnowflakeExecutor(e.Executor):
         if not self._resources:
             with debugging.span("create_session"):
                 self.dry_run |= bool(self.config.get("compiler.dry_run", False))
-                resource_class =
+                resource_class: type = Resources
                 if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
-                    resource_class =
-                self._resources = resource_class(
-
+                    resource_class = DirectAccessResources
+                self._resources = resource_class(
+                    dry_run=self.dry_run, config=self.config, generation=rai.Generation.QB,
+                    connection=self.connection,
+                    language="sql",
+                )
         return self._resources
 
     @property
     def provider(self):
         if not self._provider:
-            self._provider =
+            self._provider = Provider(resources=self.resources)
         return self._provider
 
     def execute(self, model: ir.Model, task: ir.Task, format:Literal["pandas", "snowpark"]="pandas",
relationalai/semantics/tests/test_snapshot_abstract.py:

@@ -10,7 +10,7 @@ from relationalai.semantics.tests.utils import reset_state
 from relationalai.semantics.internal import internal
 from relationalai.clients.result_helpers import sort_data_frame_result
 from relationalai.clients.util import IdentityParser
-from relationalai.clients.snowflake import Provider as SFProvider
+from relationalai.clients.resources.snowflake import Provider as SFProvider
 from relationalai import Provider
 from typing import cast, Dict, Union
 from pathlib import Path