relationalai 0.12.12__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. relationalai/__init__.py +69 -22
  2. relationalai/clients/__init__.py +15 -2
  3. relationalai/clients/client.py +4 -4
  4. relationalai/clients/local.py +5 -5
  5. relationalai/clients/resources/__init__.py +8 -0
  6. relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
  7. relationalai/clients/resources/snowflake/__init__.py +20 -0
  8. relationalai/clients/resources/snowflake/cli_resources.py +87 -0
  9. relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
  10. relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
  11. relationalai/clients/resources/snowflake/error_handlers.py +199 -0
  12. relationalai/clients/{export_procedure.py.jinja → resources/snowflake/export_procedure.py.jinja} +1 -1
  13. relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  14. relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +635 -1380
  15. relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
  16. relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  17. relationalai/clients/resources/snowflake/util.py +387 -0
  18. relationalai/early_access/dsl/ir/executor.py +4 -4
  19. relationalai/early_access/dsl/snow/api.py +2 -1
  20. relationalai/errors.py +23 -0
  21. relationalai/experimental/solvers.py +7 -7
  22. relationalai/semantics/devtools/benchmark_lqp.py +4 -5
  23. relationalai/semantics/devtools/extract_lqp.py +1 -1
  24. relationalai/semantics/internal/internal.py +4 -4
  25. relationalai/semantics/internal/snowflake.py +3 -2
  26. relationalai/semantics/lqp/executor.py +22 -22
  27. relationalai/semantics/lqp/model2lqp.py +42 -4
  28. relationalai/semantics/lqp/passes.py +1 -1
  29. relationalai/semantics/lqp/rewrite/cdc.py +1 -1
  30. relationalai/semantics/lqp/rewrite/extract_keys.py +72 -15
  31. relationalai/semantics/metamodel/builtins.py +8 -6
  32. relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
  33. relationalai/semantics/metamodel/util.py +6 -5
  34. relationalai/semantics/reasoners/graph/core.py +8 -9
  35. relationalai/semantics/rel/executor.py +14 -11
  36. relationalai/semantics/sql/compiler.py +2 -2
  37. relationalai/semantics/sql/executor/snowflake.py +9 -5
  38. relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
  39. relationalai/tools/cli.py +26 -30
  40. relationalai/tools/cli_helpers.py +10 -2
  41. relationalai/util/otel_configuration.py +2 -1
  42. relationalai/util/otel_handler.py +1 -1
  43. {relationalai-0.12.12.dist-info → relationalai-0.13.0.dist-info}/METADATA +1 -1
  44. {relationalai-0.12.12.dist-info → relationalai-0.13.0.dist-info}/RECORD +49 -40
  45. relationalai_test_util/fixtures.py +2 -1
  46. /relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
  47. {relationalai-0.12.12.dist-info → relationalai-0.13.0.dist-info}/WHEEL +0 -0
  48. {relationalai-0.12.12.dist-info → relationalai-0.13.0.dist-info}/entry_points.txt +0 -0
  49. {relationalai-0.12.12.dist-info → relationalai-0.13.0.dist-info}/licenses/LICENSE +0 -0
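The diff below is for entry 14, relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py}, the centerpiece of this release: the Snowflake client moves under a new resources/snowflake/ package, its free helper functions split out into util.py, and its monolithic error and engine-state logic replaced by handler classes in error_handlers.py and engine_state_handlers.py. As a minimal sketch of what the moves in the list above imply for imports (the new path is inferred from the renames shown; code importing only the top-level relationalai package may be unaffected):

    # Before (0.12.12): internal module path
    from relationalai.clients.snowflake import Resources

    # After (0.13.0): assumed new path, following the renames listed above
    from relationalai.clients.resources.snowflake.snowflake import Resources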
@@ -1,10 +1,8 @@
  # pyright: reportUnusedExpression=false
  from __future__ import annotations
  import base64
- import decimal
  import importlib.resources
  import io
- from numbers import Number
  import re
  import json
  import time
@@ -14,15 +12,14 @@ import uuid
  import warnings
  import atexit
  import hashlib
+ from dataclasses import dataclass
 
-
- from relationalai.auth.token_handler import TokenHandler
- from relationalai.clients.use_index_poller import DirectUseIndexPoller, UseIndexPoller
+ from ....auth.token_handler import TokenHandler
  import snowflake.snowpark
 
- from relationalai.rel_utils import sanitize_identifier, to_fqn_relation_name
- from relationalai.tools.constants import FIELD_PLACEHOLDER, RAI_APP_NAME, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, USE_DIRECT_ACCESS, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
- from .. import std
+ from ....rel_utils import sanitize_identifier, to_fqn_relation_name
+ from ....tools.constants import FIELD_PLACEHOLDER, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
+ from .... import std
  from collections import defaultdict
  import requests
  import snowflake.connector
@@ -30,33 +27,63 @@ import pyarrow as pa
 
  from snowflake.snowpark import Session
  from snowflake.snowpark.context import get_active_session
- from . import result_helpers
- from .. import debugging
- from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
+ from ... import result_helpers
+ from .... import debugging
+ from typing import Any, Dict, Iterable, Tuple, List, Literal, cast
 
  from pandas import DataFrame
 
- from ..tools.cli_controls import Spinner
- from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
- from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
- from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
- from ..clients.direct_access_client import DirectAccessClient
- from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp, normalize_datetime
- from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
- from .. import dsl, rel, metamodel as m
- from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
+ from ....tools.cli_controls import Spinner
+ from ...types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
+ from ...config import Config
+ from ...client import Client, ExportParams, ProviderBase, ResourcesBase
+ from ...util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, normalize_datetime
+ from .util import (
+ collect_error_messages,
+ process_jinja_template,
+ type_to_sql,
+ type_to_snowpark,
+ sanitize_user_name as _sanitize_user_name,
+ normalize_params,
+ format_sproc_name,
+ is_azure_url,
+ is_container_runtime,
+ imports_to_dicts,
+ txn_list_to_dicts,
+ decrypt_artifact,
+ )
+ from ....environments import runtime_env, HexEnvironment, SnowbookEnvironment
+ from .... import dsl, rel, metamodel as m
+ from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
  from concurrent.futures import ThreadPoolExecutor
- from datetime import datetime, date, timedelta
+ from datetime import datetime, timedelta
  from snowflake.snowpark.types import StringType, StructField, StructType
+ # Import error handlers and constants
+ from .error_handlers import (
+ ErrorHandler,
+ DuoSecurityErrorHandler,
+ AppMissingErrorHandler,
+ DatabaseErrorsHandler,
+ EngineErrorsHandler,
+ ServiceNotStartedErrorHandler,
+ TransactionAbortedErrorHandler,
+ )
+ # Import engine state handlers
+ from .engine_state_handlers import (
+ EngineStateHandler,
+ EngineContext,
+ SyncPendingStateHandler,
+ SyncSuspendedStateHandler,
+ SyncReadyStateHandler,
+ SyncGoneStateHandler,
+ SyncMissingEngineHandler,
+ AsyncPendingStateHandler,
+ AsyncSuspendedStateHandler,
+ AsyncReadyStateHandler,
+ AsyncGoneStateHandler,
+ AsyncMissingEngineHandler,
+ )
 
- # warehouse-based snowflake notebooks currently don't have hazmat
- crypto_disabled = False
- try:
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import padding
- except (ModuleNotFoundError, ImportError):
- crypto_disabled = True
 
  #--------------------------------------------------
  # Constants
@@ -66,232 +93,44 @@ VALID_POOL_STATUS = ["ACTIVE", "IDLE", "SUSPENDED"]
  # transaction list and get return different fields (duration vs timings)
  LIST_TXN_SQL_FIELDS = ["id", "database_name", "engine_name", "state", "abort_reason", "read_only","created_by", "created_on", "finished_at", "duration"]
  GET_TXN_SQL_FIELDS = ["id", "database", "engine", "state", "abort_reason", "read_only","created_by", "created_on", "finished_at", "timings"]
- IMPORT_STREAM_FIELDS = ["ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS", "FQ_OBJECT_NAME", "RAI_DATABASE",
- "RAI_RELATION", "DATA_SYNC_STATUS", "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
- "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"]
  VALID_ENGINE_STATES = ["READY", "PENDING"]
 
  # Cloud-specific engine sizes
  INTERNAL_ENGINE_SIZES = ["XS", "S", "M", "L"]
  ENGINE_SIZES_AWS = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_L"]
  ENGINE_SIZES_AZURE = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_SL"]
-
- FIELD_MAP = {
- "database_name": "database",
- "engine_name": "engine",
- }
- VALID_IMPORT_STATES = ["PENDING", "PROCESSING", "QUARANTINED", "LOADED"]
- ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
- ENGINE_NOT_READY_MSGS = ["engine is in pending", "engine is provisioning"]
- DATABASE_ERRORS = ["database not found"]
+ # Note: ENGINE_ERRORS, ENGINE_NOT_READY_MSGS, DATABASE_ERRORS moved to util.py
  PYREL_ROOT_DB = 'pyrel_root_db'
 
  TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
 
- DUO_TEXT = "duo security"
-
  TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
 
- #--------------------------------------------------
- # Helpers
- #--------------------------------------------------
-
- def process_jinja_template(template: str, indent_spaces = 0, **substitutions) -> str:
- """Process a Jinja-like template.
-
- Supports:
- - Variable substitution {{ var }}
- - Conditional blocks {% if condition %} ... {% endif %}
- - For loops {% for item in items %} ... {% endfor %}
- - Comments {# ... #}
- - Whitespace control with {%- and -%}
-
- Args:
- template: The template string
- indent_spaces: Number of spaces to indent the result
- **substitutions: Variable substitutions
- """
-
- def evaluate_condition(condition: str, context: dict) -> bool:
- """Safely evaluate a condition string using the context."""
- # Replace variables with their values
- for k, v in context.items():
- if isinstance(v, str):
- condition = condition.replace(k, f"'{v}'")
- else:
- condition = condition.replace(k, str(v))
- try:
- return bool(eval(condition, {"__builtins__": {}}, {}))
- except Exception:
- return False
-
- def process_expression(expr: str, context: dict) -> str:
- """Process a {{ expression }} block."""
- expr = expr.strip()
- if expr in context:
- return str(context[expr])
- return ""
-
- def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
- """Process a block of template lines recursively."""
- result = []
- i = 0
- while i < len(lines):
- line = lines[i]
-
- # Handle comments
- line = re.sub(r'{#.*?#}', '', line)
-
- # Handle if blocks
- if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
- if if_match:
- condition = if_match.group(1)
- if_block = []
- else_block = []
- i += 1
- nesting = 1
- in_else_block = False
- while i < len(lines) and nesting > 0:
- if re.search(r'{%\s*if\s+', lines[i]):
- nesting += 1
- elif re.search(r'{%\s*endif\s*%}', lines[i]):
- nesting -= 1
- elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
- in_else_block = True
- i += 1
- continue
-
- if nesting > 0:
- if in_else_block:
- else_block.append(lines[i])
- else:
- if_block.append(lines[i])
- i += 1
- if evaluate_condition(condition, context):
- result.extend(process_block(if_block, context, indent))
- else:
- result.extend(process_block(else_block, context, indent))
- continue
-
- # Handle for loops
- for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
- if for_match:
- var_name, iterable_name = for_match.groups()
- for_block = []
- i += 1
- nesting = 1
- while i < len(lines) and nesting > 0:
- if re.search(r'{%\s*for\s+', lines[i]):
- nesting += 1
- elif re.search(r'{%\s*endfor\s*%}', lines[i]):
- nesting -= 1
- if nesting > 0:
- for_block.append(lines[i])
- i += 1
- if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
- for item in context[iterable_name]:
- loop_context = dict(context)
- loop_context[var_name] = item
- result.extend(process_block(for_block, loop_context, indent))
- continue
-
- # Handle variable substitution
- line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)
-
- # Handle whitespace control
- line = re.sub(r'{%-', '{%', line)
- line = re.sub(r'-%}', '%}', line)
-
- # Add line with proper indentation, preserving blank lines
- if line.strip():
- result.append(" " * (indent_spaces + indent) + line)
- else:
- result.append("")
-
- i += 1
-
- return result
-
- # Split template into lines and process
- lines = template.split('\n')
- processed_lines = process_block(lines, substitutions)
-
- return '\n'.join(processed_lines)
-
- def type_to_sql(type) -> str:
- if type is str:
- return "VARCHAR"
- if type is int:
- return "NUMBER"
- if type is Number:
- return "DECIMAL(38, 15)"
- if type is float:
- return "FLOAT"
- if type is decimal.Decimal:
- return "DECIMAL(38, 15)"
- if type is bool:
- return "BOOLEAN"
- if type is dict:
- return "VARIANT"
- if type is list:
- return "ARRAY"
- if type is bytes:
- return "BINARY"
- if type is datetime:
- return "TIMESTAMP"
- if type is date:
- return "DATE"
- if isinstance(type, dsl.Type):
- return "VARCHAR"
- raise ValueError(f"Unknown type {type}")
-
- def type_to_snowpark(type) -> str:
- if type is str:
- return "StringType()"
- if type is int:
- return "IntegerType()"
- if type is float:
- return "FloatType()"
- if type is Number:
- return "DecimalType(38, 15)"
- if type is decimal.Decimal:
- return "DecimalType(38, 15)"
- if type is bool:
- return "BooleanType()"
- if type is dict:
- return "MapType()"
- if type is list:
- return "ArrayType()"
- if type is bytes:
- return "BinaryType()"
- if type is datetime:
- return "TimestampType()"
- if type is date:
- return "DateType()"
- if isinstance(type, dsl.Type):
- return "StringType()"
- raise ValueError(f"Unknown type {type}")
-
- def _sanitize_user_name(user: str) -> str:
- # Extract the part before the '@'
- sanitized_user = user.split('@')[0]
- # Replace any character that is not a letter, number, or underscore with '_'
- sanitized_user = re.sub(r'[^a-zA-Z0-9_]', '_', sanitized_user)
- return sanitized_user
-
- def _is_engine_issue(response_message: str) -> bool:
- return any(kw in response_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)
-
- def _is_database_issue(response_message: str) -> bool:
- return any(kw in response_message.lower() for kw in DATABASE_ERRORS)
-
-
-
  #--------------------------------------------------
  # Resources
  #--------------------------------------------------
 
  APP_NAME = "___RAI_APP___"
 
+ @dataclass
+ class ExecContext:
+ """Execution context for SQL queries, containing all parameters needed for execution and retry."""
+ code: str
+ params: List[Any] | None = None
+ raw: bool = False
+ help: bool = True
+ skip_engine_db_error_retry: bool = False
+
+ def re_execute(self, resources: 'Resources') -> Any:
+ """Re-execute this context's query using the provided resources instance."""
+ return resources._exec(
+ code=self.code,
+ params=self.params,
+ raw=self.raw,
+ help=self.help,
+ skip_engine_db_error_retry=self.skip_engine_db_error_retry
+ )
+
  class Resources(ResourcesBase):
  def __init__(
  self,
@@ -301,7 +140,7 @@ class Resources(ResourcesBase):
  dry_run: bool = False,
  reset_session: bool = False,
  generation: Generation | None = None,
- language: str = "rel",
+ language: str = "rel", # Accepted for backward compatibility, but not stored in base class
  ):
  super().__init__(profile, config=config)
  self._token_handler: TokenHandler | None = None
@@ -319,16 +158,97 @@ class Resources(ResourcesBase):
  # self.sources contains fully qualified Snowflake table/view names
  self.sources: set[str] = set()
  self._sproc_models = None
- self.database = ""
+ # Store language for backward compatibility (used by child classes for use_index polling)
  self.language = language
+ # Register error and state handlers
+ self._register_handlers()
+ # Register atexit callback to cancel pending transactions
  atexit.register(self.cancel_pending_transactions)
 
+ #--------------------------------------------------
+ # Initialization & Properties
+ #--------------------------------------------------
+
+ def _register_handlers(self) -> None:
+ """Register error and engine state handlers for processing."""
+ # Register base handlers using getter methods that subclasses can override
+ # Use defensive copying to ensure each instance has its own handler lists
+ # and prevent cross-instance contamination from subclass mutations
+ self._error_handlers = list(self._get_error_handlers())
+ self._sync_engine_state_handlers = list(self._get_engine_state_handlers(is_async=False))
+ self._async_engine_state_handlers = list(self._get_engine_state_handlers(is_async=True))
+
+ def _get_error_handlers(self) -> list[ErrorHandler]:
+ """Get list of error handlers. Subclasses can override to add custom handlers.
+
+ Returns:
+ List of error handlers for standard error processing using Strategy Pattern.
+
+ Example:
+ def _get_error_handlers(self) -> list[ErrorHandler]:
+ # Get base handlers
+ handlers = super()._get_error_handlers()
+ # Add custom handler
+ handlers.append(MyCustomErrorHandler())
+ return handlers
+ """
+ return [
+ DuoSecurityErrorHandler(),
+ AppMissingErrorHandler(),
+ DatabaseErrorsHandler(),
+ EngineErrorsHandler(),
+ ServiceNotStartedErrorHandler(),
+ TransactionAbortedErrorHandler(),
+ ]
+
+ def _get_engine_state_handlers(self, is_async: bool = False) -> list[EngineStateHandler]:
+ """Get list of engine state handlers. Subclasses can override.
+
+ Args:
+ is_async: If True, returns async handlers; if False, returns sync handlers.
+
+ Returns:
+ List of engine state handlers for processing engine states.
+
+ Example:
+ def _get_engine_state_handlers(self, is_async: bool = False) -> list[EngineStateHandler]:
+ # Get base handlers
+ handlers = super()._get_engine_state_handlers(is_async)
+ # Add custom handler
+ handlers.append(MyCustomStateHandler())
+ return handlers
+ """
+ if is_async:
+ return [
+ AsyncPendingStateHandler(),
+ AsyncSuspendedStateHandler(),
+ AsyncReadyStateHandler(),
+ AsyncGoneStateHandler(),
+ AsyncMissingEngineHandler(),
+ ]
+ else:
+ return [
+ SyncPendingStateHandler(),
+ SyncSuspendedStateHandler(),
+ SyncReadyStateHandler(),
+ SyncGoneStateHandler(),
+ SyncMissingEngineHandler(),
+ ]
+
  @property
  def token_handler(self) -> TokenHandler:
  if not self._token_handler:
  self._token_handler = TokenHandler.from_config(self.config)
  return self._token_handler
 
+ def reset(self):
+ """Reset the session."""
+ self._session = None
+
+ #--------------------------------------------------
+ # Session Management
+ #--------------------------------------------------
+
  def is_erp_running(self, app_name: str) -> bool:
  """Check if the ERP is running. The app.service_status() returns single row/column containing an array of JSON service status objects."""
  query = f"CALL {app_name}.app.service_status();"
@@ -426,7 +346,28 @@ class Resources(ResourcesBase):
  except Exception as e:
  raise e
 
+ #--------------------------------------------------
+ # Core Execution Methods
+ #--------------------------------------------------
+
  def _exec_sql(self, code: str, params: List[Any] | None, raw=False):
+ """
+ Lowest-level SQL execution method.
+
+ Directly executes SQL via the Snowflake session. This is the foundation
+ for all other execution methods. It:
+ - Replaces APP_NAME placeholder with actual app name
+ - Executes SQL with optional parameters
+ - Returns either raw session results or collected results
+
+ Args:
+ code: SQL code to execute (may contain APP_NAME placeholder)
+ params: Optional SQL parameters
+ raw: If True, return raw session results; if False, collect results
+
+ Returns:
+ Raw session results if raw=True, otherwise collected results
+ """
  assert self._session is not None
  sess_results = self._session.sql(
  code.replace(APP_NAME, self.get_app_name()),
@@ -441,86 +382,91 @@ class Resources(ResourcesBase):
  code: str,
  params: List[Any] | Any | None = None,
  raw: bool = False,
- help: bool = True
+ help: bool = True,
+ skip_engine_db_error_retry: bool = False
  ) -> Any:
+ """
+ Mid-level SQL execution method with error handling.
+
+ This is the primary method for executing SQL queries. It wraps _exec_sql
+ with comprehensive error handling and parameter normalization. Used
+ extensively throughout the codebase for direct SQL operations like:
+ - SHOW commands (warehouses, databases, etc.)
+ - CALL statements to RAI app stored procedures
+ - Transaction management queries
+
+ The error handling flow:
+ 1. Normalizes parameters and creates execution context
+ 2. Calls _exec_sql to execute the query
+ 3. On error, uses standard error handling (Strategy Pattern), which subclasses
+ can influence via `_get_error_handlers()` or by overriding `_handle_standard_exec_errors()`
+
+ Args:
+ code: SQL code to execute
+ params: Optional SQL parameters (normalized to list if needed)
+ raw: If True, return raw session results; if False, collect results
+ help: If True, enable error handling; if False, raise errors immediately
+ skip_engine_db_error_retry: If True, skip use_index retry logic in error handlers
+
+ Returns:
+ Query results (collected or raw depending on 'raw' parameter)
+ """
  # print(f"\n--- sql---\n{code}\n--- end sql---\n")
+ # Ensure session is initialized
  if not self._session:
  self._session = self.get_sf_session()
 
+ # Normalize parameters
+ normalized_params = normalize_params(params)
+
+ # Create execution context
+ ctx = ExecContext(
+ code=code,
+ params=normalized_params,
+ raw=raw,
+ help=help,
+ skip_engine_db_error_retry=skip_engine_db_error_retry
+ )
+
+ # Execute SQL
  try:
- if params is not None and not isinstance(params, list):
- params = cast(List[Any], [params])
- return self._exec_sql(code, params, raw=raw)
+ return self._exec_sql(ctx.code, ctx.params, raw=ctx.raw)
  except Exception as e:
- if not help:
+ if not ctx.help:
  raise e
- orig_message = str(e).lower()
- rai_app = self.config.get("rai_app_name", "")
- current_role = self.config.get("role")
- engine = self.get_default_engine_name()
- engine_size = self.config.get_default_engine_size()
- assert isinstance(rai_app, str), f"rai_app_name must be a string, not {type(rai_app)}"
- assert isinstance(engine, str), f"engine must be a string, not {type(engine)}"
- print("\n")
- if DUO_TEXT in orig_message:
- raise DuoSecurityFailed(e)
- if re.search(f"database '{rai_app}' does not exist or not authorized.".lower(), orig_message):
- exception = SnowflakeAppMissingException(rai_app, current_role)
- raise exception from None
- if _is_engine_issue(orig_message) or _is_database_issue(orig_message):
- try:
- self._poll_use_index(
- app_name=self.get_app_name(),
- sources=self.sources,
- model=self.database,
- engine_name=engine,
- engine_size=engine_size
- )
- return self._exec(code, params, raw=raw, help=help)
- except EngineNameValidationException as e:
- raise EngineNameValidationException(engine) from e
- except Exception as e:
- raise EngineProvisioningFailed(engine, e) from e
- elif re.search(r"javascript execution error", orig_message):
- match = re.search(r"\"message\":\"(.*)\"", orig_message)
- if match:
- message = match.group(1)
- if "engine is in pending" in message or "engine is provisioning" in message:
- raise EnginePending(engine)
- else:
- raise RAIException(message) from None
-
- if re.search(r"the relationalai service has not been started.", orig_message):
- app_name = self.config.get("rai_app_name", "")
- assert isinstance(app_name, str), f"rai_app_name must be a string, not {type(app_name)}"
- raise SnowflakeRaiAppNotStarted(app_name)
-
- if re.search(r"state:\s*aborted", orig_message):
- txn_id_match = re.search(r"id:\s*([0-9a-f\-]+)", orig_message)
- if txn_id_match:
- txn_id = txn_id_match.group(1)
- problems = self.get_transaction_problems(txn_id)
- if problems:
- for problem in problems:
- if isinstance(problem, dict):
- type_field = problem.get('TYPE')
- message_field = problem.get('MESSAGE')
- report_field = problem.get('REPORT')
- else:
- type_field = problem.TYPE
- message_field = problem.MESSAGE
- report_field = problem.REPORT
 
- raise RAIAbortedTransactionError(type_field, message_field, report_field)
- raise RAIException(str(e))
- raise RAIException(str(e))
+ # Handle standard errors
+ result = self._handle_standard_exec_errors(e, ctx)
+ if result is not None:
+ return result
+
+ #--------------------------------------------------
+ # Error Handling
+ #--------------------------------------------------
 
+ def _handle_standard_exec_errors(self, e: Exception, ctx: ExecContext) -> Any | None:
+ """
+ Handle standard Snowflake/RAI errors using Strategy Pattern.
 
- def reset(self):
- self._session = None
+ Each error type has a dedicated handler class that encapsulates
+ the detection logic and exception creation. Handlers are processed
+ in order until one matches and handles the error.
+ """
+ message = str(e).lower()
+
+ # Try each handler in order until one matches
+ for handler in self._error_handlers:
+ if handler.matches(e, message, ctx, self):
+ result = handler.handle(e, ctx, self)
+ if result is not None:
+ return result
+ return # Handler raised exception, we're done
+
+ # Fallback: transform to RAIException
+ raise RAIException(str(e))
 
  #--------------------------------------------------
- # Check direct access is enabled
+ # Feature Detection & Configuration
  #--------------------------------------------------
 
  def is_direct_access_enabled(self) -> bool:
@@ -542,9 +488,6 @@ class Resources(ResourcesBase):
  except Exception as e:
  raise Exception(f"Unable to determine if direct access is enabled. Details error: {e}") from e
 
- #--------------------------------------------------
- # Snowflake Account Flags
- #--------------------------------------------------
 
  def is_account_flag_set(self, flag: str) -> bool:
  results = self._exec(
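The next two hunks, and several later ones, swap brittle `"..." in str(e)` checks for the new collect_error_messages helper from util.py. util.py itself is not in this diff, so the following is only a sketch of plausible semantics inferred from the call sites (lowercased messages, robust to exception wrapping):

    # Sketch, not the shipped implementation: yield lowercased messages for an
    # exception and its chained causes, so substring checks survive re-raising.
    def collect_error_messages(e: BaseException) -> list[str]:
        messages: list[str] = []
        current: BaseException | None = e
        while current is not None:
            messages.append(str(current).lower())
            current = current.__cause__ or current.__context__
        return messages

    # Call-site pattern used throughout this file:
    # if any("database does not exist" in msg for msg in collect_error_messages(e)): ...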
@@ -564,7 +507,8 @@ class Resources(ResourcesBase):
  f"call {APP_NAME}.api.get_database('{database}');"
  )
  except Exception as e:
- if "Database does not exist" in str(e):
+ messages = collect_error_messages(e)
+ if any("database does not exist" in msg for msg in messages):
  return None
  raise e
 
@@ -588,10 +532,11 @@ class Resources(ResourcesBase):
  try:
  results = self._exec(query)
  except Exception as e:
- if "Database does not exist" in str(e):
+ messages = collect_error_messages(e)
+ if any("database does not exist" in msg for msg in messages):
  return None
  # fallback to None for old sql-lib versions
- if "Unknown user-defined function" in str(e):
+ if any("unknown user-defined function" in msg for msg in messages):
  return None
  raise e
 
@@ -608,6 +553,139 @@ class Resources(ResourcesBase):
  # Engines
  #--------------------------------------------------
 
+ def _prepare_engine_params(
+ self,
+ name: str | None,
+ size: str | None,
+ use_default_size: bool = False
+ ) -> tuple[str, str | None]:
+ """
+ Prepare engine parameters by resolving and validating name and size.
+
+ Args:
+ name: Engine name (None to use default)
+ size: Engine size (None to use config or default)
+ use_default_size: If True and size is None, use get_default_engine_size()
+
+ Returns:
+ Tuple of (engine_name, engine_size)
+
+ Raises:
+ EngineNameValidationException: If engine name is invalid
+ Exception: If engine size is invalid
+ """
+ from relationalai.tools.cli_helpers import validate_engine_name
+
+ engine_name = name or self.get_default_engine_name()
+
+ # Resolve engine size
+ if size:
+ engine_size = size
+ else:
+ if use_default_size:
+ engine_size = self.config.get_default_engine_size()
+ else:
+ engine_size = self.config.get("engine_size", None)
+
+ # Validate engine size
+ if engine_size:
+ is_size_valid, sizes = self.validate_engine_size(engine_size)
+ if not is_size_valid:
+ error_msg = f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}"
+ if use_default_size:
+ error_msg = f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}"
+ raise Exception(error_msg)
+
+ # Validate engine name
+ is_name_valid, _ = validate_engine_name(engine_name)
+ if not is_name_valid:
+ raise EngineNameValidationException(engine_name)
+
+ return engine_name, engine_size
+
+ def _get_state_handler(self, state: str | None, handlers: list[EngineStateHandler]) -> EngineStateHandler:
+ """Find the appropriate state handler for the given state."""
+ for handler in handlers:
+ if handler.handles_state(state):
+ return handler
+ # Fallback to missing engine handler if no match
+ return handlers[-1] # Last handler should be MissingEngineHandler
+
+ def _process_engine_state(
+ self,
+ engine: EngineState | Dict[str, Any] | None,
+ context: EngineContext,
+ handlers: list[EngineStateHandler],
+ set_active_on_success: bool = False
+ ) -> EngineState | Dict[str, Any] | None:
+ """
+ Process engine state using appropriate state handler.
+
+ Args:
+ engine: Current engine state (or None if missing)
+ context: Engine context for state handling
+ handlers: List of state handlers to use (sync or async)
+ set_active_on_success: If True, set engine as active when handler returns engine
+
+ Returns:
+ Engine state after processing, or None if engine needs to be created
+ """
+ # Find and execute appropriate state handler
+ state = engine["state"] if engine else None
+ handler = self._get_state_handler(state, handlers)
+ engine = handler.handle(engine, context, self)
+
+ # If handler returned None and we didn't start with None state, engine needs to be created
+ # (e.g., GONE state deleted the engine, so we need to create a new one)
+ if not engine and state is not None:
+ handler = self._get_state_handler(None, handlers)
+ handler.handle(None, context, self)
+ elif set_active_on_success:
+ # Cast to EngineState for type safety (handlers return EngineDict which is compatible)
+ self._set_active_engine(cast(EngineState, engine))
+
+ return engine
+
+ def _handle_engine_creation_errors(self, error: Exception, engine_name: str, preserve_rai_exception: bool = False) -> None:
+ """
+ Handle errors during engine creation using error handlers.
+
+ Args:
+ error: The exception that occurred
+ engine_name: Name of the engine being created
+ preserve_rai_exception: If True, re-raise RAIException without wrapping
+
+ Raises:
+ RAIException: If preserve_rai_exception is True and error is RAIException
+ EngineProvisioningFailed: If error is not handled by error handlers
+ """
+ # Preserve RAIException passthrough if requested (for async mode)
+ if preserve_rai_exception and isinstance(error, RAIException):
+ raise error
+
+ # Check if this is a known error type that should be handled by error handlers
+ message = str(error).lower()
+ handled = False
+ # Engine creation isn't tied to a specific SQL ExecContext; pass a context that
+ # disables use_index retry behavior (and any future ctx-dependent handlers).
+ ctx = ExecContext(code="", help=True, skip_engine_db_error_retry=True)
+ for handler in self._error_handlers:
+ if handler.matches(error, message, ctx, self):
+ handler.handle(error, ctx, self)
+ handled = True
+ break # Handler raised exception, we're done
+
+ # If not handled by error handlers, wrap in EngineProvisioningFailed
+ if not handled:
+ raise EngineProvisioningFailed(engine_name, error) from error
+
+ def validate_engine_size(self, size: str) -> Tuple[bool, List[str]]:
+ if size is not None:
+ sizes = self.get_engine_sizes()
+ if size not in sizes:
+ return False, sizes
+ return True, []
+
  def get_engine_sizes(self, cloud_provider: str|None=None):
  sizes = []
  if cloud_provider is None:
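Engine lifecycle states (PENDING, SUSPENDED, READY, GONE, missing) now follow the same dispatch shape as errors: _get_state_handler picks the first handler whose handles_state matches, and _process_engine_state runs it. A hedged sketch of a custom handler (MyCustomStateHandler is hypothetical; the handles_state/handle shapes are inferred from the call sites above, and the EngineContext fields from auto_create_engine further down):

    # Hypothetical state handler, assuming the call shapes used by
    # _get_state_handler and _process_engine_state above.
    class MyCustomStateHandler(EngineStateHandler):
        def handles_state(self, state: str | None) -> bool:
            return state == "READY"

        def handle(self, engine, context, resources):
            # e.g., log and pass the engine through unchanged
            print(f"engine {context.engine_name} is ready")
            return engine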
@@ -810,19 +888,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
  engine_size: str | None = None,
  program_span_id: str | None = None,
  headers: Dict | None = None,
- ):
- return UseIndexPoller(
- self,
- app_name,
- sources,
- model,
- engine_name,
- engine_size,
- self.language,
- program_span_id,
- headers,
- self.generation
- ).poll()
+ ) -> None:
+ """
+ Poll use_index to prepare indices for the given sources.
+
+ This is an optional interface method. Base Resources provides a no-op implementation.
+ UseIndexResources and DirectAccessResources override this to provide actual polling.
+
+ Returns:
+ None for base implementation. Child classes may return poller results.
+ """
+ return None
 
  def maybe_poll_use_index(
  self,
@@ -833,36 +909,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
  engine_size: str | None = None,
  program_span_id: str | None = None,
  headers: Dict | None = None,
- ):
- """Only call poll() if there are sources to process and cache is not valid."""
- sources_list = list(sources)
- self.database = model
- if sources_list:
- poller = UseIndexPoller(
- self,
- app_name,
- sources_list,
- model,
- engine_name,
- engine_size,
- self.language,
- program_span_id,
- headers,
- self.generation
- )
- # If cache is valid (data freshness has not expired), skip polling
- if poller.cache.is_valid():
- cached_sources = len(poller.cache.sources)
- total_sources = len(sources_list)
- cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
-
- message = f"Using cached data for {cached_sources}/{total_sources} data streams"
- if cached_timestamp:
- print(f"\n{message} (cached at {cached_timestamp})\n")
- else:
- print(f"\n{message}\n")
- else:
- return poller.poll()
+ ) -> None:
+ """
+ Only call _poll_use_index if there are sources to process.
+
+ This is an optional interface method. Base Resources provides a no-op implementation.
+ UseIndexResources and DirectAccessResources override this to provide actual polling with caching.
+
+ Returns:
+ None for base implementation. Child classes may return poller results.
+ """
+ return None
 
  #--------------------------------------------------
  # Models
@@ -900,11 +957,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
  def list_exports(self, database: str, engine: str):
  return []
 
- def format_sproc_name(self, name: str, type:Any) -> str:
- if type is datetime:
- return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
- else:
- return name
 
  def get_export_code(self, params: ExportParams, all_installs):
  sql_inputs = ", ".join([f"{name} {type_to_sql(type)}" for (name, _, type) in params.inputs])
@@ -926,15 +978,14 @@ Otherwise, remove it from your '{profile}' configuration profile.
  clean_inputs.append(f"{name} = '\"' + escape({name}) + '\"'")
  # Replace `var` with `name` and keep the following non-word character unchanged
  pattern = re.compile(re.escape(var) + r'(\W)')
- value = self.format_sproc_name(name, type)
+ value = format_sproc_name(name, type)
  safe_rel = re.sub(pattern, rf"{{{value}}}\1", safe_rel)
  if py_inputs:
  py_inputs = f", {py_inputs}"
  clean_inputs = ("\n").join(clean_inputs)
- assert __package__ is not None, "Package name must be set"
  file = "export_procedure.py.jinja"
  with importlib.resources.open_text(
- __package__, file
+ "relationalai.clients.resources.snowflake", file
  ) as f:
  template = f.read()
  def quote(s: str, f = False) -> str:
@@ -1092,15 +1143,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
  # Imports
  #--------------------------------------------------
 
- def is_valid_import_state(self, state:str):
- return state in VALID_IMPORT_STATES
-
- def imports_to_dicts(self, results):
- parsed_results = [
- {field.lower(): row[field] for field in IMPORT_STREAM_FIELDS}
- for row in results
- ]
- return parsed_results
 
  def change_stream_status(self, stream_id: str, model:str, suspend: bool):
  if stream_id and model:
@@ -1272,7 +1314,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  results = self._exec(f"CALL {APP_NAME}.api.get_data_stream('{name}', '{model}');")
  if not results:
  return None
- return self.imports_to_dicts(results)
+ return imports_to_dicts(results)
 
  def create_import_stream(self, source:ImportSource, model:str, rate = 1, options: dict|None = None):
  assert isinstance(source, ImportSourceTable), "Snowflake integration only supports loading from SF Tables. Try loading your data as a table via the Snowflake interface first."
@@ -1307,7 +1349,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
  try:
  self._exec(command)
  except Exception as e:
- if "ensure that CHANGE_TRACKING is enabled on the source object" in str(e):
+ messages = collect_error_messages(e)
+ if any("ensure that change_tracking is enabled on the source object" in msg for msg in messages):
  if self.config.get("ensure_change_tracking", False) and not tracking_just_changed:
  try:
  self._exec(f"ALTER {kind} {object} SET CHANGE_TRACKING = TRUE;")
@@ -1318,7 +1361,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  print("\n")
  exception = SnowflakeChangeTrackingNotEnabledException((object, kind))
  raise exception from None
- elif "Database does not exist" in str(e):
+ elif any("database does not exist" in msg for msg in messages):
  print("\n")
  raise ModelNotFoundException(model) from None
  raise e
@@ -1381,42 +1424,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
  return status == "COMPLETED" or status == "ABORTED"
 
- def decrypt_stream(self, key: bytes, iv: bytes, src: bytes) -> bytes:
- """Decrypt the provided stream with PKCS#5 padding handling."""
-
- if crypto_disabled:
- if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
- raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
- else:
- raise Exception("library `cryptography.hazmat` missing; please install")
-
- # `type:ignore`s are because of the conditional import, which
- # we have because warehouse-based snowflake notebooks don't support
- # the crypto library we're using.
- cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) # type: ignore
- decryptor = cipher.decryptor()
-
- # Decrypt the data
- decrypted_padded_data = decryptor.update(src) + decryptor.finalize()
-
- # Unpad the decrypted data using PKCS#5
- unpadder = padding.PKCS7(128).unpadder() # type: ignore # Use 128 directly for AES
- unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
-
- return unpadded_data
-
- def _decrypt_artifact(self, data: bytes, encryption_material: str) -> bytes:
- """Decrypts the artifact data using provided encryption material."""
- encryption_material_parts = encryption_material.split("|")
- assert len(encryption_material_parts) == 3, "Invalid encryption material"
-
- algorithm, key_base64, iv_base64 = encryption_material_parts
- assert algorithm == "AES_128_CBC", f"Unsupported encryption algorithm {algorithm}"
-
- key = base64.standard_b64decode(key_base64)
- iv = base64.standard_b64decode(iv_base64)
-
- return self.decrypt_stream(key, iv, data)
 
  def _list_exec_async_artifacts(self, txn_id: str, headers: Dict | None = None) -> Dict[str, Dict]:
  """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
@@ -1468,7 +1475,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  if len(content) == 0:
  return b""
 
- return self._decrypt_artifact(content, encryption_material)
+ return decrypt_artifact(content, encryption_material)
 
  # otherwise, return content directly
  return content
@@ -1548,7 +1555,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
  def get_url_key(self, metadata) -> str:
  # In Azure, there is only one type of URL, which is used for both internal and
  # external access; always use that one
- if self.is_azure(metadata['PRESIGNED_URL']):
+ if is_azure_url(metadata['PRESIGNED_URL']):
  return 'PRESIGNED_URL'
 
  configured = self.config.get("download_url_type", None)
@@ -1557,17 +1564,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
  elif configured == "external":
  return "PRESIGNED_URL"
 
- if self.is_container_runtime():
+ if is_container_runtime():
  return 'PRESIGNED_URL_AP'
 
  return 'PRESIGNED_URL'
 
- def is_azure(self, url) -> bool:
- return "blob.core.windows.net" in url
-
- def is_container_runtime(self) -> bool:
- return isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "container"
-
  def _exec_rai_app(
  self,
  database: str,
@@ -1581,6 +1582,40 @@ Otherwise, remove it from your '{profile}' configuration profile.
  language: str = "rel",
  query_timeout_mins: int | None = None,
  ):
+ """
+ High-level method to execute RAI app stored procedures.
+
+ Builds and executes SQL to call the RAI app's exec_async_v2 stored procedure.
+ This method handles the SQL string construction for two different formats:
+ 1. New format (with graph index): Uses object payload with parameterized query
+ 2. Legacy format: Uses positional parameters
+
+ The choice between formats depends on the use_graph_index configuration.
+ The new format allows the stored procedure to hash the model and username
+ to determine the database, while the legacy format uses the passed database directly.
+
+ This method is called by _exec_async_v2 to create transactions. It skips
+ use_index retry logic (skip_engine_db_error_retry=True) because that
+ is handled at a higher level by exec_raw/exec_lqp.
+
+ Args:
+ database: Database/model name
+ engine: Engine name (optional)
+ raw_code: Code to execute (REL, LQP, or SQL)
+ inputs: Input parameters for the query
+ readonly: Whether the transaction is read-only
+ nowait_durable: Whether to wait for durable writes
+ request_headers: Optional HTTP headers
+ bypass_index: Whether to bypass graph index setup
+ language: Query language ("rel" or "lqp")
+ query_timeout_mins: Optional query timeout in minutes
+
+ Returns:
+ Response from the stored procedure call (transaction creation result)
+
+ Raises:
+ Exception: If transaction creation fails
+ """
  assert language == "rel" or language == "lqp", "Only 'rel' and 'lqp' languages are supported"
  if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
  query_timeout_mins = int(timeout_value)
@@ -1609,9 +1644,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
  sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {query_timeout_mins}, {request_headers});"
  else:
  sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {request_headers});"
+ # Don't let exec setup GI on failure, exec_raw and exec_lqp will do that and add the correct headers.
  response = self._exec(
  sql_string,
  raw_code,
+ skip_engine_db_error_retry=True,
  )
  if not response:
  raise Exception("Failed to create transaction")
@@ -1629,7 +1666,38 @@ Otherwise, remove it from your '{profile}' configuration profile.
  bypass_index=False,
  language: str = "rel",
  query_timeout_mins: int | None = None,
+ gi_setup_skipped: bool = False,
  ):
+ """
+ High-level async execution method with transaction polling and artifact management.
+
+ This is the core method for executing queries asynchronously. It:
+ 1. Creates a transaction by calling _exec_rai_app
+ 2. Handles two execution paths:
+ - Fast path: Transaction completes immediately (COMPLETED/ABORTED)
+ - Slow path: Transaction is pending, requires polling until completion
+ 3. Manages pending transactions list
+ 4. Downloads and returns query results/artifacts
+
+ This method is called by _execute_code (base implementation) and can be
+ overridden by child classes (e.g., DirectAccessResources uses HTTP instead).
+
+ Args:
+ database: Database/model name
+ engine: Engine name (optional)
+ raw_code: Code to execute (REL, LQP, or SQL)
+ inputs: Input parameters for the query
+ readonly: Whether the transaction is read-only
+ nowait_durable: Whether to wait for durable writes
+ headers: Optional HTTP headers
+ bypass_index: Whether to bypass graph index setup
+ language: Query language ("rel" or "lqp")
+ query_timeout_mins: Optional query timeout in minutes
+ gi_setup_skipped: Whether graph index setup was skipped (for retry logic)
+
+ Returns:
+ Query results (downloaded artifacts)
+ """
  if inputs is None:
  inputs = {}
  request_headers = debugging.add_current_propagation_headers(headers)
@@ -1638,6 +1706,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
  with debugging.span("transaction", **query_attrs_dict) as txn_span:
  with debugging.span("create_v2", **query_attrs_dict) as create_span:
  request_headers['user-agent'] = get_pyrel_version(self.generation)
+ request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
+ request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
  response = self._exec_rai_app(
  database=database,
  engine=engine,
@@ -1699,204 +1769,127 @@ Otherwise, remove it from your '{profile}' configuration profile.
1699
1769
  return engine and engine["state"] == "READY"
1700
1770
 
1701
1771
  def auto_create_engine(self, name: str | None = None, size: str | None = None, headers: Dict | None = None):
1702
- from relationalai.tools.cli_helpers import validate_engine_name
1772
+ """Synchronously create/ensure an engine is ready, blocking until ready."""
1703
1773
  with debugging.span("auto_create_engine", active=self._active_engine) as span:
1704
1774
  active = self._get_active_engine()
1705
1775
  if active:
1706
1776
  return active
1707
1777
 
1708
- engine_name = name or self.get_default_engine_name()
1709
-
1710
- # Use the provided size or fall back to the config
1711
- if size:
1712
- engine_size = size
1713
- else:
1714
- engine_size = self.config.get("engine_size", None)
1715
-
1716
- # Validate engine size
1717
- if engine_size:
1718
- is_size_valid, sizes = self.validate_engine_size(engine_size)
1719
- if not is_size_valid:
1720
- raise Exception(f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
1721
-
1722
- # Validate engine name
1723
- is_name_valid, _ = validate_engine_name(engine_name)
1724
- if not is_name_valid:
1725
- raise EngineNameValidationException(engine_name)
1778
+ # Resolve and validate parameters
1779
+ engine_name, engine_size = self._prepare_engine_params(name, size)
1726
1780
 
1727
1781
  try:
1782
+ # Get current engine state
1728
1783
  engine = self.get_engine(engine_name)
1729
1784
  if engine:
1730
1785
  span.update(cast(dict, engine))
1731
1786
 
1732
- # if engine is in the pending state, poll until its status changes
1733
- # if engine is gone, delete it and create new one
1734
- # if engine is in the ready state, return engine name
1735
- if engine:
1736
- if engine["state"] == "PENDING":
1737
- # if the user explicitly specified a size, warn if the pending engine size doesn't match it
1738
- if size is not None and engine["size"] != size:
1739
- EngineSizeMismatchWarning(engine_name, engine["size"], size)
1740
- # poll until engine is ready
1741
- with Spinner(
1742
- "Waiting for engine to be initialized",
1743
- "Engine ready",
1744
- ):
1745
- poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
1746
-
1747
- elif engine["state"] == "SUSPENDED":
1748
- with Spinner(f"Resuming engine '{engine_name}'", f"Engine '{engine_name}' resumed", f"Failed to resume engine '{engine_name}'"):
1749
- try:
1750
- self.resume_engine_async(engine_name, headers=headers)
1751
- poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
1752
- except Exception:
1753
- raise EngineResumeFailed(engine_name)
1754
- elif engine["state"] == "READY":
1755
- # if the user explicitly specified a size, warn if the ready engine size doesn't match it
1756
- if size is not None and engine["size"] != size:
1757
- EngineSizeMismatchWarning(engine_name, engine["size"], size)
1758
- self._set_active_engine(engine)
1759
- return engine_name
1760
- elif engine["state"] == "GONE":
1761
- try:
1762
- # "Gone" is abnormal condition when metadata and SF service don't match
1763
- # Therefore, we have to delete the engine and create a new one
1764
- # it could be case that engine is already deleted, so we have to catch the exception
1765
- self.delete_engine(engine_name, headers=headers)
1766
- # After deleting the engine, set it to None so that we can create a new engine
1767
- engine = None
1768
- except Exception as e:
1769
- # if engine is already deleted, we will get an exception
1770
- # we can ignore this exception and create a new engine
1771
- if isinstance(e, EngineNotFoundException):
1772
- engine = None
1773
- pass
1774
- else:
1775
- raise EngineProvisioningFailed(engine_name, e) from e
1776
-
1777
- if not engine:
1778
- with Spinner(
1779
- f"Auto-creating engine {engine_name}",
1780
- f"Auto-created engine {engine_name}",
1781
- "Engine creation failed",
1782
- ):
1783
- self.create_engine(engine_name, size=engine_size, headers=headers)
1787
+ # Create context for state handling
1788
+ context = EngineContext(
1789
+ engine_name=engine_name,
1790
+ engine_size=engine_size,
1791
+ headers=headers,
1792
+ requested_size=size,
1793
+ span=span,
1794
+ )
1795
+
1796
+ # Process engine state using sync handlers
1797
+ self._process_engine_state(engine, context, self._sync_engine_state_handlers)
1798
+
1784
1799
  except Exception as e:
1785
- print(e)
1786
- if DUO_TEXT in str(e).lower():
1787
- raise DuoSecurityFailed(e)
1788
- raise EngineProvisioningFailed(engine_name, e) from e
1800
+ self._handle_engine_creation_errors(e, engine_name)
1801
+
1789
1802
  return engine_name
1790
1803
 
1791
1804
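
The parameter resolution and validation that previously lived inline (the removed lines above) is folded into one helper. A minimal sketch of _prepare_engine_params, reconstructed from the removed logic and reusing this module's existing imports; the shipped helper may differ in detail:

def _prepare_engine_params(self, name, size, use_default_size=False):
    # Sketch reconstructed from the removed inline logic; not the shipped code.
    from relationalai.tools.cli_helpers import validate_engine_name

    engine_name = name or self.get_default_engine_name()
    # Use the provided size, fall back to the config, then (for the async
    # path) to the configured default size.
    engine_size = size or self.config.get("engine_size", None)
    if not engine_size and use_default_size:
        engine_size = self.config.get_default_engine_size()

    if engine_size:
        is_size_valid, sizes = self.validate_engine_size(engine_size)
        if not is_size_valid:
            raise Exception(f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}")

    is_name_valid, _ = validate_engine_name(engine_name)
    if not is_name_valid:
        raise EngineNameValidationException(engine_name)
    return engine_name, engine_size
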
  def auto_create_engine_async(self, name: str | None = None):
1805
+ """Asynchronously create/ensure an engine, returns immediately."""
1792
1806
  active = self._get_active_engine()
1793
1807
  if active and (active == name or name is None):
1794
- return # @NOTE: This method weirdly doesn't return engine name even though all the other ones do?
1808
+ return active
1795
1809
 
1796
1810
  with Spinner(
1797
1811
  "Checking engine status",
1798
1812
  leading_newline=True,
1799
1813
  ) as spinner:
1800
- from relationalai.tools.cli_helpers import validate_engine_name
1801
1814
  with debugging.span("auto_create_engine_async", active=self._active_engine):
1802
- engine_name = name or self.get_default_engine_name()
1803
- engine_size = self.config.get("engine_size", None)
1804
- if engine_size:
1805
- is_size_valid, sizes = self.validate_engine_size(engine_size)
1806
- if not is_size_valid:
1807
- raise Exception(f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
1808
- else:
1809
- engine_size = self.config.get_default_engine_size()
1815
+ # Resolve and validate parameters (use_default_size=True for async)
1816
+ engine_name, engine_size = self._prepare_engine_params(name, None, use_default_size=True)
1810
1817
 
1811
- is_name_valid, _ = validate_engine_name(engine_name)
1812
- if not is_name_valid:
1813
- raise EngineNameValidationException(engine_name)
1814
1818
  try:
1819
+ # Get current engine state
1815
1820
  engine = self.get_engine(engine_name)
1816
- # if engine is gone, delete it and create new one
1817
- # in case of pending state, do nothing, it is use_index responsibility to wait for engine to be ready
1818
- if engine:
1819
- if engine["state"] == "PENDING":
1820
- spinner.update_messages(
1821
- {
1822
- "finished_message": f"Starting engine {engine_name}",
1823
- }
1824
- )
1825
- pass
1826
- elif engine["state"] == "SUSPENDED":
1827
- spinner.update_messages(
1828
- {
1829
- "finished_message": f"Resuming engine {engine_name}",
1830
- }
1831
- )
1832
- try:
1833
- self.resume_engine_async(engine_name)
1834
- except Exception:
1835
- raise EngineResumeFailed(engine_name)
1836
- elif engine["state"] == "READY":
1837
- spinner.update_messages(
1838
- {
1839
- "finished_message": f"Engine {engine_name} initialized",
1840
- }
1841
- )
1842
- pass
1843
- elif engine["state"] == "GONE":
1844
- spinner.update_messages(
1845
- {
1846
- "message": f"Restarting engine {engine_name}",
1847
- }
1848
- )
1849
- try:
1850
- # "Gone" is abnormal condition when metadata and SF service don't match
1851
- # Therefore, we have to delete the engine and create a new one
1852
- # it could be case that engine is already deleted, so we have to catch the exception
1853
- # set it to None so that we can create a new engine
1854
- engine = None
1855
- self.delete_engine(engine_name)
1856
- except Exception as e:
1857
- # if engine is already deleted, we will get an exception
1858
- # we can ignore this exception and create a new engine asynchronously
1859
- if isinstance(e, EngineNotFoundException):
1860
- engine = None
1861
- pass
1862
- else:
1863
- print(e)
1864
- raise EngineProvisioningFailed(engine_name, e) from e
1865
-
1866
- if not engine:
1867
- self.create_engine_async(engine_name, size=self.config.get("engine_size", None))
1868
- spinner.update_messages(
1869
- {
1870
- "finished_message": f"Starting engine {engine_name}...",
1871
- }
1872
- )
1873
- else:
1874
- self._set_active_engine(engine)
1875
1821
 
1876
- except Exception as e:
1877
- spinner.update_messages(
1878
- {
1879
- "finished_message": f"Failed to create engine {engine_name}",
1880
- }
1822
+ # Create context for state handling
1823
+ context = EngineContext(
1824
+ engine_name=engine_name,
1825
+ engine_size=engine_size,
1826
+ headers=None,
1827
+ requested_size=None,
1828
+ spinner=spinner,
1881
1829
  )
1882
- if DUO_TEXT in str(e).lower():
1883
- raise DuoSecurityFailed(e)
1884
- if isinstance(e, RAIException):
1885
- raise e
1886
- print(e)
1887
- raise EngineProvisioningFailed(engine_name, e) from e
1888
1830
 
1889
- def validate_engine_size(self, size: str) -> Tuple[bool, List[str]]:
1890
- if size is not None:
1891
- sizes = self.get_engine_sizes()
1892
- if size not in sizes:
1893
- return False, sizes
1894
- return True, []
1831
+ # Process engine state using async handlers
1832
+ self._process_engine_state(engine, context, self._async_engine_state_handlers, set_active_on_success=True)
1833
+
1834
+ except Exception as e:
1835
+ spinner.update_messages({
1836
+ "finished_message": f"Failed to create engine {engine_name}",
1837
+ })
1838
+ self._handle_engine_creation_errors(e, engine_name, preserve_rai_exception=True)
1839
+
1840
+ return engine_name
1895
1841
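
Both entry points now delegate state handling to a handler table (the handlers themselves were added in engine_state_handlers.py in this release). A hedged sketch of the dispatch, assuming handler keys mirror the removed PENDING/SUSPENDED/READY/GONE branches; the names here are illustrative, not the shipped API:

def _process_engine_state(self, engine, context, handlers, set_active_on_success=False):
    # Sketch: route the current engine state to its handler, creating the
    # engine when none exists. "CREATE" is a hypothetical key for that case.
    if engine is None:
        handlers["CREATE"](None, context)
    else:
        state = engine["state"]  # PENDING / SUSPENDED / READY / GONE
        handler = handlers.get(state)
        if handler is None:
            raise EngineProvisioningFailed(context.engine_name,
                                           Exception(f"unexpected engine state: {state}"))
        handler(engine, context)
    if set_active_on_success and engine is not None:
        self._set_active_engine(engine)
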
 
1896
1842
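
The duplicated except blocks collapse into a single helper. A sketch that preserves the removed behavior (Duo detection, optional RAIException pass-through for the async path, then EngineProvisioningFailed); assume it matches the consolidated helper in spirit rather than line for line:

def _handle_engine_creation_errors(self, e, engine_name, preserve_rai_exception=False):
    # Sketch mirroring the removed except blocks.
    if DUO_TEXT in str(e).lower():
        raise DuoSecurityFailed(e)
    if preserve_rai_exception and isinstance(e, RAIException):
        raise e
    print(e)
    raise EngineProvisioningFailed(engine_name, e) from e
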
  #--------------------------------------------------
1897
1843
  # Exec
1898
1844
  #--------------------------------------------------
1899
1845
 
1846
+ def _execute_code(
1847
+ self,
1848
+ database: str,
1849
+ engine: str | None,
1850
+ raw_code: str,
1851
+ inputs: Dict | None,
1852
+ readonly: bool,
1853
+ nowait_durable: bool,
1854
+ headers: Dict | None,
1855
+ bypass_index: bool,
1856
+ language: str,
1857
+ query_timeout_mins: int | None,
1858
+ ) -> Any:
1859
+ """
1860
+ Template method for code execution - can be overridden by child classes.
1861
+
1862
+ This is a template method that provides a hook for child classes to add
1863
+ execution logic (like retry mechanisms). The base implementation simply
1864
+ calls _exec_async_v2 directly.
1865
+
1866
+ UseIndexResources overrides this method to use _exec_with_gi_retry, which
1867
+ adds automatic use_index polling on engine/database errors.
1868
+
1869
+ This method is called by exec_lqp() and exec_raw() to provide a single
1870
+ execution point that can be customized per resource class.
1871
+
1872
+ Args:
1873
+ database: Database/model name
1874
+ engine: Engine name (optional)
1875
+ raw_code: Code to execute (already processed/encoded)
1876
+ inputs: Input parameters for the query
1877
+ readonly: Whether the transaction is read-only
1878
+ nowait_durable: Whether to wait for durable writes
1879
+ headers: Optional HTTP headers
1880
+ bypass_index: Whether to bypass graph index setup
1881
+ language: Query language ("rel" or "lqp")
1882
+ query_timeout_mins: Optional query timeout in minutes
1883
+
1884
+ Returns:
1885
+ Query results
1886
+ """
1887
+ return self._exec_async_v2(
1888
+ database, engine, raw_code, inputs, readonly, nowait_durable,
1889
+ headers=headers, bypass_index=bypass_index, language=language,
1890
+ query_timeout_mins=query_timeout_mins, gi_setup_skipped=True,
1891
+ )
1892
+
1900
1893
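
For contrast with the base implementation above, a sketch of the override the docstring describes. The _exec_with_gi_retry signature is assumed to match _execute_code; that is an assumption about the shipped code, not a quote of it:

class UseIndexResources(Resources):
    # Sketch only; the real class ships in use_index_resources.py.
    def _execute_code(self, database, engine, raw_code, inputs, readonly,
                      nowait_durable, headers, bypass_index, language,
                      query_timeout_mins):
        # Route through the retry wrapper so engine/database errors trigger
        # a use_index poll and one re-execution, like the inline retry that
        # exec_lqp/exec_raw used to carry (removed below).
        return self._exec_with_gi_retry(
            database, engine, raw_code, inputs, readonly, nowait_durable,
            headers=headers, bypass_index=bypass_index, language=language,
            query_timeout_mins=query_timeout_mins,
        )
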
  def exec_lqp(
1901
1894
  self,
1902
1895
  database: str,
@@ -1910,36 +1903,12 @@ Otherwise, remove it from your '{profile}' configuration profile.
1910
1903
  bypass_index=False,
1911
1904
  query_timeout_mins: int | None = None,
1912
1905
  ):
1906
+ """Execute LQP code."""
1913
1907
  raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
1914
-
1915
- try:
1916
- return self._exec_async_v2(
1917
- database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1918
- headers=headers, bypass_index=bypass_index, language='lqp',
1919
- query_timeout_mins=query_timeout_mins,
1920
- )
1921
- except Exception as e:
1922
- err_message = str(e).lower()
1923
- if _is_engine_issue(err_message) or _is_database_issue(err_message):
1924
- engine_name = engine or self.get_default_engine_name()
1925
- engine_size = self.config.get_default_engine_size()
1926
- self._poll_use_index(
1927
- app_name=self.get_app_name(),
1928
- sources=self.sources,
1929
- model=database,
1930
- engine_name=engine_name,
1931
- engine_size=engine_size,
1932
- headers=headers,
1933
- )
1934
-
1935
- return self._exec_async_v2(
1936
- database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1937
- headers=headers, bypass_index=bypass_index, language='lqp',
1938
- query_timeout_mins=query_timeout_mins,
1939
- )
1940
- else:
1941
- raise e
1942
-
1908
+ return self._execute_code(
1909
+ database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1910
+ headers, bypass_index, 'lqp', query_timeout_mins
1911
+ )
1943
1912
 
1944
1913
  def exec_raw(
1945
1914
  self,
@@ -1954,46 +1923,12 @@ Otherwise, remove it from your '{profile}' configuration profile.
1954
1923
  bypass_index=False,
1955
1924
  query_timeout_mins: int | None = None,
1956
1925
  ):
1926
+ """Execute raw code."""
1957
1927
  raw_code = raw_code.replace("'", "\\'")
1958
-
1959
- try:
1960
- return self._exec_async_v2(
1961
- database,
1962
- engine,
1963
- raw_code,
1964
- inputs,
1965
- readonly,
1966
- nowait_durable,
1967
- headers=headers,
1968
- bypass_index=bypass_index,
1969
- query_timeout_mins=query_timeout_mins,
1970
- )
1971
- except Exception as e:
1972
- err_message = str(e).lower()
1973
- if _is_engine_issue(err_message) or _is_database_issue(err_message):
1974
- engine_name = engine or self.get_default_engine_name()
1975
- engine_size = self.config.get_default_engine_size()
1976
- self._poll_use_index(
1977
- app_name=self.get_app_name(),
1978
- sources=self.sources,
1979
- model=database,
1980
- engine_name=engine_name,
1981
- engine_size=engine_size,
1982
- headers=headers,
1983
- )
1984
- return self._exec_async_v2(
1985
- database,
1986
- engine,
1987
- raw_code,
1988
- inputs,
1989
- readonly,
1990
- nowait_durable,
1991
- headers=headers,
1992
- bypass_index=bypass_index,
1993
- query_timeout_mins=query_timeout_mins,
1994
- )
1995
- else:
1996
- raise e
1928
+ return self._execute_code(
1929
+ database, engine, raw_code, inputs, readonly, nowait_durable,
1930
+ headers, bypass_index, 'rel', query_timeout_mins
1931
+ )
1997
1932
 
1998
1933
 
1999
1934
  def format_results(self, results, task:m.Task|None=None) -> Tuple[DataFrame, List[Any]]:
@@ -2045,7 +1980,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
2045
1980
  if query_timeout_mins is not None:
2046
1981
  res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
2047
1982
  else:
2048
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
1983
+ res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
2049
1984
  txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
2050
1985
  rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
2051
1986
  rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
@@ -2082,8 +2017,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
2082
2017
  if rejected_rows:
2083
2018
  debugging.warn(RowsDroppedFromTargetTableWarning(rejected_rows, rejected_rows_count, col_names_map))
2084
2019
  except Exception as e:
2085
- msg = str(e).lower()
2086
- if "no columns returned" in msg or "columns of results could not be determined" in msg:
2020
+ messages = collect_error_messages(e)
2021
+ if any("no columns returned" in msg or "columns of results could not be determined" in msg for msg in messages):
2087
2022
  pass
2088
2023
  else:
2089
2024
  raise e
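
collect_error_messages comes from the new error_handlers module; its use here replaces a single lowercased str(e) with a scan over every message attached to the exception. A plausible sketch, purely an assumption about its shape, that walks chained causes:

def collect_error_messages(exc):
    # Assumption: gather lowercased messages from the exception and its
    # __cause__/__context__ chain so nested Snowflake errors are matched too.
    messages, seen = [], set()
    while exc is not None and id(exc) not in seen:
        seen.add(id(exc))
        messages.append(str(exc).lower())
        exc = exc.__cause__ or exc.__context__
    return messages
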
@@ -2114,6 +2049,10 @@ Otherwise, remove it from your '{profile}' configuration profile.
2114
2049
  self.sources.add(parser.identity)
2115
2050
  return ns
2116
2051
 
2052
+ #--------------------------------------------------
2053
+ # Source Management
2054
+ #--------------------------------------------------
2055
+
2117
2056
  def _check_source_updates(self, sources: Iterable[str]):
2118
2057
  if not sources:
2119
2058
  return {}
@@ -2376,19 +2315,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
2376
2315
  #--------------------------------------------------
2377
2316
  # Transactions
2378
2317
  #--------------------------------------------------
2379
- def txn_list_to_dicts(self, transactions):
2380
- dicts = []
2381
- for txn in transactions:
2382
- dict = {}
2383
- txn_dict = txn.asDict()
2384
- for key in txn_dict:
2385
- mapValue = FIELD_MAP.get(key.lower())
2386
- if mapValue:
2387
- dict[mapValue] = txn_dict[key]
2388
- else:
2389
- dict[key.lower()] = txn_dict[key]
2390
- dicts.append(dict)
2391
- return dicts
2392
2318
 
2393
2319
  def get_transaction(self, transaction_id):
2394
2320
  results = self._exec(
@@ -2396,7 +2322,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
2396
2322
  if not results:
2397
2323
  return None
2398
2324
 
2399
- results = self.txn_list_to_dicts(results)
2325
+ results = txn_list_to_dicts(results)
2400
2326
 
2401
2327
  txn = {field: results[0][field] for field in GET_TXN_SQL_FIELDS}
2402
2328
 
@@ -2448,7 +2374,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
2448
2374
  results = self._exec(query, [limit])
2449
2375
  if not results:
2450
2376
  return []
2451
- return self.txn_list_to_dicts(results)
2377
+ return txn_list_to_dicts(results)
2452
2378
 
2453
2379
  def cancel_transaction(self, transaction_id):
2454
2380
  self._exec(f"CALL {APP_NAME}.api.cancel_own_transaction(?);", [transaction_id])
@@ -2482,66 +2408,15 @@ Otherwise, remove it from your '{profile}' configuration profile.
2482
2408
  return None
2483
2409
  return results[0][0]
2484
2410
 
2485
- def list_warehouses(self):
2486
- results = self._exec("SHOW WAREHOUSES")
2487
- if not results:
2488
- return []
2489
- return [{"name":name}
2490
- for (name, *rest) in results]
2491
-
2492
- def list_compute_pools(self):
2493
- results = self._exec("SHOW COMPUTE POOLS")
2494
- if not results:
2495
- return []
2496
- return [{"name":name, "status":status, "min_nodes":min_nodes, "max_nodes":max_nodes, "instance_family":instance_family}
2497
- for (name, status, min_nodes, max_nodes, instance_family, *rest) in results]
2498
-
2499
- def list_roles(self):
2500
- results = self._exec("SELECT CURRENT_AVAILABLE_ROLES()")
2501
- if not results:
2502
- return []
2503
- # the response is a single row with a single column containing
2504
- # a stringified JSON array of role names:
2505
- row = results[0]
2506
- if not row:
2507
- return []
2508
- return [{"name": name} for name in json.loads(row[0])]
2411
+ # CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
2412
+ # list_databases, list_sf_schemas, list_tables) are now in the CLIResources class.
2413
+ # schema_info is kept in the base Resources class since it's used by SnowflakeSchema._fetch_info().
2509
2414
 
2510
- def list_apps(self):
2511
- all_apps = self._exec(f"SHOW APPLICATIONS LIKE '{RAI_APP_NAME}'")
2512
- if not all_apps:
2513
- all_apps = self._exec("SHOW APPLICATIONS")
2514
- if not all_apps:
2515
- return []
2516
- return [{"name":name}
2517
- for (time, name, *rest) in all_apps]
2518
-
2519
- def list_databases(self):
2520
- results = self._exec("SHOW DATABASES")
2521
- if not results:
2522
- return []
2523
- return [{"name":name}
2524
- for (time, name, *rest) in results]
2525
-
2526
- def list_sf_schemas(self, database:str):
2527
- results = self._exec(f"SHOW SCHEMAS IN {database}")
2528
- if not results:
2529
- return []
2530
- return [{"name":name}
2531
- for (time, name, *rest) in results]
2532
-
2533
- def list_tables(self, database:str, schema:str):
2534
- results = self._exec(f"SHOW OBJECTS IN {database}.{schema}")
2535
- items = []
2536
- if results:
2537
- for (time, name, db_name, schema_name, kind, *rest) in results:
2538
- items.append({"name":name, "kind":kind.lower()})
2539
- return items
2540
-
2541
- def schema_info(self, database:str, schema:str, tables:Iterable[str]):
2542
- app_name = self.get_app_name()
2543
- # Only pass the db + schema as the identifier so that the resulting identity is correct
2544
- parser = IdentityParser(f"{database}.{schema}")
2415
+ def schema_info(self, database: str, schema: str, tables: Iterable[str]):
2416
+ """Get detailed schema information including primary keys, foreign keys, and columns."""
2417
+ app_name = self.get_app_name()
2418
+ # Only pass the db + schema as the identifier so that the resulting identity is correct
2419
+ parser = IdentityParser(f"{database}.{schema}")
2545
2420
 
2546
2421
  with debugging.span("schema_info"):
2547
2422
  with debugging.span("primary_keys") as span:
@@ -2556,7 +2431,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
2556
2431
 
2557
2432
  # IdentityParser will parse a single value (with no ".") and store it in this case in the db field
2558
2433
  with debugging.span("columns") as span:
2559
- tables = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
2434
+ tables_str = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
2560
2435
  query = textwrap.dedent(f"""
2561
2436
  begin
2562
2437
  SHOW COLUMNS IN SCHEMA {parser.identity};
@@ -2573,7 +2448,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
2573
2448
  ELSE FALSE
2574
2449
  END as "supported_type"
2575
2450
  FROM table(result_scan(-1)) as t
2576
- WHERE "table_name" in ({tables})
2451
+ WHERE "table_name" in ({tables_str})
2577
2452
  );
2578
2453
  return table(r);
2579
2454
  end;
@@ -2584,7 +2459,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
2584
2459
  results = defaultdict(lambda: {"pks": [], "fks": {}, "columns": {}, "invalid_columns": {}})
2585
2460
  if pks:
2586
2461
  for row in pks:
2587
- results[row[3]]["pks"].append(row[4]) # type: ignore
2462
+ results[row[3]]["pks"].append(row[4]) # type: ignore
2588
2463
  if fks:
2589
2464
  for row in fks:
2590
2465
  results[row[7]]["fks"][row[8]] = row[3]
@@ -2726,7 +2601,7 @@ class SnowflakeSchema:
2726
2601
  tables_with_invalid_columns[table_name] = table_info["invalid_columns"]
2727
2602
 
2728
2603
  if tables_with_invalid_columns:
2729
- from ..errors import UnsupportedColumnTypesWarning
2604
+ from relationalai.errors import UnsupportedColumnTypesWarning
2730
2605
  UnsupportedColumnTypesWarning(tables_with_invalid_columns)
2731
2606
 
2732
2607
  def _add(self, name, is_imported=False):
@@ -2935,10 +2810,14 @@ class Provider(ProviderBase):
2935
2810
  if resources:
2936
2811
  self.resources = resources
2937
2812
  else:
2938
- resource_class = Resources
2939
- if config and config.get("use_direct_access", USE_DIRECT_ACCESS):
2940
- resource_class = DirectAccessResources
2941
- self.resources = resource_class(profile=profile, config=config, generation=generation)
2813
+ from .resources_factory import create_resources_instance
2814
+ self.resources = create_resources_instance(
2815
+ config=config,
2816
+ profile=profile,
2817
+ generation=generation or Generation.V0,
2818
+ dry_run=False,
2819
+ language="rel",
2820
+ )
2942
2821
 
2943
2822
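
create_resources_instance centralizes the class selection that Provider and Graph previously hard-coded. A hedged sketch of the dispatch, assuming the factory keys off the same use_direct_access / use_graph_index flags the removed code read; the shipped factory lives in resources_factory.py and may choose differently:

def create_resources_instance(config=None, profile=None, connection=None,
                              generation=None, dry_run=False, language="rel"):
    # Sketch: pick a Resources implementation from config flags.
    # Config(profile) is illustrative; reuse an existing config if given.
    cfg = config or Config(profile)
    if cfg.get("use_direct_access", USE_DIRECT_ACCESS):
        cls = DirectAccessResources
    elif cfg.get("use_graph_index", USE_GRAPH_INDEX):
        cls = UseIndexResources  # assumption: graph-index setups land here
    else:
        cls = Resources
    return cls(profile=profile, config=cfg, connection=connection,
               generation=generation, dry_run=dry_run, language=language)
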
  def list_streams(self, model:str):
2944
2823
  return self.resources.list_imports(model=model)
@@ -3107,19 +2986,31 @@ def Graph(
3107
2986
  nowait_durable: bool = True,
3108
2987
  format: str = "default",
3109
2988
  ):
2989
+ from .resources_factory import create_resources_instance
2990
+ from .use_index_resources import UseIndexResources
3110
2991
 
3111
- client_class = Client
3112
- resource_class = Resources
3113
2992
  use_graph_index = config.get("use_graph_index", USE_GRAPH_INDEX)
3114
2993
  use_monotype_operators = config.get("compiler.use_monotype_operators", False)
3115
- use_direct_access = config.get("use_direct_access", USE_DIRECT_ACCESS)
3116
2994
 
3117
- if use_graph_index:
2995
+ # Create resources instance using factory
2996
+ resources = create_resources_instance(
2997
+ config=config,
2998
+ profile=profile,
2999
+ connection=connection,
3000
+ generation=Generation.V0,
3001
+ dry_run=False, # Resources instance dry_run is separate from client dry_run
3002
+ language="rel",
3003
+ )
3004
+
3005
+ # Determine client class based on resources type and config
3006
+ # SnowflakeClient is used for resources that support use_index functionality
3007
+ if use_graph_index or isinstance(resources, UseIndexResources):
3118
3008
  client_class = SnowflakeClient
3119
- if use_direct_access:
3120
- resource_class = DirectAccessResources
3009
+ else:
3010
+ client_class = Client
3011
+
3121
3012
  client = client_class(
3122
- resource_class(generation=Generation.V0, profile=profile, config=config, connection=connection),
3013
+ resources,
3123
3014
  rel.Compiler(config),
3124
3015
  name,
3125
3016
  config,
@@ -3190,639 +3081,3 @@ def Graph(
3190
3081
  debugging.set_source(pyrel_base)
3191
3082
  client.install("pyrel_base", pyrel_base)
3192
3083
  return dsl.Graph(client, name, format=format)
3193
-
3194
-
3195
-
3196
- #--------------------------------------------------
3197
- # Direct Access
3198
- #--------------------------------------------------
3199
- # Note: All direct access components should live in a separate file
3200
-
3201
- class DirectAccessResources(Resources):
3202
- """
3203
- Resources class for Direct Service Access avoiding Snowflake service functions.
3204
- """
3205
- def __init__(
3206
- self,
3207
- profile: Union[str, None] = None,
3208
- config: Union[Config, None] = None,
3209
- connection: Union[Session, None] = None,
3210
- dry_run: bool = False,
3211
- reset_session: bool = False,
3212
- generation: Optional[Generation] = None,
3213
- language: str = "rel",
3214
- ):
3215
- super().__init__(
3216
- generation=generation,
3217
- profile=profile,
3218
- config=config,
3219
- connection=connection,
3220
- reset_session=reset_session,
3221
- dry_run=dry_run,
3222
- language=language,
3223
- )
3224
- self._endpoint_info = ConfigStore(ENDPOINT_FILE)
3225
- self._service_endpoint = ""
3226
- self._direct_access_client = None
3227
- self.generation = generation
3228
- self.database = ""
3229
-
3230
- @property
3231
- def service_endpoint(self) -> str:
3232
- return self._retrieve_service_endpoint()
3233
-
3234
- def _retrieve_service_endpoint(self, enforce_update=False) -> str:
3235
- account = self.config.get("account")
3236
- app_name = self.config.get("rai_app_name")
3237
- service_endpoint_key = f"{account}.{app_name}.service_endpoint"
3238
- if self._service_endpoint and not enforce_update:
3239
- return self._service_endpoint
3240
- if self._endpoint_info.get(service_endpoint_key, "") and not enforce_update:
3241
- self._service_endpoint = str(self._endpoint_info.get(service_endpoint_key, ""))
3242
- return self._service_endpoint
3243
-
3244
- is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
3245
- query = f"CALL {self.get_app_name()}.app.service_endpoint({not is_snowflake_notebook});"
3246
- result = self._exec(query)
3247
- assert result, f"Could not retrieve service endpoint for {self.get_app_name()}"
3248
- if is_snowflake_notebook:
3249
- self._service_endpoint = f"http://{result[0]['SERVICE_ENDPOINT']}"
3250
- else:
3251
- self._service_endpoint = f"https://{result[0]['SERVICE_ENDPOINT']}"
3252
-
3253
- self._endpoint_info.set(service_endpoint_key, self._service_endpoint)
3254
- # save the endpoint to `ENDPOINT_FILE` to avoid calling the endpoint with every
3255
- # pyrel execution
3256
- try:
3257
- self._endpoint_info.save()
3258
- except Exception:
3259
- print("Failed to persist endpoints to file. This might slow down future executions.")
3260
-
3261
- return self._service_endpoint
3262
-
3263
- @property
3264
- def direct_access_client(self) -> DirectAccessClient:
3265
- if self._direct_access_client:
3266
- return self._direct_access_client
3267
- try:
3268
- service_endpoint = self.service_endpoint
3269
- self._direct_access_client = DirectAccessClient(
3270
- self.config, self.token_handler, service_endpoint, self.generation,
3271
- )
3272
- except Exception as e:
3273
- raise e
3274
- return self._direct_access_client
3275
-
3276
- def request(
3277
- self,
3278
- endpoint: str,
3279
- payload: Dict[str, Any] | None = None,
3280
- headers: Dict[str, str] | None = None,
3281
- path_params: Dict[str, str] | None = None,
3282
- query_params: Dict[str, str] | None = None,
3283
- skip_auto_create: bool = False,
3284
- ) -> requests.Response:
3285
- with debugging.span("direct_access_request"):
3286
- def _send_request():
3287
- return self.direct_access_client.request(
3288
- endpoint=endpoint,
3289
- payload=payload,
3290
- headers=headers,
3291
- path_params=path_params,
3292
- query_params=query_params,
3293
- )
3294
- try:
3295
- response = _send_request()
3296
- if response.status_code != 200:
3297
- # For 404 responses with skip_auto_create=True, return immediately to let caller handle it
3298
- # (e.g., get_engine needs to check 404 and return None for auto_create_engine)
3299
- # For skip_auto_create=False, continue to auto-creation logic below
3300
- if response.status_code == 404 and skip_auto_create:
3301
- return response
3302
-
3303
- try:
3304
- message = response.json().get("message", "")
3305
- except requests.exceptions.JSONDecodeError:
3306
- # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
3307
- # this should have been caught by the 404 check above, so this is an error.
3308
- # For skip_auto_create=False, we explicitly check status_code below,
3309
- # so we don't need to parse the message.
3310
- if skip_auto_create:
3311
- raise ResponseStatusException(
3312
- f"Failed to parse error response from endpoint {endpoint}.", response
3313
- )
3314
- message = "" # Not used when we check status_code directly
3315
-
3316
- # fix engine on engine error and retry
3317
- # Skip auto-retry if skip_auto_create is True to avoid recursion
3318
- if (_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message):
3319
- engine_name = payload.get("caller_engine_name", "") if payload else ""
3320
- engine_name = engine_name or self.get_default_engine_name()
3321
- engine_size = self.config.get_default_engine_size()
3322
- self._poll_use_index(
3323
- app_name=self.get_app_name(),
3324
- sources=self.sources,
3325
- model=self.database,
3326
- engine_name=engine_name,
3327
- engine_size=engine_size,
3328
- headers=headers,
3329
- )
3330
- response = _send_request()
3331
- except requests.exceptions.ConnectionError as e:
3332
- if "NameResolutionError" in str(e):
3333
- # when we can not resolve the service endpoint, we assume it is outdated
3334
- # hence, we try to retrieve it again and query again.
3335
- self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
3336
- enforce_update=True,
3337
- )
3338
- return _send_request()
3339
- # raise in all other cases
3340
- raise e
3341
- return response
3342
-
3343
- def _exec_async_v2(
3344
- self,
3345
- database: str,
3346
- engine: Union[str, None],
3347
- raw_code: str,
3348
- inputs: Dict | None = None,
3349
- readonly=True,
3350
- nowait_durable=False,
3351
- headers: Dict[str, str] | None = None,
3352
- bypass_index=False,
3353
- language: str = "rel",
3354
- query_timeout_mins: int | None = None,
3355
- ):
3356
-
3357
- with debugging.span("transaction") as txn_span:
3358
- with debugging.span("create_v2") as create_span:
3359
-
3360
- use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
3361
-
3362
- payload = {
3363
- "dbname": database,
3364
- "engine_name": engine,
3365
- "query": raw_code,
3366
- "v1_inputs": inputs,
3367
- "nowait_durable": nowait_durable,
3368
- "readonly": readonly,
3369
- "language": language,
3370
- }
3371
- if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
3372
- query_timeout_mins = int(timeout_value)
3373
- if query_timeout_mins is not None:
3374
- payload["timeout_mins"] = query_timeout_mins
3375
- query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
3376
-
3377
- response = self.request(
3378
- "create_txn", payload=payload, headers=headers, query_params=query_params,
3379
- )
3380
-
3381
- if response.status_code != 200:
3382
- raise ResponseStatusException("Failed to create transaction.", response)
3383
-
3384
- artifact_info = {}
3385
- response_content = response.json()
3386
-
3387
- txn_id = response_content["transaction"]['id']
3388
- state = response_content["transaction"]['state']
3389
-
3390
- txn_span["txn_id"] = txn_id
3391
- create_span["txn_id"] = txn_id
3392
- debugging.event("transaction_created", txn_span, txn_id=txn_id)
3393
-
3394
- # fast path: transaction already finished
3395
- if state in ["COMPLETED", "ABORTED"]:
3396
- if txn_id in self._pending_transactions:
3397
- self._pending_transactions.remove(txn_id)
3398
-
3399
- # Process rows to get the rest of the artifacts
3400
- for result in response_content.get("results", []):
3401
- filename = result['filename']
3402
- # making keys uppercase to match the old behavior
3403
- artifact_info[filename] = {k.upper(): v for k, v in result.items()}
3404
-
3405
- # Slow path: transaction not done yet; start polling
3406
- else:
3407
- self._pending_transactions.append(txn_id)
3408
- with debugging.span("wait", txn_id=txn_id):
3409
- poll_with_specified_overhead(
3410
- lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
3411
- )
3412
- artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
3413
-
3414
- with debugging.span("fetch"):
3415
- return self._download_results(artifact_info, txn_id, state)
3416
-
3417
- def _prepare_index(
3418
- self,
3419
- model: str,
3420
- engine_name: str,
3421
- engine_size: str = "",
3422
- language: str = "rel",
3423
- rai_relations: List[str] | None = None,
3424
- pyrel_program_id: str | None = None,
3425
- skip_pull_relations: bool = False,
3426
- headers: Dict | None = None,
3427
- ):
3428
- """
3429
- Prepare the index for the given engine and model.
3430
- """
3431
- with debugging.span("prepare_index"):
3432
- if headers is None:
3433
- headers = {}
3434
-
3435
- payload = {
3436
- "model_name": model,
3437
- "caller_engine_name": engine_name,
3438
- "language": language,
3439
- "pyrel_program_id": pyrel_program_id,
3440
- "skip_pull_relations": skip_pull_relations,
3441
- "rai_relations": rai_relations or [],
3442
- "user_agent": get_pyrel_version(self.generation),
3443
- }
3444
- # Only include engine_size if it has a non-empty string value
3445
- if engine_size and engine_size.strip():
3446
- payload["caller_engine_size"] = engine_size
3447
-
3448
- response = self.request(
3449
- "prepare_index", payload=payload, headers=headers
3450
- )
3451
-
3452
- if response.status_code != 200:
3453
- raise ResponseStatusException("Failed to prepare index.", response)
3454
-
3455
- return response.json()
3456
-
3457
- def _poll_use_index(
3458
- self,
3459
- app_name: str,
3460
- sources: Iterable[str],
3461
- model: str,
3462
- engine_name: str,
3463
- engine_size: str | None = None,
3464
- program_span_id: str | None = None,
3465
- headers: Dict | None = None,
3466
- ):
3467
- return DirectUseIndexPoller(
3468
- self,
3469
- app_name=app_name,
3470
- sources=sources,
3471
- model=model,
3472
- engine_name=engine_name,
3473
- engine_size=engine_size,
3474
- language=self.language,
3475
- program_span_id=program_span_id,
3476
- headers=headers,
3477
- generation=self.generation,
3478
- ).poll()
3479
-
3480
- def maybe_poll_use_index(
3481
- self,
3482
- app_name: str,
3483
- sources: Iterable[str],
3484
- model: str,
3485
- engine_name: str,
3486
- engine_size: str | None = None,
3487
- program_span_id: str | None = None,
3488
- headers: Dict | None = None,
3489
- ):
3490
- """Only call poll() if there are sources to process and cache is not valid."""
3491
- sources_list = list(sources)
3492
- self.database = model
3493
- if sources_list:
3494
- poller = DirectUseIndexPoller(
3495
- self,
3496
- app_name=app_name,
3497
- sources=sources_list,
3498
- model=model,
3499
- engine_name=engine_name,
3500
- engine_size=engine_size,
3501
- language=self.language,
3502
- program_span_id=program_span_id,
3503
- headers=headers,
3504
- generation=self.generation,
3505
- )
3506
- # If cache is valid (data freshness has not expired), skip polling
3507
- if poller.cache.is_valid():
3508
- cached_sources = len(poller.cache.sources)
3509
- total_sources = len(sources_list)
3510
- cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
3511
-
3512
- message = f"Using cached data for {cached_sources}/{total_sources} data streams"
3513
- if cached_timestamp:
3514
- print(f"\n{message} (cached at {cached_timestamp})\n")
3515
- else:
3516
- print(f"\n{message}\n")
3517
- else:
3518
- return poller.poll()
3519
-
3520
- def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
3521
- """Check whether the given transaction has completed."""
3522
-
3523
- with debugging.span("check_status"):
3524
- response = self.request(
3525
- "get_txn",
3526
- headers=headers,
3527
- path_params={"txn_id": txn_id},
3528
- )
3529
- assert response, f"No results from get_transaction('{txn_id}')"
3530
-
3531
- response_content = response.json()
3532
- transaction = response_content["transaction"]
3533
- status: str = transaction['state']
3534
-
3535
- # remove the transaction from the pending list if it's completed or aborted
3536
- if status in ["COMPLETED", "ABORTED"]:
3537
- if txn_id in self._pending_transactions:
3538
- self._pending_transactions.remove(txn_id)
3539
-
3540
- if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
3541
- config_file_path = getattr(self.config, 'file_path', None)
3542
- timeout_ms = int(transaction.get("timeout_ms", 0))
3543
- timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
3544
- raise QueryTimeoutExceededException(
3545
- timeout_mins=timeout_mins,
3546
- query_id=txn_id,
3547
- config_file_path=config_file_path,
3548
- )
3549
-
3550
- # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
3551
- return status == "COMPLETED" or status == "ABORTED"
3552
-
3553
- def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
3554
- """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
3555
- with debugging.span("list_results"):
3556
- response = self.request(
3557
- "get_txn_artifacts",
3558
- headers=headers,
3559
- path_params={"txn_id": txn_id},
3560
- )
3561
- assert response, f"No results from get_transaction_artifacts('{txn_id}')"
3562
- artifact_info = {}
3563
- for result in response.json()["results"]:
3564
- filename = result['filename']
3565
- # making keys uppercase to match the old behavior
3566
- artifact_info[filename] = {k.upper(): v for k, v in result.items()}
3567
- return artifact_info
3568
-
3569
- def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
3570
- with debugging.span("get_transaction_problems"):
3571
- response = self.request(
3572
- "get_txn_problems",
3573
- path_params={"txn_id": txn_id},
3574
- )
3575
- response_content = response.json()
3576
- if not response_content:
3577
- return []
3578
- return response_content.get("problems", [])
3579
-
3580
- def get_transaction_events(self, transaction_id: str, continuation_token: str = ''):
3581
- response = self.request(
3582
- "get_txn_events",
3583
- path_params={"txn_id": transaction_id, "stream_name": "profiler"},
3584
- query_params={"continuation_token": continuation_token},
3585
- )
3586
- response_content = response.json()
3587
- if not response_content:
3588
- return {
3589
- "events": [],
3590
- "continuation_token": None
3591
- }
3592
- return response_content
3593
-
3594
- #--------------------------------------------------
3595
- # Databases
3596
- #--------------------------------------------------
3597
-
3598
- def get_installed_packages(self, database: str) -> Union[Dict, None]:
3599
- use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
3600
- if use_graph_index:
3601
- response = self.request(
3602
- "get_model_package_versions",
3603
- payload={"model_name": database},
3604
- )
3605
- else:
3606
- response = self.request(
3607
- "get_package_versions",
3608
- path_params={"db_name": database},
3609
- )
3610
- if response.status_code == 404 and response.json().get("message", "") == "database not found":
3611
- return None
3612
- if response.status_code != 200:
3613
- raise ResponseStatusException(
3614
- f"Failed to retrieve package versions for {database}.", response
3615
- )
3616
-
3617
- content = response.json()
3618
- if not content:
3619
- return None
3620
-
3621
- return safe_json_loads(content["package_versions"])
3622
-
3623
- def get_database(self, database: str):
3624
- with debugging.span("get_database", dbname=database):
3625
- if not database:
3626
- raise ValueError("Database name must be provided to get database.")
3627
- response = self.request(
3628
- "get_db",
3629
- path_params={},
3630
- query_params={"name": database},
3631
- )
3632
- if response.status_code != 200:
3633
- raise ResponseStatusException(f"Failed to get db. db:{database}", response)
3634
-
3635
- response_content = response.json()
3636
-
3637
- if (response_content.get("databases") and len(response_content["databases"]) == 1):
3638
- db = response_content["databases"][0]
3639
- return {
3640
- "id": db["id"],
3641
- "name": db["name"],
3642
- "created_by": db.get("created_by"),
3643
- "created_on": ms_to_timestamp(db.get("created_on")),
3644
- "deleted_by": db.get("deleted_by"),
3645
- "deleted_on": ms_to_timestamp(db.get("deleted_on")),
3646
- "state": db["state"],
3647
- }
3648
- else:
3649
- return None
3650
-
3651
- def create_graph(self, name: str):
3652
- with debugging.span("create_model", dbname=name):
3653
- return self._create_database(name,"")
3654
-
3655
- def delete_graph(self, name:str, force=False, language: str = "rel"):
3656
- prop_hdrs = debugging.gen_current_propagation_headers()
3657
- if self.config.get("use_graph_index", USE_GRAPH_INDEX):
3658
- keep_database = not force and self.config.get("reuse_model", True)
3659
- with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
3660
- response = self.request(
3661
- "release_index",
3662
- payload={
3663
- "model_name": name,
3664
- "keep_database": keep_database,
3665
- "language": language,
3666
- "user_agent": get_pyrel_version(self.generation),
3667
- },
3668
- headers=prop_hdrs,
3669
- )
3670
- if (
3671
- response.status_code != 200
3672
- and not (
3673
- response.status_code == 404
3674
- and "database not found" in response.json().get("message", "")
3675
- )
3676
- ):
3677
- raise ResponseStatusException(f"Failed to release index. Model: {name} ", response)
3678
- else:
3679
- with debugging.span("delete_model", name=name):
3680
- self._delete_database(name, headers=prop_hdrs)
3681
-
3682
- def clone_graph(self, target_name:str, source_name:str, nowait_durable=True, force=False):
3683
- if force and self.get_graph(target_name):
3684
- self.delete_graph(target_name)
3685
- with debugging.span("clone_model", target_name=target_name, source_name=source_name):
3686
- return self._create_database(target_name,source_name)
3687
-
3688
- def _delete_database(self, name:str, headers:Dict={}):
3689
- with debugging.span("_delete_database", dbname=name):
3690
- response = self.request(
3691
- "delete_db",
3692
- path_params={"db_name": name},
3693
- query_params={},
3694
- headers=headers,
3695
- )
3696
- if response.status_code != 200:
3697
- raise ResponseStatusException(f"Failed to delete db. db:{name} ", response)
3698
-
3699
- def _create_database(self, name:str, source_name:str):
3700
- with debugging.span("_create_database", dbname=name):
3701
- payload = {
3702
- "name": name,
3703
- "source_name": source_name,
3704
- }
3705
- response = self.request(
3706
- "create_db", payload=payload, headers={}, query_params={},
3707
- )
3708
- if response.status_code != 200:
3709
- raise ResponseStatusException(f"Failed to create db. db:{name}", response)
3710
-
3711
- #--------------------------------------------------
3712
- # Engines
3713
- #--------------------------------------------------
3714
-
3715
- def list_engines(self, state: str | None = None):
3716
- response = self.request("list_engines")
3717
- if response.status_code != 200:
3718
- raise ResponseStatusException(
3719
- "Failed to retrieve engines.", response
3720
- )
3721
- response_content = response.json()
3722
- if not response_content:
3723
- return []
3724
- engines = [
3725
- {
3726
- "name": engine["name"],
3727
- "id": engine["id"],
3728
- "size": engine["size"],
3729
- "state": engine["status"], # callers are expecting 'state'
3730
- "created_by": engine["created_by"],
3731
- "created_on": engine["created_on"],
3732
- "updated_on": engine["updated_on"],
3733
- }
3734
- for engine in response_content.get("engines", [])
3735
- if state is None or engine.get("status") == state
3736
- ]
3737
- return sorted(engines, key=lambda x: x["name"])
3738
-
3739
- def get_engine(self, name: str):
3740
- response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
3741
- if response.status_code == 404: # engine not found return 404
3742
- return None
3743
- elif response.status_code != 200:
3744
- raise ResponseStatusException(
3745
- f"Failed to retrieve engine {name}.", response
3746
- )
3747
- engine = response.json()
3748
- if not engine:
3749
- return None
3750
- engine_state: EngineState = {
3751
- "name": engine["name"],
3752
- "id": engine["id"],
3753
- "size": engine["size"],
3754
- "state": engine["status"], # callers are expecting 'state'
3755
- "created_by": engine["created_by"],
3756
- "created_on": engine["created_on"],
3757
- "updated_on": engine["updated_on"],
3758
- "version": engine["version"],
3759
- "auto_suspend": engine["auto_suspend_mins"],
3760
- "suspends_at": engine["suspends_at"],
3761
- }
3762
- return engine_state
3763
-
3764
- def _create_engine(
3765
- self,
3766
- name: str,
3767
- size: str | None = None,
3768
- auto_suspend_mins: int | None = None,
3769
- is_async: bool = False,
3770
- headers: Dict[str, str] | None = None
3771
- ):
3772
- # only async engine creation supported via direct access
3773
- if not is_async:
3774
- return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
3775
- payload:Dict[str, Any] = {
3776
- "name": name,
3777
- }
3778
- if auto_suspend_mins is not None:
3779
- payload["auto_suspend_mins"] = auto_suspend_mins
3780
- if size is not None:
3781
- payload["size"] = size
3782
- response = self.request(
3783
- "create_engine",
3784
- payload=payload,
3785
- path_params={"engine_type": "logic"},
3786
- headers=headers,
3787
- skip_auto_create=True,
3788
- )
3789
- if response.status_code != 200:
3790
- raise ResponseStatusException(
3791
- f"Failed to create engine {name} with size {size}.", response
3792
- )
3793
-
3794
- def delete_engine(self, name:str, force:bool = False, headers={}):
3795
- response = self.request(
3796
- "delete_engine",
3797
- path_params={"engine_name": name, "engine_type": "logic"},
3798
- headers=headers,
3799
- skip_auto_create=True,
3800
- )
3801
- if response.status_code != 200:
3802
- raise ResponseStatusException(
3803
- f"Failed to delete engine {name}.", response
3804
- )
3805
-
3806
- def suspend_engine(self, name:str):
3807
- response = self.request(
3808
- "suspend_engine",
3809
- path_params={"engine_name": name, "engine_type": "logic"},
3810
- skip_auto_create=True,
3811
- )
3812
- if response.status_code != 200:
3813
- raise ResponseStatusException(
3814
- f"Failed to suspend engine {name}.", response
3815
- )
3816
-
3817
- def resume_engine_async(self, name:str, headers={}):
3818
- response = self.request(
3819
- "resume_engine",
3820
- path_params={"engine_name": name, "engine_type": "logic"},
3821
- headers=headers,
3822
- skip_auto_create=True,
3823
- )
3824
- if response.status_code != 200:
3825
- raise ResponseStatusException(
3826
- f"Failed to resume engine {name}.", response
3827
- )
3828
- return {}