relationalai 1.0.0a2__py3-none-any.whl → 1.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. relationalai/config/shims.py +1 -0
  2. relationalai/semantics/__init__.py +7 -1
  3. relationalai/semantics/frontend/base.py +19 -13
  4. relationalai/semantics/frontend/core.py +30 -2
  5. relationalai/semantics/frontend/front_compiler.py +38 -11
  6. relationalai/semantics/frontend/pprint.py +1 -1
  7. relationalai/semantics/metamodel/rewriter.py +6 -2
  8. relationalai/semantics/metamodel/typer.py +70 -26
  9. relationalai/semantics/reasoners/__init__.py +11 -0
  10. relationalai/semantics/reasoners/graph/__init__.py +38 -0
  11. relationalai/semantics/reasoners/graph/core.py +9015 -0
  12. relationalai/shims/executor.py +4 -1
  13. relationalai/shims/hoister.py +9 -0
  14. relationalai/shims/mm2v0.py +47 -34
  15. relationalai/tools/cli/cli.py +138 -0
  16. relationalai/tools/cli/docs.py +394 -0
  17. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/METADATA +5 -3
  18. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/RECORD +57 -43
  19. v0/relationalai/__init__.py +69 -22
  20. v0/relationalai/clients/__init__.py +15 -2
  21. v0/relationalai/clients/client.py +4 -4
  22. v0/relationalai/clients/exec_txn_poller.py +91 -0
  23. v0/relationalai/clients/local.py +5 -5
  24. v0/relationalai/clients/resources/__init__.py +8 -0
  25. v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
  26. v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
  27. v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
  28. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +717 -0
  29. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
  30. v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
  31. v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  32. v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +642 -1399
  33. v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +51 -12
  34. v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  35. v0/relationalai/clients/resources/snowflake/util.py +387 -0
  36. v0/relationalai/early_access/dsl/ir/executor.py +4 -4
  37. v0/relationalai/early_access/dsl/snow/api.py +2 -1
  38. v0/relationalai/errors.py +18 -0
  39. v0/relationalai/experimental/solvers.py +7 -7
  40. v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
  41. v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
  42. v0/relationalai/semantics/internal/snowflake.py +1 -1
  43. v0/relationalai/semantics/lqp/executor.py +7 -12
  44. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  45. v0/relationalai/semantics/metamodel/util.py +6 -5
  46. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +335 -84
  47. v0/relationalai/semantics/rel/executor.py +14 -11
  48. v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
  49. v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
  50. v0/relationalai/tools/cli.py +26 -30
  51. v0/relationalai/tools/cli_helpers.py +10 -2
  52. v0/relationalai/util/otel_configuration.py +2 -1
  53. v0/relationalai/util/otel_handler.py +1 -1
  54. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/WHEEL +0 -0
  55. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/entry_points.txt +0 -0
  56. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/top_level.txt +0 -0
  57. v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py}
@@ -1,10 +1,8 @@
 # pyright: reportUnusedExpression=false
 from __future__ import annotations
 import base64
-import decimal
 import importlib.resources
 import io
-from numbers import Number
 import re
 import json
 import time
@@ -14,15 +12,15 @@ import uuid
 import warnings
 import atexit
 import hashlib
+from dataclasses import dataclass
 
-
-from v0.relationalai.auth.token_handler import TokenHandler
-from v0.relationalai.clients.use_index_poller import DirectUseIndexPoller, UseIndexPoller
+from ....auth.token_handler import TokenHandler
+from v0.relationalai.clients.exec_txn_poller import ExecTxnPoller, query_complete_message
 import snowflake.snowpark
 
-from v0.relationalai.rel_utils import sanitize_identifier, to_fqn_relation_name
-from v0.relationalai.tools.constants import FIELD_PLACEHOLDER, RAI_APP_NAME, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, USE_DIRECT_ACCESS, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
-from .. import std
+from ....rel_utils import sanitize_identifier, to_fqn_relation_name
+from ....tools.constants import FIELD_PLACEHOLDER, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
+from .... import std
 from collections import defaultdict
 import requests
 import snowflake.connector
@@ -30,33 +28,63 @@ import pyarrow as pa
 
 from snowflake.snowpark import Session
 from snowflake.snowpark.context import get_active_session
-from . import result_helpers
-from .. import debugging
-from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
+from ... import result_helpers
+from .... import debugging
+from typing import Any, Dict, Iterable, Tuple, List, Literal, cast
 
 from pandas import DataFrame
 
-from ..tools.cli_controls import Spinner
-from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
-from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
-from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
-from ..clients.direct_access_client import DirectAccessClient
-from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp, normalize_datetime
-from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
-from .. import dsl, rel, metamodel as m
-from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
+from ....tools.cli_controls import Spinner
+from ...types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
+from ...config import Config
+from ...client import Client, ExportParams, ProviderBase, ResourcesBase
+from ...util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, normalize_datetime
+from .util import (
+    collect_error_messages,
+    process_jinja_template,
+    type_to_sql,
+    type_to_snowpark,
+    sanitize_user_name as _sanitize_user_name,
+    normalize_params,
+    format_sproc_name,
+    is_azure_url,
+    is_container_runtime,
+    imports_to_dicts,
+    txn_list_to_dicts,
+    decrypt_artifact,
+)
+from ....environments import runtime_env, HexEnvironment, SnowbookEnvironment
+from .... import dsl, rel, metamodel as m
+from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, GuardRailsException, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
 from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime, date, timedelta
+from datetime import datetime, timedelta
 from snowflake.snowpark.types import StringType, StructField, StructType
+# Import error handlers and constants
+from .error_handlers import (
+    ErrorHandler,
+    DuoSecurityErrorHandler,
+    AppMissingErrorHandler,
+    DatabaseErrorsHandler,
+    EngineErrorsHandler,
+    ServiceNotStartedErrorHandler,
+    TransactionAbortedErrorHandler,
+)
+# Import engine state handlers
+from .engine_state_handlers import (
+    EngineStateHandler,
+    EngineContext,
+    SyncPendingStateHandler,
+    SyncSuspendedStateHandler,
+    SyncReadyStateHandler,
+    SyncGoneStateHandler,
+    SyncMissingEngineHandler,
+    AsyncPendingStateHandler,
+    AsyncSuspendedStateHandler,
+    AsyncReadyStateHandler,
+    AsyncGoneStateHandler,
+    AsyncMissingEngineHandler,
+)
 
-# warehouse-based snowflake notebooks currently don't have hazmat
-crypto_disabled = False
-try:
-    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-    from cryptography.hazmat.backends import default_backend
-    from cryptography.hazmat.primitives import padding
-except (ModuleNotFoundError, ImportError):
-    crypto_disabled = True
 
 #--------------------------------------------------
 # Constants
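The strategy-pattern machinery imported above is defined in the new error_handlers.py and engine_state_handlers.py modules, which this diff does not show. Based purely on how the handlers are invoked later in this file (handler.matches(e, message, ctx, self), handler.handle(e, ctx, self), handler.handles_state(state)), a minimal sketch of the assumed interfaces:

    # Sketch of the assumed interfaces; the real definitions live in
    # error_handlers.py / engine_state_handlers.py and may differ.
    from __future__ import annotations
    from abc import ABC, abstractmethod
    from typing import Any

    class ErrorHandler(ABC):
        @abstractmethod
        def matches(self, error: Exception, message: str, ctx: Any, resources: Any) -> bool:
            """True if this handler recognizes the lowercased error message."""

        @abstractmethod
        def handle(self, error: Exception, ctx: Any, resources: Any) -> Any:
            """Raise a domain-specific exception, or return a retried result."""

    class EngineStateHandler(ABC):
        @abstractmethod
        def handles_state(self, state: str | None) -> bool:
            """True if this handler owns the given engine state (e.g. 'PENDING')."""

        @abstractmethod
        def handle(self, engine: Any, context: Any, resources: Any) -> Any:
            """Drive the engine toward READY; return the resulting engine state."""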
@@ -66,225 +94,28 @@ VALID_POOL_STATUS = ["ACTIVE", "IDLE", "SUSPENDED"]
 # transaction list and get return different fields (duration vs timings)
 LIST_TXN_SQL_FIELDS = ["id", "database_name", "engine_name", "state", "abort_reason", "read_only","created_by", "created_on", "finished_at", "duration"]
 GET_TXN_SQL_FIELDS = ["id", "database", "engine", "state", "abort_reason", "read_only","created_by", "created_on", "finished_at", "timings"]
-IMPORT_STREAM_FIELDS = ["ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS", "FQ_OBJECT_NAME", "RAI_DATABASE",
-                        "RAI_RELATION", "DATA_SYNC_STATUS", "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
-                        "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"]
 VALID_ENGINE_STATES = ["READY", "PENDING"]
 
 # Cloud-specific engine sizes
 INTERNAL_ENGINE_SIZES = ["XS", "S", "M", "L"]
 ENGINE_SIZES_AWS = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_L"]
 ENGINE_SIZES_AZURE = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_SL"]
-
-FIELD_MAP = {
-    "database_name": "database",
-    "engine_name": "engine",
-}
-VALID_IMPORT_STATES = ["PENDING", "PROCESSING", "QUARANTINED", "LOADED"]
-ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
-ENGINE_NOT_READY_MSGS = ["engine is in pending", "engine is provisioning"]
-DATABASE_ERRORS = ["database not found"]
+# Note: ENGINE_ERRORS, ENGINE_NOT_READY_MSGS, DATABASE_ERRORS moved to util.py
 PYREL_ROOT_DB = 'pyrel_root_db'
 
 TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
 
-DUO_TEXT = "duo security"
-
 TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+GUARDRAILS_ABORT_REASON = "guard rail violation"
+
+PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
 
 #--------------------------------------------------
 # Helpers
 #--------------------------------------------------
 
-def process_jinja_template(template: str, indent_spaces = 0, **substitutions) -> str:
-    """Process a Jinja-like template.
-
-    Supports:
-    - Variable substitution {{ var }}
-    - Conditional blocks {% if condition %} ... {% endif %}
-    - For loops {% for item in items %} ... {% endfor %}
-    - Comments {# ... #}
-    - Whitespace control with {%- and -%}
-
-    Args:
-        template: The template string
-        indent_spaces: Number of spaces to indent the result
-        **substitutions: Variable substitutions
-    """
-
-    def evaluate_condition(condition: str, context: dict) -> bool:
-        """Safely evaluate a condition string using the context."""
-        # Replace variables with their values
-        for k, v in context.items():
-            if isinstance(v, str):
-                condition = condition.replace(k, f"'{v}'")
-            else:
-                condition = condition.replace(k, str(v))
-        try:
-            return bool(eval(condition, {"__builtins__": {}}, {}))
-        except Exception:
-            return False
-
-    def process_expression(expr: str, context: dict) -> str:
-        """Process a {{ expression }} block."""
-        expr = expr.strip()
-        if expr in context:
-            return str(context[expr])
-        return ""
-
-    def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
-        """Process a block of template lines recursively."""
-        result = []
-        i = 0
-        while i < len(lines):
-            line = lines[i]
-
-            # Handle comments
-            line = re.sub(r'{#.*?#}', '', line)
-
-            # Handle if blocks
-            if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
-            if if_match:
-                condition = if_match.group(1)
-                if_block = []
-                else_block = []
-                i += 1
-                nesting = 1
-                in_else_block = False
-                while i < len(lines) and nesting > 0:
-                    if re.search(r'{%\s*if\s+', lines[i]):
-                        nesting += 1
-                    elif re.search(r'{%\s*endif\s*%}', lines[i]):
-                        nesting -= 1
-                    elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
-                        in_else_block = True
-                        i += 1
-                        continue
-
-                    if nesting > 0:
-                        if in_else_block:
-                            else_block.append(lines[i])
-                        else:
-                            if_block.append(lines[i])
-                    i += 1
-                if evaluate_condition(condition, context):
-                    result.extend(process_block(if_block, context, indent))
-                else:
-                    result.extend(process_block(else_block, context, indent))
-                continue
-
-            # Handle for loops
-            for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
-            if for_match:
-                var_name, iterable_name = for_match.groups()
-                for_block = []
-                i += 1
-                nesting = 1
-                while i < len(lines) and nesting > 0:
-                    if re.search(r'{%\s*for\s+', lines[i]):
-                        nesting += 1
-                    elif re.search(r'{%\s*endfor\s*%}', lines[i]):
-                        nesting -= 1
-                    if nesting > 0:
-                        for_block.append(lines[i])
-                    i += 1
-                if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
-                    for item in context[iterable_name]:
-                        loop_context = dict(context)
-                        loop_context[var_name] = item
-                        result.extend(process_block(for_block, loop_context, indent))
-                continue
-
-            # Handle variable substitution
-            line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)
-
-            # Handle whitespace control
-            line = re.sub(r'{%-', '{%', line)
-            line = re.sub(r'-%}', '%}', line)
-
-            # Add line with proper indentation, preserving blank lines
-            if line.strip():
-                result.append(" " * (indent_spaces + indent) + line)
-            else:
-                result.append("")
-
-            i += 1
-
-        return result
-
-    # Split template into lines and process
-    lines = template.split('\n')
-    processed_lines = process_block(lines, substitutions)
-
-    return '\n'.join(processed_lines)
-
-def type_to_sql(type) -> str:
-    if type is str:
-        return "VARCHAR"
-    if type is int:
-        return "NUMBER"
-    if type is Number:
-        return "DECIMAL(38, 15)"
-    if type is float:
-        return "FLOAT"
-    if type is decimal.Decimal:
-        return "DECIMAL(38, 15)"
-    if type is bool:
-        return "BOOLEAN"
-    if type is dict:
-        return "VARIANT"
-    if type is list:
-        return "ARRAY"
-    if type is bytes:
-        return "BINARY"
-    if type is datetime:
-        return "TIMESTAMP"
-    if type is date:
-        return "DATE"
-    if isinstance(type, dsl.Type):
-        return "VARCHAR"
-    raise ValueError(f"Unknown type {type}")
-
-def type_to_snowpark(type) -> str:
-    if type is str:
-        return "StringType()"
-    if type is int:
-        return "IntegerType()"
-    if type is float:
-        return "FloatType()"
-    if type is Number:
-        return "DecimalType(38, 15)"
-    if type is decimal.Decimal:
-        return "DecimalType(38, 15)"
-    if type is bool:
-        return "BooleanType()"
-    if type is dict:
-        return "MapType()"
-    if type is list:
-        return "ArrayType()"
-    if type is bytes:
-        return "BinaryType()"
-    if type is datetime:
-        return "TimestampType()"
-    if type is date:
-        return "DateType()"
-    if isinstance(type, dsl.Type):
-        return "StringType()"
-    raise ValueError(f"Unknown type {type}")
-
-def _sanitize_user_name(user: str) -> str:
-    # Extract the part before the '@'
-    sanitized_user = user.split('@')[0]
-    # Replace any character that is not a letter, number, or underscore with '_'
-    sanitized_user = re.sub(r'[^a-zA-Z0-9_]', '_', sanitized_user)
-    return sanitized_user
-
-def _is_engine_issue(response_message: str) -> bool:
-    return any(kw in response_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)
-
-def _is_database_issue(response_message: str) -> bool:
-    return any(kw in response_message.lower() for kw in DATABASE_ERRORS)
-
+def should_print_txn_progress(config) -> bool:
+    return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))
 
 #--------------------------------------------------
 # Resources
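The new print_txn_progress config flag gates the transaction-progress printing added to _exec_async_v2 further down. A minimal sketch of the helper's behavior (using a plain dict as a stand-in for the pyrel Config object):

    PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"

    def should_print_txn_progress(config) -> bool:
        return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))

    assert should_print_txn_progress({"print_txn_progress": True}) is True
    assert should_print_txn_progress({}) is False  # defaults to off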
@@ -292,6 +123,25 @@ def _is_database_issue(response_message: str) -> bool:
 
 APP_NAME = "___RAI_APP___"
 
+@dataclass
+class ExecContext:
+    """Execution context for SQL queries, containing all parameters needed for execution and retry."""
+    code: str
+    params: List[Any] | None = None
+    raw: bool = False
+    help: bool = True
+    skip_engine_db_error_retry: bool = False
+
+    def re_execute(self, resources: 'Resources') -> Any:
+        """Re-execute this context's query using the provided resources instance."""
+        return resources._exec(
+            code=self.code,
+            params=self.params,
+            raw=self.raw,
+            help=self.help,
+            skip_engine_db_error_retry=self.skip_engine_db_error_retry
+        )
+
 class Resources(ResourcesBase):
     def __init__(
         self,
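ExecContext captures everything needed to replay a failed _exec call, which is what lets an error handler repair a precondition and then retry. A sketch with a hypothetical handler (not part of the package):

    # Hypothetical handler: repair the precondition, then replay the original call.
    class MyEngineRetryHandler:  # illustrative only
        def matches(self, error, message, ctx, resources) -> bool:
            return "engine not found" in message and not ctx.skip_engine_db_error_retry

        def handle(self, error, ctx, resources):
            resources.auto_create_engine()    # provision an engine first
            return ctx.re_execute(resources)  # then re-run the exact failing query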
@@ -301,7 +151,7 @@ class Resources(ResourcesBase):
         dry_run: bool = False,
         reset_session: bool = False,
         generation: Generation | None = None,
-        language: str = "rel",
+        language: str = "rel",  # Accepted for backward compatibility, but not stored in base class
     ):
         super().__init__(profile, config=config)
         self._token_handler: TokenHandler | None = None
@@ -319,16 +169,97 @@ class Resources(ResourcesBase):
         # self.sources contains fully qualified Snowflake table/view names
         self.sources: set[str] = set()
         self._sproc_models = None
-        self.database = ""
+        # Store language for backward compatibility (used by child classes for use_index polling)
         self.language = language
+        # Register error and state handlers
+        self._register_handlers()
+        # Register atexit callback to cancel pending transactions
         atexit.register(self.cancel_pending_transactions)
 
+    #--------------------------------------------------
+    # Initialization & Properties
+    #--------------------------------------------------
+
+    def _register_handlers(self) -> None:
+        """Register error and engine state handlers for processing."""
+        # Register base handlers using getter methods that subclasses can override
+        # Use defensive copying to ensure each instance has its own handler lists
+        # and prevent cross-instance contamination from subclass mutations
+        self._error_handlers = list(self._get_error_handlers())
+        self._sync_engine_state_handlers = list(self._get_engine_state_handlers(is_async=False))
+        self._async_engine_state_handlers = list(self._get_engine_state_handlers(is_async=True))
+
+    def _get_error_handlers(self) -> list[ErrorHandler]:
+        """Get list of error handlers. Subclasses can override to add custom handlers.
+
+        Returns:
+            List of error handlers for standard error processing using Strategy Pattern.
+
+        Example:
+            def _get_error_handlers(self) -> list[ErrorHandler]:
+                # Get base handlers
+                handlers = super()._get_error_handlers()
+                # Add custom handler
+                handlers.append(MyCustomErrorHandler())
+                return handlers
+        """
+        return [
+            DuoSecurityErrorHandler(),
+            AppMissingErrorHandler(),
+            DatabaseErrorsHandler(),
+            EngineErrorsHandler(),
+            ServiceNotStartedErrorHandler(),
+            TransactionAbortedErrorHandler(),
+        ]
+
+    def _get_engine_state_handlers(self, is_async: bool = False) -> list[EngineStateHandler]:
+        """Get list of engine state handlers. Subclasses can override.
+
+        Args:
+            is_async: If True, returns async handlers; if False, returns sync handlers.
+
+        Returns:
+            List of engine state handlers for processing engine states.
+
+        Example:
+            def _get_engine_state_handlers(self, is_async: bool = False) -> list[EngineStateHandler]:
+                # Get base handlers
+                handlers = super()._get_engine_state_handlers(is_async)
+                # Add custom handler
+                handlers.append(MyCustomStateHandler())
+                return handlers
+        """
+        if is_async:
+            return [
+                AsyncPendingStateHandler(),
+                AsyncSuspendedStateHandler(),
+                AsyncReadyStateHandler(),
+                AsyncGoneStateHandler(),
+                AsyncMissingEngineHandler(),
+            ]
+        else:
+            return [
+                SyncPendingStateHandler(),
+                SyncSuspendedStateHandler(),
+                SyncReadyStateHandler(),
+                SyncGoneStateHandler(),
+                SyncMissingEngineHandler(),
+            ]
+
     @property
     def token_handler(self) -> TokenHandler:
         if not self._token_handler:
             self._token_handler = TokenHandler.from_config(self.config)
         return self._token_handler
 
+    def reset(self):
+        """Reset the session."""
+        self._session = None
+
+    #--------------------------------------------------
+    # Session Management
+    #--------------------------------------------------
+
     def is_erp_running(self, app_name: str) -> bool:
         """Check if the ERP is running. The app.service_status() returns single row/column containing an array of JSON service status objects."""
         query = f"CALL {app_name}.app.service_status();"
@@ -426,7 +357,28 @@ class Resources(ResourcesBase):
         except Exception as e:
             raise e
 
+    #--------------------------------------------------
+    # Core Execution Methods
+    #--------------------------------------------------
+
     def _exec_sql(self, code: str, params: List[Any] | None, raw=False):
+        """
+        Lowest-level SQL execution method.
+
+        Directly executes SQL via the Snowflake session. This is the foundation
+        for all other execution methods. It:
+        - Replaces APP_NAME placeholder with actual app name
+        - Executes SQL with optional parameters
+        - Returns either raw session results or collected results
+
+        Args:
+            code: SQL code to execute (may contain APP_NAME placeholder)
+            params: Optional SQL parameters
+            raw: If True, return raw session results; if False, collect results
+
+        Returns:
+            Raw session results if raw=True, otherwise collected results
+        """
         assert self._session is not None
         sess_results = self._session.sql(
            code.replace(APP_NAME, self.get_app_name()),
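Callers embed the APP_NAME placeholder in their SQL, and _exec_sql rewrites it to the configured native-app name before execution. A small illustration of that substitution (the app name here is made up):

    APP_NAME = "___RAI_APP___"

    def render(code: str, app_name: str) -> str:
        # mirrors the substitution _exec_sql performs before session.sql(...)
        return code.replace(APP_NAME, app_name)

    print(render(f"call {APP_NAME}.api.get_database('my_model');", "relationalai_app"))
    # -> call relationalai_app.api.get_database('my_model');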
@@ -444,85 +396,88 @@
         help: bool = True,
         skip_engine_db_error_retry: bool = False
     ) -> Any:
+        """
+        Mid-level SQL execution method with error handling.
+
+        This is the primary method for executing SQL queries. It wraps _exec_sql
+        with comprehensive error handling and parameter normalization. Used
+        extensively throughout the codebase for direct SQL operations like:
+        - SHOW commands (warehouses, databases, etc.)
+        - CALL statements to RAI app stored procedures
+        - Transaction management queries
+
+        The error handling flow:
+        1. Normalizes parameters and creates execution context
+        2. Calls _exec_sql to execute the query
+        3. On error, uses standard error handling (Strategy Pattern), which subclasses
+           can influence via `_get_error_handlers()` or by overriding `_handle_standard_exec_errors()`
+
+        Args:
+            code: SQL code to execute
+            params: Optional SQL parameters (normalized to list if needed)
+            raw: If True, return raw session results; if False, collect results
+            help: If True, enable error handling; if False, raise errors immediately
+            skip_engine_db_error_retry: If True, skip use_index retry logic in error handlers
+
+        Returns:
+            Query results (collected or raw depending on 'raw' parameter)
+        """
         # print(f"\n--- sql---\n{code}\n--- end sql---\n")
+        # Ensure session is initialized
         if not self._session:
             self._session = self.get_sf_session()
 
+        # Normalize parameters
+        normalized_params = normalize_params(params)
+
+        # Create execution context
+        ctx = ExecContext(
+            code=code,
+            params=normalized_params,
+            raw=raw,
+            help=help,
+            skip_engine_db_error_retry=skip_engine_db_error_retry
+        )
+
+        # Execute SQL
         try:
-            if params is not None and not isinstance(params, list):
-                params = cast(List[Any], [params])
-            return self._exec_sql(code, params, raw=raw)
+            return self._exec_sql(ctx.code, ctx.params, raw=ctx.raw)
         except Exception as e:
-            if not help:
+            if not ctx.help:
                 raise e
-            orig_message = str(e).lower()
-            rai_app = self.config.get("rai_app_name", "")
-            current_role = self.config.get("role")
-            engine = self.get_default_engine_name()
-            engine_size = self.config.get_default_engine_size()
-            assert isinstance(rai_app, str), f"rai_app_name must be a string, not {type(rai_app)}"
-            assert isinstance(engine, str), f"engine must be a string, not {type(engine)}"
-            print("\n")
-            if DUO_TEXT in orig_message:
-                raise DuoSecurityFailed(e)
-            if re.search(f"database '{rai_app}' does not exist or not authorized.".lower(), orig_message):
-                exception = SnowflakeAppMissingException(rai_app, current_role)
-                raise exception from None
-            # skip initializing the index if the query is a user transaction. exec_raw/exec_lqp will handle that case with the correct request headers.
-            if (_is_engine_issue(orig_message) or _is_database_issue(orig_message)) and not skip_engine_db_error_retry:
-                try:
-                    self._poll_use_index(
-                        app_name=self.get_app_name(),
-                        sources=self.sources,
-                        model=self.database,
-                        engine_name=engine,
-                        engine_size=engine_size
-                    )
-                    return self._exec(code, params, raw=raw, help=help)
-                except EngineNameValidationException as e:
-                    raise EngineNameValidationException(engine) from e
-                except Exception as e:
-                    raise EngineProvisioningFailed(engine, e) from e
-            elif re.search(r"javascript execution error", orig_message):
-                match = re.search(r"\"message\":\"(.*)\"", orig_message)
-                if match:
-                    message = match.group(1)
-                    if "engine is in pending" in message or "engine is provisioning" in message:
-                        raise EnginePending(engine)
-                    else:
-                        raise RAIException(message) from None
-
-            if re.search(r"the relationalai service has not been started.", orig_message):
-                app_name = self.config.get("rai_app_name", "")
-                assert isinstance(app_name, str), f"rai_app_name must be a string, not {type(app_name)}"
-                raise SnowflakeRaiAppNotStarted(app_name)
-
-            if re.search(r"state:\s*aborted", orig_message):
-                txn_id_match = re.search(r"id:\s*([0-9a-f\-]+)", orig_message)
-                if txn_id_match:
-                    txn_id = txn_id_match.group(1)
-                    problems = self.get_transaction_problems(txn_id)
-                    if problems:
-                        for problem in problems:
-                            if isinstance(problem, dict):
-                                type_field = problem.get('TYPE')
-                                message_field = problem.get('MESSAGE')
-                                report_field = problem.get('REPORT')
-                            else:
-                                type_field = problem.TYPE
-                                message_field = problem.MESSAGE
-                                report_field = problem.REPORT
 
-                            raise RAIAbortedTransactionError(type_field, message_field, report_field)
-                    raise RAIException(str(e))
-            raise RAIException(str(e))
+            # Handle standard errors
+            result = self._handle_standard_exec_errors(e, ctx)
+            if result is not None:
+                return result
 
+    #--------------------------------------------------
+    # Error Handling
+    #--------------------------------------------------
 
-    def reset(self):
-        self._session = None
+    def _handle_standard_exec_errors(self, e: Exception, ctx: ExecContext) -> Any | None:
+        """
+        Handle standard Snowflake/RAI errors using Strategy Pattern.
+
+        Each error type has a dedicated handler class that encapsulates
+        the detection logic and exception creation. Handlers are processed
+        in order until one matches and handles the error.
+        """
+        message = str(e).lower()
+
+        # Try each handler in order until one matches
+        for handler in self._error_handlers:
+            if handler.matches(e, message, ctx, self):
+                result = handler.handle(e, ctx, self)
+                if result is not None:
+                    return result
+                return  # Handler raised exception, we're done
+
+        # Fallback: transform to RAIException
+        raise RAIException(str(e))
 
     #--------------------------------------------------
-    # Check direct access is enabled
+    # Feature Detection & Configuration
     #--------------------------------------------------
 
     def is_direct_access_enabled(self) -> bool:
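So a matching handler has three possible outcomes: raise a translated domain exception (the common case), return a non-None result (e.g. after re-executing the query via ctx.re_execute), or return None to swallow the error; anything unmatched falls back to RAIException. A toy model of that contract (names hypothetical; DemoRAIException stands in for the real exception class):

    class DemoRAIException(Exception):
        pass

    class TranslateHandler:  # hypothetical
        def matches(self, e, message, ctx, resources):
            return "known" in message
        def handle(self, e, ctx, resources):
            raise DemoRAIException(f"translated: {e}")

    def dispatch(handlers, e, ctx=None, resources=None):
        message = str(e).lower()
        for handler in handlers:
            if handler.matches(e, message, ctx, resources):
                result = handler.handle(e, ctx, resources)
                if result is not None:
                    return result  # handler repaired and retried the call
                return None        # handler swallowed the error
        raise DemoRAIException(str(e))  # fallback for unrecognized errors

    try:
        dispatch([TranslateHandler()], Exception("known failure"))
    except DemoRAIException as exc:
        print(exc)  # translated: known failure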
@@ -544,9 +499,6 @@
         except Exception as e:
             raise Exception(f"Unable to determine if direct access is enabled. Details error: {e}") from e
 
-    #--------------------------------------------------
-    # Snowflake Account Flags
-    #--------------------------------------------------
 
     def is_account_flag_set(self, flag: str) -> bool:
         results = self._exec(
@@ -566,7 +518,8 @@
                 f"call {APP_NAME}.api.get_database('{database}');"
             )
         except Exception as e:
-            if "Database does not exist" in str(e):
+            messages = collect_error_messages(e)
+            if any("database does not exist" in msg for msg in messages):
                 return None
             raise e
 
@@ -590,10 +543,11 @@
         try:
             results = self._exec(query)
         except Exception as e:
-            if "Database does not exist" in str(e):
+            messages = collect_error_messages(e)
+            if any("database does not exist" in msg for msg in messages):
                 return None
             # fallback to None for old sql-lib versions
-            if "Unknown user-defined function" in str(e):
+            if any("unknown user-defined function" in msg for msg in messages):
                 return None
             raise e
 
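collect_error_messages comes from the new util.py, which this diff does not include. Given how it is used above (case-insensitive substring checks over several messages), it presumably flattens an exception and its cause/context chain into lowercased strings; a hedged sketch of that behavior, which may differ from the real implementation:

    def collect_error_messages(e: BaseException) -> list[str]:
        # Gather lowercased messages from an exception and its __cause__/__context__
        # chain so callers can substring-match without worrying about wrapping.
        messages: list[str] = []
        seen: set[int] = set()
        current: BaseException | None = e
        while current is not None and id(current) not in seen:
            seen.add(id(current))
            messages.append(str(current).lower())
            current = current.__cause__ or current.__context__
        return messages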
@@ -610,6 +564,139 @@
     # Engines
     #--------------------------------------------------
 
+    def _prepare_engine_params(
+        self,
+        name: str | None,
+        size: str | None,
+        use_default_size: bool = False
+    ) -> tuple[str, str | None]:
+        """
+        Prepare engine parameters by resolving and validating name and size.
+
+        Args:
+            name: Engine name (None to use default)
+            size: Engine size (None to use config or default)
+            use_default_size: If True and size is None, use get_default_engine_size()
+
+        Returns:
+            Tuple of (engine_name, engine_size)
+
+        Raises:
+            EngineNameValidationException: If engine name is invalid
+            Exception: If engine size is invalid
+        """
+        from v0.relationalai.tools.cli_helpers import validate_engine_name
+
+        engine_name = name or self.get_default_engine_name()
+
+        # Resolve engine size
+        if size:
+            engine_size = size
+        else:
+            if use_default_size:
+                engine_size = self.config.get_default_engine_size()
+            else:
+                engine_size = self.config.get("engine_size", None)
+
+        # Validate engine size
+        if engine_size:
+            is_size_valid, sizes = self.validate_engine_size(engine_size)
+            if not is_size_valid:
+                error_msg = f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}"
+                if use_default_size:
+                    error_msg = f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}"
+                raise Exception(error_msg)
+
+        # Validate engine name
+        is_name_valid, _ = validate_engine_name(engine_name)
+        if not is_name_valid:
+            raise EngineNameValidationException(engine_name)
+
+        return engine_name, engine_size
+
+    def _get_state_handler(self, state: str | None, handlers: list[EngineStateHandler]) -> EngineStateHandler:
+        """Find the appropriate state handler for the given state."""
+        for handler in handlers:
+            if handler.handles_state(state):
+                return handler
+        # Fallback to missing engine handler if no match
+        return handlers[-1]  # Last handler should be MissingEngineHandler
+
+    def _process_engine_state(
+        self,
+        engine: EngineState | Dict[str, Any] | None,
+        context: EngineContext,
+        handlers: list[EngineStateHandler],
+        set_active_on_success: bool = False
+    ) -> EngineState | Dict[str, Any] | None:
+        """
+        Process engine state using appropriate state handler.
+
+        Args:
+            engine: Current engine state (or None if missing)
+            context: Engine context for state handling
+            handlers: List of state handlers to use (sync or async)
+            set_active_on_success: If True, set engine as active when handler returns engine
+
+        Returns:
+            Engine state after processing, or None if engine needs to be created
+        """
+        # Find and execute appropriate state handler
+        state = engine["state"] if engine else None
+        handler = self._get_state_handler(state, handlers)
+        engine = handler.handle(engine, context, self)
+
+        # If handler returned None and we didn't start with None state, engine needs to be created
+        # (e.g., GONE state deleted the engine, so we need to create a new one)
+        if not engine and state is not None:
+            handler = self._get_state_handler(None, handlers)
+            handler.handle(None, context, self)
+        elif set_active_on_success:
+            # Cast to EngineState for type safety (handlers return EngineDict which is compatible)
+            self._set_active_engine(cast(EngineState, engine))
+
+        return engine
+
+    def _handle_engine_creation_errors(self, error: Exception, engine_name: str, preserve_rai_exception: bool = False) -> None:
+        """
+        Handle errors during engine creation using error handlers.
+
+        Args:
+            error: The exception that occurred
+            engine_name: Name of the engine being created
+            preserve_rai_exception: If True, re-raise RAIException without wrapping
+
+        Raises:
+            RAIException: If preserve_rai_exception is True and error is RAIException
+            EngineProvisioningFailed: If error is not handled by error handlers
+        """
+        # Preserve RAIException passthrough if requested (for async mode)
+        if preserve_rai_exception and isinstance(error, RAIException):
+            raise error
+
+        # Check if this is a known error type that should be handled by error handlers
+        message = str(error).lower()
+        handled = False
+        # Engine creation isn't tied to a specific SQL ExecContext; pass a context that
+        # disables use_index retry behavior (and any future ctx-dependent handlers).
+        ctx = ExecContext(code="", help=True, skip_engine_db_error_retry=True)
+        for handler in self._error_handlers:
+            if handler.matches(error, message, ctx, self):
+                handler.handle(error, ctx, self)
+                handled = True
+                break  # Handler raised exception, we're done
+
+        # If not handled by error handlers, wrap in EngineProvisioningFailed
+        if not handled:
+            raise EngineProvisioningFailed(engine_name, error) from error
+
+    def validate_engine_size(self, size: str) -> Tuple[bool, List[str]]:
+        if size is not None:
+            sizes = self.get_engine_sizes()
+            if size not in sizes:
+                return False, sizes
+        return True, []
+
     def get_engine_sizes(self, cloud_provider: str|None=None):
         sizes = []
         if cloud_provider is None:
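The truncated auto_create_engine hunk at the bottom of this diff shows how these helpers compose; condensed into a sketch (error handling and spans omitted; span=None is an assumption):

    # Condensed sketch of the composition (see the auto_create_engine hunk below).
    def ensure_engine(resources, name=None, size=None, headers=None):
        # resolve + validate name/size, raising on bad input
        engine_name, engine_size = resources._prepare_engine_params(name, size)
        engine = resources.get_engine(engine_name)
        context = EngineContext(engine_name=engine_name, engine_size=engine_size,
                                headers=headers, requested_size=size, span=None)
        # dispatch to the sync state handlers (PENDING/SUSPENDED/READY/GONE/missing)
        return resources._process_engine_state(
            engine, context, resources._sync_engine_state_handlers,
            set_active_on_success=True,
        )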
@@ -812,19 +899,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
         engine_size: str | None = None,
         program_span_id: str | None = None,
         headers: Dict | None = None,
-    ):
-        return UseIndexPoller(
-            self,
-            app_name,
-            sources,
-            model,
-            engine_name,
-            engine_size,
-            self.language,
-            program_span_id,
-            headers,
-            self.generation
-        ).poll()
+    ) -> None:
+        """
+        Poll use_index to prepare indices for the given sources.
+
+        This is an optional interface method. Base Resources provides a no-op implementation.
+        UseIndexResources and DirectAccessResources override this to provide actual polling.
+
+        Returns:
+            None for base implementation. Child classes may return poller results.
+        """
+        return None
 
     def maybe_poll_use_index(
         self,
@@ -835,36 +920,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
         engine_size: str | None = None,
         program_span_id: str | None = None,
         headers: Dict | None = None,
-    ):
-        """Only call poll() if there are sources to process and cache is not valid."""
-        sources_list = list(sources)
-        self.database = model
-        if sources_list:
-            poller = UseIndexPoller(
-                self,
-                app_name,
-                sources_list,
-                model,
-                engine_name,
-                engine_size,
-                self.language,
-                program_span_id,
-                headers,
-                self.generation
-            )
-            # If cache is valid (data freshness has not expired), skip polling
-            if poller.cache.is_valid():
-                cached_sources = len(poller.cache.sources)
-                total_sources = len(sources_list)
-                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
-
-                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
-                if cached_timestamp:
-                    print(f"\n{message} (cached at {cached_timestamp})\n")
-                else:
-                    print(f"\n{message}\n")
-            else:
-                return poller.poll()
+    ) -> None:
+        """
+        Only call _poll_use_index if there are sources to process.
+
+        This is an optional interface method. Base Resources provides a no-op implementation.
+        UseIndexResources and DirectAccessResources override this to provide actual polling with caching.
+
+        Returns:
+            None for base implementation. Child classes may return poller results.
+        """
+        return None
 
     #--------------------------------------------------
     # Models
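Per the docstrings, the real polling now lives in the new use_index_resources.py and direct_access_resources.py (files 28 and 34 in the list above). A sketch of the override shape they presumably use (signature abridged; the subclassing relationship is an assumption):

    # Sketch only: UseIndexResources is defined in use_index_resources.py (not shown).
    class UseIndexResources(Resources):  # assumed subclass relationship
        def _poll_use_index(self, app_name, sources, model, engine_name=None,
                            engine_size=None, program_span_id=None, headers=None):
            # the subclass supplies the actual poller-based implementation here
            ...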
@@ -902,11 +968,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
     def list_exports(self, database: str, engine: str):
         return []
 
-    def format_sproc_name(self, name: str, type:Any) -> str:
-        if type is datetime:
-            return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
-        else:
-            return name
 
     def get_export_code(self, params: ExportParams, all_installs):
         sql_inputs = ", ".join([f"{name} {type_to_sql(type)}" for (name, _, type) in params.inputs])
@@ -928,15 +989,14 @@ Otherwise, remove it from your '{profile}' configuration profile.
             clean_inputs.append(f"{name} = '\"' + escape({name}) + '\"'")
             # Replace `var` with `name` and keep the following non-word character unchanged
             pattern = re.compile(re.escape(var) + r'(\W)')
-            value = self.format_sproc_name(name, type)
+            value = format_sproc_name(name, type)
             safe_rel = re.sub(pattern, rf"{{{value}}}\1", safe_rel)
         if py_inputs:
             py_inputs = f", {py_inputs}"
         clean_inputs = ("\n").join(clean_inputs)
-        assert __package__ is not None, "Package name must be set"
         file = "export_procedure.py.jinja"
         with importlib.resources.open_text(
-            __package__, file
+            "relationalai.clients.resources.snowflake", file
         ) as f:
             template = f.read()
         def quote(s: str, f = False) -> str:
@@ -1094,15 +1154,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
     # Imports
     #--------------------------------------------------
 
-    def is_valid_import_state(self, state:str):
-        return state in VALID_IMPORT_STATES
-
-    def imports_to_dicts(self, results):
-        parsed_results = [
-            {field.lower(): row[field] for field in IMPORT_STREAM_FIELDS}
-            for row in results
-        ]
-        return parsed_results
 
     def change_stream_status(self, stream_id: str, model:str, suspend: bool):
         if stream_id and model:
@@ -1274,7 +1325,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         results = self._exec(f"CALL {APP_NAME}.api.get_data_stream('{name}', '{model}');")
         if not results:
             return None
-        return self.imports_to_dicts(results)
+        return imports_to_dicts(results)
 
     def create_import_stream(self, source:ImportSource, model:str, rate = 1, options: dict|None = None):
         assert isinstance(source, ImportSourceTable), "Snowflake integration only supports loading from SF Tables. Try loading your data as a table via the Snowflake interface first."
@@ -1309,7 +1360,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
         try:
             self._exec(command)
         except Exception as e:
-            if "ensure that CHANGE_TRACKING is enabled on the source object" in str(e):
+            messages = collect_error_messages(e)
+            if any("ensure that change_tracking is enabled on the source object" in msg for msg in messages):
                 if self.config.get("ensure_change_tracking", False) and not tracking_just_changed:
                     try:
                         self._exec(f"ALTER {kind} {object} SET CHANGE_TRACKING = TRUE;")
@@ -1320,7 +1372,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
                 print("\n")
                 exception = SnowflakeChangeTrackingNotEnabledException((object, kind))
                 raise exception from None
-            elif "Database does not exist" in str(e):
+            elif any("database does not exist" in msg for msg in messages):
                 print("\n")
                 raise ModelNotFoundException(model) from None
             raise e
@@ -1370,55 +1422,22 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if txn_id in self._pending_transactions:
             self._pending_transactions.remove(txn_id)
 
-        if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
-            config_file_path = getattr(self.config, 'file_path', None)
-            # todo: use the timeout returned alongside the transaction as soon as it's exposed
-            timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
-            raise QueryTimeoutExceededException(
-                timeout_mins=timeout_mins,
-                query_id=txn_id,
-                config_file_path=config_file_path,
-            )
+        if status == "ABORTED":
+            if response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
+                config_file_path = getattr(self.config, 'file_path', None)
+                # todo: use the timeout returned alongside the transaction as soon as it's exposed
+                timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+                raise QueryTimeoutExceededException(
+                    timeout_mins=timeout_mins,
+                    query_id=txn_id,
+                    config_file_path=config_file_path,
+                )
+            elif response_row.get("ABORT_REASON", "") == GUARDRAILS_ABORT_REASON:
+                raise GuardRailsException()
 
         # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
         return status == "COMPLETED" or status == "ABORTED"
 
-    def decrypt_stream(self, key: bytes, iv: bytes, src: bytes) -> bytes:
-        """Decrypt the provided stream with PKCS#5 padding handling."""
-
-        if crypto_disabled:
-            if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
-                raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
-            else:
-                raise Exception("library `cryptography.hazmat` missing; please install")
-
-        # `type:ignore`s are because of the conditional import, which
-        # we have because warehouse-based snowflake notebooks don't support
-        # the crypto library we're using.
-        cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())  # type: ignore
-        decryptor = cipher.decryptor()
-
-        # Decrypt the data
-        decrypted_padded_data = decryptor.update(src) + decryptor.finalize()
-
-        # Unpad the decrypted data using PKCS#5
-        unpadder = padding.PKCS7(128).unpadder()  # type: ignore # Use 128 directly for AES
-        unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
-
-        return unpadded_data
-
-    def _decrypt_artifact(self, data: bytes, encryption_material: str) -> bytes:
-        """Decrypts the artifact data using provided encryption material."""
-        encryption_material_parts = encryption_material.split("|")
-        assert len(encryption_material_parts) == 3, "Invalid encryption material"
-
-        algorithm, key_base64, iv_base64 = encryption_material_parts
-        assert algorithm == "AES_128_CBC", f"Unsupported encryption algorithm {algorithm}"
-
-        key = base64.standard_b64decode(key_base64)
-        iv = base64.standard_b64decode(iv_base64)
-
-        return self.decrypt_stream(key, iv, data)
 
     def _list_exec_async_artifacts(self, txn_id: str, headers: Dict | None = None) -> Dict[str, Dict]:
         """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
@@ -1470,7 +1489,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
             if len(content) == 0:
                 return b""
 
-            return self._decrypt_artifact(content, encryption_material)
+            return decrypt_artifact(content, encryption_material)
 
         # otherwise, return content directly
         return content
@@ -1550,7 +1569,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
     def get_url_key(self, metadata) -> str:
         # In Azure, there is only one type of URL, which is used for both internal and
        # external access; always use that one
-        if self.is_azure(metadata['PRESIGNED_URL']):
+        if is_azure_url(metadata['PRESIGNED_URL']):
            return 'PRESIGNED_URL'
 
         configured = self.config.get("download_url_type", None)
@@ -1559,17 +1578,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
         elif configured == "external":
             return "PRESIGNED_URL"
 
-        if self.is_container_runtime():
+        if is_container_runtime():
             return 'PRESIGNED_URL_AP'
 
         return 'PRESIGNED_URL'
 
-    def is_azure(self, url) -> bool:
-        return "blob.core.windows.net" in url
-
-    def is_container_runtime(self) -> bool:
-        return isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "container"
-
     def _exec_rai_app(
         self,
         database: str,
@@ -1583,6 +1596,40 @@ Otherwise, remove it from your '{profile}' configuration profile.
         language: str = "rel",
         query_timeout_mins: int | None = None,
     ):
+        """
+        High-level method to execute RAI app stored procedures.
+
+        Builds and executes SQL to call the RAI app's exec_async_v2 stored procedure.
+        This method handles the SQL string construction for two different formats:
+        1. New format (with graph index): Uses object payload with parameterized query
+        2. Legacy format: Uses positional parameters
+
+        The choice between formats depends on the use_graph_index configuration.
+        The new format allows the stored procedure to hash the model and username
+        to determine the database, while the legacy format uses the passed database directly.
+
+        This method is called by _exec_async_v2 to create transactions. It skips
+        use_index retry logic (skip_engine_db_error_retry=True) because that
+        is handled at a higher level by exec_raw/exec_lqp.
+
+        Args:
+            database: Database/model name
+            engine: Engine name (optional)
+            raw_code: Code to execute (REL, LQP, or SQL)
+            inputs: Input parameters for the query
+            readonly: Whether the transaction is read-only
+            nowait_durable: Whether to wait for durable writes
+            request_headers: Optional HTTP headers
+            bypass_index: Whether to bypass graph index setup
+            language: Query language ("rel" or "lqp")
+            query_timeout_mins: Optional query timeout in minutes
+
+        Returns:
+            Response from the stored procedure call (transaction creation result)
+
+        Raises:
+            Exception: If transaction creation fails
+        """
         assert language == "rel" or language == "lqp", "Only 'rel' and 'lqp' languages are supported"
         if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
             query_timeout_mins = int(timeout_value)
@@ -1635,12 +1682,43 @@ Otherwise, remove it from your '{profile}' configuration profile.
         query_timeout_mins: int | None = None,
         gi_setup_skipped: bool = False,
     ):
+        """
+        High-level async execution method with transaction polling and artifact management.
+
+        This is the core method for executing queries asynchronously. It:
+        1. Creates a transaction by calling _exec_rai_app
+        2. Handles two execution paths:
+           - Fast path: Transaction completes immediately (COMPLETED/ABORTED)
+           - Slow path: Transaction is pending, requires polling until completion
+        3. Manages pending transactions list
+        4. Downloads and returns query results/artifacts
+
+        This method is called by _execute_code (base implementation) and can be
+        overridden by child classes (e.g., DirectAccessResources uses HTTP instead).
+
+        Args:
+            database: Database/model name
+            engine: Engine name (optional)
+            raw_code: Code to execute (REL, LQP, or SQL)
+            inputs: Input parameters for the query
+            readonly: Whether the transaction is read-only
+            nowait_durable: Whether to wait for durable writes
+            headers: Optional HTTP headers
+            bypass_index: Whether to bypass graph index setup
+            language: Query language ("rel" or "lqp")
+            query_timeout_mins: Optional query timeout in minutes
+            gi_setup_skipped: Whether graph index setup was skipped (for retry logic)
+
+        Returns:
+            Query results (downloaded artifacts)
+        """
         if inputs is None:
             inputs = {}
         request_headers = debugging.add_current_propagation_headers(headers)
         query_attrs_dict = json.loads(request_headers.get("X-Query-Attributes", "{}"))
 
         with debugging.span("transaction", **query_attrs_dict) as txn_span:
+            txn_start_time = time.time()
             with debugging.span("create_v2", **query_attrs_dict) as create_span:
                 request_headers['user-agent'] = get_pyrel_version(self.generation)
                 request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
@@ -1671,8 +1749,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
             create_span["txn_id"] = txn_id
             debugging.event("transaction_created", txn_span, txn_id=txn_id)
 
+            print_txn_progress = should_print_txn_progress(self.config)
+
             # fast path: transaction already finished
             if state in ["COMPLETED", "ABORTED"]:
+                txn_end_time = time.time()
                 if txn_id in self._pending_transactions:
                     self._pending_transactions.remove(txn_id)
 
@@ -1681,13 +1762,24 @@ Otherwise, remove it from your '{profile}' configuration profile.
                     filename = row['FILENAME']
                     artifact_info[filename] = row
 
+                txn_duration = txn_end_time - txn_start_time
+                if print_txn_progress:
+                    print(
+                        query_complete_message(txn_id, txn_duration, status_header=True)
+                    )
+
             # Slow path: transaction not done yet; start polling
             else:
                 self._pending_transactions.append(txn_id)
+                # Use the interactive poller for transaction status
                 with debugging.span("wait", txn_id=txn_id):
-                    poll_with_specified_overhead(
-                        lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
-                    )
+                    if print_txn_progress:
+                        poller = ExecTxnPoller(resource=self, txn_id=txn_id, headers=request_headers, txn_start_time=txn_start_time)
+                        poller.poll()
+                    else:
+                        poll_with_specified_overhead(
+                            lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
+                        )
                 artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
 
             with debugging.span("fetch"):
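poll_with_specified_overhead is imported from the clients util module and is not shown in this diff. From its call sites (overhead_rate=0.1, max_delay=0.5, timeout=900) it presumably sleeps in proportion to elapsed time so polling overhead stays near the given rate; a hedged sketch of that contract:

    import time

    def poll_with_specified_overhead(check, overhead_rate, max_delay=None, timeout=None):
        """Sketch: call check() until truthy, sleeping ~overhead_rate of elapsed time."""
        start = time.monotonic()
        while not check():
            elapsed = time.monotonic() - start
            if timeout is not None and elapsed > timeout:
                raise TimeoutError("polling timed out")
            delay = max(elapsed * overhead_rate, 0.001)  # floor avoids a busy loop
            if max_delay is not None:
                delay = min(delay, max_delay)
            time.sleep(delay)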
@@ -1706,205 +1798,81 @@ Otherwise, remove it from your '{profile}' configuration profile.
             return engine and engine["state"] == "READY"
 
     def auto_create_engine(self, name: str | None = None, size: str | None = None, headers: Dict | None = None):
-        from v0.relationalai.tools.cli_helpers import validate_engine_name
+        """Synchronously create or ensure an engine, blocking until it is ready."""
         with debugging.span("auto_create_engine", active=self._active_engine) as span:
             active = self._get_active_engine()
             if active:
                 return active
 
-            engine_name = name or self.get_default_engine_name()
-
-            # Use the provided size or fall back to the config
-            if size:
-                engine_size = size
-            else:
-                engine_size = self.config.get("engine_size", None)
-
-            # Validate engine size
-            if engine_size:
-                is_size_valid, sizes = self.validate_engine_size(engine_size)
-                if not is_size_valid:
-                    raise Exception(f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
-
-            # Validate engine name
-            is_name_valid, _ = validate_engine_name(engine_name)
-            if not is_name_valid:
-                raise EngineNameValidationException(engine_name)
+            # Resolve and validate parameters
+            engine_name, engine_size = self._prepare_engine_params(name, size)
 
             try:
+                # Get current engine state
                 engine = self.get_engine(engine_name)
                 if engine:
                     span.update(cast(dict, engine))
 
-                # if the engine is in the pending state, poll until its status changes
-                # if the engine is gone, delete it and create a new one
-                # if the engine is in the ready state, return the engine name
-                if engine:
-                    if engine["state"] == "PENDING":
-                        # if the user explicitly specified a size, warn if the pending engine size doesn't match it
-                        if size is not None and engine["size"] != size:
-                            EngineSizeMismatchWarning(engine_name, engine["size"], size)
-                        # poll until the engine is ready
-                        with Spinner(
-                            "Waiting for engine to be initialized",
-                            "Engine ready",
-                        ):
-                            poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
-
-                    elif engine["state"] == "SUSPENDED":
-                        with Spinner(f"Resuming engine '{engine_name}'", f"Engine '{engine_name}' resumed", f"Failed to resume engine '{engine_name}'"):
-                            try:
-                                self.resume_engine_async(engine_name, headers=headers)
-                                poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
-                            except Exception:
-                                raise EngineResumeFailed(engine_name)
-                    elif engine["state"] == "READY":
-                        # if the user explicitly specified a size, warn if the ready engine size doesn't match it
-                        if size is not None and engine["size"] != size:
-                            EngineSizeMismatchWarning(engine_name, engine["size"], size)
-                        self._set_active_engine(engine)
-                        return engine_name
-                    elif engine["state"] == "GONE":
-                        try:
-                            # "GONE" is an abnormal condition in which metadata and the SF service don't match.
-                            # Therefore, we have to delete the engine and create a new one.
-                            # It could be the case that the engine is already deleted, so we have to catch the exception.
-                            self.delete_engine(engine_name, headers=headers)
-                            # After deleting the engine, set it to None so that we can create a new engine
-                            engine = None
-                        except Exception as e:
-                            # if the engine is already deleted, we will get an exception;
-                            # we can ignore it and create a new engine
-                            if isinstance(e, EngineNotFoundException):
-                                engine = None
-                                pass
-                            else:
-                                raise EngineProvisioningFailed(engine_name, e) from e
-
-                if not engine:
-                    with Spinner(
-                        f"Auto-creating engine {engine_name}",
-                        f"Auto-created engine {engine_name}",
-                        "Engine creation failed",
-                    ):
-                        self.create_engine(engine_name, size=engine_size, headers=headers)
+                # Create context for state handling
+                context = EngineContext(
+                    engine_name=engine_name,
+                    engine_size=engine_size,
+                    headers=headers,
+                    requested_size=size,
+                    span=span,
+                )
+
+                # Process engine state using sync handlers
+                self._process_engine_state(engine, context, self._sync_engine_state_handlers)
+
             except Exception as e:
-                print(e)
-                if DUO_TEXT in str(e).lower():
-                    raise DuoSecurityFailed(e)
-                raise EngineProvisioningFailed(engine_name, e) from e
+                self._handle_engine_creation_errors(e, engine_name)
+
             return engine_name
 
     def auto_create_engine_async(self, name: str | None = None):
+        """Asynchronously create or ensure an engine; returns immediately."""
         active = self._get_active_engine()
         if active and (active == name or name is None):
-            return # @NOTE: This method weirdly doesn't return engine name even though all the other ones do?
+            return active
 
         with Spinner(
             "Checking engine status",
             leading_newline=True,
         ) as spinner:
-            from v0.relationalai.tools.cli_helpers import validate_engine_name
             with debugging.span("auto_create_engine_async", active=self._active_engine):
-                engine_name = name or self.get_default_engine_name()
-                engine_size = self.config.get("engine_size", None)
-                if engine_size:
-                    is_size_valid, sizes = self.validate_engine_size(engine_size)
-                    if not is_size_valid:
-                        raise Exception(f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
-                else:
-                    engine_size = self.config.get_default_engine_size()
+                # Resolve and validate parameters (use_default_size=True for async)
+                engine_name, engine_size = self._prepare_engine_params(name, None, use_default_size=True)
 
-                is_name_valid, _ = validate_engine_name(engine_name)
-                if not is_name_valid:
-                    raise EngineNameValidationException(engine_name)
                 try:
+                    # Get current engine state
                     engine = self.get_engine(engine_name)
-                    # if the engine is gone, delete it and create a new one;
-                    # in the pending state, do nothing; it is use_index's responsibility to wait for the engine to be ready
-                    if engine:
-                        if engine["state"] == "PENDING":
-                            spinner.update_messages(
-                                {
-                                    "finished_message": f"Starting engine {engine_name}",
-                                }
-                            )
-                            pass
-                        elif engine["state"] == "SUSPENDED":
-                            spinner.update_messages(
-                                {
-                                    "finished_message": f"Resuming engine {engine_name}",
-                                }
-                            )
-                            try:
-                                self.resume_engine_async(engine_name)
-                            except Exception:
-                                raise EngineResumeFailed(engine_name)
-                        elif engine["state"] == "READY":
-                            spinner.update_messages(
-                                {
-                                    "finished_message": f"Engine {engine_name} initialized",
-                                }
-                            )
-                            pass
-                        elif engine["state"] == "GONE":
-                            spinner.update_messages(
-                                {
-                                    "message": f"Restarting engine {engine_name}",
-                                }
-                            )
-                            try:
-                                # "GONE" is an abnormal condition in which metadata and the SF service don't match.
-                                # Therefore, we have to delete the engine and create a new one.
-                                # It could be the case that the engine is already deleted, so we have to catch the exception.
-                                # Set it to None so that we can create a new engine.
-                                engine = None
-                                self.delete_engine(engine_name)
-                            except Exception as e:
-                                # if the engine is already deleted, we will get an exception;
-                                # we can ignore it and create a new engine asynchronously
-                                if isinstance(e, EngineNotFoundException):
-                                    engine = None
-                                    pass
-                                else:
-                                    print(e)
-                                    raise EngineProvisioningFailed(engine_name, e) from e
-
-                        if not engine:
-                            self.create_engine_async(engine_name, size=self.config.get("engine_size", None))
-                            spinner.update_messages(
-                                {
-                                    "finished_message": f"Starting engine {engine_name}...",
-                                }
-                            )
-                        else:
-                            self._set_active_engine(engine)
 
-                except Exception as e:
-                    spinner.update_messages(
-                        {
-                            "finished_message": f"Failed to create engine {engine_name}",
+                    # Create context for state handling
+                    context = EngineContext(
+                        engine_name=engine_name,
+                        engine_size=engine_size,
+                        headers=None,
+                        requested_size=None,
+                        spinner=spinner,
                     )
-                    if DUO_TEXT in str(e).lower():
-                        raise DuoSecurityFailed(e)
-                    if isinstance(e, RAIException):
-                        raise e
-                    print(e)
-                    raise EngineProvisioningFailed(engine_name, e) from e
 
-    def validate_engine_size(self, size: str) -> Tuple[bool, List[str]]:
-        if size is not None:
-            sizes = self.get_engine_sizes()
-            if size not in sizes:
-                return False, sizes
-        return True, []
+                    # Process engine state using async handlers
+                    self._process_engine_state(engine, context, self._async_engine_state_handlers, set_active_on_success=True)
+
+                except Exception as e:
+                    spinner.update_messages({
+                        "finished_message": f"Failed to create engine {engine_name}",
+                    })
+                    self._handle_engine_creation_errors(e, engine_name, preserve_rai_exception=True)
+
+                return engine_name
1870
 
1903
1871
  #--------------------------------------------------
1904
1872
  # Exec
1905
1873
  #--------------------------------------------------
1906
1874
 
1907
- def _exec_with_gi_retry(
1875
+ def _execute_code(
1908
1876
  self,
1909
1877
  database: str,
1910
1878
  engine: str | None,
@@ -1916,39 +1884,40 @@ Otherwise, remove it from your '{profile}' configuration profile.
         bypass_index: bool,
         language: str,
         query_timeout_mins: int | None,
-    ):
-        """Execute with graph index retry logic.
-
-        Attempts execution with gi_setup_skipped=True first. If an engine or database
-        issue occurs, polls use_index and retries with gi_setup_skipped=False.
+    ) -> Any:
         """
-        try:
-            return self._exec_async_v2(
-                database, engine, raw_code, inputs, readonly, nowait_durable,
-                headers=headers, bypass_index=bypass_index, language=language,
-                query_timeout_mins=query_timeout_mins, gi_setup_skipped=True,
-            )
-        except Exception as e:
-            err_message = str(e).lower()
-            if _is_engine_issue(err_message) or _is_database_issue(err_message):
-                engine_name = engine or self.get_default_engine_name()
-                engine_size = self.config.get_default_engine_size()
-                self._poll_use_index(
-                    app_name=self.get_app_name(),
-                    sources=self.sources,
-                    model=database,
-                    engine_name=engine_name,
-                    engine_size=engine_size,
-                    headers=headers,
-                )
-
-                return self._exec_async_v2(
-                    database, engine, raw_code, inputs, readonly, nowait_durable,
-                    headers=headers, bypass_index=bypass_index, language=language,
-                    query_timeout_mins=query_timeout_mins, gi_setup_skipped=False,
-                )
-            else:
-                raise e
+        Template method for code execution; can be overridden by child classes.
+
+        This is a template method that provides a hook for child classes to add
+        execution logic (like retry mechanisms). The base implementation simply
+        calls _exec_async_v2 directly.
+
+        UseIndexResources overrides this method to use _exec_with_gi_retry, which
+        adds automatic use_index polling on engine/database errors.
+
+        This method is called by exec_lqp() and exec_raw() to provide a single
+        execution point that can be customized per resource class.
+
+        Args:
+            database: Database/model name
+            engine: Engine name (optional)
+            raw_code: Code to execute (already processed/encoded)
+            inputs: Input parameters for the query
+            readonly: Whether the transaction is read-only
+            nowait_durable: Whether to return without waiting for durable writes
+            headers: Optional HTTP headers
+            bypass_index: Whether to bypass graph index setup
+            language: Query language ("rel" or "lqp")
+            query_timeout_mins: Optional query timeout in minutes
+
+        Returns:
+            Query results
+        """
+        return self._exec_async_v2(
+            database, engine, raw_code, inputs, readonly, nowait_durable,
+            headers=headers, bypass_index=bypass_index, language=language,
+            query_timeout_mins=query_timeout_mins, gi_setup_skipped=True,
+        )
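A sketch of the hook in use: a child class overrides _execute_code to route execution through the retrying path. The override below assumes _exec_with_gi_retry on UseIndexResources keeps the positional signature of the method removed above:

class UseIndexResources(Resources):
    def _execute_code(self, database, engine, raw_code, inputs, readonly,
                      nowait_durable, headers, bypass_index, language,
                      query_timeout_mins):
        # Every exec_lqp()/exec_raw() call now flows through the retrying
        # variant instead of hitting _exec_async_v2 directly.
        return self._exec_with_gi_retry(
            database, engine, raw_code, inputs, readonly, nowait_durable,
            headers, bypass_index, language, query_timeout_mins,
        )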
 
     def exec_lqp(
         self,
@@ -1963,13 +1932,13 @@ Otherwise, remove it from your '{profile}' configuration profile.
         bypass_index=False,
         query_timeout_mins: int | None = None,
     ):
+        """Execute LQP code."""
         raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
-        return self._exec_with_gi_retry(
+        return self._execute_code(
             database, engine, raw_code_b64, inputs, readonly, nowait_durable,
             headers, bypass_index, 'lqp', query_timeout_mins
         )
 
-
     def exec_raw(
         self,
         database: str,
@@ -1983,8 +1952,9 @@ Otherwise, remove it from your '{profile}' configuration profile.
         bypass_index=False,
         query_timeout_mins: int | None = None,
     ):
+        """Execute raw code."""
         raw_code = raw_code.replace("'", "\\'")
-        return self._exec_with_gi_retry(
+        return self._execute_code(
             database, engine, raw_code, inputs, readonly, nowait_durable,
             headers, bypass_index, 'rel', query_timeout_mins
         )
@@ -2039,7 +2009,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if query_timeout_mins is not None:
             res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
         else:
-            res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
+            res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
         txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
         rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
         rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
@@ -2076,8 +2046,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
             if rejected_rows:
                 debugging.warn(RowsDroppedFromTargetTableWarning(rejected_rows, rejected_rows_count, col_names_map))
         except Exception as e:
-            msg = str(e).lower()
-            if "no columns returned" in msg or "columns of results could not be determined" in msg:
+            messages = collect_error_messages(e)
+            if any("no columns returned" in msg or "columns of results could not be determined" in msg for msg in messages):
                 pass
             else:
                 raise e
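collect_error_messages widens the old single-string check to the whole exception chain, so a match buried in a wrapped cause is still caught. A plausible sketch of such a helper (assumed behavior, not the packaged code):

def collect_error_messages(exc):
    # Gather the lowercased message of the exception and of every chained
    # cause/context, guarding against cycles.
    messages, seen = [], set()
    current = exc
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        messages.append(str(current).lower())
        current = current.__cause__ or current.__context__
    return messages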
@@ -2108,6 +2078,10 @@ Otherwise, remove it from your '{profile}' configuration profile.
         self.sources.add(parser.identity)
         return ns
 
+    #--------------------------------------------------
+    # Source Management
+    #--------------------------------------------------
+
     def _check_source_updates(self, sources: Iterable[str]):
         if not sources:
             return {}
@@ -2370,19 +2344,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
     #--------------------------------------------------
     # Transactions
     #--------------------------------------------------
-    def txn_list_to_dicts(self, transactions):
-        dicts = []
-        for txn in transactions:
-            dict = {}
-            txn_dict = txn.asDict()
-            for key in txn_dict:
-                mapValue = FIELD_MAP.get(key.lower())
-                if mapValue:
-                    dict[mapValue] = txn_dict[key]
-                else:
-                    dict[key.lower()] = txn_dict[key]
-            dicts.append(dict)
-        return dicts
 
     def get_transaction(self, transaction_id):
         results = self._exec(
@@ -2390,7 +2351,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if not results:
             return None
 
-        results = self.txn_list_to_dicts(results)
+        results = txn_list_to_dicts(results)
 
         txn = {field: results[0][field] for field in GET_TXN_SQL_FIELDS}
 
@@ -2442,7 +2403,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         results = self._exec(query, [limit])
         if not results:
             return []
-        return self.txn_list_to_dicts(results)
+        return txn_list_to_dicts(results)
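The method removed above survives as a module-level txn_list_to_dicts helper (presumably in the new util module, since it is now called unqualified). A sketch consistent with the deleted body, dropping its shadowing of the dict builtin; FIELD_MAP is the same mapping the old method used:

def txn_list_to_dicts(transactions):
    # Convert Snowpark rows to plain dicts, renaming columns via FIELD_MAP;
    # unmapped columns keep their lowercased names.
    dicts = []
    for txn in transactions:
        row = txn.asDict()
        dicts.append({
            FIELD_MAP.get(key.lower(), key.lower()): value
            for key, value in row.items()
        })
    return dicts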
 
     def cancel_transaction(self, transaction_id):
         self._exec(f"CALL {APP_NAME}.api.cancel_own_transaction(?);", [transaction_id])
@@ -2476,63 +2437,12 @@ Otherwise, remove it from your '{profile}' configuration profile.
             return None
         return results[0][0]
 
-    def list_warehouses(self):
-        results = self._exec("SHOW WAREHOUSES")
-        if not results:
-            return []
-        return [{"name":name}
-                for (name, *rest) in results]
-
-    def list_compute_pools(self):
-        results = self._exec("SHOW COMPUTE POOLS")
-        if not results:
-            return []
-        return [{"name":name, "status":status, "min_nodes":min_nodes, "max_nodes":max_nodes, "instance_family":instance_family}
-                for (name, status, min_nodes, max_nodes, instance_family, *rest) in results]
-
-    def list_roles(self):
-        results = self._exec("SELECT CURRENT_AVAILABLE_ROLES()")
-        if not results:
-            return []
-        # the response is a single row with a single column containing
-        # a stringified JSON array of role names:
-        row = results[0]
-        if not row:
-            return []
-        return [{"name": name} for name in json.loads(row[0])]
-
-    def list_apps(self):
-        all_apps = self._exec(f"SHOW APPLICATIONS LIKE '{RAI_APP_NAME}'")
-        if not all_apps:
-            all_apps = self._exec("SHOW APPLICATIONS")
-        if not all_apps:
-            return []
-        return [{"name":name}
-                for (time, name, *rest) in all_apps]
-
-    def list_databases(self):
-        results = self._exec("SHOW DATABASES")
-        if not results:
-            return []
-        return [{"name":name}
-                for (time, name, *rest) in results]
-
-    def list_sf_schemas(self, database:str):
-        results = self._exec(f"SHOW SCHEMAS IN {database}")
-        if not results:
-            return []
-        return [{"name":name}
-                for (time, name, *rest) in results]
-
-    def list_tables(self, database:str, schema:str):
-        results = self._exec(f"SHOW OBJECTS IN {database}.{schema}")
-        items = []
-        if results:
-            for (time, name, db_name, schema_name, kind, *rest) in results:
-                items.append({"name":name, "kind":kind.lower()})
-        return items
+    # CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
+    # list_databases, list_sf_schemas, list_tables) are now in CLIResources class
+    # schema_info is kept in base Resources class since it's used by SnowflakeSchema._fetch_info()
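A sketch of where those helpers land, with one method reproduced from the deleted code; the subclass shape shown here is an assumption:

class CLIResources(Resources):
    # Discovery helpers used only by the CLI, split out so the base
    # Resources class stays focused on query execution.
    def list_warehouses(self):
        results = self._exec("SHOW WAREHOUSES")
        if not results:
            return []
        return [{"name": name} for (name, *rest) in results]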
 
-    def schema_info(self, database:str, schema:str, tables:Iterable[str]):
+    def schema_info(self, database: str, schema: str, tables: Iterable[str]):
+        """Get detailed schema information including primary keys, foreign keys, and columns."""
         app_name = self.get_app_name()
         # Only pass the db + schema as the identifier so that the resulting identity is correct
         parser = IdentityParser(f"{database}.{schema}")
@@ -2550,7 +2460,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
 
         # IdentityParser will parse a single value (with no "."), storing it, in this case, in the db field
         with debugging.span("columns") as span:
-            tables = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
+            tables_str = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
             query = textwrap.dedent(f"""
                 begin
                 SHOW COLUMNS IN SCHEMA {parser.identity};
@@ -2567,7 +2477,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
                     ELSE FALSE
                     END as "supported_type"
                 FROM table(result_scan(-1)) as t
-                WHERE "table_name" in ({tables})
+                WHERE "table_name" in ({tables_str})
                 );
                 return table(r);
                 end;
@@ -2578,7 +2488,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         results = defaultdict(lambda: {"pks": [], "fks": {}, "columns": {}, "invalid_columns": {}})
         if pks:
             for row in pks:
-                results[row[3]]["pks"].append(row[4]) # type: ignore
+                results[row[3]]["pks"].append(row[4])  # type: ignore
         if fks:
             for row in fks:
                 results[row[7]]["fks"][row[8]] = row[3]
@@ -2720,7 +2630,7 @@ class SnowflakeSchema:
                 tables_with_invalid_columns[table_name] = table_info["invalid_columns"]
 
         if tables_with_invalid_columns:
-            from ..errors import UnsupportedColumnTypesWarning
+            from v0.relationalai.errors import UnsupportedColumnTypesWarning
             UnsupportedColumnTypesWarning(tables_with_invalid_columns)
 
     def _add(self, name, is_imported=False):
@@ -2929,10 +2839,14 @@ class Provider(ProviderBase):
         if resources:
             self.resources = resources
         else:
-            resource_class = Resources
-            if config and config.get("use_direct_access", USE_DIRECT_ACCESS):
-                resource_class = DirectAccessResources
-            self.resources = resource_class(profile=profile, config=config, generation=generation)
+            from .resources_factory import create_resources_instance
+            self.resources = create_resources_instance(
+                config=config,
+                profile=profile,
+                generation=generation or Generation.V0,
+                dry_run=False,
+                language="rel",
+            )
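A sketch of what create_resources_instance plausibly does, generalizing the selection it replaces here and in Graph() below; the exact branching is an assumption based on the flags used in this file:

def create_resources_instance(config=None, profile=None, connection=None,
                              generation=None, dry_run=False, language="rel"):
    # Choose the Resources flavor from config flags instead of hard-coding
    # a class at each call site.
    if config and config.get("use_direct_access", USE_DIRECT_ACCESS):
        cls = DirectAccessResources
    elif config and config.get("use_graph_index", USE_GRAPH_INDEX):
        cls = UseIndexResources
    else:
        cls = Resources
    return cls(profile=profile, config=config, connection=connection,
               generation=generation, dry_run=dry_run, language=language)

Centralizing the choice keeps Provider and Graph() from drifting apart as resource classes are added.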
 
     def list_streams(self, model:str):
         return self.resources.list_imports(model=model)
@@ -3101,19 +3015,31 @@ def Graph(
     nowait_durable: bool = True,
     format: str = "default",
 ):
+    from .resources_factory import create_resources_instance
+    from .use_index_resources import UseIndexResources
 
-    client_class = Client
-    resource_class = Resources
     use_graph_index = config.get("use_graph_index", USE_GRAPH_INDEX)
     use_monotype_operators = config.get("compiler.use_monotype_operators", False)
-    use_direct_access = config.get("use_direct_access", USE_DIRECT_ACCESS)
 
-    if use_graph_index:
+    # Create resources instance using factory
+    resources = create_resources_instance(
+        config=config,
+        profile=profile,
+        connection=connection,
+        generation=Generation.V0,
+        dry_run=False,  # Resources instance dry_run is separate from client dry_run
+        language="rel",
+    )
+
+    # Determine client class based on resources type and config
+    # SnowflakeClient is used for resources that support use_index functionality
+    if use_graph_index or isinstance(resources, UseIndexResources):
         client_class = SnowflakeClient
-        if use_direct_access:
-            resource_class = DirectAccessResources
+    else:
+        client_class = Client
+
     client = client_class(
-        resource_class(generation=Generation.V0, profile=profile, config=config, connection=connection),
+        resources,
         rel.Compiler(config),
         name,
         config,
@@ -3184,686 +3110,3 @@ def Graph(
     debugging.set_source(pyrel_base)
     client.install("pyrel_base", pyrel_base)
     return dsl.Graph(client, name, format=format)
-
-
-
-#--------------------------------------------------
-# Direct Access
-#--------------------------------------------------
-# Note: All direct access components should live in a separate file
-
-class DirectAccessResources(Resources):
-    """
-    Resources class for Direct Service Access avoiding Snowflake service functions.
-    """
-    def __init__(
-        self,
-        profile: Union[str, None] = None,
-        config: Union[Config, None] = None,
-        connection: Union[Session, None] = None,
-        dry_run: bool = False,
-        reset_session: bool = False,
-        generation: Optional[Generation] = None,
-        language: str = "rel",
-    ):
-        super().__init__(
-            generation=generation,
-            profile=profile,
-            config=config,
-            connection=connection,
-            reset_session=reset_session,
-            dry_run=dry_run,
-            language=language,
-        )
-        self._endpoint_info = ConfigStore(ENDPOINT_FILE)
-        self._service_endpoint = ""
-        self._direct_access_client = None
-        self.generation = generation
-        self.database = ""
-
-    @property
-    def service_endpoint(self) -> str:
-        return self._retrieve_service_endpoint()
-
-    def _retrieve_service_endpoint(self, enforce_update=False) -> str:
-        account = self.config.get("account")
-        app_name = self.config.get("rai_app_name")
-        service_endpoint_key = f"{account}.{app_name}.service_endpoint"
-        if self._service_endpoint and not enforce_update:
-            return self._service_endpoint
-        if self._endpoint_info.get(service_endpoint_key, "") and not enforce_update:
-            self._service_endpoint = str(self._endpoint_info.get(service_endpoint_key, ""))
-            return self._service_endpoint
-
-        is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
-        query = f"CALL {self.get_app_name()}.app.service_endpoint({not is_snowflake_notebook});"
-        result = self._exec(query)
-        assert result, f"Could not retrieve service endpoint for {self.get_app_name()}"
-        if is_snowflake_notebook:
-            self._service_endpoint = f"http://{result[0]['SERVICE_ENDPOINT']}"
-        else:
-            self._service_endpoint = f"https://{result[0]['SERVICE_ENDPOINT']}"
-
-        self._endpoint_info.set(service_endpoint_key, self._service_endpoint)
-        # save the endpoint to ENDPOINT_FILE to avoid calling the endpoint with every
-        # pyrel execution
-        try:
-            self._endpoint_info.save()
-        except Exception:
-            print("Failed to persist endpoints to file. This might slow down future executions.")
-
-        return self._service_endpoint
-
-    @property
-    def direct_access_client(self) -> DirectAccessClient:
-        if self._direct_access_client:
-            return self._direct_access_client
-        try:
-            service_endpoint = self.service_endpoint
-            self._direct_access_client = DirectAccessClient(
-                self.config, self.token_handler, service_endpoint, self.generation,
-            )
-        except Exception as e:
-            raise e
-        return self._direct_access_client
-
-    def request(
-        self,
-        endpoint: str,
-        payload: Dict[str, Any] | None = None,
-        headers: Dict[str, str] | None = None,
-        path_params: Dict[str, str] | None = None,
-        query_params: Dict[str, str] | None = None,
-        skip_auto_create: bool = False,
-        skip_engine_db_error_retry: bool = False,
-    ) -> requests.Response:
-        with debugging.span("direct_access_request"):
-            def _send_request():
-                return self.direct_access_client.request(
-                    endpoint=endpoint,
-                    payload=payload,
-                    headers=headers,
-                    path_params=path_params,
-                    query_params=query_params,
-                )
-            try:
-                response = _send_request()
-                if response.status_code != 200:
-                    # For 404 responses with skip_auto_create=True, return immediately to let the caller handle it
-                    # (e.g., get_engine needs to check 404 and return None for auto_create_engine).
-                    # For skip_auto_create=False, continue to the auto-creation logic below.
-                    if response.status_code == 404 and skip_auto_create:
-                        return response
-
-                    try:
-                        message = response.json().get("message", "")
-                    except requests.exceptions.JSONDecodeError:
-                        # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
-                        # this should have been caught by the 404 check above, so this is an error.
-                        # For skip_auto_create=False, we explicitly check status_code below,
-                        # so we don't need to parse the message.
-                        if skip_auto_create:
-                            raise ResponseStatusException(
-                                f"Failed to parse error response from endpoint {endpoint}.", response
-                            )
-                        message = ""  # Not used when we check status_code directly
-
-                    # fix the engine on an engine error and retry.
-                    # Skip setting up GI if skip_auto_create is True, to avoid recursion, or if skip_engine_db_error_retry is true, to let _exec_async_v2 perform the retry with the correct headers.
-                    if ((_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message)) and not skip_engine_db_error_retry:
-                        engine_name = payload.get("caller_engine_name", "") if payload else ""
-                        engine_name = engine_name or self.get_default_engine_name()
-                        engine_size = self.config.get_default_engine_size()
-                        self._poll_use_index(
-                            app_name=self.get_app_name(),
-                            sources=self.sources,
-                            model=self.database,
-                            engine_name=engine_name,
-                            engine_size=engine_size,
-                            headers=headers,
-                        )
-                        response = _send_request()
-            except requests.exceptions.ConnectionError as e:
-                if "NameResolutionError" in str(e):
-                    # when we cannot resolve the service endpoint, we assume it is outdated;
-                    # hence, we try to retrieve it again and query again.
-                    self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
-                        enforce_update=True,
-                    )
-                    return _send_request()
-                # raise in all other cases
-                raise e
-            return response
-
-    def _txn_request_with_gi_retry(
-        self,
-        payload: Dict,
-        headers: Dict[str, str],
-        query_params: Dict,
-        engine: Union[str, None],
-    ):
-        """Make a request with graph index retry logic.
-
-        Attempts the request with gi_setup_skipped=True first. If an engine or database
-        issue occurs, polls use_index and retries with gi_setup_skipped=False.
-        """
-        response = self.request(
-            "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
-        )
-
-        if response.status_code != 200:
-            try:
-                message = response.json().get("message", "")
-            except requests.exceptions.JSONDecodeError:
-                message = ""
-
-            if _is_engine_issue(message) or _is_database_issue(message):
-                engine_name = engine or self.get_default_engine_name()
-                engine_size = self.config.get_default_engine_size()
-                self._poll_use_index(
-                    app_name=self.get_app_name(),
-                    sources=self.sources,
-                    model=self.database,
-                    engine_name=engine_name,
-                    engine_size=engine_size,
-                    headers=headers,
-                )
-                headers['gi_setup_skipped'] = 'False'
-                response = self.request(
-                    "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
-                )
-            else:
-                raise ResponseStatusException("Failed to create transaction.", response)
-
-        return response
-
-    def _exec_async_v2(
-        self,
-        database: str,
-        engine: Union[str, None],
-        raw_code: str,
-        inputs: Dict | None = None,
-        readonly=True,
-        nowait_durable=False,
-        headers: Dict[str, str] | None = None,
-        bypass_index=False,
-        language: str = "rel",
-        query_timeout_mins: int | None = None,
-        gi_setup_skipped: bool = False,
-    ):
-
-        with debugging.span("transaction") as txn_span:
-            with debugging.span("create_v2") as create_span:
-
-                use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-
-                payload = {
-                    "dbname": database,
-                    "engine_name": engine,
-                    "query": raw_code,
-                    "v1_inputs": inputs,
-                    "nowait_durable": nowait_durable,
-                    "readonly": readonly,
-                    "language": language,
-                }
-                if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
-                    query_timeout_mins = int(timeout_value)
-                if query_timeout_mins is not None:
-                    payload["timeout_mins"] = query_timeout_mins
-                query_params = {"use_graph_index": str(use_graph_index and not bypass_index)}
-
-                # Add gi_setup_skipped to headers
-                if headers is None:
-                    headers = {}
-                headers["gi_setup_skipped"] = str(gi_setup_skipped)
-                headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
-
-                response = self._txn_request_with_gi_retry(
-                    payload, headers, query_params, engine
-                )
-
-                artifact_info = {}
-                response_content = response.json()
-
-                txn_id = response_content["transaction"]['id']
-                state = response_content["transaction"]['state']
-
-                txn_span["txn_id"] = txn_id
-                create_span["txn_id"] = txn_id
-                debugging.event("transaction_created", txn_span, txn_id=txn_id)
-
-                # fast path: transaction already finished
-                if state in ["COMPLETED", "ABORTED"]:
-                    if txn_id in self._pending_transactions:
-                        self._pending_transactions.remove(txn_id)
-
-                    # Process rows to get the rest of the artifacts
-                    for result in response_content.get("results", []):
-                        filename = result['filename']
-                        # making keys uppercase to match the old behavior
-                        artifact_info[filename] = {k.upper(): v for k, v in result.items()}
-
-                # Slow path: transaction not done yet; start polling
-                else:
-                    self._pending_transactions.append(txn_id)
-                    with debugging.span("wait", txn_id=txn_id):
-                        poll_with_specified_overhead(
-                            lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
-                        )
-                    artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
-
-        with debugging.span("fetch"):
-            return self._download_results(artifact_info, txn_id, state)
-
-    def _prepare_index(
-        self,
-        model: str,
-        engine_name: str,
-        engine_size: str = "",
-        language: str = "rel",
-        rai_relations: List[str] | None = None,
-        pyrel_program_id: str | None = None,
-        skip_pull_relations: bool = False,
-        headers: Dict | None = None,
-    ):
-        """
-        Prepare the index for the given engine and model.
-        """
-        with debugging.span("prepare_index"):
-            if headers is None:
-                headers = {}
-
-            payload = {
-                "model_name": model,
-                "caller_engine_name": engine_name,
-                "language": language,
-                "pyrel_program_id": pyrel_program_id,
-                "skip_pull_relations": skip_pull_relations,
-                "rai_relations": rai_relations or [],
-                "user_agent": get_pyrel_version(self.generation),
-            }
-            # Only include engine_size if it has a non-empty string value
-            if engine_size and engine_size.strip():
-                payload["caller_engine_size"] = engine_size
-
-            response = self.request(
-                "prepare_index", payload=payload, headers=headers
-            )
-
-            if response.status_code != 200:
-                raise ResponseStatusException("Failed to prepare index.", response)
-
-            return response.json()
-
-    def _poll_use_index(
-        self,
-        app_name: str,
-        sources: Iterable[str],
-        model: str,
-        engine_name: str,
-        engine_size: str | None = None,
-        program_span_id: str | None = None,
-        headers: Dict | None = None,
-    ):
-        return DirectUseIndexPoller(
-            self,
-            app_name=app_name,
-            sources=sources,
-            model=model,
-            engine_name=engine_name,
-            engine_size=engine_size,
-            language=self.language,
-            program_span_id=program_span_id,
-            headers=headers,
-            generation=self.generation,
-        ).poll()
-
-    def maybe_poll_use_index(
-        self,
-        app_name: str,
-        sources: Iterable[str],
-        model: str,
-        engine_name: str,
-        engine_size: str | None = None,
-        program_span_id: str | None = None,
-        headers: Dict | None = None,
-    ):
-        """Only call poll() if there are sources to process and the cache is not valid."""
-        sources_list = list(sources)
-        self.database = model
-        if sources_list:
-            poller = DirectUseIndexPoller(
-                self,
-                app_name=app_name,
-                sources=sources_list,
-                model=model,
-                engine_name=engine_name,
-                engine_size=engine_size,
-                language=self.language,
-                program_span_id=program_span_id,
-                headers=headers,
-                generation=self.generation,
-            )
-            # If the cache is valid (data freshness has not expired), skip polling
-            if poller.cache.is_valid():
-                cached_sources = len(poller.cache.sources)
-                total_sources = len(sources_list)
-                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
-
-                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
-                if cached_timestamp:
-                    print(f"\n{message} (cached at {cached_timestamp})\n")
-                else:
-                    print(f"\n{message}\n")
-            else:
-                return poller.poll()
-
-    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
-        """Check whether the given transaction has completed."""
-
-        with debugging.span("check_status"):
-            response = self.request(
-                "get_txn",
-                headers=headers,
-                path_params={"txn_id": txn_id},
-            )
-            assert response, f"No results from get_transaction('{txn_id}')"
-
-            response_content = response.json()
-            transaction = response_content["transaction"]
-            status: str = transaction['state']
-
-            # remove the transaction from the pending list if it's completed or aborted
-            if status in ["COMPLETED", "ABORTED"]:
-                if txn_id in self._pending_transactions:
-                    self._pending_transactions.remove(txn_id)
-
-            if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
-                config_file_path = getattr(self.config, 'file_path', None)
-                timeout_ms = int(transaction.get("timeout_ms", 0))
-                timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
-                raise QueryTimeoutExceededException(
-                    timeout_mins=timeout_mins,
-                    query_id=txn_id,
-                    config_file_path=config_file_path,
-                )
-
-            # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
-            return status == "COMPLETED" or status == "ABORTED"
-
-    def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
-        """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
-        with debugging.span("list_results"):
-            response = self.request(
-                "get_txn_artifacts",
-                headers=headers,
-                path_params={"txn_id": txn_id},
-            )
-            assert response, f"No results from get_transaction_artifacts('{txn_id}')"
-            artifact_info = {}
-            for result in response.json()["results"]:
-                filename = result['filename']
-                # making keys uppercase to match the old behavior
-                artifact_info[filename] = {k.upper(): v for k, v in result.items()}
-            return artifact_info
-
-    def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
-        with debugging.span("get_transaction_problems"):
-            response = self.request(
-                "get_txn_problems",
-                path_params={"txn_id": txn_id},
-            )
-            response_content = response.json()
-            if not response_content:
-                return []
-            return response_content.get("problems", [])
-
-    def get_transaction_events(self, transaction_id: str, continuation_token: str = ''):
-        response = self.request(
-            "get_txn_events",
-            path_params={"txn_id": transaction_id, "stream_name": "profiler"},
-            query_params={"continuation_token": continuation_token},
-        )
-        response_content = response.json()
-        if not response_content:
-            return {
-                "events": [],
-                "continuation_token": None
-            }
-        return response_content
-
-    #--------------------------------------------------
-    # Databases
-    #--------------------------------------------------
-
-    def get_installed_packages(self, database: str) -> Union[Dict, None]:
-        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-        if use_graph_index:
-            response = self.request(
-                "get_model_package_versions",
-                payload={"model_name": database},
-            )
-        else:
-            response = self.request(
-                "get_package_versions",
-                path_params={"db_name": database},
-            )
-        if response.status_code == 404 and response.json().get("message", "") == "database not found":
-            return None
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to retrieve package versions for {database}.", response
-            )
-
-        content = response.json()
-        if not content:
-            return None
-
-        return safe_json_loads(content["package_versions"])
-
-    def get_database(self, database: str):
-        with debugging.span("get_database", dbname=database):
-            if not database:
-                raise ValueError("Database name must be provided to get database.")
-            response = self.request(
-                "get_db",
-                path_params={},
-                query_params={"name": database},
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to get db. db:{database}", response)
-
-            response_content = response.json()
-
-            if (response_content.get("databases") and len(response_content["databases"]) == 1):
-                db = response_content["databases"][0]
-                return {
-                    "id": db["id"],
-                    "name": db["name"],
-                    "created_by": db.get("created_by"),
-                    "created_on": ms_to_timestamp(db.get("created_on")),
-                    "deleted_by": db.get("deleted_by"),
-                    "deleted_on": ms_to_timestamp(db.get("deleted_on")),
-                    "state": db["state"],
-                }
-            else:
-                return None
-
-    def create_graph(self, name: str):
-        with debugging.span("create_model", dbname=name):
-            return self._create_database(name, "")
-
-    def delete_graph(self, name: str, force=False, language: str = "rel"):
-        prop_hdrs = debugging.gen_current_propagation_headers()
-        if self.config.get("use_graph_index", USE_GRAPH_INDEX):
-            keep_database = not force and self.config.get("reuse_model", True)
-            with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
-                response = self.request(
-                    "release_index",
-                    payload={
-                        "model_name": name,
-                        "keep_database": keep_database,
-                        "language": language,
-                        "user_agent": get_pyrel_version(self.generation),
-                    },
-                    headers=prop_hdrs,
-                )
-                if (
-                    response.status_code != 200
-                    and not (
-                        response.status_code == 404
-                        and "database not found" in response.json().get("message", "")
-                    )
-                ):
-                    raise ResponseStatusException(f"Failed to release index. Model: {name} ", response)
-        else:
-            with debugging.span("delete_model", name=name):
-                self._delete_database(name, headers=prop_hdrs)
-
-    def clone_graph(self, target_name: str, source_name: str, nowait_durable=True, force=False):
-        if force and self.get_graph(target_name):
-            self.delete_graph(target_name)
-        with debugging.span("clone_model", target_name=target_name, source_name=source_name):
-            return self._create_database(target_name, source_name)
-
-    def _delete_database(self, name: str, headers: Dict = {}):
-        with debugging.span("_delete_database", dbname=name):
-            response = self.request(
-                "delete_db",
-                path_params={"db_name": name},
-                query_params={},
-                headers=headers,
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to delete db. db:{name} ", response)
-
-    def _create_database(self, name: str, source_name: str):
-        with debugging.span("_create_database", dbname=name):
-            payload = {
-                "name": name,
-                "source_name": source_name,
-            }
-            response = self.request(
-                "create_db", payload=payload, headers={}, query_params={},
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to create db. db:{name}", response)
-
-    #--------------------------------------------------
-    # Engines
-    #--------------------------------------------------
-
-    def list_engines(self, state: str | None = None):
-        response = self.request("list_engines")
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                "Failed to retrieve engines.", response
-            )
-        response_content = response.json()
-        if not response_content:
-            return []
-        engines = [
-            {
-                "name": engine["name"],
-                "id": engine["id"],
-                "size": engine["size"],
-                "state": engine["status"],  # callers are expecting 'state'
-                "created_by": engine["created_by"],
-                "created_on": engine["created_on"],
-                "updated_on": engine["updated_on"],
-            }
-            for engine in response_content.get("engines", [])
-            if state is None or engine.get("status") == state
-        ]
-        return sorted(engines, key=lambda x: x["name"])
-
-    def get_engine(self, name: str):
-        response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
-        if response.status_code == 404:  # engine not found returns 404
-            return None
-        elif response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to retrieve engine {name}.", response
-            )
-        engine = response.json()
-        if not engine:
-            return None
-        engine_state: EngineState = {
-            "name": engine["name"],
-            "id": engine["id"],
-            "size": engine["size"],
-            "state": engine["status"],  # callers are expecting 'state'
-            "created_by": engine["created_by"],
-            "created_on": engine["created_on"],
-            "updated_on": engine["updated_on"],
-            "version": engine["version"],
-            "auto_suspend": engine["auto_suspend_mins"],
-            "suspends_at": engine["suspends_at"],
-        }
-        return engine_state
-
-    def _create_engine(
-        self,
-        name: str,
-        size: str | None = None,
-        auto_suspend_mins: int | None = None,
-        is_async: bool = False,
-        headers: Dict[str, str] | None = None
-    ):
-        # only async engine creation is supported via direct access
-        if not is_async:
-            return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
-        payload: Dict[str, Any] = {
-            "name": name,
-        }
-        if auto_suspend_mins is not None:
-            payload["auto_suspend_mins"] = auto_suspend_mins
-        if size is not None:
-            payload["size"] = size
-        response = self.request(
-            "create_engine",
-            payload=payload,
-            path_params={"engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to create engine {name} with size {size}.", response
-            )
-
-    def delete_engine(self, name: str, force: bool = False, headers={}):
-        response = self.request(
-            "delete_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to delete engine {name}.", response
-            )
-
-    def suspend_engine(self, name: str):
-        response = self.request(
-            "suspend_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to suspend engine {name}.", response
-            )
-
-    def resume_engine_async(self, name: str, headers={}):
-        response = self.request(
-            "resume_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to resume engine {name}.", response
-            )
-        return {}