relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/config/shims.py +1 -0
- relationalai/semantics/__init__.py +7 -1
- relationalai/semantics/frontend/base.py +19 -13
- relationalai/semantics/frontend/core.py +30 -2
- relationalai/semantics/frontend/front_compiler.py +38 -11
- relationalai/semantics/frontend/pprint.py +1 -1
- relationalai/semantics/metamodel/rewriter.py +6 -2
- relationalai/semantics/metamodel/typer.py +70 -26
- relationalai/semantics/reasoners/__init__.py +11 -0
- relationalai/semantics/reasoners/graph/__init__.py +38 -0
- relationalai/semantics/reasoners/graph/core.py +9015 -0
- relationalai/shims/hoister.py +9 -0
- relationalai/shims/mm2v0.py +32 -24
- relationalai/tools/cli/cli.py +138 -0
- relationalai/tools/cli/docs.py +394 -0
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/METADATA +5 -3
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/RECORD +29 -24
- v0/relationalai/clients/exec_txn_poller.py +91 -0
- v0/relationalai/clients/resources/snowflake/__init__.py +2 -2
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +16 -10
- v0/relationalai/clients/resources/snowflake/snowflake.py +43 -14
- v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
- v0/relationalai/errors.py +18 -0
- v0/relationalai/semantics/lqp/executor.py +3 -1
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
- v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +335 -84
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/top_level.txt +0 -0
|
@@ -13,7 +13,7 @@ from ...config import Config, ConfigStore, ENDPOINT_FILE
|
|
|
13
13
|
from ...direct_access_client import DirectAccessClient
|
|
14
14
|
from ...types import EngineState
|
|
15
15
|
from ...util import get_pyrel_version, poll_with_specified_overhead, safe_json_loads, ms_to_timestamp
|
|
16
|
-
from ....errors import ResponseStatusException, QueryTimeoutExceededException
|
|
16
|
+
from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException
|
|
17
17
|
from snowflake.snowpark import Session
|
|
18
18
|
|
|
19
19
|
# Import UseIndexResources to enable use_index functionality with direct access
|
|
@@ -27,6 +27,7 @@ from typing import Iterable
|
|
|
27
27
|
|
|
28
28
|
# Constants
|
|
29
29
|
TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
|
|
30
|
+
TXN_ABORT_REASON_GUARD_RAILS = "guard rail violation"
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
class DirectAccessResources(UseIndexResources):
|
|
@@ -355,15 +356,20 @@ class DirectAccessResources(UseIndexResources):
|
|
|
355
356
|
if txn_id in self._pending_transactions:
|
|
356
357
|
self._pending_transactions.remove(txn_id)
|
|
357
358
|
|
|
358
|
-
if status == "ABORTED"
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
359
|
+
if status == "ABORTED":
|
|
360
|
+
reason = transaction.get("abort_reason", "")
|
|
361
|
+
|
|
362
|
+
if reason == TXN_ABORT_REASON_TIMEOUT:
|
|
363
|
+
config_file_path = getattr(self.config, 'file_path', None)
|
|
364
|
+
timeout_ms = int(transaction.get("timeout_ms", 0))
|
|
365
|
+
timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
|
|
366
|
+
raise QueryTimeoutExceededException(
|
|
367
|
+
timeout_mins=timeout_mins,
|
|
368
|
+
query_id=txn_id,
|
|
369
|
+
config_file_path=config_file_path,
|
|
370
|
+
)
|
|
371
|
+
elif reason == TXN_ABORT_REASON_GUARD_RAILS:
|
|
372
|
+
raise GuardRailsException(response_content.get("progress", {}))
|
|
367
373
|
|
|
368
374
|
# @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
|
|
369
375
|
return status == "COMPLETED" or status == "ABORTED"
|
|
@@ -15,6 +15,7 @@ import hashlib
|
|
|
15
15
|
from dataclasses import dataclass
|
|
16
16
|
|
|
17
17
|
from ....auth.token_handler import TokenHandler
|
|
18
|
+
from v0.relationalai.clients.exec_txn_poller import ExecTxnPoller, query_complete_message
|
|
18
19
|
import snowflake.snowpark
|
|
19
20
|
|
|
20
21
|
from ....rel_utils import sanitize_identifier, to_fqn_relation_name
|
|
@@ -54,7 +55,7 @@ from .util import (
|
|
|
54
55
|
)
|
|
55
56
|
from ....environments import runtime_env, HexEnvironment, SnowbookEnvironment
|
|
56
57
|
from .... import dsl, rel, metamodel as m
|
|
57
|
-
from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
|
|
58
|
+
from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, GuardRailsException, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
|
|
58
59
|
from concurrent.futures import ThreadPoolExecutor
|
|
59
60
|
from datetime import datetime, timedelta
|
|
60
61
|
from snowflake.snowpark.types import StringType, StructField, StructType
|
|
@@ -105,6 +106,16 @@ PYREL_ROOT_DB = 'pyrel_root_db'
|
|
|
105
106
|
TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
|
|
106
107
|
|
|
107
108
|
TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
|
|
109
|
+
GUARDRAILS_ABORT_REASON = "guard rail violation"
|
|
110
|
+
|
|
111
|
+
PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
|
|
112
|
+
|
|
113
|
+
#--------------------------------------------------
|
|
114
|
+
# Helpers
|
|
115
|
+
#--------------------------------------------------
|
|
116
|
+
|
|
117
|
+
def should_print_txn_progress(config) -> bool:
    """Return True when the config enables printing of transaction progress.

    Reads the ``PRINT_TXN_PROGRESS_FLAG`` entry from ``config`` (treated as
    disabled when absent) and coerces whatever value is stored to a bool.
    """
    flag_value = config.get(PRINT_TXN_PROGRESS_FLAG, False)
    return bool(flag_value)
|
|
108
119
|
|
|
109
120
|
#--------------------------------------------------
|
|
110
121
|
# Resources
|
|
@@ -1411,15 +1422,18 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1411
1422
|
if txn_id in self._pending_transactions:
|
|
1412
1423
|
self._pending_transactions.remove(txn_id)
|
|
1413
1424
|
|
|
1414
|
-
if status == "ABORTED"
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1425
|
+
if status == "ABORTED":
|
|
1426
|
+
if response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
|
|
1427
|
+
config_file_path = getattr(self.config, 'file_path', None)
|
|
1428
|
+
# todo: use the timeout returned alongside the transaction as soon as it's exposed
|
|
1429
|
+
timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
|
|
1430
|
+
raise QueryTimeoutExceededException(
|
|
1431
|
+
timeout_mins=timeout_mins,
|
|
1432
|
+
query_id=txn_id,
|
|
1433
|
+
config_file_path=config_file_path,
|
|
1434
|
+
)
|
|
1435
|
+
elif response_row.get("ABORT_REASON", "") == GUARDRAILS_ABORT_REASON:
|
|
1436
|
+
raise GuardRailsException()
|
|
1423
1437
|
|
|
1424
1438
|
# @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
|
|
1425
1439
|
return status == "COMPLETED" or status == "ABORTED"
|
|
@@ -1704,6 +1718,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1704
1718
|
query_attrs_dict = json.loads(request_headers.get("X-Query-Attributes", "{}"))
|
|
1705
1719
|
|
|
1706
1720
|
with debugging.span("transaction", **query_attrs_dict) as txn_span:
|
|
1721
|
+
txn_start_time = time.time()
|
|
1707
1722
|
with debugging.span("create_v2", **query_attrs_dict) as create_span:
|
|
1708
1723
|
request_headers['user-agent'] = get_pyrel_version(self.generation)
|
|
1709
1724
|
request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
|
|
@@ -1734,8 +1749,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1734
1749
|
create_span["txn_id"] = txn_id
|
|
1735
1750
|
debugging.event("transaction_created", txn_span, txn_id=txn_id)
|
|
1736
1751
|
|
|
1752
|
+
print_txn_progress = should_print_txn_progress(self.config)
|
|
1753
|
+
|
|
1737
1754
|
# fast path: transaction already finished
|
|
1738
1755
|
if state in ["COMPLETED", "ABORTED"]:
|
|
1756
|
+
txn_end_time = time.time()
|
|
1739
1757
|
if txn_id in self._pending_transactions:
|
|
1740
1758
|
self._pending_transactions.remove(txn_id)
|
|
1741
1759
|
|
|
@@ -1744,13 +1762,24 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1744
1762
|
filename = row['FILENAME']
|
|
1745
1763
|
artifact_info[filename] = row
|
|
1746
1764
|
|
|
1765
|
+
txn_duration = txn_end_time - txn_start_time
|
|
1766
|
+
if print_txn_progress:
|
|
1767
|
+
print(
|
|
1768
|
+
query_complete_message(txn_id, txn_duration, status_header=True)
|
|
1769
|
+
)
|
|
1770
|
+
|
|
1747
1771
|
# Slow path: transaction not done yet; start polling
|
|
1748
1772
|
else:
|
|
1749
1773
|
self._pending_transactions.append(txn_id)
|
|
1774
|
+
# Use the interactive poller for transaction status
|
|
1750
1775
|
with debugging.span("wait", txn_id=txn_id):
|
|
1751
|
-
|
|
1752
|
-
|
|
1753
|
-
|
|
1776
|
+
if print_txn_progress:
|
|
1777
|
+
poller = ExecTxnPoller(resource=self, txn_id=txn_id, headers=request_headers, txn_start_time=txn_start_time)
|
|
1778
|
+
poller.poll()
|
|
1779
|
+
else:
|
|
1780
|
+
poll_with_specified_overhead(
|
|
1781
|
+
lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
|
|
1782
|
+
)
|
|
1754
1783
|
artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
|
|
1755
1784
|
|
|
1756
1785
|
with debugging.span("fetch"):
|
|
@@ -2408,7 +2437,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
2408
2437
|
return None
|
|
2409
2438
|
return results[0][0]
|
|
2410
2439
|
|
|
2411
|
-
# CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
|
|
2440
|
+
# CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
|
|
2412
2441
|
# list_databases, list_sf_schemas, list_tables) are now in CLIResources class
|
|
2413
2442
|
# schema_info is kept in base Resources class since it's used by SnowflakeSchema._fetch_info()
|
|
2414
2443
|
|
|
@@ -189,6 +189,9 @@ class UseIndexPoller:
|
|
|
189
189
|
# on every 5th iteration we reset the cdc status, so it will be checked again
|
|
190
190
|
self.should_check_cdc = True
|
|
191
191
|
|
|
192
|
+
# Request a data stream health check only until the server reports it has run (normally the first call)
|
|
193
|
+
self.check_data_stream_health = True
|
|
194
|
+
|
|
192
195
|
self.wait_for_stream_sync = self.res.config.get(
|
|
193
196
|
"wait_for_stream_sync", WAIT_FOR_STREAM_SYNC
|
|
194
197
|
)
|
|
@@ -503,6 +506,7 @@ class UseIndexPoller:
|
|
|
503
506
|
"init_engine_async": self.init_engine_async,
|
|
504
507
|
"language": self.language,
|
|
505
508
|
"data_freshness_mins": self.data_freshness,
|
|
509
|
+
"check_data_stream_health": self.check_data_stream_health
|
|
506
510
|
})
|
|
507
511
|
|
|
508
512
|
request_headers = debugging.add_current_propagation_headers(self.headers)
|
|
@@ -535,6 +539,7 @@ class UseIndexPoller:
|
|
|
535
539
|
errors = use_index_data.get("errors", [])
|
|
536
540
|
relations = use_index_data.get("relations", {})
|
|
537
541
|
cdc_enabled = use_index_data.get("cdcEnabled", False)
|
|
542
|
+
health_checked = use_index_data.get("healthChecked", False)
|
|
538
543
|
if self.check_ready_count % ERP_CHECK_FREQUENCY == 0 or not cdc_enabled:
|
|
539
544
|
self.should_check_cdc = True
|
|
540
545
|
else:
|
|
@@ -542,6 +547,9 @@ class UseIndexPoller:
|
|
|
542
547
|
|
|
543
548
|
if engines and self.init_engine_async:
|
|
544
549
|
self.init_engine_async = False
|
|
550
|
+
|
|
551
|
+
if self.check_data_stream_health and health_checked:
|
|
552
|
+
self.check_data_stream_health = False
|
|
545
553
|
|
|
546
554
|
break_loop = False
|
|
547
555
|
has_stream_errors = False
|
v0/relationalai/errors.py
CHANGED
|
@@ -2436,6 +2436,24 @@ class QueryTimeoutExceededException(RAIException):
|
|
|
2436
2436
|
Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
|
|
2437
2437
|
""")
|
|
2438
2438
|
|
|
2439
|
+
class GuardRailsException(RAIException):
    """Raised when a transaction is aborted because of a guard rails violation.

    ``progress`` is the optional progress payload reported for the aborted
    transaction. Every warning found under
    ``progress["tasks"][<task>]["warnings"]`` is rendered into the exception
    content; when no warning can be rendered, the generic abort message is
    used as the content instead.
    """

    def __init__(self, progress: "dict[str, Any] | None" = None):
        self.name = "Guard Rails Violation"
        self.message = "Transaction aborted due to guard rails violation."
        # Use a fresh dict per instance: a `{}` default would be one object
        # shared by every exception created without an argument. An explicit
        # None is tolerated the same way.
        self.progress = progress if progress is not None else {}
        self.content = self.format_message()
        super().__init__(self.message, self.name, self.content)

    def format_message(self):
        """Build the content text: one section per reported warning.

        Falls back to the generic message when the progress payload carries
        no warnings, so the content is never empty.
        """
        messages = []
        # `or {}` also covers payloads where the key is present but null.
        tasks = self.progress.get("tasks") or {}
        for task in tasks.values():
            # Building the exception must not itself fail on a sparse payload,
            # otherwise a KeyError would mask the guard rails violation.
            relation_name = task.get("task_name", "unknown")
            warnings = task.get("warnings") or {}
            for warning_type, warning_data in warnings.items():
                warning_message = warning_data.get("message", "")
                messages.append(textwrap.dedent(f"""
                    Relation Name: [yellow]{relation_name}[/yellow]
                    Warning: {warning_type}
                    Message: {warning_message}
                """))
        if not messages:
            messages.append(self.message)
        return "\n".join(messages)
|
|
2439
2457
|
|
|
2440
2458
|
#--------------------------------------------------
|
|
2441
2459
|
# Azure Exceptions
|
|
@@ -31,7 +31,9 @@ if TYPE_CHECKING:
|
|
|
31
31
|
|
|
32
32
|
# Whenever the logic engine introduces a breaking change in behaviour, we bump this version
|
|
33
33
|
# once the client is ready to handle it.
|
|
34
|
-
|
|
34
|
+
#
|
|
35
|
+
# [2026-01-09] bumping to 1 to opt in to hard validation errors from the engine
|
|
36
|
+
DEFAULT_LQP_SEMANTICS_VERSION = "1"
|
|
35
37
|
|
|
36
38
|
class LQPExecutor(e.Executor):
|
|
37
39
|
"""Executes LQP using the RAI client."""
|
|
@@ -118,6 +118,17 @@ class ExtractKeys(Pass):
|
|
|
118
118
|
the same here).
|
|
119
119
|
"""
|
|
120
120
|
class ExtractKeysRewriter(Rewriter):
|
|
121
|
+
def __init__(self):
    """Initialise the rewriter with an empty compound-key cache."""
    super().__init__()
    # Maps a set of original output keys to the compound-key variable created
    # for it, so repeated requests for the same keys reuse one variable.
    self.compound_keys: dict[Any, ir.Var] = {}
|
|
124
|
+
|
|
125
|
+
def _get_compound_key(self, orig_keys: Iterable[ir.Var]) -> ir.Var:
    """Return the compound-key variable memoized for ``orig_keys``.

    The first request for a given ``orig_keys`` creates a new variable named
    ``compound_key`` of type ``types.Hash`` and records it in
    ``self.compound_keys``; later requests with an equal key get that same
    variable back.
    """
    # NOTE(review): orig_keys is used directly as a dict key, so it must be
    # hashable — confirm callers never pass a plain list.
    cache = self.compound_keys
    if orig_keys not in cache:
        cache[orig_keys] = f.var("compound_key", types.Hash)
    return cache[orig_keys]
|
|
131
|
+
|
|
121
132
|
def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
|
|
122
133
|
outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
|
|
123
134
|
# We are not in a logical with an output at this level.
|
|
@@ -170,7 +181,7 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
170
181
|
annos = list(output.annotations)
|
|
171
182
|
annos.append(f.annotation(builtins.output_keys, tuple(output_keys)))
|
|
172
183
|
# Create a compound key that will be used in place of the original keys.
|
|
173
|
-
compound_key =
|
|
184
|
+
compound_key = self._get_compound_key(output_keys)
|
|
174
185
|
|
|
175
186
|
for key_combination in combinations:
|
|
176
187
|
missing_keys = OrderedSet.from_iterable(output_keys)
|
|
@@ -192,8 +203,13 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
192
203
|
# handle the construct node in each clone
|
|
193
204
|
values: list[ir.Value] = [compound_key.type]
|
|
194
205
|
for key in output_keys:
|
|
195
|
-
|
|
196
|
-
|
|
206
|
+
if isinstance(key.type, ir.UnionType):
|
|
207
|
+
# the typer can derive union types when multiple distinct entities flow
|
|
208
|
+
# into a relation's field, so use AnyEntity as the type marker
|
|
209
|
+
values.append(ir.Literal(types.String, "AnyEntity"))
|
|
210
|
+
else:
|
|
211
|
+
assert isinstance(key.type, ir.ScalarType)
|
|
212
|
+
values.append(ir.Literal(types.String, key.type.name))
|
|
197
213
|
if key in key_combination:
|
|
198
214
|
values.append(key)
|
|
199
215
|
body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
|
|
@@ -408,6 +424,12 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
408
424
|
for arg in args[:-1]:
|
|
409
425
|
extended_vars.add(arg)
|
|
410
426
|
there_is_progress = True
|
|
427
|
+
elif isinstance(task, ir.Not):
|
|
428
|
+
if isinstance(task.task, ir.Logical):
|
|
429
|
+
hoisted = helpers.hoisted_vars(task.task.hoisted)
|
|
430
|
+
if var in hoisted:
|
|
431
|
+
partitions[var].add(task)
|
|
432
|
+
there_is_progress = True
|
|
411
433
|
else:
|
|
412
434
|
assert False, f"invalid node kind {type(task)}"
|
|
413
435
|
|