relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. relationalai/config/shims.py +1 -0
  2. relationalai/semantics/__init__.py +7 -1
  3. relationalai/semantics/frontend/base.py +19 -13
  4. relationalai/semantics/frontend/core.py +30 -2
  5. relationalai/semantics/frontend/front_compiler.py +38 -11
  6. relationalai/semantics/frontend/pprint.py +1 -1
  7. relationalai/semantics/metamodel/rewriter.py +6 -2
  8. relationalai/semantics/metamodel/typer.py +70 -26
  9. relationalai/semantics/reasoners/__init__.py +11 -0
  10. relationalai/semantics/reasoners/graph/__init__.py +38 -0
  11. relationalai/semantics/reasoners/graph/core.py +9015 -0
  12. relationalai/shims/hoister.py +9 -0
  13. relationalai/shims/mm2v0.py +32 -24
  14. relationalai/tools/cli/cli.py +138 -0
  15. relationalai/tools/cli/docs.py +394 -0
  16. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/METADATA +5 -3
  17. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/RECORD +29 -24
  18. v0/relationalai/clients/exec_txn_poller.py +91 -0
  19. v0/relationalai/clients/resources/snowflake/__init__.py +2 -2
  20. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +16 -10
  21. v0/relationalai/clients/resources/snowflake/snowflake.py +43 -14
  22. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  23. v0/relationalai/errors.py +18 -0
  24. v0/relationalai/semantics/lqp/executor.py +3 -1
  25. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  26. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +335 -84
  27. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/WHEEL +0 -0
  28. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/entry_points.txt +0 -0
  29. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a4.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@ from ...config import Config, ConfigStore, ENDPOINT_FILE
13
13
  from ...direct_access_client import DirectAccessClient
14
14
  from ...types import EngineState
15
15
  from ...util import get_pyrel_version, poll_with_specified_overhead, safe_json_loads, ms_to_timestamp
16
- from ....errors import ResponseStatusException, QueryTimeoutExceededException
16
+ from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException
17
17
  from snowflake.snowpark import Session
18
18
 
19
19
  # Import UseIndexResources to enable use_index functionality with direct access
@@ -27,6 +27,7 @@ from typing import Iterable
27
27
 
28
28
  # Constants
29
29
  TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
30
+ TXN_ABORT_REASON_GUARD_RAILS = "guard rail violation"
30
31
 
31
32
 
32
33
  class DirectAccessResources(UseIndexResources):
@@ -355,15 +356,20 @@ class DirectAccessResources(UseIndexResources):
355
356
  if txn_id in self._pending_transactions:
356
357
  self._pending_transactions.remove(txn_id)
357
358
 
358
- if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
359
- config_file_path = getattr(self.config, 'file_path', None)
360
- timeout_ms = int(transaction.get("timeout_ms", 0))
361
- timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
362
- raise QueryTimeoutExceededException(
363
- timeout_mins=timeout_mins,
364
- query_id=txn_id,
365
- config_file_path=config_file_path,
366
- )
359
+ if status == "ABORTED":
360
+ reason = transaction.get("abort_reason", "")
361
+
362
+ if reason == TXN_ABORT_REASON_TIMEOUT:
363
+ config_file_path = getattr(self.config, 'file_path', None)
364
+ timeout_ms = int(transaction.get("timeout_ms", 0))
365
+ timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
366
+ raise QueryTimeoutExceededException(
367
+ timeout_mins=timeout_mins,
368
+ query_id=txn_id,
369
+ config_file_path=config_file_path,
370
+ )
371
+ elif reason == TXN_ABORT_REASON_GUARD_RAILS:
372
+ raise GuardRailsException(response_content.get("progress", {}))
367
373
 
368
374
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
369
375
  return status == "COMPLETED" or status == "ABORTED"
@@ -15,6 +15,7 @@ import hashlib
15
15
  from dataclasses import dataclass
16
16
 
17
17
  from ....auth.token_handler import TokenHandler
18
+ from v0.relationalai.clients.exec_txn_poller import ExecTxnPoller, query_complete_message
18
19
  import snowflake.snowpark
19
20
 
20
21
  from ....rel_utils import sanitize_identifier, to_fqn_relation_name
@@ -54,7 +55,7 @@ from .util import (
54
55
  )
55
56
  from ....environments import runtime_env, HexEnvironment, SnowbookEnvironment
56
57
  from .... import dsl, rel, metamodel as m
57
- from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
58
+ from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, GuardRailsException, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
58
59
  from concurrent.futures import ThreadPoolExecutor
59
60
  from datetime import datetime, timedelta
60
61
  from snowflake.snowpark.types import StringType, StructField, StructType
@@ -105,6 +106,16 @@ PYREL_ROOT_DB = 'pyrel_root_db'
105
106
  TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
106
107
 
107
108
  TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
109
+ GUARDRAILS_ABORT_REASON = "guard rail violation"
110
+
111
+ PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
112
+
113
+ #--------------------------------------------------
114
+ # Helpers
115
+ #--------------------------------------------------
116
+
117
+ def should_print_txn_progress(config) -> bool:
118
+ return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))
108
119
 
109
120
  #--------------------------------------------------
110
121
  # Resources
@@ -1411,15 +1422,18 @@ Otherwise, remove it from your '{profile}' configuration profile.
1411
1422
  if txn_id in self._pending_transactions:
1412
1423
  self._pending_transactions.remove(txn_id)
1413
1424
 
1414
- if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
1415
- config_file_path = getattr(self.config, 'file_path', None)
1416
- # todo: use the timeout returned alongside the transaction as soon as it's exposed
1417
- timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
1418
- raise QueryTimeoutExceededException(
1419
- timeout_mins=timeout_mins,
1420
- query_id=txn_id,
1421
- config_file_path=config_file_path,
1422
- )
1425
+ if status == "ABORTED":
1426
+ if response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
1427
+ config_file_path = getattr(self.config, 'file_path', None)
1428
+ # todo: use the timeout returned alongside the transaction as soon as it's exposed
1429
+ timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
1430
+ raise QueryTimeoutExceededException(
1431
+ timeout_mins=timeout_mins,
1432
+ query_id=txn_id,
1433
+ config_file_path=config_file_path,
1434
+ )
1435
+ elif response_row.get("ABORT_REASON", "") == GUARDRAILS_ABORT_REASON:
1436
+ raise GuardRailsException()
1423
1437
 
1424
1438
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
1425
1439
  return status == "COMPLETED" or status == "ABORTED"
@@ -1704,6 +1718,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
1704
1718
  query_attrs_dict = json.loads(request_headers.get("X-Query-Attributes", "{}"))
1705
1719
 
1706
1720
  with debugging.span("transaction", **query_attrs_dict) as txn_span:
1721
+ txn_start_time = time.time()
1707
1722
  with debugging.span("create_v2", **query_attrs_dict) as create_span:
1708
1723
  request_headers['user-agent'] = get_pyrel_version(self.generation)
1709
1724
  request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
@@ -1734,8 +1749,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
1734
1749
  create_span["txn_id"] = txn_id
1735
1750
  debugging.event("transaction_created", txn_span, txn_id=txn_id)
1736
1751
 
1752
+ print_txn_progress = should_print_txn_progress(self.config)
1753
+
1737
1754
  # fast path: transaction already finished
1738
1755
  if state in ["COMPLETED", "ABORTED"]:
1756
+ txn_end_time = time.time()
1739
1757
  if txn_id in self._pending_transactions:
1740
1758
  self._pending_transactions.remove(txn_id)
1741
1759
 
@@ -1744,13 +1762,24 @@ Otherwise, remove it from your '{profile}' configuration profile.
1744
1762
  filename = row['FILENAME']
1745
1763
  artifact_info[filename] = row
1746
1764
 
1765
+ txn_duration = txn_end_time - txn_start_time
1766
+ if print_txn_progress:
1767
+ print(
1768
+ query_complete_message(txn_id, txn_duration, status_header=True)
1769
+ )
1770
+
1747
1771
  # Slow path: transaction not done yet; start polling
1748
1772
  else:
1749
1773
  self._pending_transactions.append(txn_id)
1774
+ # Use the interactive poller for transaction status
1750
1775
  with debugging.span("wait", txn_id=txn_id):
1751
- poll_with_specified_overhead(
1752
- lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
1753
- )
1776
+ if print_txn_progress:
1777
+ poller = ExecTxnPoller(resource=self, txn_id=txn_id, headers=request_headers, txn_start_time=txn_start_time)
1778
+ poller.poll()
1779
+ else:
1780
+ poll_with_specified_overhead(
1781
+ lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
1782
+ )
1754
1783
  artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
1755
1784
 
1756
1785
  with debugging.span("fetch"):
@@ -2408,7 +2437,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
2408
2437
  return None
2409
2438
  return results[0][0]
2410
2439
 
2411
- # CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
2440
+ # CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
2412
2441
  # list_databases, list_sf_schemas, list_tables) are now in CLIResources class
2413
2442
  # schema_info is kept in base Resources class since it's used by SnowflakeSchema._fetch_info()
2414
2443
 
@@ -189,6 +189,9 @@ class UseIndexPoller:
189
189
  # on every 5th iteration we reset the cdc status, so it will be checked again
190
190
  self.should_check_cdc = True
191
191
 
192
+ # Flag to only check data stream health once in the first call
193
+ self.check_data_stream_health = True
194
+
192
195
  self.wait_for_stream_sync = self.res.config.get(
193
196
  "wait_for_stream_sync", WAIT_FOR_STREAM_SYNC
194
197
  )
@@ -503,6 +506,7 @@ class UseIndexPoller:
503
506
  "init_engine_async": self.init_engine_async,
504
507
  "language": self.language,
505
508
  "data_freshness_mins": self.data_freshness,
509
+ "check_data_stream_health": self.check_data_stream_health
506
510
  })
507
511
 
508
512
  request_headers = debugging.add_current_propagation_headers(self.headers)
@@ -535,6 +539,7 @@ class UseIndexPoller:
535
539
  errors = use_index_data.get("errors", [])
536
540
  relations = use_index_data.get("relations", {})
537
541
  cdc_enabled = use_index_data.get("cdcEnabled", False)
542
+ health_checked = use_index_data.get("healthChecked", False)
538
543
  if self.check_ready_count % ERP_CHECK_FREQUENCY == 0 or not cdc_enabled:
539
544
  self.should_check_cdc = True
540
545
  else:
@@ -542,6 +547,9 @@ class UseIndexPoller:
542
547
 
543
548
  if engines and self.init_engine_async:
544
549
  self.init_engine_async = False
550
+
551
+ if self.check_data_stream_health and health_checked:
552
+ self.check_data_stream_health = False
545
553
 
546
554
  break_loop = False
547
555
  has_stream_errors = False
v0/relationalai/errors.py CHANGED
@@ -2436,6 +2436,24 @@ class QueryTimeoutExceededException(RAIException):
2436
2436
  Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
2437
2437
  """)
2438
2438
 
2439
+ class GuardRailsException(RAIException):
2440
+ def __init__(self, progress: dict[str, Any]={}):
2441
+ self.name = "Guard Rails Violation"
2442
+ self.message = "Transaction aborted due to guard rails violation."
2443
+ self.progress = progress
2444
+ self.content = self.format_message()
2445
+ super().__init__(self.message, self.name, self.content)
2446
+
2447
+ def format_message(self):
2448
+ messages = [] if self.progress else [self.message]
2449
+ for task in self.progress.get("tasks", {}).values():
2450
+ for warning_type, warning_data in task.get("warnings", {}).items():
2451
+ messages.append(textwrap.dedent(f"""
2452
+ Relation Name: [yellow]{task["task_name"]}[/yellow]
2453
+ Warning: {warning_type}
2454
+ Message: {warning_data["message"]}
2455
+ """))
2456
+ return "\n".join(messages)
2439
2457
 
2440
2458
  #--------------------------------------------------
2441
2459
  # Azure Exceptions
@@ -31,7 +31,9 @@ if TYPE_CHECKING:
31
31
 
32
32
  # Whenever the logic engine introduces a breaking change in behaviour, we bump this version
33
33
  # once the client is ready to handle it.
34
- DEFAULT_LQP_SEMANTICS_VERSION = "0"
34
+ #
35
+ # [2026-01-09] bumping to 1 to opt-into hard validation errors from the engine
36
+ DEFAULT_LQP_SEMANTICS_VERSION = "1"
35
37
 
36
38
  class LQPExecutor(e.Executor):
37
39
  """Executes LQP using the RAI client."""
@@ -118,6 +118,17 @@ class ExtractKeys(Pass):
118
118
  the same here).
119
119
  """
120
120
  class ExtractKeysRewriter(Rewriter):
121
+ def __init__(self):
122
+ super().__init__()
123
+ self.compound_keys: dict[Any, ir.Var] = {}
124
+
125
+ def _get_compound_key(self, orig_keys: Iterable[ir.Var]) -> ir.Var:
126
+ if orig_keys in self.compound_keys:
127
+ return self.compound_keys[orig_keys]
128
+ compound_key = f.var("compound_key", types.Hash)
129
+ self.compound_keys[orig_keys] = compound_key
130
+ return compound_key
131
+
121
132
  def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
122
133
  outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
123
134
  # We are not in a logical with an output at this level.
@@ -170,7 +181,7 @@ class ExtractKeysRewriter(Rewriter):
170
181
  annos = list(output.annotations)
171
182
  annos.append(f.annotation(builtins.output_keys, tuple(output_keys)))
172
183
  # Create a compound key that will be used in place of the original keys.
173
- compound_key = f.var("compound_key", types.Hash)
184
+ compound_key = self._get_compound_key(output_keys)
174
185
 
175
186
  for key_combination in combinations:
176
187
  missing_keys = OrderedSet.from_iterable(output_keys)
@@ -192,8 +203,13 @@ class ExtractKeysRewriter(Rewriter):
192
203
  # handle the construct node in each clone
193
204
  values: list[ir.Value] = [compound_key.type]
194
205
  for key in output_keys:
195
- assert isinstance(key.type, ir.ScalarType)
196
- values.append(ir.Literal(types.String, key.type.name))
206
+ if isinstance(key.type, ir.UnionType):
207
+ # the typer can derive union types when multiple distinct entities flow
208
+ # into a relation's field, so use AnyEntity as the type marker
209
+ values.append(ir.Literal(types.String, "AnyEntity"))
210
+ else:
211
+ assert isinstance(key.type, ir.ScalarType)
212
+ values.append(ir.Literal(types.String, key.type.name))
197
213
  if key in key_combination:
198
214
  values.append(key)
199
215
  body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
@@ -408,6 +424,12 @@ class ExtractKeysRewriter(Rewriter):
408
424
  for arg in args[:-1]:
409
425
  extended_vars.add(arg)
410
426
  there_is_progress = True
427
+ elif isinstance(task, ir.Not):
428
+ if isinstance(task.task, ir.Logical):
429
+ hoisted = helpers.hoisted_vars(task.task.hoisted)
430
+ if var in hoisted:
431
+ partitions[var].add(task)
432
+ there_is_progress = True
411
433
  else:
412
434
  assert False, f"invalid node kind {type(task)}"
413
435