relationalai 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
relationalai/clients/exec_txn_poller.py ADDED
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import time
+from typing import Dict, Optional, TYPE_CHECKING
+
+from relationalai import debugging
+from relationalai.clients.util import poll_with_specified_overhead
+from relationalai.tools.cli_controls import create_progress
+from relationalai.util.format import format_duration
+
+if TYPE_CHECKING:
+    from relationalai.clients.resources.snowflake import Resources
+
+# Polling behavior constants
+POLL_OVERHEAD_RATE = 0.1  # Overhead rate for exponential backoff
+
+# Text color constants
+GREEN_COLOR = '\033[92m'
+GRAY_COLOR = '\033[90m'
+ENDC = '\033[0m'
+
+
+class ExecTxnPoller:
+    """
+    Encapsulates the polling logic for exec_async transaction completion.
+    """
+
+    def __init__(
+        self,
+        resource: "Resources",
+        txn_id: str,
+        headers: Optional[Dict] = None,
+        txn_start_time: Optional[float] = None,
+    ):
+        self.res = resource
+        self.txn_id = txn_id
+        self.headers = headers or {}
+        self.txn_start_time = txn_start_time or time.time()
+
+    def poll(self) -> bool:
+        """
+        Poll for transaction completion with interactive progress display.
+
+        Returns:
+            True if transaction completed successfully, False otherwise
+        """
+
+        # Don't show duration summary - we handle our own completion message
+        with create_progress(
+            description="Evaluating Query...",
+            success_message="",  # We'll handle this ourselves
+            leading_newline=False,
+            trailing_newline=False,
+            show_duration_summary=False,
+        ) as progress:
+            def check_status() -> bool:
+                """Check if transaction is complete."""
+                elapsed = time.time() - self.txn_start_time
+                # Update the main status with elapsed time
+                progress.update_main_status(
+                    query_progress_message(self.txn_id, elapsed)
+                )
+                return self.res._check_exec_async_status(self.txn_id, headers=self.headers)
+
+            with debugging.span("wait", txn_id=self.txn_id):
+                poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
+
+            # Calculate final duration
+            total_duration = time.time() - self.txn_start_time
+
+            # Update to success message with duration
+            progress.update_main_status(
+                query_complete_message(self.txn_id, total_duration)
+            )
+
+            return True
+
+def query_progress_message(id: str, duration: float) -> str:
+    return (
+        # Print with whitespace to align with the end of the transaction ID
+        f"Evaluating Query... {format_duration(duration):>18}\n" +
+        f"{GRAY_COLOR}Query: {id}{ENDC}"
+    )
+
+def query_complete_message(id: str, duration: float, status_header: bool = False) -> str:
+    return (
+        (f"{GREEN_COLOR}✅ " if status_header else "") +
+        # Print with whitespace to align with the end of the transaction ID
+        f"Query Complete: {format_duration(duration):>24}\n" +
+        f"{GRAY_COLOR}Query: {id}{ENDC}"
+    )
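
The poller hands its wait loop to `poll_with_specified_overhead`, whose name and `overhead_rate=0.1` argument suggest the delay between checks grows with elapsed time so that polling cost stays near 10% of the transaction's runtime. A minimal sketch of driving the new class against a stub resource (everything in the stub is illustrative; the real caller passes a `Resources` instance):

    # Sketch: ExecTxnPoller with a stand-in resource (illustrative only).
    from relationalai.clients.exec_txn_poller import ExecTxnPoller

    class StubResource:
        def __init__(self):
            self.checks = 0
        def _check_exec_async_status(self, txn_id, headers=None):
            self.checks += 1
            return self.checks >= 3  # pretend the txn finishes on the third poll

    poller = ExecTxnPoller(resource=StubResource(), txn_id="txn-demo")
    poller.poll()  # renders "Evaluating Query..." then the completion message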
relationalai/clients/resources/snowflake/__init__.py CHANGED
@@ -2,7 +2,7 @@
 Snowflake resources module.
 """
 # Import order matters - Resources must be imported first since other classes depend on it
-from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, INTERNAL_ENGINE_SIZES, ENGINE_SIZES_AWS, ENGINE_SIZES_AZURE, PrimaryKey
+from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, INTERNAL_ENGINE_SIZES, ENGINE_SIZES_AWS, ENGINE_SIZES_AZURE, PrimaryKey, PRINT_TXN_PROGRESS_FLAG

 # These imports depend on Resources, so they come after
 from .cli_resources import CLIResources
@@ -14,7 +14,7 @@ __all__ = [
     'Resources', 'DirectAccessResources', 'Provider', 'Graph', 'SnowflakeClient',
     'APP_NAME', 'PYREL_ROOT_DB', 'CLIResources', 'UseIndexResources', 'ExecContext',
     'INTERNAL_ENGINE_SIZES', 'ENGINE_SIZES_AWS', 'ENGINE_SIZES_AZURE', 'PrimaryKey',
-    'create_resources_instance',
+    'PRINT_TXN_PROGRESS_FLAG', 'create_resources_instance',
 ]

relationalai/clients/resources/snowflake/direct_access_resources.py CHANGED
@@ -12,12 +12,13 @@ from ....environments import runtime_env, SnowbookEnvironment
 from ...config import Config, ConfigStore, ENDPOINT_FILE
 from ...direct_access_client import DirectAccessClient
 from ...types import EngineState
-from ...util import get_pyrel_version, poll_with_specified_overhead, safe_json_loads, ms_to_timestamp
-from ....errors import ResponseStatusException, QueryTimeoutExceededException
+from ...util import get_pyrel_version, safe_json_loads, ms_to_timestamp
+from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException
 from snowflake.snowpark import Session

 # Import UseIndexResources to enable use_index functionality with direct access
 from .use_index_resources import UseIndexResources
+from .snowflake import TxnCreationResult

 # Import helper functions from util
 from .util import is_engine_issue as _is_engine_issue, is_database_issue as _is_database_issue, collect_error_messages
@@ -27,6 +28,7 @@ from typing import Iterable

 # Constants
 TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+TXN_ABORT_REASON_GUARD_RAILS = "guard rail violation"


 class DirectAccessResources(UseIndexResources):
@@ -217,83 +219,59 @@ class DirectAccessResources(UseIndexResources):

         return response

-    def _exec_async_v2(
+    def _create_v2_txn(
         self,
         database: str,
-        engine: Union[str, None],
+        engine: str | None,
         raw_code: str,
-        inputs: Dict | None = None,
-        readonly=True,
-        nowait_durable=False,
-        headers: Dict[str, str] | None = None,
-        bypass_index=False,
-        language: str = "rel",
-        query_timeout_mins: int | None = None,
-        gi_setup_skipped: bool = False,
-    ):
-
-        with debugging.span("transaction") as txn_span:
-            with debugging.span("create_v2") as create_span:
-
-                use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-
-                payload = {
-                    "dbname": database,
-                    "engine_name": engine,
-                    "query": raw_code,
-                    "v1_inputs": inputs,
-                    "nowait_durable": nowait_durable,
-                    "readonly": readonly,
-                    "language": language,
-                }
-                if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
-                    query_timeout_mins = int(timeout_value)
-                if query_timeout_mins is not None:
-                    payload["timeout_mins"] = query_timeout_mins
-                query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
-
-                # Add gi_setup_skipped to headers
-                if headers is None:
-                    headers = {}
-                headers["gi_setup_skipped"] = str(gi_setup_skipped)
-                headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
-
-                response = self._txn_request_with_gi_retry(
-                    payload, headers, query_params, engine
-                )
-
-                artifact_info = {}
-                response_content = response.json()
+        inputs: Dict,
+        headers: Dict[str, str],
+        readonly: bool,
+        nowait_durable: bool,
+        bypass_index: bool,
+        language: str,
+        query_timeout_mins: int | None,
+    ) -> TxnCreationResult:
+        """
+        Create a transaction via direct HTTP access and return the result.

-                txn_id = response_content["transaction"]['id']
-                state = response_content["transaction"]['state']
+        This override uses HTTP requests instead of SQL stored procedures.
+        """
+        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)

-                txn_span["txn_id"] = txn_id
-                create_span["txn_id"] = txn_id
-                debugging.event("transaction_created", txn_span, txn_id=txn_id)
+        payload = {
+            "dbname": database,
+            "engine_name": engine,
+            "query": raw_code,
+            "v1_inputs": inputs,
+            "nowait_durable": nowait_durable,
+            "readonly": readonly,
+            "language": language,
+        }
+        if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
+            query_timeout_mins = int(timeout_value)
+        if query_timeout_mins is not None:
+            payload["timeout_mins"] = query_timeout_mins
+        query_params = {"use_graph_index": str(use_graph_index and not bypass_index)}
+
+        response = self._txn_request_with_gi_retry(
+            payload, headers, query_params, engine
+        )

-                # fast path: transaction already finished
-                if state in ["COMPLETED", "ABORTED"]:
-                    if txn_id in self._pending_transactions:
-                        self._pending_transactions.remove(txn_id)
+        response_content = response.json()

-                    # Process rows to get the rest of the artifacts
-                    for result in response_content.get("results", []):
-                        filename = result['filename']
-                        # making keys uppercase to match the old behavior
-                        artifact_info[filename] = {k.upper(): v for k, v in result.items()}
+        txn_id = response_content["transaction"]['id']
+        state = response_content["transaction"]['state']

-                # Slow path: transaction not done yet; start polling
-                else:
-                    self._pending_transactions.append(txn_id)
-                    with debugging.span("wait", txn_id=txn_id):
-                        poll_with_specified_overhead(
-                            lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
-                        )
-                    artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
+        # Build artifact_info if transaction completed immediately (fast path)
+        artifact_info: Dict[str, Dict] = {}
+        if state in ["COMPLETED", "ABORTED"]:
+            for result in response_content.get("results", []):
+                filename = result['filename']
+                # making keys uppercase to match the old behavior
+                artifact_info[filename] = {k.upper(): v for k, v in result.items()}

-        with debugging.span("fetch"):
-            return self._download_results(artifact_info, txn_id, state)
+        return TxnCreationResult(txn_id=txn_id, state=state, artifact_info=artifact_info)

     def _prepare_index(
         self,
@@ -355,15 +333,20 @@ class DirectAccessResources(UseIndexResources):
         if txn_id in self._pending_transactions:
             self._pending_transactions.remove(txn_id)

-        if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
-            config_file_path = getattr(self.config, 'file_path', None)
-            timeout_ms = int(transaction.get("timeout_ms", 0))
-            timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
-            raise QueryTimeoutExceededException(
-                timeout_mins=timeout_mins,
-                query_id=txn_id,
-                config_file_path=config_file_path,
-            )
+        if status == "ABORTED":
+            reason = transaction.get("abort_reason", "")
+
+            if reason == TXN_ABORT_REASON_TIMEOUT:
+                config_file_path = getattr(self.config, 'file_path', None)
+                timeout_ms = int(transaction.get("timeout_ms", 0))
+                timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+                raise QueryTimeoutExceededException(
+                    timeout_mins=timeout_mins,
+                    query_id=txn_id,
+                    config_file_path=config_file_path,
+                )
+            elif reason == TXN_ABORT_REASON_GUARD_RAILS:
+                raise GuardRailsException(response_content.get("progress", {}))

         # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
         return status == "COMPLETED" or status == "ABORTED"
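
Both abort branches now surface as typed exceptions at the query call site. A sketch of handling them (the `run_query` helper is a stand-in so the snippet is self-contained; the exception names come from this diff):

    # Sketch: handling the two abort reasons (illustrative call site).
    from relationalai.errors import GuardRailsException, QueryTimeoutExceededException

    def run_query():
        # stand-in for a pyrel query; raise the guard-rails error for the demo
        raise GuardRailsException({})

    try:
        run_query()
    except QueryTimeoutExceededException:
        print("aborted: transaction timeout")   # abort_reason == "transaction timeout"
    except GuardRailsException as exc:
        print(exc.content)                      # abort_reason == "guard rail violation"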
relationalai/clients/resources/snowflake/snowflake.py CHANGED
@@ -15,6 +15,7 @@ import hashlib
 from dataclasses import dataclass

 from ....auth.token_handler import TokenHandler
+from relationalai.clients.exec_txn_poller import ExecTxnPoller, query_complete_message
 import snowflake.snowpark

 from ....rel_utils import sanitize_identifier, to_fqn_relation_name
@@ -54,7 +55,7 @@ from .util import (
 )
 from ....environments import runtime_env, HexEnvironment, SnowbookEnvironment
 from .... import dsl, rel, metamodel as m
-from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
+from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, GuardRailsException, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timedelta
 from snowflake.snowpark.types import StringType, StructField, StructType
@@ -105,6 +106,16 @@ PYREL_ROOT_DB = 'pyrel_root_db'
 TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]

 TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+GUARDRAILS_ABORT_REASON = "guard rail violation"
+
+PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
+
+#--------------------------------------------------
+# Helpers
+#--------------------------------------------------
+
+def should_print_txn_progress(config) -> bool:
+    return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))

 #--------------------------------------------------
 # Resources
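
The helper just reads the new config key. A self-contained sketch of what it evaluates (the config object is stubbed with a dict here; the real one is a relationalai `Config`):

    # Sketch: the flag lookup, with a dict-backed stand-in for Config.
    class StubConfig:
        def __init__(self, values): self.values = values
        def get(self, key, default=None): return self.values.get(key, default)

    PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"  # mirrors the constant above

    def should_print_txn_progress(config) -> bool:  # mirrors the helper above
        return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))

    assert should_print_txn_progress(StubConfig({"print_txn_progress": True}))
    assert not should_print_txn_progress(StubConfig({}))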
@@ -131,6 +142,19 @@ class ExecContext:
             skip_engine_db_error_retry=self.skip_engine_db_error_retry
         )

+
+@dataclass
+class TxnCreationResult:
+    """Result of creating a transaction via _create_v2_txn.
+
+    This standardizes the response format between different implementations
+    (SQL stored procedure vs HTTP direct access).
+    """
+    txn_id: str
+    state: str
+    artifact_info: Dict[str, Dict]  # Populated if fast-path (state is COMPLETED/ABORTED)
+
+
 class Resources(ResourcesBase):
     def __init__(
         self,
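
The dataclass is the seam between the two transports. A condensed, self-contained sketch of the fast/slow dispatch `_exec_async_v2` performs on it (the real flow appears later in this diff):

    # Sketch: fast/slow dispatch on TxnCreationResult (condensed).
    from dataclasses import dataclass
    from typing import Dict

    @dataclass
    class TxnCreationResult:  # mirrors the definition above
        txn_id: str
        state: str
        artifact_info: Dict[str, Dict]

    def dispatch(result: TxnCreationResult) -> str:
        if result.state in ["COMPLETED", "ABORTED"]:
            # fast path: the creation response already listed the artifacts
            return f"fast path, {len(result.artifact_info)} artifact(s)"
        # slow path: poll _check_exec_async_status, then list artifacts
        return "slow path, polling"

    print(dispatch(TxnCreationResult("txn-1", "COMPLETED", {"out": {"FILENAME": "out"}})))
    print(dispatch(TxnCreationResult("txn-2", "CREATED", {})))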
@@ -1411,15 +1435,18 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if txn_id in self._pending_transactions:
             self._pending_transactions.remove(txn_id)

-        if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
-            config_file_path = getattr(self.config, 'file_path', None)
-            # todo: use the timeout returned alongside the transaction as soon as it's exposed
-            timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
-            raise QueryTimeoutExceededException(
-                timeout_mins=timeout_mins,
-                query_id=txn_id,
-                config_file_path=config_file_path,
-            )
+        if status == "ABORTED":
+            if response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
+                config_file_path = getattr(self.config, 'file_path', None)
+                # todo: use the timeout returned alongside the transaction as soon as it's exposed
+                timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+                raise QueryTimeoutExceededException(
+                    timeout_mins=timeout_mins,
+                    query_id=txn_id,
+                    config_file_path=config_file_path,
+                )
+            elif response_row.get("ABORT_REASON", "") == GUARDRAILS_ABORT_REASON:
+                raise GuardRailsException()

         # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
         return status == "COMPLETED" or status == "ABORTED"
@@ -1654,6 +1681,72 @@ Otherwise, remove it from your '{profile}' configuration profile.
             raise Exception("Failed to create transaction")
         return response

+    def _create_v2_txn(
+        self,
+        database: str,
+        engine: str | None,
+        raw_code: str,
+        inputs: Dict,
+        headers: Dict[str, str],
+        readonly: bool,
+        nowait_durable: bool,
+        bypass_index: bool,
+        language: str,
+        query_timeout_mins: int | None,
+    ) -> TxnCreationResult:
+        """
+        Create a transaction and return the result.
+
+        This method handles calling the RAI app stored procedure to create a transaction
+        and parses the response into a standardized TxnCreationResult format.
+
+        This method can be overridden by subclasses (e.g., DirectAccessResources)
+        to use different transport mechanisms (HTTP instead of SQL).
+
+        Args:
+            database: Database/model name
+            engine: Engine name (optional)
+            raw_code: Code to execute (REL, LQP, or SQL)
+            inputs: Input parameters for the query
+            headers: HTTP headers (must be prepared by caller)
+            readonly: Whether the transaction is read-only
+            nowait_durable: Whether to wait for durable writes
+            bypass_index: Whether to bypass graph index setup
+            language: Query language ("rel" or "lqp")
+            query_timeout_mins: Optional query timeout in minutes
+
+        Returns:
+            TxnCreationResult containing txn_id, state, and artifact_info
+        """
+        response = self._exec_rai_app(
+            database=database,
+            engine=engine,
+            raw_code=raw_code,
+            inputs=inputs,
+            readonly=readonly,
+            nowait_durable=nowait_durable,
+            request_headers=headers,
+            bypass_index=bypass_index,
+            language=language,
+            query_timeout_mins=query_timeout_mins,
+        )
+
+        rows = list(iter(response))
+
+        # process the first row since txn_id and state are the same for all rows
+        first_row = rows[0]
+        txn_id = first_row['ID']
+        state = first_row['STATE']
+
+        # Build artifact_info if transaction completed immediately (fast path)
+        artifact_info: Dict[str, Dict] = {}
+        if state in ["COMPLETED", "ABORTED"]:
+            for row in rows:
+                filename = row['FILENAME']
+                artifact_info[filename] = row
+
+        return TxnCreationResult(txn_id=txn_id, state=state, artifact_info=artifact_info)
+
     def _exec_async_v2(
         self,
         database: str,
@@ -1672,15 +1765,20 @@ Otherwise, remove it from your '{profile}' configuration profile.
         High-level async execution method with transaction polling and artifact management.

         This is the core method for executing queries asynchronously. It:
-        1. Creates a transaction by calling _exec_rai_app
+        1. Creates a transaction by calling _create_v2_txn
         2. Handles two execution paths:
            - Fast path: Transaction completes immediately (COMPLETED/ABORTED)
            - Slow path: Transaction is pending, requires polling until completion
         3. Manages pending transactions list
         4. Downloads and returns query results/artifacts

-        This method is called by _execute_code (base implementation) and can be
-        overridden by child classes (e.g., DirectAccessResources uses HTTP instead).
+        This method is called by _execute_code (base implementation), and calls the
+        following methods that can be overridden by child classes (e.g.,
+        DirectAccessResources uses HTTP instead):
+        - _create_v2_txn
+        - _check_exec_async_status
+        - _list_exec_async_artifacts
+        - _download_results

         Args:
             database: Database/model name
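
A toy rendering of the template-method seam the docstring lists: the base class owns the poll/fetch flow while subclasses replace the transport (class names here are illustrative; the real pair is Resources/DirectAccessResources):

    # Toy sketch of the override seam (illustrative names only).
    class BaseFlow:
        def execute(self):
            txn = self.create_txn()        # overridable transport step
            return f"polled and fetched {txn}"
        def create_txn(self):
            return "txn-via-stored-proc"   # SQL transport (like Resources)

    class HttpFlow(BaseFlow):
        def create_txn(self):
            return "txn-via-http"          # HTTP transport (like DirectAccessResources)

    print(BaseFlow().execute(), HttpFlow().execute())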
@@ -1704,53 +1802,62 @@ Otherwise, remove it from your '{profile}' configuration profile.
         query_attrs_dict = json.loads(request_headers.get("X-Query-Attributes", "{}"))

         with debugging.span("transaction", **query_attrs_dict) as txn_span:
+            txn_start_time = time.time()
             with debugging.span("create_v2", **query_attrs_dict) as create_span:
+                # Prepare headers for transaction creation
                 request_headers['user-agent'] = get_pyrel_version(self.generation)
                 request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
                 request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
-                response = self._exec_rai_app(
+
+                # Create the transaction
+                result = self._create_v2_txn(
                     database=database,
                     engine=engine,
                     raw_code=raw_code,
                     inputs=inputs,
+                    headers=request_headers,
                     readonly=readonly,
                     nowait_durable=nowait_durable,
-                    request_headers=request_headers,
                     bypass_index=bypass_index,
                     language=language,
                     query_timeout_mins=query_timeout_mins,
                 )

-                artifact_info = {}
-                rows = list(iter(response))
-
-                # process the first row since txn_id and state are the same for all rows
-                first_row = rows[0]
-                txn_id = first_row['ID']
-                state = first_row['STATE']
-                filename = first_row['FILENAME']
+                txn_id = result.txn_id
+                state = result.state

                 txn_span["txn_id"] = txn_id
                 create_span["txn_id"] = txn_id
                 debugging.event("transaction_created", txn_span, txn_id=txn_id)

+                print_txn_progress = should_print_txn_progress(self.config)
+
                 # fast path: transaction already finished
                 if state in ["COMPLETED", "ABORTED"]:
+                    txn_end_time = time.time()
                     if txn_id in self._pending_transactions:
                         self._pending_transactions.remove(txn_id)

-                    # Process rows to get the rest of the artifacts
-                    for row in rows:
-                        filename = row['FILENAME']
-                        artifact_info[filename] = row
+                    artifact_info = result.artifact_info
+
+                    txn_duration = txn_end_time - txn_start_time
+                    if print_txn_progress:
+                        print(
+                            query_complete_message(txn_id, txn_duration, status_header=True)
+                        )

                 # Slow path: transaction not done yet; start polling
                 else:
                     self._pending_transactions.append(txn_id)
+                    # Use the interactive poller for transaction status
                     with debugging.span("wait", txn_id=txn_id):
-                        poll_with_specified_overhead(
-                            lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
-                        )
+                        if print_txn_progress:
+                            poller = ExecTxnPoller(resource=self, txn_id=txn_id, headers=request_headers, txn_start_time=txn_start_time)
+                            poller.poll()
+                        else:
+                            poll_with_specified_overhead(
+                                lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
+                            )
                     artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)

         with debugging.span("fetch"):
@@ -2408,7 +2515,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
             return None
         return results[0][0]

-    # CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
+    # CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
     # list_databases, list_sf_schemas, list_tables) are now in CLIResources class
     # schema_info is kept in base Resources class since it's used by SnowflakeSchema._fetch_info()

relationalai/clients/resources/snowflake/use_index_poller.py CHANGED
@@ -189,6 +189,9 @@ class UseIndexPoller:
         # on every 5th iteration we reset the cdc status, so it will be checked again
         self.should_check_cdc = True

+        # Flag to only check data stream health once in the first call
+        self.check_data_stream_health = True
+
         self.wait_for_stream_sync = self.res.config.get(
             "wait_for_stream_sync", WAIT_FOR_STREAM_SYNC
         )
@@ -503,6 +506,7 @@ class UseIndexPoller:
             "init_engine_async": self.init_engine_async,
             "language": self.language,
             "data_freshness_mins": self.data_freshness,
+            "check_data_stream_health": self.check_data_stream_health
         })

         request_headers = debugging.add_current_propagation_headers(self.headers)
@@ -535,6 +539,7 @@ class UseIndexPoller:
         errors = use_index_data.get("errors", [])
         relations = use_index_data.get("relations", {})
         cdc_enabled = use_index_data.get("cdcEnabled", False)
+        health_checked = use_index_data.get("healthChecked", False)
         if self.check_ready_count % ERP_CHECK_FREQUENCY == 0 or not cdc_enabled:
             self.should_check_cdc = True
         else:
@@ -542,6 +547,9 @@ class UseIndexPoller:

         if engines and self.init_engine_async:
             self.init_engine_async = False
+
+        if self.check_data_stream_health and health_checked:
+            self.check_data_stream_health = False

         break_loop = False
         has_stream_errors = False
relationalai/errors.py CHANGED
@@ -2436,6 +2436,24 @@ class QueryTimeoutExceededException(RAIException):
 Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
         """)

+class GuardRailsException(RAIException):
+    def __init__(self, progress: dict[str, Any]={}):
+        self.name = "Guard Rails Violation"
+        self.message = "Transaction aborted due to guard rails violation."
+        self.progress = progress
+        self.content = self.format_message()
+        super().__init__(self.message, self.name, self.content)
+
+    def format_message(self):
+        messages = [] if self.progress else [self.message]
+        for task in self.progress.get("tasks", {}).values():
+            for warning_type, warning_data in task.get("warnings", {}).items():
+                messages.append(textwrap.dedent(f"""
+                    Relation Name: [yellow]{task["task_name"]}[/yellow]
+                    Warning: {warning_type}
+                    Message: {warning_data["message"]}
+                """))
+        return "\n".join(messages)

#--------------------------------------------------
# Azure Exceptions
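
Since `format_message` walks the `progress` payload, a small sketch of the shape it expects (inferred from the loop above, not from a documented schema):

    # Sketch: the progress payload shape format_message walks (inferred).
    from relationalai.errors import GuardRailsException

    progress = {
        "tasks": {
            "t1": {
                "task_name": "my_relation",
                "warnings": {"row_limit": {"message": "too many rows produced"}},
            }
        }
    }
    exc = GuardRailsException(progress)
    print(exc.content)
    # -> Relation Name: [yellow]my_relation[/yellow]
    #    Warning: row_limit
    #    Message: too many rows produced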
relationalai/semantics/lqp/rewrite/extract_keys.py CHANGED
@@ -424,6 +424,12 @@ class ExtractKeysRewriter(Rewriter):
                 for arg in args[:-1]:
                     extended_vars.add(arg)
                 there_is_progress = True
+            elif isinstance(task, ir.Not):
+                if isinstance(task.task, ir.Logical):
+                    hoisted = helpers.hoisted_vars(task.task.hoisted)
+                    if var in hoisted:
+                        partitions[var].add(task)
+                        there_is_progress = True
             else:
                 assert False, f"invalid node kind {type(task)}"

relationalai/semantics/reasoners/optimization/solvers_pb.py CHANGED
@@ -1,8 +1,8 @@
-"""Solver model implementation using protobuf format.
+"""Solver model implementation supporting protobuf and CSV formats.

 This module provides the SolverModelPB class for defining optimization and
-constraint programming problems that are serialized to protobuf format and
-solved by external solver engines.
+constraint programming problems that are serialized and solved by external
+solver engines. Supports both protobuf (default) and CSV (future) exchange formats.

 Note: This protobuf-based implementation will be deprecated in favor of the
 development version (solvers_dev.py) in future releases.
@@ -23,7 +23,6 @@ from relationalai.util.timeout import calc_remaining_timeout_minutes

 from .common import make_name

-
 # =============================================================================
 # Solver ProtoBuf Format Constants and Helpers
 # =============================================================================
@@ -191,6 +190,7 @@ class SolverModelPB:
         """
         b.define(b.RawSource("rel", textwrap.dedent(install_rel)))

+
     # -------------------------------------------------------------------------
     # Variable Handling
     # -------------------------------------------------------------------------
@@ -501,69 +501,218 @@ class SolverModelPB:
     # Solving and Result Handling
     # -------------------------------------------------------------------------

-    def solve(
-        self, solver: Solver, log_to_console: bool = False, **kwargs: Any
+    def _export_model_to_csv(
+        self,
+        model_id: str,
+        executor: RelExecutor,
+        prefix_lowercase: str,
+        query_timeout_mins: Optional[int] = None
     ) -> None:
-        """Solve the model.
+        """Export model to CSV files in Snowflake stage.

         Args:
-            solver: Solver instance.
-            log_to_console: Whether to show solver output.
-            **kwargs: Solver options and parameters.
+            model_id: Unique model identifier for stage paths.
+            executor: RelExecutor instance.
+            prefix_lowercase: Prefix for relation names.
+            query_timeout_mins: Query timeout in minutes.
         """
-        options = {**kwargs, "version": 1}
+        stage_base_no_txn = f"snowflake://APP_STATE.RAI_INTERNAL_STAGE/SOLVERS/job_{model_id}"
+
+        # Export all model relations using Rel-native export_csv in a single transaction
+        # Transformations (uuid_string, encode_base64) are done inline in the export query
+        export_rel = textwrap.dedent(f"""
+            // Get transaction ID for folder naming - solver service validates ownership
+            // Use uuid_string to get proper UUID format, then replace hyphens with underscores
+            def txn_id_str {{string_replace[uuid_string[current_transaction_id], "-", "_"]}}
+
+            // Define base path with txn_id in folder name: model_{{txn_id}}/
+            def base_path {{"{stage_base_no_txn}/model"}}
+
+            // Export variable_hash.csv - single column: HASH (UUID string)
+            // Transformation: convert Variable UInt128 to UUID string inline
+            def variable_hash_data(:HASH, v, h):
+                {self.Variable._name}(v) and uuid_string(v, h)
+
+            def export[:variable_hash]: {{export_csv[{{
+                (:path, base_path ++ "/variable_hash_" ++ txn_id_str ++ ".csv");
+                (:data, variable_hash_data);
+                (:compression, "gzip")
+            }}]}}
+
+            // Export variable_name.csv - columns: HASH (UUID string), VALUE (name string)
+            // Transformation: convert Variable UInt128 to UUID string inline
+            def variable_name_data(:HASH, v, h):
+                {prefix_lowercase}variable_name(v, _) and uuid_string(v, h)
+            def variable_name_data(:VALUE, v, name):
+                {prefix_lowercase}variable_name(v, name)
+
+            def export[:variable_name]: {{export_csv[{{
+                (:path, base_path ++ "/variable_name_" ++ txn_id_str ++ ".csv");
+                (:data, variable_name_data);
+                (:compression, "gzip")
+            }}]}}
+
+            // Export constraint.csv - single column: VALUE (base64 encoded constraint)
+            // Transformation: encode_base64 done inline
+            def constraint_data(:VALUE, c, e):
+                exists((s) |
+                    {self.Constraint._name}(c) and
+                    {prefix_lowercase}constraint_serialized(c, s) and
+                    encode_base64(s, e))
+
+            def export[:constraint]: {{export_csv[{{
+                (:path, base_path ++ "/constraint_" ++ txn_id_str ++ ".csv");
+                (:data, constraint_data);
+                (:compression, "gzip")
+            }}]}}
+
+            // Export min_objective.csv - columns: HASH (UUID string), VALUE (base64 encoded)
+            // Transformations: uuid_string and encode_base64 done inline
+            def min_objective_data(:HASH, obj, h):
+                {self.MinObjective._name}(obj) and uuid_string(obj, h)
+            def min_objective_data(:VALUE, obj, e):
+                exists((s) |
+                    {self.MinObjective._name}(obj) and
+                    {prefix_lowercase}minobjective_serialized(obj, s) and
+                    encode_base64(s, e))
+
+            def export[:min_objective]: {{export_csv[{{
+                (:path, base_path ++ "/min_objective_" ++ txn_id_str ++ ".csv");
+                (:data, min_objective_data);
+                (:compression, "gzip")
+            }}]}}
+
+            // Export max_objective.csv - columns: HASH (UUID string), VALUE (base64 encoded)
+            // Transformations: uuid_string and encode_base64 done inline
+            def max_objective_data(:HASH, obj, h):
+                {self.MaxObjective._name}(obj) and uuid_string(obj, h)
+            def max_objective_data(:VALUE, obj, e):
+                exists((s) |
+                    {self.MaxObjective._name}(obj) and
+                    {prefix_lowercase}maxobjective_serialized(obj, s) and
+                    encode_base64(s, e))
+
+            def export[:max_objective]: {{export_csv[{{
+                (:path, base_path ++ "/max_objective_" ++ txn_id_str ++ ".csv");
+                (:data, max_objective_data);
+                (:compression, "gzip")
+            }}]}}
+        """)
+
+        executor.execute_raw(export_rel, readonly=False, query_timeout_mins=query_timeout_mins)
+
+    def _import_solver_results_from_csv(
+        self,
+        model_id: str,
+        executor: RelExecutor,
+        prefix_lowercase: str,
+        query_timeout_mins: Optional[int] = None
+    ) -> None:
+        """Import solver results from CSV files in Snowflake stage.

-        # Validate solver options
-        for option_key, option_value in options.items():
-            if not isinstance(option_key, str):
-                raise TypeError(
-                    f"Solver option keys must be strings, but got {type(option_key).__name__} for key {option_key!r}."
-                )
-            if not isinstance(option_value, (int, float, str, bool)):
-                raise TypeError(
-                    f"Solver option values must be int, float, str, or bool, "
-                    f"but got {type(option_value).__name__} for option {option_key!r}."
-                )
+        Loads and extracts CSV files in a single transaction to minimize overhead.

-        # Three-phase solve process:
-        # 1. Export model to Snowflake as protobuf
-        # 2. Execute solver job (external solver reads from Snowflake)
-        # 3. Extract and load results back into the model
-        input_id = uuid.uuid4()
-        model_uri = f"snowflake://APP_STATE.RAI_INTERNAL_STAGE/job-inputs/solver/{input_id}/model.binpb"
-        sf_input_uri = f"snowflake://job-inputs/solver/{input_id}/model.binpb"
-        payload: dict[str, Any] = {"solver": solver.solver_name.lower()}
-        payload["options"] = options
-        payload["model_uri"] = sf_input_uri
+        Args:
+            model_id: Unique model identifier for stage paths.
+            executor: RelExecutor instance.
+            prefix_lowercase: Prefix for relation names.
+            query_timeout_mins: Query timeout in minutes.
+        """
+        result_stage_base = f"snowflake://APP_STATE.RAI_INTERNAL_STAGE/SOLVERS/job_{model_id}/results"
+        value_parse_fn = "parse_int" if self._num_type == "int" else "parse_float"
+
+        # Single transaction: Load CSV files and extract/map results
+        # Use inline definitions to avoid needing declared relations
+        load_and_extract_rel = textwrap.dedent(f"""
+            // Define CSV loading inline (no declare needed)
+            // Load ancillary.csv - contains solver metadata (NAME, VALUE columns)
+            def ancillary_config[:path]: "{result_stage_base}/ancillary.csv.gz"
+            def ancillary_config[:syntax, :header_row]: 1
+            def ancillary_config[:schema, :NAME]: "string"
+            def ancillary_config[:schema, :VALUE]: "string"
+            def {prefix_lowercase}solver_ancillary_raw {{load_csv[ancillary_config]}}
+
+            // Load objective_values.csv - contains objective values (SOL_INDEX, VALUE columns)
+            def objective_values_config[:path]: "{result_stage_base}/objective_values.csv.gz"
+            def objective_values_config[:syntax, :header_row]: 1
+            def objective_values_config[:schema, :SOL_INDEX]: "string"
+            def objective_values_config[:schema, :VALUE]: "string"
+            def {prefix_lowercase}solver_objective_values_raw {{load_csv[objective_values_config]}}
+
+            // Load points.csv.gz - contains solution points (SOL_INDEX, VAR_HASH, VALUE columns)
+            def points_config[:path]: "{result_stage_base}/points.csv.gz"
+            def points_config[:syntax, :header_row]: 1
+            def points_config[:schema, :SOL_INDEX]: "string"
+            def points_config[:schema, :VAR_HASH]: "string"
+            def points_config[:schema, :VALUE]: "string"
+            def {prefix_lowercase}solver_points_raw {{load_csv[points_config]}}
+
+            // Clear existing result data
+            def delete[:{self.result_info._name}]: {self.result_info._name}
+            def delete[:{self.point._name}]: {self.point._name}
+            def delete[:{self.points._name}]: {self.points._name}

-        executor = self._model._to_executor()
-        if not isinstance(executor, RelExecutor):
-            raise ValueError(f"Expected RelExecutor, got {type(executor).__name__}.")
-        prefix_lowercase = f"solvermodel_{self._id}_"
+            // Extract ancillary data (result info) - NAME and VALUE columns
+            def insert(:{self.result_info._name}, key, val): {{
+                exists((row) |
+                    {prefix_lowercase}solver_ancillary_raw(:NAME, row, key) and
+                    {prefix_lowercase}solver_ancillary_raw(:VALUE, row, val))
+            }}

-        query_timeout_mins = kwargs.get("query_timeout_mins", None)
-        config = self._model._config
-        if (
-            query_timeout_mins is None
-            and (
-                timeout_value := config.get(
-                    "query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS
-                )
-            )
-            is not None
-        ):
-            query_timeout_mins = int(timeout_value)
-        config_file_path = getattr(config, "file_path", None)
-        start_time = time.monotonic()
-        remaining_timeout_minutes = query_timeout_mins
+            // Extract objective value from objective_values CSV (first solution)
+            def insert(:{self.result_info._name}, "objective_value", val): {{
+                exists((row) |
+                    {prefix_lowercase}solver_objective_values_raw(:SOL_INDEX, row, "1") and
+                    {prefix_lowercase}solver_objective_values_raw(:VALUE, row, val))
+            }}

-        # Step 1: Materialize the model and store it in Snowflake
-        print("export model")
-        # TODO(coey): Weird hack to avoid uninitialized properties error
-        # This forces evaluation of the Variable concept before export
-        b.select(b.count(self.Variable)).to_df()
-        export_model_relation = f"""
-        // TODO maybe only want to pass names if printing - like in old setup
+            // Extract solution points from points.csv.gz into points property
+            // This file has SOL_INDEX, VAR_HASH, VALUE columns
+            // Convert CSV string index to Int128 for points property signature
+            // Convert value to Int128 (for int) or Float64 (for float)
+            def insert(:{self.points._name}, sol_idx_int128, var, val_converted): {{
+                exists((row, sol_idx_str, var_hash_str, val_str, sol_idx_int, val) |
+                    {prefix_lowercase}solver_points_raw(:SOL_INDEX, row, sol_idx_str) and
+                    {prefix_lowercase}solver_points_raw(:VAR_HASH, row, var_hash_str) and
+                    {prefix_lowercase}solver_points_raw(:VALUE, row, val_str) and
+                    parse_int(sol_idx_str, sol_idx_int) and
+                    parse_uuid(var_hash_str, var) and
+                    {value_parse_fn}(val_str, val) and
+                    ::std::mirror::convert(std::mirror::typeof[Int128], sol_idx_int, sol_idx_int128) and
+                    {'::std::mirror::convert(std::mirror::typeof[Int128], val, val_converted)' if self._num_type == 'int' else '::std::mirror::convert(std::mirror::typeof[Float64], val, val_converted)'})
+            }}
+
+            // Extract first solution into point property (default solution)
+            // Filter to SOL_INDEX = 1
+            def insert(:{self.point._name}, var, val_converted): {{
+                exists((row, var_hash_str, val_str, val) |
+                    {prefix_lowercase}solver_points_raw(:SOL_INDEX, row, "1") and
+                    {prefix_lowercase}solver_points_raw(:VAR_HASH, row, var_hash_str) and
+                    {prefix_lowercase}solver_points_raw(:VALUE, row, val_str) and
+                    parse_uuid(var_hash_str, var) and
+                    {value_parse_fn}(val_str, val) and
+                    {'::std::mirror::convert(std::mirror::typeof[Int128], val, val_converted)' if self._num_type == 'int' else '::std::mirror::convert(std::mirror::typeof[Float64], val, val_converted)'})
+            }}
+        """)
+
+        executor.execute_raw(load_and_extract_rel, readonly=False, query_timeout_mins=query_timeout_mins)
+
+    def _export_model_to_protobuf(
+        self,
+        model_uri: str,
+        executor: RelExecutor,
+        prefix_lowercase: str,
+        query_timeout_mins: Optional[int] = None
+    ) -> None:
+        """Export model to protobuf format in Snowflake stage.
+
+        Args:
+            model_uri: Snowflake URI for the protobuf file.
+            executor: RelExecutor instance.
+            prefix_lowercase: Prefix for relation names.
+            query_timeout_mins: Query timeout in minutes.
+        """
+        export_rel = f"""
         // Collect all model components into a relation for serialization
         def model_relation {{
             (:variable, {self.Variable._name});
@@ -584,31 +733,24 @@ class SolverModelPB:
         def export {{ config }}
         """
         executor.execute_raw(
-            textwrap.dedent(export_model_relation),
-            query_timeout_mins=remaining_timeout_minutes,
+            textwrap.dedent(export_rel),
+            query_timeout_mins=query_timeout_mins
         )

-        # Step 2: Execute solver job and wait for completion
-        print("execute solver job")
-        remaining_timeout_minutes = calc_remaining_timeout_minutes(
-            start_time,
-            query_timeout_mins,
-            config_file_path=config_file_path,
-        )
-        job_id = solver._exec_job(
-            payload,
-            log_to_console=log_to_console,
-            query_timeout_mins=remaining_timeout_minutes,
-        )
+    def _import_solver_results_from_protobuf(
+        self,
+        job_id: str,
+        executor: RelExecutor,
+        query_timeout_mins: Optional[int] = None
+    ) -> None:
+        """Import solver results from protobuf format.

-        # Step 3: Extract and insert solver results into the model
-        print("extract result")
-        remaining_timeout_minutes = calc_remaining_timeout_minutes(
-            start_time,
-            query_timeout_mins,
-            config_file_path=config_file_path,
-        )
-        extract_results_relation = f"""
+        Args:
+            job_id: Job identifier for result location.
+            executor: RelExecutor instance.
+            query_timeout_mins: Query timeout in minutes.
+        """
+        extract_rel = f"""
         def raw_result {{
             load_binary["snowflake://APP_STATE.RAI_INTERNAL_STAGE/job-results/{job_id}/result.binpb"]
         }}
@@ -625,6 +767,7 @@ class SolverModelPB:
         def insert(:{self.result_info._name}, key, value):
             exists((original_key) | string(extracted[original_key], value) and ::std::mirror::lower(original_key, key))
         """
+
         if self._num_type == "int":
             insert_points_relation = f"""
                 def insert(:{self.point._name}, variable, value):
@@ -645,15 +788,123 @@ class SolverModelPB:
                     ::std::mirror::convert(std::mirror::typeof[Int128], float_index, point_index)
                 )
         """
+
         executor.execute_raw(
-            textwrap.dedent(extract_results_relation)
-            + textwrap.dedent(insert_points_relation),
+            textwrap.dedent(extract_rel) + textwrap.dedent(insert_points_relation),
             readonly=False,
-            query_timeout_mins=remaining_timeout_minutes,
+            query_timeout_mins=query_timeout_mins
         )

-        print("finished solve")
+    def solve(
+        self, solver: Solver, log_to_console: bool = False, **kwargs: Any
+    ) -> None:
+        """Solve the model.
+
+        Args:
+            solver: Solver instance.
+            log_to_console: Whether to show solver output.
+            **kwargs: Solver options and parameters.
+        """
+
+        use_csv_store = solver.engine_settings.get("store", {})\
+            .get("csv", {})\
+            .get("enabled", False)
+
+        print(f"Using {'csv' if use_csv_store else 'protobuf'} store...")
+
+        options = {**kwargs, "version": 1}
+
+        # Validate solver options
+        for option_key, option_value in options.items():
+            if not isinstance(option_key, str):
+                raise TypeError(
+                    f"Solver option keys must be strings, but got {type(option_key).__name__} for key {option_key!r}."
+                )
+            if not isinstance(option_value, (int, float, str, bool)):
+                raise TypeError(
+                    f"Solver option values must be int, float, str, or bool, "
+                    f"but got {type(option_value).__name__} for option {option_key!r}."
+                )
+
+        executor = self._model._to_executor()
+        if not isinstance(executor, RelExecutor):
+            raise ValueError(f"Expected RelExecutor, got {type(executor).__name__}.")
+        prefix_lowercase = f"solvermodel_{self._id}_"
+
+        # Initialize timeout from config
+        query_timeout_mins = kwargs.get("query_timeout_mins", None)
+        config = self._model._config
+        if (
+            query_timeout_mins is None
+            and (
+                timeout_value := config.get(
+                    "query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS
+                )
+            )
+            is not None
+        ):
+            query_timeout_mins = int(timeout_value)
+        config_file_path = getattr(config, "file_path", None)
+        start_time = time.monotonic()
+
+        # Force evaluation of Variable concept before export
+        b.select(b.count(self.Variable)).to_df()
+
+        # Prepare payload for solver service
+        payload: dict[str, Any] = {"solver": solver.solver_name.lower(), "options": options}
+
+        if use_csv_store:
+            # CSV format: model and results are exchanged via CSV files
+            model_id = str(uuid.uuid4()).upper().replace('-', '_')
+            payload["model_uri"] = f"snowflake://SOLVERS/job_{model_id}/model"
+
+            print("Exporting model to CSV...")
+            remaining_timeout_minutes = calc_remaining_timeout_minutes(
+                start_time, query_timeout_mins, config_file_path=config_file_path
+            )
+            self._export_model_to_csv(model_id, executor, prefix_lowercase, remaining_timeout_minutes)
+            print("Model CSV export completed")
+
+            print("Execute solver job")
+            remaining_timeout_minutes = calc_remaining_timeout_minutes(
+                start_time, query_timeout_mins, config_file_path=config_file_path
+            )
+            solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
+
+            print("Loading and extracting solver results...")
+            remaining_timeout_minutes = calc_remaining_timeout_minutes(
+                start_time, query_timeout_mins, config_file_path=config_file_path
+            )
+            self._import_solver_results_from_csv(model_id, executor, prefix_lowercase, remaining_timeout_minutes)
+
+        else:  # protobuf format
+            # Protobuf format: model and results are exchanged via binary protobuf
+            input_id = uuid.uuid4()
+            model_uri = f"snowflake://APP_STATE.RAI_INTERNAL_STAGE/job-inputs/solver/{input_id}/model.binpb"
+            sf_input_uri = f"snowflake://job-inputs/solver/{input_id}/model.binpb"
+            payload["model_uri"] = sf_input_uri
+
+            print("Export model...")
+            remaining_timeout_minutes = calc_remaining_timeout_minutes(
+                start_time, query_timeout_mins, config_file_path=config_file_path
+            )
+            self._export_model_to_protobuf(model_uri, executor, prefix_lowercase, remaining_timeout_minutes)
+
+            print("Execute solver job...")
+            remaining_timeout_minutes = calc_remaining_timeout_minutes(
+                start_time, query_timeout_mins, config_file_path=config_file_path
+            )
+            job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
+
+            print("Extract result...")
+            remaining_timeout_minutes = calc_remaining_timeout_minutes(
+                start_time, query_timeout_mins, config_file_path=config_file_path
+            )
+            self._import_solver_results_from_protobuf(job_id, executor, remaining_timeout_minutes)
+
+        print("Finished solve")
         print()
+        return None

     def load_point(self, point_index: int) -> None:
         """Load a solution point.
relationalai-0.13.0.dist-info/METADATA → relationalai-0.13.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: relationalai
-Version: 0.13.0
+Version: 0.13.1
 Summary: RelationalAI Library and CLI
 Author-email: RelationalAI <support@relational.ai>
 License-File: LICENSE
@@ -11,7 +11,7 @@ Requires-Dist: colorama
 Requires-Dist: cryptography
 Requires-Dist: gravis
 Requires-Dist: inquirerpy
-Requires-Dist: lqp==0.1.19
+Requires-Dist: lqp==0.2.1
 Requires-Dist: nicegui==2.16.1
 Requires-Dist: numpy<2
 Requires-Dist: opentelemetry-api
relationalai-0.13.0.dist-info/RECORD → relationalai-0.13.1.dist-info/RECORD RENAMED
@@ -4,7 +4,7 @@ relationalai/debugging.py,sha256=wqGly2Yji4ErdV9F_S1Bny35SiGqI3C7S-hap7B8yUs,119
 relationalai/dependencies.py,sha256=tL113efcISkJUiDXYHmRdU_usdD7gmee-VRHA7N4EFA,16574
 relationalai/docutils.py,sha256=1gVv9mk0ytdMB2W7_NvslJefmSQtTOg8LHTCDcGCjyE,1554
 relationalai/dsl.py,sha256=UJr93X8kwnnyUY-kjPzp_jhsp2pYBUnDfu8mhNXPNII,66116
-relationalai/errors.py,sha256=JeYaycLGFeT0jidSMERGkpN02zswW5TGrIV647robmw,95973
+relationalai/errors.py,sha256=PaWzKsfWwjca7J0huILkiR515_PwJXP1Hlq6owOjL0Y,96828
 relationalai/metagen.py,sha256=o10PNvR_myr_61DC8g6lkB093bFo9qXGUkZKgKyfXiE,26821
 relationalai/metamodel.py,sha256=P1hliwHd1nYxbXON4LZeaYZD6T6pZm97HgmFBFrWyCk,32886
 relationalai/rel.py,sha256=ePmAXx4NxOdsPcHNHyGH3Jkp_cB3QzfKu5p_EQSHPh0,38293
@@ -22,6 +22,7 @@ relationalai/clients/__init__.py,sha256=LQ_yHsutRMpoW2mOTmOPGF8mrbP0OiV5E68t8uVw
 relationalai/clients/client.py,sha256=gk_V9KS7_MM2dLL2OCO7EPLHD9dsRwR6R-30SW8lDwU,35759
 relationalai/clients/config.py,sha256=hERaKjc3l4kd-kf0l-NUOHrWunCn8gmFWpuE0j3ScJg,24457
 relationalai/clients/direct_access_client.py,sha256=VGjQ7wzduxCo04BkxSZjlPAgqK-aBc32zIXcMfAzzSU,6436
+relationalai/clients/exec_txn_poller.py,sha256=JbmrTvsWGwxM5fcZVjeJihQDqhMf-ySfHN3Jn0HGwG0,3108
 relationalai/clients/hash_util.py,sha256=pZVR1FX3q4G_19p_r6wpIR2tIM8_WUlfAR7AVZJjIYM,1495
 relationalai/clients/local.py,sha256=vo5ikSWg38l3xQAh9yL--4sMAj_T5Tn7YEZiw7TCH08,23504
 relationalai/clients/profile_polling.py,sha256=pUH7WKH4nYDD0SlQtg3wsWdj0K7qt6nZqUw8jTthCBs,2565
@@ -30,16 +31,16 @@ relationalai/clients/types.py,sha256=eNo6akcMTbnBFbBbHd5IgVeY-zuAgtXlOs8Bo1SWmVU
 relationalai/clients/util.py,sha256=NJC8fnrWHR01NydwESPSetIHRWf7jQJURYpaWJjmDyE,12311
 relationalai/clients/resources/__init__.py,sha256=pymn8gB86Q3C2bVoFei0KAL8pX_U04uDY9TE4TKzTBs,260
 relationalai/clients/resources/azure/azure.py,sha256=TDapfM5rLoHrPrXg5cUe827m3AO0gSqQjNid1VUlUFo,20631
-relationalai/clients/resources/snowflake/__init__.py,sha256=9VR-hSIw4ZSEWisKcWhNEcRVBmBfueXNCTOOfLt-8rs,871
+relationalai/clients/resources/snowflake/__init__.py,sha256=Ofyf1RZu9GLQdvsjpHDUHEQHHVODb9vKYI4hMOxczH4,923
 relationalai/clients/resources/snowflake/cache_store.py,sha256=A-qd11wcwN3TkIqvlN0_iFUU3aEjJal3T2pqFBwkkzQ,3966
 relationalai/clients/resources/snowflake/cli_resources.py,sha256=xTIcCzvgbkxuNAEvzZoRpj0n-js0hZCK30q7IZXztbI,3252
-relationalai/clients/resources/snowflake/direct_access_resources.py,sha256=Xvh1e6TxUW2dTSS-9HadrfWVKrxNQ5GikqM4yjohJkM,29849
+relationalai/clients/resources/snowflake/direct_access_resources.py,sha256=5LZIQCSiX_TT2tVHqbTBzJTBXZzHmQI7cxFOY6CDDBM,28871
 relationalai/clients/resources/snowflake/engine_state_handlers.py,sha256=SQBu4GfbyABU6xrEV-koivC-ubsVrfCBTF0FEQgJM5g,12054
 relationalai/clients/resources/snowflake/error_handlers.py,sha256=581G2xOihUoiPlucC_Z2FOzhKu_swdIc3uORd0yJQuA,8805
 relationalai/clients/resources/snowflake/export_procedure.py.jinja,sha256=00iLO2qmvJoqAeJUWt3bAsFDDnof7Ab2spggzelbwK4,10566
 relationalai/clients/resources/snowflake/resources_factory.py,sha256=4LGd4IQ6z8hGeGlO1TIjSFJEeUNHutaB7j9q1a9rYfQ,3385
-relationalai/clients/resources/snowflake/snowflake.py,sha256=kWs7W_jsf_iy2N0VO5Uqu9VHXsbH2rQghPn1s5-9GAg,133037
-relationalai/clients/resources/snowflake/use_index_poller.py,sha256=gMcfILcpD1-wkJ0C1rGZQSy1oZBq6Gm8x9N7jdgEkhA,48448
+relationalai/clients/resources/snowflake/snowflake.py,sha256=SuFwAdhaFrd3JZFEjVJMVVFuLrC_T3dqXS27ABrXZAw,136973
+relationalai/clients/resources/snowflake/use_index_poller.py,sha256=xF_9XAymO82YZcZOQqz0r9oZZTqvbWW-WjDTZm4-tFM,48859
 relationalai/clients/resources/snowflake/use_index_resources.py,sha256=69PNWHI_uf-Aw_evfwC6j8HLVdjhp84vs8hLkjnhwbg,6462
 relationalai/clients/resources/snowflake/util.py,sha256=BEnm1B1-nqqHdm41RNxblbb-zqXbtqEGGZmTdAYeN_M,13841
 relationalai/early_access/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -334,7 +335,7 @@ relationalai/semantics/lqp/rewrite/__init__.py,sha256=V9ERED9qdh4VvY9Ud_M8Zn8lhV
 relationalai/semantics/lqp/rewrite/annotate_constraints.py,sha256=b_Ly4_80dQpRzWbeLC72JVfxzhwOPBpiCdEqtBiEiwM,2310
 relationalai/semantics/lqp/rewrite/cdc.py,sha256=8By-BrsayAWYxvrOcoqQ9w0-Nj4KQnaRH8haiog0-4o,10400
 relationalai/semantics/lqp/rewrite/extract_common.py,sha256=ZRvmeYHN8JEkU-j3fRx1e0_JK-46n6NqhxtwZe6L10c,14690
-relationalai/semantics/lqp/rewrite/extract_keys.py,sha256=gC-AkDA4DpqpFkXYc3u-LExbK6e79SrSxs5G6-TEb58,21797
+relationalai/semantics/lqp/rewrite/extract_keys.py,sha256=aw5xvsRN3556ncHWyG90jEuykw66c2cWV7n4QRsJh0Y,22150
 relationalai/semantics/lqp/rewrite/function_annotations.py,sha256=9ZzLASvXh_OgQ04eup0AyoMIh2HxWHkoRETLm1-XtWs,4660
 relationalai/semantics/lqp/rewrite/functional_dependencies.py,sha256=4oQcVQtAGDqY850B1bNszigQopf6y9Y_CaUyWx42PtM,12718
 relationalai/semantics/lqp/rewrite/quantify_vars.py,sha256=bOowgQ45zmP0HOhsTlE92WdVBCTXSkszcCYbPMeIibw,12004
@@ -370,7 +371,7 @@ relationalai/semantics/reasoners/graph/tests/README.md,sha256=XbauTzt6VA_YEOcrlZ
 relationalai/semantics/reasoners/optimization/__init__.py,sha256=lpavly1Qa3VKvLgrbpp-tsxY9hcqHL6buxuekgKPakw,2212
 relationalai/semantics/reasoners/optimization/common.py,sha256=V0c9eGHJKI-gt0X-q9o0bIkgBCdWFdjWq2NQITKwmXg,3124
 relationalai/semantics/reasoners/optimization/solvers_dev.py,sha256=lbw3c8Z6PlHRDm7TdAhICPShlGoab9uR_4uacMPvpBw,24493
-relationalai/semantics/reasoners/optimization/solvers_pb.py,sha256=ESwraHU9c4NCEVRZ16tnBZsUCmJg7lUhy-v0-GGq0qo,48000
+relationalai/semantics/reasoners/optimization/solvers_pb.py,sha256=COJuNm5kyyAQULkc1YS-91op4y4doobZEaE_RBLZWwo,60144
 relationalai/semantics/rel/__init__.py,sha256=pMlVTC_TbQ45mP1LpzwFBBgPxpKc0H3uJDvvDXEWzvs,55
 relationalai/semantics/rel/builtins.py,sha256=kQToiELc4NnvCmXyFtu9CsGZNdTQtSzTB-nuyIfQcsM,1562
 relationalai/semantics/rel/compiler.py,sha256=pFkEbuPKVd8AI4tiklcv06LbNnK8KfoV4FwmY9Lrhqo,43044
@@ -443,7 +444,7 @@ relationalai/util/spans_file_handler.py,sha256=a0sDwDPBBvGsM6be2En3mId9sXpuJlXia
 relationalai/util/timeout.py,sha256=2o6BVNFnFc-B2j-i1pEkZcQbMRto9ps2emci0XwiA4I,783
 relationalai/util/tracing_handler.py,sha256=H919ETAxh7Z1tRz9x8m90qP51_264UunHAPw8Sr6x2g,1729
 relationalai_test_util/__init__.py,sha256=Io_9_IQXXnrUlaL7S1Ndv-4YHilNxy36LrL723MI7lw,118
-relationalai_test_util/fixtures.py,sha256=rNOd8HbguWQi0j63QGoh5iFHLFfM1JUBgDyf87dEsKY,9214
+relationalai_test_util/fixtures.py,sha256=WhpPto6LtZTGWQhYV9Tmz9yJGoVAkYlGRrgM_7errm4,9455
 relationalai_test_util/snapshot.py,sha256=FeH2qYBzLxr2-9qs0yElPIgWUjm_SrzawB3Jgn-aSuE,9291
 relationalai_test_util/traceback.py,sha256=lD0qaEmCyO-7xg9CNf6IzwS-Q-sTS8N9YIv8RroAE50,3298
 frontend/debugger/dist/.gitignore,sha256=JAo-DTfS6GthQGP1NH6wLU-ZymwlTea4KHH_jZVTKn0,14
@@ -451,8 +452,8 @@ frontend/debugger/dist/index.html,sha256=0wIQ1Pm7BclVV1wna6Mj8OmgU73B9rSEGPVX-Wo
 frontend/debugger/dist/assets/favicon-Dy0ZgA6N.png,sha256=tPXOEhOrM4tJyZVJQVBO_yFgNAlgooY38ZsjyrFstgg,620
 frontend/debugger/dist/assets/index-Cssla-O7.js,sha256=MxgIGfdKQyBWgufck1xYggQNhW5nj6BPjCF6Wleo-f0,298886
 frontend/debugger/dist/assets/index-DlHsYx1V.css,sha256=21pZtAjKCcHLFjbjfBQTF6y7QmOic-4FYaKNmwdNZVE,60141
-relationalai-0.13.0.dist-info/METADATA,sha256=RsdnclgO5I9Akv8gxCW-DM5-8ohVCYowAwY-ulQmq6A,2562
-relationalai-0.13.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-relationalai-0.13.0.dist-info/entry_points.txt,sha256=fo_oLFJih3PUgYuHXsk7RnCjBm9cqRNR--ab6DgI6-0,88
-relationalai-0.13.0.dist-info/licenses/LICENSE,sha256=pPyTVXFYhirkEW9VsnHIgUjT0Vg8_xsE6olrF5SIgpc,11343
-relationalai-0.13.0.dist-info/RECORD,,
+relationalai-0.13.1.dist-info/METADATA,sha256=zO4iuv-8tLuz3rByVdd1X3MSdh5MU6XZzqzXcAvdf4k,2561
+relationalai-0.13.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+relationalai-0.13.1.dist-info/entry_points.txt,sha256=fo_oLFJih3PUgYuHXsk7RnCjBm9cqRNR--ab6DgI6-0,88
+relationalai-0.13.1.dist-info/licenses/LICENSE,sha256=pPyTVXFYhirkEW9VsnHIgUjT0Vg8_xsE6olrF5SIgpc,11343
+relationalai-0.13.1.dist-info/RECORD,,
relationalai_test_util/fixtures.py CHANGED
@@ -28,7 +28,7 @@ def graph_index_config_fixture():
     yield config
     return

-def engine_config_fixture(size, use_direct_access=False, reset_session=False, generation=rai.Generation.V0):
+def engine_config_fixture(size, use_direct_access=False, reset_session=False, query_timeout_mins=None, generation=rai.Generation.V0):
     # Check for an externally provided engine name
     # It is used in GitHub Actions to run tests against a specific engine
     engine_name = os.getenv("ENGINE_NAME")
@@ -44,6 +44,8 @@ def engine_config_fixture(size, use_direct_access=False, reset_session=False, ge

         config.set("reuse_model", False)
         config.set("enable_otel_handler", True)
+        if query_timeout_mins is not None:
+            config.set("query_timeout_mins", query_timeout_mins)

         yield config
         return
@@ -54,6 +56,8 @@ def engine_config_fixture(size, use_direct_access=False, reset_session=False, ge
     if config.file_path is not None:
         # Set test defaults
         config.set("reuse_model", False)
+        if query_timeout_mins is not None:
+            config.set("query_timeout_mins", query_timeout_mins)
    # Try to reset the session instead of using active session if reset_session is true
    rai.Resources(config=config, reset_session=reset_session, generation=generation)
    yield config
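
A sketch of threading the new parameter through a test (pytest wiring assumed; the engine size string is illustrative, only `query_timeout_mins` is the new piece):

    # Sketch: consuming the extended fixture in a test module (illustrative).
    import pytest
    from relationalai_test_util.fixtures import engine_config_fixture

    @pytest.fixture(scope="module")
    def config():
        yield from engine_config_fixture("S", query_timeout_mins=5)

    def test_timeout_is_configured(config):
        assert config.get("query_timeout_mins") == 5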