relationalai 1.0.0a2__py3-none-any.whl → 1.0.0a4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- relationalai/config/shims.py +1 -0
- relationalai/semantics/__init__.py +7 -1
- relationalai/semantics/frontend/base.py +19 -13
- relationalai/semantics/frontend/core.py +30 -2
- relationalai/semantics/frontend/front_compiler.py +38 -11
- relationalai/semantics/frontend/pprint.py +1 -1
- relationalai/semantics/metamodel/rewriter.py +6 -2
- relationalai/semantics/metamodel/typer.py +70 -26
- relationalai/semantics/reasoners/__init__.py +11 -0
- relationalai/semantics/reasoners/graph/__init__.py +38 -0
- relationalai/semantics/reasoners/graph/core.py +9015 -0
- relationalai/shims/executor.py +4 -1
- relationalai/shims/hoister.py +9 -0
- relationalai/shims/mm2v0.py +47 -34
- relationalai/tools/cli/cli.py +138 -0
- relationalai/tools/cli/docs.py +394 -0
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/METADATA +5 -3
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/RECORD +57 -43
- v0/relationalai/__init__.py +69 -22
- v0/relationalai/clients/__init__.py +15 -2
- v0/relationalai/clients/client.py +4 -4
- v0/relationalai/clients/exec_txn_poller.py +91 -0
- v0/relationalai/clients/local.py +5 -5
- v0/relationalai/clients/resources/__init__.py +8 -0
- v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
- v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
- v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +717 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
- v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
- v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
- v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +642 -1399
- v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +51 -12
- v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
- v0/relationalai/clients/resources/snowflake/util.py +387 -0
- v0/relationalai/early_access/dsl/ir/executor.py +4 -4
- v0/relationalai/early_access/dsl/snow/api.py +2 -1
- v0/relationalai/errors.py +18 -0
- v0/relationalai/experimental/solvers.py +7 -7
- v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
- v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
- v0/relationalai/semantics/internal/snowflake.py +1 -1
- v0/relationalai/semantics/lqp/executor.py +7 -12
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
- v0/relationalai/semantics/metamodel/util.py +6 -5
- v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +335 -84
- v0/relationalai/semantics/rel/executor.py +14 -11
- v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
- v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
- v0/relationalai/tools/cli.py +26 -30
- v0/relationalai/tools/cli_helpers.py +10 -2
- v0/relationalai/util/otel_configuration.py +2 -1
- v0/relationalai/util/otel_handler.py +1 -1
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/top_level.txt +0 -0
- /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} (some long lines are truncated in the registry's diff view; they are kept truncated below)

@@ -1,10 +1,8 @@
 # pyright: reportUnusedExpression=false
 from __future__ import annotations
 import base64
-import decimal
 import importlib.resources
 import io
-from numbers import Number
 import re
 import json
 import time
@@ -14,15 +12,15 @@ import uuid
 import warnings
 import atexit
 import hashlib
+from dataclasses import dataclass
 
-
-from v0.relationalai.
-from v0.relationalai.clients.use_index_poller import DirectUseIndexPoller, UseIndexPoller
+from ....auth.token_handler import TokenHandler
+from v0.relationalai.clients.exec_txn_poller import ExecTxnPoller, query_complete_message
 import snowflake.snowpark
 
-from
-from
-from
+from ....rel_utils import sanitize_identifier, to_fqn_relation_name
+from ....tools.constants import FIELD_PLACEHOLDER, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
+from .... import std
 from collections import defaultdict
 import requests
 import snowflake.connector
@@ -30,33 +28,63 @@ import pyarrow as pa
 
 from snowflake.snowpark import Session
 from snowflake.snowpark.context import get_active_session
-from
-from
-from typing import Any, Dict, Iterable,
+from ... import result_helpers
+from .... import debugging
+from typing import Any, Dict, Iterable, Tuple, List, Literal, cast
 
 from pandas import DataFrame
 
-from
-from
-from
-from
-from
-from
-
-
-
+from ....tools.cli_controls import Spinner
+from ...types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
+from ...config import Config
+from ...client import Client, ExportParams, ProviderBase, ResourcesBase
+from ...util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, normalize_datetime
+from .util import (
+    collect_error_messages,
+    process_jinja_template,
+    type_to_sql,
+    type_to_snowpark,
+    sanitize_user_name as _sanitize_user_name,
+    normalize_params,
+    format_sproc_name,
+    is_azure_url,
+    is_container_runtime,
+    imports_to_dicts,
+    txn_list_to_dicts,
+    decrypt_artifact,
+)
+from ....environments import runtime_env, HexEnvironment, SnowbookEnvironment
+from .... import dsl, rel, metamodel as m
+from ....errors import EngineProvisioningFailed, EngineNameValidationException, Errors, GuardRailsException, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIException, HexSessionException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, ModelNotFoundException, UnknownSourceWarning, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
 from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime,
+from datetime import datetime, timedelta
 from snowflake.snowpark.types import StringType, StructField, StructType
+# Import error handlers and constants
+from .error_handlers import (
+    ErrorHandler,
+    DuoSecurityErrorHandler,
+    AppMissingErrorHandler,
+    DatabaseErrorsHandler,
+    EngineErrorsHandler,
+    ServiceNotStartedErrorHandler,
+    TransactionAbortedErrorHandler,
+)
+# Import engine state handlers
+from .engine_state_handlers import (
+    EngineStateHandler,
+    EngineContext,
+    SyncPendingStateHandler,
+    SyncSuspendedStateHandler,
+    SyncReadyStateHandler,
+    SyncGoneStateHandler,
+    SyncMissingEngineHandler,
+    AsyncPendingStateHandler,
+    AsyncSuspendedStateHandler,
+    AsyncReadyStateHandler,
+    AsyncGoneStateHandler,
+    AsyncMissingEngineHandler,
+)
 
-# warehouse-based snowflake notebooks currently don't have hazmat
-crypto_disabled = False
-try:
-    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-    from cryptography.hazmat.backends import default_backend
-    from cryptography.hazmat.primitives import padding
-except (ModuleNotFoundError, ImportError):
-    crypto_disabled = True
 
 #--------------------------------------------------
 # Constants
@@ -66,225 +94,28 @@ VALID_POOL_STATUS = ["ACTIVE", "IDLE", "SUSPENDED"]
 # transaction list and get return different fields (duration vs timings)
 LIST_TXN_SQL_FIELDS = ["id", "database_name", "engine_name", "state", "abort_reason", "read_only","created_by", "created_on", "finished_at", "duration"]
 GET_TXN_SQL_FIELDS = ["id", "database", "engine", "state", "abort_reason", "read_only","created_by", "created_on", "finished_at", "timings"]
-IMPORT_STREAM_FIELDS = ["ID", "CREATED_AT", "CREATED_BY", "STATUS", "REFERENCE_NAME", "REFERENCE_ALIAS", "FQ_OBJECT_NAME", "RAI_DATABASE",
-                        "RAI_RELATION", "DATA_SYNC_STATUS", "PENDING_BATCHES_COUNT", "NEXT_BATCH_STATUS", "NEXT_BATCH_UNLOADED_TIMESTAMP",
-                        "NEXT_BATCH_DETAILS", "LAST_BATCH_DETAILS", "LAST_BATCH_UNLOADED_TIMESTAMP", "CDC_STATUS"]
 VALID_ENGINE_STATES = ["READY", "PENDING"]
 
 # Cloud-specific engine sizes
 INTERNAL_ENGINE_SIZES = ["XS", "S", "M", "L"]
 ENGINE_SIZES_AWS = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_L"]
 ENGINE_SIZES_AZURE = ["HIGHMEM_X64_S", "HIGHMEM_X64_M", "HIGHMEM_X64_SL"]
-
-FIELD_MAP = {
-    "database_name": "database",
-    "engine_name": "engine",
-}
-VALID_IMPORT_STATES = ["PENDING", "PROCESSING", "QUARANTINED", "LOADED"]
-ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
-ENGINE_NOT_READY_MSGS = ["engine is in pending", "engine is provisioning"]
-DATABASE_ERRORS = ["database not found"]
+# Note: ENGINE_ERRORS, ENGINE_NOT_READY_MSGS, DATABASE_ERRORS moved to util.py
 PYREL_ROOT_DB = 'pyrel_root_db'
 
 TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
 
-DUO_TEXT = "duo security"
-
 TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+GUARDRAILS_ABORT_REASON = "guard rail violation"
+
+PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
 
 #--------------------------------------------------
 # Helpers
 #--------------------------------------------------
 
-def
-
-
-    Supports:
-    - Variable substitution {{ var }}
-    - Conditional blocks {% if condition %} ... {% endif %}
-    - For loops {% for item in items %} ... {% endfor %}
-    - Comments {# ... #}
-    - Whitespace control with {%- and -%}
-
-    Args:
-        template: The template string
-        indent_spaces: Number of spaces to indent the result
-        **substitutions: Variable substitutions
-    """
-
-    def evaluate_condition(condition: str, context: dict) -> bool:
-        """Safely evaluate a condition string using the context."""
-        # Replace variables with their values
-        for k, v in context.items():
-            if isinstance(v, str):
-                condition = condition.replace(k, f"'{v}'")
-            else:
-                condition = condition.replace(k, str(v))
-        try:
-            return bool(eval(condition, {"__builtins__": {}}, {}))
-        except Exception:
-            return False
-
-    def process_expression(expr: str, context: dict) -> str:
-        """Process a {{ expression }} block."""
-        expr = expr.strip()
-        if expr in context:
-            return str(context[expr])
-        return ""
-
-    def process_block(lines: List[str], context: dict, indent: int = 0) -> List[str]:
-        """Process a block of template lines recursively."""
-        result = []
-        i = 0
-        while i < len(lines):
-            line = lines[i]
-
-            # Handle comments
-            line = re.sub(r'{#.*?#}', '', line)
-
-            # Handle if blocks
-            if_match = re.search(r'{%\s*if\s+(.+?)\s*%}', line)
-            if if_match:
-                condition = if_match.group(1)
-                if_block = []
-                else_block = []
-                i += 1
-                nesting = 1
-                in_else_block = False
-                while i < len(lines) and nesting > 0:
-                    if re.search(r'{%\s*if\s+', lines[i]):
-                        nesting += 1
-                    elif re.search(r'{%\s*endif\s*%}', lines[i]):
-                        nesting -= 1
-                    elif nesting == 1 and re.search(r'{%\s*else\s*%}', lines[i]):
-                        in_else_block = True
-                        i += 1
-                        continue
-
-                    if nesting > 0:
-                        if in_else_block:
-                            else_block.append(lines[i])
-                        else:
-                            if_block.append(lines[i])
-                    i += 1
-                if evaluate_condition(condition, context):
-                    result.extend(process_block(if_block, context, indent))
-                else:
-                    result.extend(process_block(else_block, context, indent))
-                continue
-
-            # Handle for loops
-            for_match = re.search(r'{%\s*for\s+(\w+)\s+in\s+(\w+)\s*%}', line)
-            if for_match:
-                var_name, iterable_name = for_match.groups()
-                for_block = []
-                i += 1
-                nesting = 1
-                while i < len(lines) and nesting > 0:
-                    if re.search(r'{%\s*for\s+', lines[i]):
-                        nesting += 1
-                    elif re.search(r'{%\s*endfor\s*%}', lines[i]):
-                        nesting -= 1
-                    if nesting > 0:
-                        for_block.append(lines[i])
-                    i += 1
-                if iterable_name in context and isinstance(context[iterable_name], (list, tuple)):
-                    for item in context[iterable_name]:
-                        loop_context = dict(context)
-                        loop_context[var_name] = item
-                        result.extend(process_block(for_block, loop_context, indent))
-                continue
-
-            # Handle variable substitution
-            line = re.sub(r'{{\s*(\w+)\s*}}', lambda m: process_expression(m.group(1), context), line)
-
-            # Handle whitespace control
-            line = re.sub(r'{%-', '{%', line)
-            line = re.sub(r'-%}', '%}', line)
-
-            # Add line with proper indentation, preserving blank lines
-            if line.strip():
-                result.append(" " * (indent_spaces + indent) + line)
-            else:
-                result.append("")
-
-            i += 1
-
-        return result
-
-    # Split template into lines and process
-    lines = template.split('\n')
-    processed_lines = process_block(lines, substitutions)
-
-    return '\n'.join(processed_lines)
-
-def type_to_sql(type) -> str:
-    if type is str:
-        return "VARCHAR"
-    if type is int:
-        return "NUMBER"
-    if type is Number:
-        return "DECIMAL(38, 15)"
-    if type is float:
-        return "FLOAT"
-    if type is decimal.Decimal:
-        return "DECIMAL(38, 15)"
-    if type is bool:
-        return "BOOLEAN"
-    if type is dict:
-        return "VARIANT"
-    if type is list:
-        return "ARRAY"
-    if type is bytes:
-        return "BINARY"
-    if type is datetime:
-        return "TIMESTAMP"
-    if type is date:
-        return "DATE"
-    if isinstance(type, dsl.Type):
-        return "VARCHAR"
-    raise ValueError(f"Unknown type {type}")
-
-def type_to_snowpark(type) -> str:
-    if type is str:
-        return "StringType()"
-    if type is int:
-        return "IntegerType()"
-    if type is float:
-        return "FloatType()"
-    if type is Number:
-        return "DecimalType(38, 15)"
-    if type is decimal.Decimal:
-        return "DecimalType(38, 15)"
-    if type is bool:
-        return "BooleanType()"
-    if type is dict:
-        return "MapType()"
-    if type is list:
-        return "ArrayType()"
-    if type is bytes:
-        return "BinaryType()"
-    if type is datetime:
-        return "TimestampType()"
-    if type is date:
-        return "DateType()"
-    if isinstance(type, dsl.Type):
-        return "StringType()"
-    raise ValueError(f"Unknown type {type}")
-
-def _sanitize_user_name(user: str) -> str:
-    # Extract the part before the '@'
-    sanitized_user = user.split('@')[0]
-    # Replace any character that is not a letter, number, or underscore with '_'
-    sanitized_user = re.sub(r'[^a-zA-Z0-9_]', '_', sanitized_user)
-    return sanitized_user
-
-def _is_engine_issue(response_message: str) -> bool:
-    return any(kw in response_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)
-
-def _is_database_issue(response_message: str) -> bool:
-    return any(kw in response_message.lower() for kw in DATABASE_ERRORS)
-
+def should_print_txn_progress(config) -> bool:
+    return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))
 
 #--------------------------------------------------
 # Resources
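The large template-processing helper removed above now ships in the new `util.py` (imported as `process_jinja_template`). As orientation only, here is a minimal sketch of the variable-substitution part of its mini-syntax; the `render` stand-in below is hypothetical and covers just `{{ var }}`, not the `{% if %}`/`{% for %}` blocks, comments, or whitespace control the real helper also handles:

```python
import re

# Hypothetical stand-in for the relocated util.process_jinja_template,
# reduced to {{ var }} substitution only.
def render(template: str, **substitutions) -> str:
    return re.sub(
        r"{{\s*(\w+)\s*}}",
        lambda m: str(substitutions.get(m.group(1), "")),
        template,
    )

print(render("CALL {{ app }}.api.get_database('{{ db }}');", app="rai_app", db="my_model"))
# -> CALL rai_app.api.get_database('my_model');
```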
@@ -292,6 +123,25 @@ def _is_database_issue(response_message: str) -> bool:
 
 APP_NAME = "___RAI_APP___"
 
+@dataclass
+class ExecContext:
+    """Execution context for SQL queries, containing all parameters needed for execution and retry."""
+    code: str
+    params: List[Any] | None = None
+    raw: bool = False
+    help: bool = True
+    skip_engine_db_error_retry: bool = False
+
+    def re_execute(self, resources: 'Resources') -> Any:
+        """Re-execute this context's query using the provided resources instance."""
+        return resources._exec(
+            code=self.code,
+            params=self.params,
+            raw=self.raw,
+            help=self.help,
+            skip_engine_db_error_retry=self.skip_engine_db_error_retry
+        )
+
 class Resources(ResourcesBase):
     def __init__(
         self,
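A standalone sketch of the retry pattern the new `ExecContext` enables: because the dataclass captures every parameter of the failed call, `re_execute` replays it exactly. `FakeResources` below is a hypothetical stand-in for the real `Resources` class:

```python
from dataclasses import dataclass
from typing import Any, List

# Sketch: a handler can capture every parameter of a failed call and replay it verbatim.
@dataclass
class ExecContext:
    code: str
    params: List[Any] | None = None
    raw: bool = False

    def re_execute(self, resources: "FakeResources") -> Any:
        return resources._exec(self.code, self.params, raw=self.raw)

class FakeResources:  # hypothetical stand-in for the real Resources class
    def _exec(self, code: str, params: List[Any] | None = None, raw: bool = False) -> str:
        return f"ran {code!r} with params={params} raw={raw}"

ctx = ExecContext(code="SHOW DATABASES;", params=None, raw=True)
print(ctx.re_execute(FakeResources()))  # replays with identical arguments
```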
@@ -301,7 +151,7 @@ class Resources(ResourcesBase):
         dry_run: bool = False,
         reset_session: bool = False,
         generation: Generation | None = None,
-        language: str = "rel",
+        language: str = "rel",  # Accepted for backward compatibility, but not stored in base class
     ):
         super().__init__(profile, config=config)
         self._token_handler: TokenHandler | None = None
@@ -319,16 +169,97 @@ class Resources(ResourcesBase):
         # self.sources contains fully qualified Snowflake table/view names
         self.sources: set[str] = set()
         self._sproc_models = None
-
+        # Store language for backward compatibility (used by child classes for use_index polling)
         self.language = language
+        # Register error and state handlers
+        self._register_handlers()
+        # Register atexit callback to cancel pending transactions
         atexit.register(self.cancel_pending_transactions)
 
+    #--------------------------------------------------
+    # Initialization & Properties
+    #--------------------------------------------------
+
+    def _register_handlers(self) -> None:
+        """Register error and engine state handlers for processing."""
+        # Register base handlers using getter methods that subclasses can override
+        # Use defensive copying to ensure each instance has its own handler lists
+        # and prevent cross-instance contamination from subclass mutations
+        self._error_handlers = list(self._get_error_handlers())
+        self._sync_engine_state_handlers = list(self._get_engine_state_handlers(is_async=False))
+        self._async_engine_state_handlers = list(self._get_engine_state_handlers(is_async=True))
+
+    def _get_error_handlers(self) -> list[ErrorHandler]:
+        """Get list of error handlers. Subclasses can override to add custom handlers.
+
+        Returns:
+            List of error handlers for standard error processing using Strategy Pattern.
+
+        Example:
+            def _get_error_handlers(self) -> list[ErrorHandler]:
+                # Get base handlers
+                handlers = super()._get_error_handlers()
+                # Add custom handler
+                handlers.append(MyCustomErrorHandler())
+                return handlers
+        """
+        return [
+            DuoSecurityErrorHandler(),
+            AppMissingErrorHandler(),
+            DatabaseErrorsHandler(),
+            EngineErrorsHandler(),
+            ServiceNotStartedErrorHandler(),
+            TransactionAbortedErrorHandler(),
+        ]
+
+    def _get_engine_state_handlers(self, is_async: bool = False) -> list[EngineStateHandler]:
+        """Get list of engine state handlers. Subclasses can override.
+
+        Args:
+            is_async: If True, returns async handlers; if False, returns sync handlers.
+
+        Returns:
+            List of engine state handlers for processing engine states.
+
+        Example:
+            def _get_engine_state_handlers(self, is_async: bool = False) -> list[EngineStateHandler]:
+                # Get base handlers
+                handlers = super()._get_engine_state_handlers(is_async)
+                # Add custom handler
+                handlers.append(MyCustomStateHandler())
+                return handlers
+        """
+        if is_async:
+            return [
+                AsyncPendingStateHandler(),
+                AsyncSuspendedStateHandler(),
+                AsyncReadyStateHandler(),
+                AsyncGoneStateHandler(),
+                AsyncMissingEngineHandler(),
+            ]
+        else:
+            return [
+                SyncPendingStateHandler(),
+                SyncSuspendedStateHandler(),
+                SyncReadyStateHandler(),
+                SyncGoneStateHandler(),
+                SyncMissingEngineHandler(),
+            ]
+
     @property
     def token_handler(self) -> TokenHandler:
         if not self._token_handler:
             self._token_handler = TokenHandler.from_config(self.config)
         return self._token_handler
 
+    def reset(self):
+        """Reset the session."""
+        self._session = None
+
+    #--------------------------------------------------
+    # Session Management
+    #--------------------------------------------------
+
     def is_erp_running(self, app_name: str) -> bool:
         """Check if the ERP is running. The app.service_status() returns single row/column containing an array of JSON service status objects."""
         query = f"CALL {app_name}.app.service_status();"
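The handlers registered above follow the Strategy Pattern named in the docstrings. The interface sketched below is inferred only from the call sites (`handler.matches(...)` / `handler.handle(...)` in `_handle_standard_exec_errors` further down); the actual definitions live in the new `error_handlers.py`, which this diff does not include:

```python
from typing import Any, Optional, Protocol

# Inferred shape of the handler interface; an assumption, not the library's API.
class ErrorHandler(Protocol):
    def matches(self, e: Exception, message: str, ctx: Any, resources: Any) -> bool: ...
    def handle(self, e: Exception, ctx: Any, resources: Any) -> Optional[Any]: ...

class DatabaseMissingHandler:  # hypothetical example implementation
    def matches(self, e: Exception, message: str, ctx: Any, resources: Any) -> bool:
        return "database not found" in message

    def handle(self, e: Exception, ctx: Any, resources: Any) -> Optional[Any]:
        raise RuntimeError("database not found; create it before querying") from e

handlers: list[ErrorHandler] = [DatabaseMissingHandler()]
err = Exception("Database not found: my_model")
message = str(err).lower()
print(any(h.matches(err, message, ctx=None, resources=None) for h in handlers))  # True
```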
@@ -426,7 +357,28 @@ class Resources(ResourcesBase):
         except Exception as e:
             raise e
 
+    #--------------------------------------------------
+    # Core Execution Methods
+    #--------------------------------------------------
+
     def _exec_sql(self, code: str, params: List[Any] | None, raw=False):
+        """
+        Lowest-level SQL execution method.
+
+        Directly executes SQL via the Snowflake session. This is the foundation
+        for all other execution methods. It:
+        - Replaces APP_NAME placeholder with actual app name
+        - Executes SQL with optional parameters
+        - Returns either raw session results or collected results
+
+        Args:
+            code: SQL code to execute (may contain APP_NAME placeholder)
+            params: Optional SQL parameters
+            raw: If True, return raw session results; if False, collect results
+
+        Returns:
+            Raw session results if raw=True, otherwise collected results
+        """
         assert self._session is not None
         sess_results = self._session.sql(
             code.replace(APP_NAME, self.get_app_name()),
@@ -444,85 +396,88 @@ class Resources(ResourcesBase):
         help: bool = True,
         skip_engine_db_error_retry: bool = False
     ) -> Any:
+        """
+        Mid-level SQL execution method with error handling.
+
+        This is the primary method for executing SQL queries. It wraps _exec_sql
+        with comprehensive error handling and parameter normalization. Used
+        extensively throughout the codebase for direct SQL operations like:
+        - SHOW commands (warehouses, databases, etc.)
+        - CALL statements to RAI app stored procedures
+        - Transaction management queries
+
+        The error handling flow:
+        1. Normalizes parameters and creates execution context
+        2. Calls _exec_sql to execute the query
+        3. On error, uses standard error handling (Strategy Pattern), which subclasses
+           can influence via `_get_error_handlers()` or by overriding `_handle_standard_exec_errors()`
+
+        Args:
+            code: SQL code to execute
+            params: Optional SQL parameters (normalized to list if needed)
+            raw: If True, return raw session results; if False, collect results
+            help: If True, enable error handling; if False, raise errors immediately
+            skip_engine_db_error_retry: If True, skip use_index retry logic in error handlers
+
+        Returns:
+            Query results (collected or raw depending on 'raw' parameter)
+        """
         # print(f"\n--- sql---\n{code}\n--- end sql---\n")
+        # Ensure session is initialized
         if not self._session:
             self._session = self.get_sf_session()
 
+        # Normalize parameters
+        normalized_params = normalize_params(params)
+
+        # Create execution context
+        ctx = ExecContext(
+            code=code,
+            params=normalized_params,
+            raw=raw,
+            help=help,
+            skip_engine_db_error_retry=skip_engine_db_error_retry
+        )
+
+        # Execute SQL
         try:
-
-            params = cast(List[Any], [params])
-            return self._exec_sql(code, params, raw=raw)
+            return self._exec_sql(ctx.code, ctx.params, raw=ctx.raw)
         except Exception as e:
-            if not help:
+            if not ctx.help:
                 raise e
-            orig_message = str(e).lower()
-            rai_app = self.config.get("rai_app_name", "")
-            current_role = self.config.get("role")
-            engine = self.get_default_engine_name()
-            engine_size = self.config.get_default_engine_size()
-            assert isinstance(rai_app, str), f"rai_app_name must be a string, not {type(rai_app)}"
-            assert isinstance(engine, str), f"engine must be a string, not {type(engine)}"
-            print("\n")
-            if DUO_TEXT in orig_message:
-                raise DuoSecurityFailed(e)
-            if re.search(f"database '{rai_app}' does not exist or not authorized.".lower(), orig_message):
-                exception = SnowflakeAppMissingException(rai_app, current_role)
-                raise exception from None
-            # skip initializing the index if the query is a user transaction. exec_raw/exec_lqp will handle that case with the correct request headers.
-            if (_is_engine_issue(orig_message) or _is_database_issue(orig_message)) and not skip_engine_db_error_retry:
-                try:
-                    self._poll_use_index(
-                        app_name=self.get_app_name(),
-                        sources=self.sources,
-                        model=self.database,
-                        engine_name=engine,
-                        engine_size=engine_size
-                    )
-                    return self._exec(code, params, raw=raw, help=help)
-                except EngineNameValidationException as e:
-                    raise EngineNameValidationException(engine) from e
-                except Exception as e:
-                    raise EngineProvisioningFailed(engine, e) from e
-            elif re.search(r"javascript execution error", orig_message):
-                match = re.search(r"\"message\":\"(.*)\"", orig_message)
-                if match:
-                    message = match.group(1)
-                    if "engine is in pending" in message or "engine is provisioning" in message:
-                        raise EnginePending(engine)
-                    else:
-                        raise RAIException(message) from None
-
-            if re.search(r"the relationalai service has not been started.", orig_message):
-                app_name = self.config.get("rai_app_name", "")
-                assert isinstance(app_name, str), f"rai_app_name must be a string, not {type(app_name)}"
-                raise SnowflakeRaiAppNotStarted(app_name)
-
-            if re.search(r"state:\s*aborted", orig_message):
-                txn_id_match = re.search(r"id:\s*([0-9a-f\-]+)", orig_message)
-                if txn_id_match:
-                    txn_id = txn_id_match.group(1)
-                    problems = self.get_transaction_problems(txn_id)
-                    if problems:
-                        for problem in problems:
-                            if isinstance(problem, dict):
-                                type_field = problem.get('TYPE')
-                                message_field = problem.get('MESSAGE')
-                                report_field = problem.get('REPORT')
-                            else:
-                                type_field = problem.TYPE
-                                message_field = problem.MESSAGE
-                                report_field = problem.REPORT
 
-
-
-
+            # Handle standard errors
+            result = self._handle_standard_exec_errors(e, ctx)
+            if result is not None:
+                return result
 
+    #--------------------------------------------------
+    # Error Handling
+    #--------------------------------------------------
 
-    def
-
+    def _handle_standard_exec_errors(self, e: Exception, ctx: ExecContext) -> Any | None:
+        """
+        Handle standard Snowflake/RAI errors using Strategy Pattern.
+
+        Each error type has a dedicated handler class that encapsulates
+        the detection logic and exception creation. Handlers are processed
+        in order until one matches and handles the error.
+        """
+        message = str(e).lower()
+
+        # Try each handler in order until one matches
+        for handler in self._error_handlers:
+            if handler.matches(e, message, ctx, self):
+                result = handler.handle(e, ctx, self)
+                if result is not None:
+                    return result
+                return  # Handler raised exception, we're done
+
+        # Fallback: transform to RAIException
+        raise RAIException(str(e))
 
     #--------------------------------------------------
-    #
+    # Feature Detection & Configuration
     #--------------------------------------------------
 
     def is_direct_access_enabled(self) -> bool:
@@ -544,9 +499,6 @@ class Resources(ResourcesBase):
         except Exception as e:
             raise Exception(f"Unable to determine if direct access is enabled. Details error: {e}") from e
 
-    #--------------------------------------------------
-    # Snowflake Account Flags
-    #--------------------------------------------------
 
     def is_account_flag_set(self, flag: str) -> bool:
         results = self._exec(
@@ -566,7 +518,8 @@ class Resources(ResourcesBase):
                 f"call {APP_NAME}.api.get_database('{database}');"
             )
         except Exception as e:
-
+            messages = collect_error_messages(e)
+            if any("database does not exist" in msg for msg in messages):
                 return None
             raise e
 
@@ -590,10 +543,11 @@ class Resources(ResourcesBase):
         try:
             results = self._exec(query)
         except Exception as e:
-
+            messages = collect_error_messages(e)
+            if any("database does not exist" in msg for msg in messages):
                 return None
             # fallback to None for old sql-lib versions
-            if "
+            if any("unknown user-defined function" in msg for msg in messages):
                 return None
             raise e
 
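The two hunks above switch from substring checks on one flattened error string to checking each message collected from the exception. A self-contained sketch of the matching style this enables; the `collect_error_messages` body below is an assumption (walking the `__cause__`/`__context__` chain), since the real helper lives in the new `util.py`, which this diff does not include:

```python
# Hypothetical stand-in for util.collect_error_messages: gather the lowercased
# message of every exception in the chain so matching survives wrapping.
def collect_error_messages(e: BaseException) -> list[str]:
    messages: list[str] = []
    seen: set[int] = set()
    current: BaseException | None = e
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        messages.append(str(current).lower())
        current = current.__cause__ or current.__context__
    return messages

try:
    try:
        raise ValueError("Database does not exist")
    except ValueError as inner:
        raise RuntimeError("call failed") from inner
except RuntimeError as e:
    print(any("database does not exist" in msg for msg in collect_error_messages(e)))  # True
```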
@@ -610,6 +564,139 @@ class Resources(ResourcesBase):
     # Engines
     #--------------------------------------------------
 
+    def _prepare_engine_params(
+        self,
+        name: str | None,
+        size: str | None,
+        use_default_size: bool = False
+    ) -> tuple[str, str | None]:
+        """
+        Prepare engine parameters by resolving and validating name and size.
+
+        Args:
+            name: Engine name (None to use default)
+            size: Engine size (None to use config or default)
+            use_default_size: If True and size is None, use get_default_engine_size()
+
+        Returns:
+            Tuple of (engine_name, engine_size)
+
+        Raises:
+            EngineNameValidationException: If engine name is invalid
+            Exception: If engine size is invalid
+        """
+        from v0.relationalai.tools.cli_helpers import validate_engine_name
+
+        engine_name = name or self.get_default_engine_name()
+
+        # Resolve engine size
+        if size:
+            engine_size = size
+        else:
+            if use_default_size:
+                engine_size = self.config.get_default_engine_size()
+            else:
+                engine_size = self.config.get("engine_size", None)
+
+        # Validate engine size
+        if engine_size:
+            is_size_valid, sizes = self.validate_engine_size(engine_size)
+            if not is_size_valid:
+                error_msg = f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}"
+                if use_default_size:
+                    error_msg = f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}"
+                raise Exception(error_msg)
+
+        # Validate engine name
+        is_name_valid, _ = validate_engine_name(engine_name)
+        if not is_name_valid:
+            raise EngineNameValidationException(engine_name)
+
+        return engine_name, engine_size
+
+    def _get_state_handler(self, state: str | None, handlers: list[EngineStateHandler]) -> EngineStateHandler:
+        """Find the appropriate state handler for the given state."""
+        for handler in handlers:
+            if handler.handles_state(state):
+                return handler
+        # Fallback to missing engine handler if no match
+        return handlers[-1]  # Last handler should be MissingEngineHandler
+
+    def _process_engine_state(
+        self,
+        engine: EngineState | Dict[str, Any] | None,
+        context: EngineContext,
+        handlers: list[EngineStateHandler],
+        set_active_on_success: bool = False
+    ) -> EngineState | Dict[str, Any] | None:
+        """
+        Process engine state using appropriate state handler.
+
+        Args:
+            engine: Current engine state (or None if missing)
+            context: Engine context for state handling
+            handlers: List of state handlers to use (sync or async)
+            set_active_on_success: If True, set engine as active when handler returns engine
+
+        Returns:
+            Engine state after processing, or None if engine needs to be created
+        """
+        # Find and execute appropriate state handler
+        state = engine["state"] if engine else None
+        handler = self._get_state_handler(state, handlers)
+        engine = handler.handle(engine, context, self)
+
+        # If handler returned None and we didn't start with None state, engine needs to be created
+        # (e.g., GONE state deleted the engine, so we need to create a new one)
+        if not engine and state is not None:
+            handler = self._get_state_handler(None, handlers)
+            handler.handle(None, context, self)
+        elif set_active_on_success:
+            # Cast to EngineState for type safety (handlers return EngineDict which is compatible)
+            self._set_active_engine(cast(EngineState, engine))
+
+        return engine
+
+    def _handle_engine_creation_errors(self, error: Exception, engine_name: str, preserve_rai_exception: bool = False) -> None:
+        """
+        Handle errors during engine creation using error handlers.
+
+        Args:
+            error: The exception that occurred
+            engine_name: Name of the engine being created
+            preserve_rai_exception: If True, re-raise RAIException without wrapping
+
+        Raises:
+            RAIException: If preserve_rai_exception is True and error is RAIException
+            EngineProvisioningFailed: If error is not handled by error handlers
+        """
+        # Preserve RAIException passthrough if requested (for async mode)
+        if preserve_rai_exception and isinstance(error, RAIException):
+            raise error
+
+        # Check if this is a known error type that should be handled by error handlers
+        message = str(error).lower()
+        handled = False
+        # Engine creation isn't tied to a specific SQL ExecContext; pass a context that
+        # disables use_index retry behavior (and any future ctx-dependent handlers).
+        ctx = ExecContext(code="", help=True, skip_engine_db_error_retry=True)
+        for handler in self._error_handlers:
+            if handler.matches(error, message, ctx, self):
+                handler.handle(error, ctx, self)
+                handled = True
+                break  # Handler raised exception, we're done
+
+        # If not handled by error handlers, wrap in EngineProvisioningFailed
+        if not handled:
+            raise EngineProvisioningFailed(engine_name, error) from error
+
+    def validate_engine_size(self, size: str) -> Tuple[bool, List[str]]:
+        if size is not None:
+            sizes = self.get_engine_sizes()
+            if size not in sizes:
+                return False, sizes
+        return True, []
+
     def get_engine_sizes(self, cloud_provider: str|None=None):
         sizes = []
         if cloud_provider is None:
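The `_get_state_handler` fallback above relies on ordering: the catch-all missing-engine handler sits last. A self-contained sketch of that dispatch, with hypothetical handler classes (the real ones live in the new `engine_state_handlers.py`, not shown in this diff):

```python
from typing import Optional, Protocol

# Inferred shape of the state-handler interface, based on how _get_state_handler
# and _process_engine_state use it above; an assumption, not the library's API.
class EngineStateHandler(Protocol):
    def handles_state(self, state: Optional[str]) -> bool: ...

class ReadyHandler:
    def handles_state(self, state: Optional[str]) -> bool:
        return state == "READY"

class MissingEngineHandler:
    def handles_state(self, state: Optional[str]) -> bool:
        return True  # catch-all: placed last, also covers state=None

def get_state_handler(state: Optional[str], handlers: list[EngineStateHandler]) -> EngineStateHandler:
    for handler in handlers:
        if handler.handles_state(state):
            return handler
    return handlers[-1]  # fallback mirrors _get_state_handler above

handlers: list[EngineStateHandler] = [ReadyHandler(), MissingEngineHandler()]
print(type(get_state_handler("READY", handlers)).__name__)    # ReadyHandler
print(type(get_state_handler("PENDING", handlers)).__name__)  # MissingEngineHandler
```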
@@ -812,19 +899,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
         engine_size: str | None = None,
         program_span_id: str | None = None,
         headers: Dict | None = None,
-    ):
-
-
-
-
-
-
-
-
-
-            self.generation
-        ).poll()
+    ) -> None:
+        """
+        Poll use_index to prepare indices for the given sources.
+
+        This is an optional interface method. Base Resources provides a no-op implementation.
+        UseIndexResources and DirectAccessResources override this to provide actual polling.
+
+        Returns:
+            None for base implementation. Child classes may return poller results.
+        """
+        return None
 
     def maybe_poll_use_index(
         self,
@@ -835,36 +920,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
         engine_size: str | None = None,
         program_span_id: str | None = None,
         headers: Dict | None = None,
-    ):
-        """
-
-
-
-
-
-
-
-
-            engine_size,
-            self.language,
-            program_span_id,
-            headers,
-            self.generation
-        )
-        # If cache is valid (data freshness has not expired), skip polling
-        if poller.cache.is_valid():
-            cached_sources = len(poller.cache.sources)
-            total_sources = len(sources_list)
-            cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
-
-            message = f"Using cached data for {cached_sources}/{total_sources} data streams"
-            if cached_timestamp:
-                print(f"\n{message} (cached at {cached_timestamp})\n")
-            else:
-                print(f"\n{message}\n")
-        else:
-            return poller.poll()
+    ) -> None:
+        """
+        Only call _poll_use_index if there are sources to process.
+
+        This is an optional interface method. Base Resources provides a no-op implementation.
+        UseIndexResources and DirectAccessResources override this to provide actual polling with caching.
+
+        Returns:
+            None for base implementation. Child classes may return poller results.
+        """
+        return None
 
     #--------------------------------------------------
     # Models
@@ -902,11 +968,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
     def list_exports(self, database: str, engine: str):
         return []
 
-    def format_sproc_name(self, name: str, type:Any) -> str:
-        if type is datetime:
-            return f"{name}.astimezone(ZoneInfo('UTC')).isoformat(timespec='milliseconds')"
-        else:
-            return name
 
     def get_export_code(self, params: ExportParams, all_installs):
         sql_inputs = ", ".join([f"{name} {type_to_sql(type)}" for (name, _, type) in params.inputs])
@@ -928,15 +989,14 @@ Otherwise, remove it from your '{profile}' configuration profile.
             clean_inputs.append(f"{name} = '\"' + escape({name}) + '\"'")
             # Replace `var` with `name` and keep the following non-word character unchanged
             pattern = re.compile(re.escape(var) + r'(\W)')
-            value =
+            value = format_sproc_name(name, type)
             safe_rel = re.sub(pattern, rf"{{{value}}}\1", safe_rel)
         if py_inputs:
             py_inputs = f", {py_inputs}"
         clean_inputs = ("\n").join(clean_inputs)
-        assert __package__ is not None, "Package name must be set"
         file = "export_procedure.py.jinja"
         with importlib.resources.open_text(
-
+            "relationalai.clients.resources.snowflake", file
         ) as f:
             template = f.read()
         def quote(s: str, f = False) -> str:
@@ -1094,15 +1154,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
     # Imports
     #--------------------------------------------------
 
-    def is_valid_import_state(self, state:str):
-        return state in VALID_IMPORT_STATES
-
-    def imports_to_dicts(self, results):
-        parsed_results = [
-            {field.lower(): row[field] for field in IMPORT_STREAM_FIELDS}
-            for row in results
-        ]
-        return parsed_results
 
     def change_stream_status(self, stream_id: str, model:str, suspend: bool):
         if stream_id and model:
@@ -1274,7 +1325,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         results = self._exec(f"CALL {APP_NAME}.api.get_data_stream('{name}', '{model}');")
         if not results:
             return None
-        return
+        return imports_to_dicts(results)
 
     def create_import_stream(self, source:ImportSource, model:str, rate = 1, options: dict|None = None):
         assert isinstance(source, ImportSourceTable), "Snowflake integration only supports loading from SF Tables. Try loading your data as a table via the Snowflake interface first."
@@ -1309,7 +1360,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
         try:
             self._exec(command)
         except Exception as e:
-
+            messages = collect_error_messages(e)
+            if any("ensure that change_tracking is enabled on the source object" in msg for msg in messages):
                 if self.config.get("ensure_change_tracking", False) and not tracking_just_changed:
                     try:
                         self._exec(f"ALTER {kind} {object} SET CHANGE_TRACKING = TRUE;")
@@ -1320,7 +1372,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
                 print("\n")
                 exception = SnowflakeChangeTrackingNotEnabledException((object, kind))
                 raise exception from None
-            elif "
+            elif any("database does not exist" in msg for msg in messages):
                 print("\n")
                 raise ModelNotFoundException(model) from None
             raise e
@@ -1370,55 +1422,22 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if txn_id in self._pending_transactions:
             self._pending_transactions.remove(txn_id)
 
-        if status == "ABORTED"
-
-
-
-
-
-
-
-
+        if status == "ABORTED":
+            if response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
+                config_file_path = getattr(self.config, 'file_path', None)
+                # todo: use the timeout returned alongside the transaction as soon as it's exposed
+                timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+                raise QueryTimeoutExceededException(
+                    timeout_mins=timeout_mins,
+                    query_id=txn_id,
+                    config_file_path=config_file_path,
+                )
+            elif response_row.get("ABORT_REASON", "") == GUARDRAILS_ABORT_REASON:
+                raise GuardRailsException()
 
         # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
         return status == "COMPLETED" or status == "ABORTED"
 
-    def decrypt_stream(self, key: bytes, iv: bytes, src: bytes) -> bytes:
-        """Decrypt the provided stream with PKCS#5 padding handling."""
-
-        if crypto_disabled:
-            if isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "warehouse":
-                raise Exception("Please open the navigation-bar dropdown labeled *Packages* and select `cryptography` under the *Anaconda Packages* section, and then re-run your query.")
-            else:
-                raise Exception("library `cryptography.hazmat` missing; please install")
-
-        # `type:ignore`s are because of the conditional import, which
-        # we have because warehouse-based snowflake notebooks don't support
-        # the crypto library we're using.
-        cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())  # type: ignore
-        decryptor = cipher.decryptor()
-
-        # Decrypt the data
-        decrypted_padded_data = decryptor.update(src) + decryptor.finalize()
-
-        # Unpad the decrypted data using PKCS#5
-        unpadder = padding.PKCS7(128).unpadder()  # type: ignore # Use 128 directly for AES
-        unpadded_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
-
-        return unpadded_data
-
-    def _decrypt_artifact(self, data: bytes, encryption_material: str) -> bytes:
-        """Decrypts the artifact data using provided encryption material."""
-        encryption_material_parts = encryption_material.split("|")
-        assert len(encryption_material_parts) == 3, "Invalid encryption material"
-
-        algorithm, key_base64, iv_base64 = encryption_material_parts
-        assert algorithm == "AES_128_CBC", f"Unsupported encryption algorithm {algorithm}"
-
-        key = base64.standard_b64decode(key_base64)
-        iv = base64.standard_b64decode(iv_base64)
-
-        return self.decrypt_stream(key, iv, data)
 
     def _list_exec_async_artifacts(self, txn_id: str, headers: Dict | None = None) -> Dict[str, Dict]:
         """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
@@ -1470,7 +1489,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if len(content) == 0:
             return b""
 
-            return
+            return decrypt_artifact(content, encryption_material)
 
         # otherwise, return content directly
         return content
@@ -1550,7 +1569,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
     def get_url_key(self, metadata) -> str:
         # In Azure, there is only one type of URL, which is used for both internal and
         # external access; always use that one
-        if
+        if is_azure_url(metadata['PRESIGNED_URL']):
            return 'PRESIGNED_URL'
 
        configured = self.config.get("download_url_type", None)
|
|
|
1559
1578
|
elif configured == "external":
|
|
1560
1579
|
return "PRESIGNED_URL"
|
|
1561
1580
|
|
|
1562
|
-
if
|
|
1581
|
+
if is_container_runtime():
|
|
1563
1582
|
return 'PRESIGNED_URL_AP'
|
|
1564
1583
|
|
|
1565
1584
|
return 'PRESIGNED_URL'
|
|
1566
1585
|
|
|
1567
|
-
def is_azure(self, url) -> bool:
|
|
1568
|
-
return "blob.core.windows.net" in url
|
|
1569
|
-
|
|
1570
|
-
def is_container_runtime(self) -> bool:
|
|
1571
|
-
return isinstance(runtime_env, SnowbookEnvironment) and runtime_env.runner == "container"
|
|
1572
|
-
|
|
1573
1586
|
def _exec_rai_app(
|
|
1574
1587
|
self,
|
|
1575
1588
|
database: str,
|
|
@@ -1583,6 +1596,40 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1583
1596
|
language: str = "rel",
|
|
1584
1597
|
query_timeout_mins: int | None = None,
|
|
1585
1598
|
):
|
|
1599
|
+
"""
|
|
1600
|
+
High-level method to execute RAI app stored procedures.
|
|
1601
|
+
|
|
1602
|
+
Builds and executes SQL to call the RAI app's exec_async_v2 stored procedure.
|
|
1603
|
+
This method handles the SQL string construction for two different formats:
|
|
1604
|
+
1. New format (with graph index): Uses object payload with parameterized query
|
|
1605
|
+
2. Legacy format: Uses positional parameters
|
|
1606
|
+
|
|
1607
|
+
The choice between formats depends on the use_graph_index configuration.
|
|
1608
|
+
The new format allows the stored procedure to hash the model and username
|
|
1609
|
+
to determine the database, while the legacy format uses the passed database directly.
|
|
1610
|
+
|
|
1611
|
+
This method is called by _exec_async_v2 to create transactions. It skips
|
|
1612
|
+
use_index retry logic (skip_engine_db_error_retry=True) because that
|
|
1613
|
+
is handled at a higher level by exec_raw/exec_lqp.
|
|
1614
|
+
|
|
1615
|
+
Args:
|
|
1616
|
+
database: Database/model name
|
|
1617
|
+
engine: Engine name (optional)
|
|
1618
|
+
raw_code: Code to execute (REL, LQP, or SQL)
|
|
1619
|
+
inputs: Input parameters for the query
|
|
1620
|
+
readonly: Whether the transaction is read-only
|
|
1621
|
+
nowait_durable: Whether to wait for durable writes
|
|
1622
|
+
request_headers: Optional HTTP headers
|
|
1623
|
+
bypass_index: Whether to bypass graph index setup
|
|
1624
|
+
language: Query language ("rel" or "lqp")
|
|
1625
|
+
query_timeout_mins: Optional query timeout in minutes
|
|
1626
|
+
|
|
1627
|
+
Returns:
|
|
1628
|
+
Response from the stored procedure call (transaction creation result)
|
|
1629
|
+
|
|
1630
|
+
Raises:
|
|
1631
|
+
Exception: If transaction creation fails
|
|
1632
|
+
"""
|
|
1586
1633
|
assert language == "rel" or language == "lqp", "Only 'rel' and 'lqp' languages are supported"
|
|
1587
1634
|
if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
|
|
1588
1635
|
query_timeout_mins = int(timeout_value)
|
|
@@ -1635,12 +1682,43 @@ Otherwise, remove it from your '{profile}' configuration profile.
         query_timeout_mins: int | None = None,
         gi_setup_skipped: bool = False,
     ):
+        """
+        High-level async execution method with transaction polling and artifact management.
+
+        This is the core method for executing queries asynchronously. It:
+        1. Creates a transaction by calling _exec_rai_app
+        2. Handles two execution paths:
+           - Fast path: Transaction completes immediately (COMPLETED/ABORTED)
+           - Slow path: Transaction is pending, requires polling until completion
+        3. Manages pending transactions list
+        4. Downloads and returns query results/artifacts
+
+        This method is called by _execute_code (base implementation) and can be
+        overridden by child classes (e.g., DirectAccessResources uses HTTP instead).
+
+        Args:
+            database: Database/model name
+            engine: Engine name (optional)
+            raw_code: Code to execute (REL, LQP, or SQL)
+            inputs: Input parameters for the query
+            readonly: Whether the transaction is read-only
+            nowait_durable: Whether to wait for durable writes
+            headers: Optional HTTP headers
+            bypass_index: Whether to bypass graph index setup
+            language: Query language ("rel" or "lqp")
+            query_timeout_mins: Optional query timeout in minutes
+            gi_setup_skipped: Whether graph index setup was skipped (for retry logic)
+
+        Returns:
+            Query results (downloaded artifacts)
+        """
         if inputs is None:
             inputs = {}
         request_headers = debugging.add_current_propagation_headers(headers)
         query_attrs_dict = json.loads(request_headers.get("X-Query-Attributes", "{}"))
 
         with debugging.span("transaction", **query_attrs_dict) as txn_span:
+            txn_start_time = time.time()
             with debugging.span("create_v2", **query_attrs_dict) as create_span:
                 request_headers['user-agent'] = get_pyrel_version(self.generation)
                 request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
|
|
|
1671
1749
|
create_span["txn_id"] = txn_id
|
|
1672
1750
|
debugging.event("transaction_created", txn_span, txn_id=txn_id)
|
|
1673
1751
|
|
|
1752
|
+
print_txn_progress = should_print_txn_progress(self.config)
|
|
1753
|
+
|
|
1674
1754
|
# fast path: transaction already finished
|
|
1675
1755
|
if state in ["COMPLETED", "ABORTED"]:
|
|
1756
|
+
txn_end_time = time.time()
|
|
1676
1757
|
if txn_id in self._pending_transactions:
|
|
1677
1758
|
self._pending_transactions.remove(txn_id)
|
|
1678
1759
|
|
|
@@ -1681,13 +1762,24 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1681
1762
|
filename = row['FILENAME']
|
|
1682
1763
|
artifact_info[filename] = row
|
|
1683
1764
|
|
|
1765
|
+
txn_duration = txn_end_time - txn_start_time
|
|
1766
|
+
if print_txn_progress:
|
|
1767
|
+
print(
|
|
1768
|
+
query_complete_message(txn_id, txn_duration, status_header=True)
|
|
1769
|
+
)
|
|
1770
|
+
|
|
1684
1771
|
# Slow path: transaction not done yet; start polling
|
|
1685
1772
|
else:
|
|
1686
1773
|
self._pending_transactions.append(txn_id)
|
|
1774
|
+
# Use the interactive poller for transaction status
|
|
1687
1775
|
with debugging.span("wait", txn_id=txn_id):
|
|
1688
|
-
|
|
1689
|
-
|
|
1690
|
-
|
|
1776
|
+
if print_txn_progress:
|
|
1777
|
+
poller = ExecTxnPoller(resource=self, txn_id=txn_id, headers=request_headers, txn_start_time=txn_start_time)
|
|
1778
|
+
poller.poll()
|
|
1779
|
+
else:
|
|
1780
|
+
poll_with_specified_overhead(
|
|
1781
|
+
lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
|
|
1782
|
+
)
|
|
1691
1783
|
artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
|
|
1692
1784
|
|
|
1693
1785
|
with debugging.span("fetch"):
|
|
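
The slow path above now chooses between two pollers: the new `ExecTxnPoller` (from `exec_txn_poller.py`, added in this release) when progress printing is on, and the quiet `poll_with_specified_overhead` helper otherwise. To illustrate what overhead-bounded polling means, here is a minimal sketch of such a helper; the name and the `overhead_rate`/`max_delay`/`timeout` parameters appear in the diff, but this body is an assumption, not the SDK's actual implementation:

```python
import time

def poll_with_specified_overhead(check, overhead_rate: float,
                                 max_delay: float = 120.0,
                                 timeout: float | None = None) -> None:
    """Sketch: call check() until it returns True, sleeping an amount
    proportional to elapsed time so polling cost stays near
    `overhead_rate` of total wall-clock time (assumed behavior)."""
    start = time.time()
    while not check():
        elapsed = time.time() - start
        if timeout is not None and elapsed > timeout:
            raise TimeoutError(f"still not done after {elapsed:.0f}s")
        time.sleep(min(max(elapsed * overhead_rate, 0.05), max_delay))
```

Under that reading, the `0.1` passed at the call site means roughly 10% polling overhead: waits start near-instant and stretch out as the transaction runs longer.
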
@@ -1706,205 +1798,81 @@ Otherwise, remove it from your '{profile}' configuration profile.
         return engine and engine["state"] == "READY"

     def auto_create_engine(self, name: str | None = None, size: str | None = None, headers: Dict | None = None):
-
+        """Synchronously create/ensure an engine is ready, blocking until ready."""
         with debugging.span("auto_create_engine", active=self._active_engine) as span:
             active = self._get_active_engine()
             if active:
                 return active

-
-
-            # Use the provided size or fall back to the config
-            if size:
-                engine_size = size
-            else:
-                engine_size = self.config.get("engine_size", None)
-
-            # Validate engine size
-            if engine_size:
-                is_size_valid, sizes = self.validate_engine_size(engine_size)
-                if not is_size_valid:
-                    raise Exception(f"Invalid engine size '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
-
-            # Validate engine name
-            is_name_valid, _ = validate_engine_name(engine_name)
-            if not is_name_valid:
-                raise EngineNameValidationException(engine_name)
+            # Resolve and validate parameters
+            engine_name, engine_size = self._prepare_engine_params(name, size)

             try:
+                # Get current engine state
                 engine = self.get_engine(engine_name)
                 if engine:
                     span.update(cast(dict, engine))

-                    #
-
-
-
-
-
-
-
-
-
-
-                    ):
-                        poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
-
-                    elif engine["state"] == "SUSPENDED":
-                        with Spinner(f"Resuming engine '{engine_name}'", f"Engine '{engine_name}' resumed", f"Failed to resume engine '{engine_name}'"):
-                            try:
-                                self.resume_engine_async(engine_name, headers=headers)
-                                poll_with_specified_overhead(lambda: self.is_engine_ready(engine_name), overhead_rate=0.1, max_delay=0.5, timeout=900)
-                            except Exception:
-                                raise EngineResumeFailed(engine_name)
-                    elif engine["state"] == "READY":
-                        # if the user explicitly specified a size, warn if the ready engine size doesn't match it
-                        if size is not None and engine["size"] != size:
-                            EngineSizeMismatchWarning(engine_name, engine["size"], size)
-                        self._set_active_engine(engine)
-                        return engine_name
-                    elif engine["state"] == "GONE":
-                        try:
-                            # "Gone" is abnormal condition when metadata and SF service don't match
-                            # Therefore, we have to delete the engine and create a new one
-                            # it could be case that engine is already deleted, so we have to catch the exception
-                            self.delete_engine(engine_name, headers=headers)
-                            # After deleting the engine, set it to None so that we can create a new engine
-                            engine = None
-                        except Exception as e:
-                            # if engine is already deleted, we will get an exception
-                            # we can ignore this exception and create a new engine
-                            if isinstance(e, EngineNotFoundException):
-                                engine = None
-                                pass
-                            else:
-                                raise EngineProvisioningFailed(engine_name, e) from e
-
-                if not engine:
-                    with Spinner(
-                        f"Auto-creating engine {engine_name}",
-                        f"Auto-created engine {engine_name}",
-                        "Engine creation failed",
-                    ):
-                        self.create_engine(engine_name, size=engine_size, headers=headers)
+                # Create context for state handling
+                context = EngineContext(
+                    engine_name=engine_name,
+                    engine_size=engine_size,
+                    headers=headers,
+                    requested_size=size,
+                    span=span,
+                )
+
+                # Process engine state using sync handlers
+                self._process_engine_state(engine, context, self._sync_engine_state_handlers)
+
             except Exception as e:
-
-
-                    raise DuoSecurityFailed(e)
-                raise EngineProvisioningFailed(engine_name, e) from e
+                self._handle_engine_creation_errors(e, engine_name)
+
             return engine_name

     def auto_create_engine_async(self, name: str | None = None):
+        """Asynchronously create/ensure an engine, returns immediately."""
         active = self._get_active_engine()
         if active and (active == name or name is None):
-            return
+            return active

         with Spinner(
             "Checking engine status",
             leading_newline=True,
         ) as spinner:
-            from v0.relationalai.tools.cli_helpers import validate_engine_name
             with debugging.span("auto_create_engine_async", active=self._active_engine):
-
-                engine_size = self.
-                if engine_size:
-                    is_size_valid, sizes = self.validate_engine_size(engine_size)
-                    if not is_size_valid:
-                        raise Exception(f"Invalid engine size in config: '{engine_size}'. Valid sizes are: {', '.join(sizes)}")
-                else:
-                    engine_size = self.config.get_default_engine_size()
+                # Resolve and validate parameters (use_default_size=True for async)
+                engine_name, engine_size = self._prepare_engine_params(name, None, use_default_size=True)

-                is_name_valid, _ = validate_engine_name(engine_name)
-                if not is_name_valid:
-                    raise EngineNameValidationException(engine_name)
                 try:
+                    # Get current engine state
                     engine = self.get_engine(engine_name)
-                    # if engine is gone, delete it and create new one
-                    # in case of pending state, do nothing, it is use_index responsibility to wait for engine to be ready
-                    if engine:
-                        if engine["state"] == "PENDING":
-                            spinner.update_messages(
-                                {
-                                    "finished_message": f"Starting engine {engine_name}",
-                                }
-                            )
-                            pass
-                        elif engine["state"] == "SUSPENDED":
-                            spinner.update_messages(
-                                {
-                                    "finished_message": f"Resuming engine {engine_name}",
-                                }
-                            )
-                            try:
-                                self.resume_engine_async(engine_name)
-                            except Exception:
-                                raise EngineResumeFailed(engine_name)
-                        elif engine["state"] == "READY":
-                            spinner.update_messages(
-                                {
-                                    "finished_message": f"Engine {engine_name} initialized",
-                                }
-                            )
-                            pass
-                        elif engine["state"] == "GONE":
-                            spinner.update_messages(
-                                {
-                                    "message": f"Restarting engine {engine_name}",
-                                }
-                            )
-                            try:
-                                # "Gone" is abnormal condition when metadata and SF service don't match
-                                # Therefore, we have to delete the engine and create a new one
-                                # it could be case that engine is already deleted, so we have to catch the exception
-                                # set it to None so that we can create a new engine
-                                engine = None
-                                self.delete_engine(engine_name)
-                            except Exception as e:
-                                # if engine is already deleted, we will get an exception
-                                # we can ignore this exception and create a new engine asynchronously
-                                if isinstance(e, EngineNotFoundException):
-                                    engine = None
-                                    pass
-                                else:
-                                    print(e)
-                                    raise EngineProvisioningFailed(engine_name, e) from e
-
-                    if not engine:
-                        self.create_engine_async(engine_name, size=self.config.get("engine_size", None))
-                        spinner.update_messages(
-                            {
-                                "finished_message": f"Starting engine {engine_name}...",
-                            }
-                        )
-                    else:
-                        self._set_active_engine(engine)

-
-
-
-
-
+                    # Create context for state handling
+                    context = EngineContext(
+                        engine_name=engine_name,
+                        engine_size=engine_size,
+                        headers=None,
+                        requested_size=None,
+                        spinner=spinner,
                     )
-                    if DUO_TEXT in str(e).lower():
-                        raise DuoSecurityFailed(e)
-                    if isinstance(e, RAIException):
-                        raise e
-                    print(e)
-                    raise EngineProvisioningFailed(engine_name, e) from e

-
-
-
-
-
-
+                    # Process engine state using async handlers
+                    self._process_engine_state(engine, context, self._async_engine_state_handlers, set_active_on_success=True)
+
+                except Exception as e:
+                    spinner.update_messages({
+                        "finished_message": f"Failed to create engine {engine_name}",
+                    })
+                    self._handle_engine_creation_errors(e, engine_name, preserve_rai_exception=True)
+
+                return engine_name

     #--------------------------------------------------
     # Exec
     #--------------------------------------------------

-    def
+    def _execute_code(
         self,
         database: str,
         engine: str | None,
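
The hunk above replaces both engines' long `if/elif` chains over `engine["state"]` with an `EngineContext` plus handler tables (`_sync_engine_state_handlers`, `_async_engine_state_handlers`, presumably defined in the new `engine_state_handlers.py`). A minimal sketch of that dispatch shape follows; the field names match the keyword arguments visible above, while the types and handler signature are assumptions:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

@dataclass
class EngineContext:
    engine_name: str
    engine_size: Optional[str]
    headers: Optional[Dict]
    requested_size: Optional[str]
    span: Any = None      # sync path passes the tracing span
    spinner: Any = None    # async path passes the progress spinner

# Hypothetical handler signature: one callable per engine state,
# keyed by "PENDING" / "SUSPENDED" / "READY" / "GONE" / None (engine missing).
Handler = Callable[[Optional[dict], EngineContext], None]

def _process_engine_state(engine: Optional[dict], context: EngineContext,
                          handlers: Dict[Optional[str], Handler]) -> None:
    state = engine["state"] if engine else None
    handler = handlers.get(state)
    if handler is None:
        raise ValueError(f"no handler registered for engine state {state!r}")
    handler(engine, context)
```

The payoff is that the sync and async flows share one traversal and differ only in the handler table they pass in.
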
@@ -1916,39 +1884,40 @@ Otherwise, remove it from your '{profile}' configuration profile.
         bypass_index: bool,
         language: str,
         query_timeout_mins: int | None,
-    ):
-        """Execute with graph index retry logic.
-
-        Attempts execution with gi_setup_skipped=True first. If an engine or database
-        issue occurs, polls use_index and retries with gi_setup_skipped=False.
+    ) -> Any:
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Template method for code execution - can be overridden by child classes.
+
+        This is a template method that provides a hook for child classes to add
+        execution logic (like retry mechanisms). The base implementation simply
+        calls _exec_async_v2 directly.
+
+        UseIndexResources overrides this method to use _exec_with_gi_retry, which
+        adds automatic use_index polling on engine/database errors.
+
+        This method is called by exec_lqp() and exec_raw() to provide a single
+        execution point that can be customized per resource class.
+
+        Args:
+            database: Database/model name
+            engine: Engine name (optional)
+            raw_code: Code to execute (already processed/encoded)
+            inputs: Input parameters for the query
+            readonly: Whether the transaction is read-only
+            nowait_durable: Whether to wait for durable writes
+            headers: Optional HTTP headers
+            bypass_index: Whether to bypass graph index setup
+            language: Query language ("rel" or "lqp")
+            query_timeout_mins: Optional query timeout in minutes
+
+        Returns:
+            Query results
+        """
+        return self._exec_async_v2(
+            database, engine, raw_code, inputs, readonly, nowait_durable,
+            headers=headers, bypass_index=bypass_index, language=language,
+            query_timeout_mins=query_timeout_mins, gi_setup_skipped=True,
+        )

     def exec_lqp(
         self,
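
`_execute_code` is a textbook template method. Stripped of the Snowflake plumbing, the override relationship described in the docstring looks roughly like this (class and method names are from the diff; the bodies are simplified sketches, not the shipped code):

```python
class Resources:
    def _execute_code(self, *args, **kwargs):
        # Base implementation: one shot, no retry; gi_setup_skipped=True
        # because retrying is the subclass's concern.
        return self._exec_async_v2(*args, gi_setup_skipped=True, **kwargs)

class UseIndexResources(Resources):
    def _execute_code(self, *args, **kwargs):
        # Override: add use_index polling and retry on engine/database errors.
        return self._exec_with_gi_retry(*args, **kwargs)
```

Since `exec_lqp` and `exec_raw` only ever call `_execute_code`, each resource class customizes execution in exactly one place.
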
@@ -1963,13 +1932,13 @@ Otherwise, remove it from your '{profile}' configuration profile.
         bypass_index=False,
         query_timeout_mins: int | None = None,
     ):
+        """Execute LQP code."""
         raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
-        return self.
+        return self._execute_code(
             database, engine, raw_code_b64, inputs, readonly, nowait_durable,
             headers, bypass_index, 'lqp', query_timeout_mins
         )

-
     def exec_raw(
         self,
         database: str,
@@ -1983,8 +1952,9 @@ Otherwise, remove it from your '{profile}' configuration profile.
         bypass_index=False,
         query_timeout_mins: int | None = None,
     ):
+        """Execute raw code."""
         raw_code = raw_code.replace("'", "\\'")
-        return self.
+        return self._execute_code(
             database, engine, raw_code, inputs, readonly, nowait_durable,
             headers, bypass_index, 'rel', query_timeout_mins
         )
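
Note the asymmetric preprocessing: `exec_lqp` receives the LQP payload as bytes and base64-encodes it, while `exec_raw` escapes single quotes in the Rel source. A toy illustration of the LQP step (the payload bytes are invented):

```python
import base64

raw_code = b"\x0a\x04demo"  # stand-in for a binary LQP payload
raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
print(raw_code_b64)  # 'CgRkZW1v' - plain ASCII, safe to pass through the SQL layer
```
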
@@ -2039,7 +2009,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if query_timeout_mins is not None:
             res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
         else:
-
+            res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
         txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
         rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
         rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
@@ -2076,8 +2046,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
             if rejected_rows:
                 debugging.warn(RowsDroppedFromTargetTableWarning(rejected_rows, rejected_rows_count, col_names_map))
         except Exception as e:
-
-            if "no columns returned" in msg or "columns of results could not be determined" in msg:
+            messages = collect_error_messages(e)
+            if any("no columns returned" in msg or "columns of results could not be determined" in msg for msg in messages):
                 pass
             else:
                 raise e
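
`collect_error_messages` is new in this release (the growing `metamodel/util.py` in the file list is its likely home). The diff implies it turns one exception into a list of message strings so the substring checks also see wrapped causes. A sketch under that assumption:

```python
def collect_error_messages(e: BaseException) -> list[str]:
    # Assumed behavior: collect str() of the exception and of every chained
    # cause/context, so callers can substring-match any of them.
    messages: list[str] = []
    seen: set[int] = set()
    current: BaseException | None = e
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        messages.append(str(current))
        current = current.__cause__ or current.__context__
    return messages
```
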
@@ -2108,6 +2078,10 @@ Otherwise, remove it from your '{profile}' configuration profile.
         self.sources.add(parser.identity)
         return ns

+    #--------------------------------------------------
+    # Source Management
+    #--------------------------------------------------
+
     def _check_source_updates(self, sources: Iterable[str]):
         if not sources:
             return {}
@@ -2370,19 +2344,6 @@ Otherwise, remove it from your '{profile}' configuration profile.
     #--------------------------------------------------
     # Transactions
     #--------------------------------------------------
-    def txn_list_to_dicts(self, transactions):
-        dicts = []
-        for txn in transactions:
-            dict = {}
-            txn_dict = txn.asDict()
-            for key in txn_dict:
-                mapValue = FIELD_MAP.get(key.lower())
-                if mapValue:
-                    dict[mapValue] = txn_dict[key]
-                else:
-                    dict[key.lower()] = txn_dict[key]
-            dicts.append(dict)
-        return dicts

     def get_transaction(self, transaction_id):
         results = self._exec(
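
`txn_list_to_dicts` disappears as a method here, yet the following hunks call it as a free function, so it has evidently moved to module scope (the new `resources/snowflake/util.py` is the plausible home). Reassembled from the removed body, with the built-in-shadowing `dict` variable renamed:

```python
def txn_list_to_dicts(transactions):
    # Convert Snowpark Row objects into plain dicts, renaming keys
    # via FIELD_MAP (a module-level constant in the original source).
    dicts = []
    for txn in transactions:
        row = {}
        txn_dict = txn.asDict()
        for key in txn_dict:
            map_value = FIELD_MAP.get(key.lower())
            if map_value:
                row[map_value] = txn_dict[key]
            else:
                row[key.lower()] = txn_dict[key]
        dicts.append(row)
    return dicts
```
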
@@ -2390,7 +2351,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         if not results:
             return None

-        results =
+        results = txn_list_to_dicts(results)

         txn = {field: results[0][field] for field in GET_TXN_SQL_FIELDS}

@@ -2442,7 +2403,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         results = self._exec(query, [limit])
         if not results:
             return []
-        return
+        return txn_list_to_dicts(results)

     def cancel_transaction(self, transaction_id):
         self._exec(f"CALL {APP_NAME}.api.cancel_own_transaction(?);", [transaction_id])
@@ -2476,63 +2437,12 @@ Otherwise, remove it from your '{profile}' configuration profile.
             return None
         return results[0][0]

-
-
-
-            return []
-        return [{"name":name}
-                for (name, *rest) in results]
-
-    def list_compute_pools(self):
-        results = self._exec("SHOW COMPUTE POOLS")
-        if not results:
-            return []
-        return [{"name":name, "status":status, "min_nodes":min_nodes, "max_nodes":max_nodes, "instance_family":instance_family}
-                for (name, status, min_nodes, max_nodes, instance_family, *rest) in results]
-
-    def list_roles(self):
-        results = self._exec("SELECT CURRENT_AVAILABLE_ROLES()")
-        if not results:
-            return []
-        # the response is a single row with a single column containing
-        # a stringified JSON array of role names:
-        row = results[0]
-        if not row:
-            return []
-        return [{"name": name} for name in json.loads(row[0])]
-
-    def list_apps(self):
-        all_apps = self._exec(f"SHOW APPLICATIONS LIKE '{RAI_APP_NAME}'")
-        if not all_apps:
-            all_apps = self._exec("SHOW APPLICATIONS")
-        if not all_apps:
-            return []
-        return [{"name":name}
-                for (time, name, *rest) in all_apps]
-
-    def list_databases(self):
-        results = self._exec("SHOW DATABASES")
-        if not results:
-            return []
-        return [{"name":name}
-                for (time, name, *rest) in results]
-
-    def list_sf_schemas(self, database:str):
-        results = self._exec(f"SHOW SCHEMAS IN {database}")
-        if not results:
-            return []
-        return [{"name":name}
-                for (time, name, *rest) in results]
-
-    def list_tables(self, database:str, schema:str):
-        results = self._exec(f"SHOW OBJECTS IN {database}.{schema}")
-        items = []
-        if results:
-            for (time, name, db_name, schema_name, kind, *rest) in results:
-                items.append({"name":name, "kind":kind.lower()})
-        return items
+    # CLI methods (list_warehouses, list_compute_pools, list_roles, list_apps,
+    # list_databases, list_sf_schemas, list_tables) are now in CLIResources class
+    # schema_info is kept in base Resources class since it's used by SnowflakeSchema._fetch_info()

-    def schema_info(self, database:str, schema:str, tables:Iterable[str]):
+    def schema_info(self, database: str, schema: str, tables: Iterable[str]):
+        """Get detailed schema information including primary keys, foreign keys, and columns."""
         app_name = self.get_app_name()
         # Only pass the db + schema as the identifier so that the resulting identity is correct
         parser = IdentityParser(f"{database}.{schema}")
@@ -2550,7 +2460,7 @@ Otherwise, remove it from your '{profile}' configuration profile.

         # IdentityParser will parse a single value (with no ".") and store it in this case in the db field
         with debugging.span("columns") as span:
-
+            tables_str = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
             query = textwrap.dedent(f"""
                 begin
                 SHOW COLUMNS IN SCHEMA {parser.identity};
@@ -2567,7 +2477,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
                     ELSE FALSE
                 END as "supported_type"
                 FROM table(result_scan(-1)) as t
-                WHERE "table_name" in ({
+                WHERE "table_name" in ({tables_str})
                 );
                 return table(r);
                 end;
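
This restores the interpolation that was truncated in the old rendering: the `IN` list is the `tables_str` built one hunk up. For plain (dot-free) table names, where `IdentityParser` stores the value in its `db` field per the comment above, the expansion looks like this (illustrative names):

```python
tables = ["CUSTOMERS", "ORDERS"]  # illustrative, dot-free table names
tables_str = ", ".join([f"'{IdentityParser(t).db}'" for t in tables])
# tables_str == "'CUSTOMERS', 'ORDERS'"
# -> WHERE "table_name" in ('CUSTOMERS', 'ORDERS')
```
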
@@ -2578,7 +2488,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         results = defaultdict(lambda: {"pks": [], "fks": {}, "columns": {}, "invalid_columns": {}})
         if pks:
             for row in pks:
-                results[row[3]]["pks"].append(row[4])
+                results[row[3]]["pks"].append(row[4])  # type: ignore
         if fks:
             for row in fks:
                 results[row[7]]["fks"][row[8]] = row[3]
@@ -2720,7 +2630,7 @@ class SnowflakeSchema:
             tables_with_invalid_columns[table_name] = table_info["invalid_columns"]

         if tables_with_invalid_columns:
-            from
+            from v0.relationalai.errors import UnsupportedColumnTypesWarning
             UnsupportedColumnTypesWarning(tables_with_invalid_columns)

     def _add(self, name, is_imported=False):
@@ -2929,10 +2839,14 @@ class Provider(ProviderBase):
         if resources:
             self.resources = resources
         else:
-
-
-
-
+            from .resources_factory import create_resources_instance
+            self.resources = create_resources_instance(
+                config=config,
+                profile=profile,
+                generation=generation or Generation.V0,
+                dry_run=False,
+                language="rel",
+            )

     def list_streams(self, model:str):
         return self.resources.list_imports(model=model)
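
`Provider` construction (and `Graph` below) now routes through `create_resources_instance` from the new `resources_factory.py`. The factory body isn't shown in this diff; given the resource classes this release ships (`DirectAccessResources`, `UseIndexResources`, base `Resources`) and the config flags used elsewhere (`use_direct_access`, `use_graph_index`), a plausible sketch of the dispatch is:

```python
def create_resources_instance(config, profile=None, connection=None,
                              generation=None, dry_run=False, language="rel"):
    # Hypothetical selection logic - the real rules live in resources_factory.py.
    kwargs = dict(profile=profile, config=config, connection=connection,
                  generation=generation, dry_run=dry_run, language=language)
    if config.get("use_direct_access", USE_DIRECT_ACCESS):
        return DirectAccessResources(**kwargs)
    if config.get("use_graph_index", USE_GRAPH_INDEX):
        return UseIndexResources(**kwargs)
    return Resources(**kwargs)
```

Centralizing the choice means `Provider` and `Graph` can no longer drift apart in how they pick a resource class.
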
@@ -3101,19 +3015,31 @@ def Graph(
     nowait_durable: bool = True,
     format: str = "default",
 ):
+    from .resources_factory import create_resources_instance
+    from .use_index_resources import UseIndexResources

-    client_class = Client
-    resource_class = Resources
     use_graph_index = config.get("use_graph_index", USE_GRAPH_INDEX)
     use_monotype_operators = config.get("compiler.use_monotype_operators", False)
-    use_direct_access = config.get("use_direct_access", USE_DIRECT_ACCESS)

-
+    # Create resources instance using factory
+    resources = create_resources_instance(
+        config=config,
+        profile=profile,
+        connection=connection,
+        generation=Generation.V0,
+        dry_run=False,  # Resources instance dry_run is separate from client dry_run
+        language="rel",
+    )
+
+    # Determine client class based on resources type and config
+    # SnowflakeClient is used for resources that support use_index functionality
+    if use_graph_index or isinstance(resources, UseIndexResources):
         client_class = SnowflakeClient
-
-
+    else:
+        client_class = Client
+
     client = client_class(
-
+        resources,
         rel.Compiler(config),
         name,
         config,
@@ -3184,686 +3110,3 @@ def Graph(
     debugging.set_source(pyrel_base)
     client.install("pyrel_base", pyrel_base)
     return dsl.Graph(client, name, format=format)
-
-
-
-#--------------------------------------------------
-# Direct Access
-#--------------------------------------------------
-# Note: All direct access components should live in a separate file
-
-class DirectAccessResources(Resources):
-    """
-    Resources class for Direct Service Access avoiding Snowflake service functions.
-    """
-    def __init__(
-        self,
-        profile: Union[str, None] = None,
-        config: Union[Config, None] = None,
-        connection: Union[Session, None] = None,
-        dry_run: bool = False,
-        reset_session: bool = False,
-        generation: Optional[Generation] = None,
-        language: str = "rel",
-    ):
-        super().__init__(
-            generation=generation,
-            profile=profile,
-            config=config,
-            connection=connection,
-            reset_session=reset_session,
-            dry_run=dry_run,
-            language=language,
-        )
-        self._endpoint_info = ConfigStore(ENDPOINT_FILE)
-        self._service_endpoint = ""
-        self._direct_access_client = None
-        self.generation = generation
-        self.database = ""
-
-    @property
-    def service_endpoint(self) -> str:
-        return self._retrieve_service_endpoint()
-
-    def _retrieve_service_endpoint(self, enforce_update=False) -> str:
-        account = self.config.get("account")
-        app_name = self.config.get("rai_app_name")
-        service_endpoint_key = f"{account}.{app_name}.service_endpoint"
-        if self._service_endpoint and not enforce_update:
-            return self._service_endpoint
-        if self._endpoint_info.get(service_endpoint_key, "") and not enforce_update:
-            self._service_endpoint = str(self._endpoint_info.get(service_endpoint_key, ""))
-            return self._service_endpoint
-
-        is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
-        query = f"CALL {self.get_app_name()}.app.service_endpoint({not is_snowflake_notebook});"
-        result = self._exec(query)
-        assert result, f"Could not retrieve service endpoint for {self.get_app_name()}"
-        if is_snowflake_notebook:
-            self._service_endpoint = f"http://{result[0]['SERVICE_ENDPOINT']}"
-        else:
-            self._service_endpoint = f"https://{result[0]['SERVICE_ENDPOINT']}"
-
-        self._endpoint_info.set(service_endpoint_key, self._service_endpoint)
-        # save the endpoint to `ENDPOINT_FILE` to avoid calling the endpoint with every
-        # pyrel execution
-        try:
-            self._endpoint_info.save()
-        except Exception:
-            print("Failed to persist endpoints to file. This might slow down future executions.")
-
-        return self._service_endpoint
-
-    @property
-    def direct_access_client(self) -> DirectAccessClient:
-        if self._direct_access_client:
-            return self._direct_access_client
-        try:
-            service_endpoint = self.service_endpoint
-            self._direct_access_client = DirectAccessClient(
-                self.config, self.token_handler, service_endpoint, self.generation,
-            )
-        except Exception as e:
-            raise e
-        return self._direct_access_client
-
-    def request(
-        self,
-        endpoint: str,
-        payload: Dict[str, Any] | None = None,
-        headers: Dict[str, str] | None = None,
-        path_params: Dict[str, str] | None = None,
-        query_params: Dict[str, str] | None = None,
-        skip_auto_create: bool = False,
-        skip_engine_db_error_retry: bool = False,
-    ) -> requests.Response:
-        with debugging.span("direct_access_request"):
-            def _send_request():
-                return self.direct_access_client.request(
-                    endpoint=endpoint,
-                    payload=payload,
-                    headers=headers,
-                    path_params=path_params,
-                    query_params=query_params,
-                )
-            try:
-                response = _send_request()
-                if response.status_code != 200:
-                    # For 404 responses with skip_auto_create=True, return immediately to let caller handle it
-                    # (e.g., get_engine needs to check 404 and return None for auto_create_engine)
-                    # For skip_auto_create=False, continue to auto-creation logic below
-                    if response.status_code == 404 and skip_auto_create:
-                        return response
-
-                    try:
-                        message = response.json().get("message", "")
-                    except requests.exceptions.JSONDecodeError:
-                        # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
-                        # this should have been caught by the 404 check above, so this is an error.
-                        # For skip_auto_create=False, we explicitly check status_code below,
-                        # so we don't need to parse the message.
-                        if skip_auto_create:
-                            raise ResponseStatusException(
-                                f"Failed to parse error response from endpoint {endpoint}.", response
-                            )
-                        message = ""  # Not used when we check status_code directly
-
-                    # fix engine on engine error and retry
-                    # Skip setting up GI if skip_auto_create is True to avoid recursion or skip_engine_db_error_retry is true to let _exec_async_v2 perform the retry with the correct headers.
-                    if ((_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message)) and not skip_engine_db_error_retry:
-                        engine_name = payload.get("caller_engine_name", "") if payload else ""
-                        engine_name = engine_name or self.get_default_engine_name()
-                        engine_size = self.config.get_default_engine_size()
-                        self._poll_use_index(
-                            app_name=self.get_app_name(),
-                            sources=self.sources,
-                            model=self.database,
-                            engine_name=engine_name,
-                            engine_size=engine_size,
-                            headers=headers,
-                        )
-                        response = _send_request()
-            except requests.exceptions.ConnectionError as e:
-                if "NameResolutionError" in str(e):
-                    # when we can not resolve the service endpoint, we assume it is outdated
-                    # hence, we try to retrieve it again and query again.
-                    self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
-                        enforce_update=True,
-                    )
-                    return _send_request()
-                # raise in all other cases
-                raise e
-            return response
-
-    def _txn_request_with_gi_retry(
-        self,
-        payload: Dict,
-        headers: Dict[str, str],
-        query_params: Dict,
-        engine: Union[str, None],
-    ):
-        """Make request with graph index retry logic.
-
-        Attempts request with gi_setup_skipped=True first. If an engine or database
-        issue occurs, polls use_index and retries with gi_setup_skipped=False.
-        """
-        response = self.request(
-            "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
-        )
-
-        if response.status_code != 200:
-            try:
-                message = response.json().get("message", "")
-            except requests.exceptions.JSONDecodeError:
-                message = ""
-
-            if _is_engine_issue(message) or _is_database_issue(message):
-                engine_name = engine or self.get_default_engine_name()
-                engine_size = self.config.get_default_engine_size()
-                self._poll_use_index(
-                    app_name=self.get_app_name(),
-                    sources=self.sources,
-                    model=self.database,
-                    engine_name=engine_name,
-                    engine_size=engine_size,
-                    headers=headers,
-                )
-                headers['gi_setup_skipped'] = 'False'
-                response = self.request(
-                    "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
-                )
-            else:
-                raise ResponseStatusException("Failed to create transaction.", response)
-
-        return response
-
-    def _exec_async_v2(
-        self,
-        database: str,
-        engine: Union[str, None],
-        raw_code: str,
-        inputs: Dict | None = None,
-        readonly=True,
-        nowait_durable=False,
-        headers: Dict[str, str] | None = None,
-        bypass_index=False,
-        language: str = "rel",
-        query_timeout_mins: int | None = None,
-        gi_setup_skipped: bool = False,
-    ):
-
-        with debugging.span("transaction") as txn_span:
-            with debugging.span("create_v2") as create_span:
-
-                use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-
-                payload = {
-                    "dbname": database,
-                    "engine_name": engine,
-                    "query": raw_code,
-                    "v1_inputs": inputs,
-                    "nowait_durable": nowait_durable,
-                    "readonly": readonly,
-                    "language": language,
-                }
-                if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
-                    query_timeout_mins = int(timeout_value)
-                if query_timeout_mins is not None:
-                    payload["timeout_mins"] = query_timeout_mins
-                query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
-
-                # Add gi_setup_skipped to headers
-                if headers is None:
-                    headers = {}
-                headers["gi_setup_skipped"] = str(gi_setup_skipped)
-                headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
-
-                response = self._txn_request_with_gi_retry(
-                    payload, headers, query_params, engine
-                )
-
-                artifact_info = {}
-                response_content = response.json()
-
-                txn_id = response_content["transaction"]['id']
-                state = response_content["transaction"]['state']
-
-                txn_span["txn_id"] = txn_id
-                create_span["txn_id"] = txn_id
-                debugging.event("transaction_created", txn_span, txn_id=txn_id)
-
-            # fast path: transaction already finished
-            if state in ["COMPLETED", "ABORTED"]:
-                if txn_id in self._pending_transactions:
-                    self._pending_transactions.remove(txn_id)
-
-                # Process rows to get the rest of the artifacts
-                for result in response_content.get("results", []):
-                    filename = result['filename']
-                    # making keys uppercase to match the old behavior
-                    artifact_info[filename] = {k.upper(): v for k, v in result.items()}
-
-            # Slow path: transaction not done yet; start polling
-            else:
-                self._pending_transactions.append(txn_id)
-                with debugging.span("wait", txn_id=txn_id):
-                    poll_with_specified_overhead(
-                        lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
-                    )
-                artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
-
-            with debugging.span("fetch"):
-                return self._download_results(artifact_info, txn_id, state)
-
-    def _prepare_index(
-        self,
-        model: str,
-        engine_name: str,
-        engine_size: str = "",
-        language: str = "rel",
-        rai_relations: List[str] | None = None,
-        pyrel_program_id: str | None = None,
-        skip_pull_relations: bool = False,
-        headers: Dict | None = None,
-    ):
-        """
-        Prepare the index for the given engine and model.
-        """
-        with debugging.span("prepare_index"):
-            if headers is None:
-                headers = {}
-
-            payload = {
-                "model_name": model,
-                "caller_engine_name": engine_name,
-                "language": language,
-                "pyrel_program_id": pyrel_program_id,
-                "skip_pull_relations": skip_pull_relations,
-                "rai_relations": rai_relations or [],
-                "user_agent": get_pyrel_version(self.generation),
-            }
-            # Only include engine_size if it has a non-empty string value
-            if engine_size and engine_size.strip():
-                payload["caller_engine_size"] = engine_size
-
-            response = self.request(
-                "prepare_index", payload=payload, headers=headers
-            )
-
-            if response.status_code != 200:
-                raise ResponseStatusException("Failed to prepare index.", response)
-
-            return response.json()
-
-    def _poll_use_index(
-        self,
-        app_name: str,
-        sources: Iterable[str],
-        model: str,
-        engine_name: str,
-        engine_size: str | None = None,
-        program_span_id: str | None = None,
-        headers: Dict | None = None,
-    ):
-        return DirectUseIndexPoller(
-            self,
-            app_name=app_name,
-            sources=sources,
-            model=model,
-            engine_name=engine_name,
-            engine_size=engine_size,
-            language=self.language,
-            program_span_id=program_span_id,
-            headers=headers,
-            generation=self.generation,
-        ).poll()
-
-    def maybe_poll_use_index(
-        self,
-        app_name: str,
-        sources: Iterable[str],
-        model: str,
-        engine_name: str,
-        engine_size: str | None = None,
-        program_span_id: str | None = None,
-        headers: Dict | None = None,
-    ):
-        """Only call poll() if there are sources to process and cache is not valid."""
-        sources_list = list(sources)
-        self.database = model
-        if sources_list:
-            poller = DirectUseIndexPoller(
-                self,
-                app_name=app_name,
-                sources=sources_list,
-                model=model,
-                engine_name=engine_name,
-                engine_size=engine_size,
-                language=self.language,
-                program_span_id=program_span_id,
-                headers=headers,
-                generation=self.generation,
-            )
-            # If cache is valid (data freshness has not expired), skip polling
-            if poller.cache.is_valid():
-                cached_sources = len(poller.cache.sources)
-                total_sources = len(sources_list)
-                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
-
-                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
-                if cached_timestamp:
-                    print(f"\n{message} (cached at {cached_timestamp})\n")
-                else:
-                    print(f"\n{message}\n")
-            else:
-                return poller.poll()
-
-    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
-        """Check whether the given transaction has completed."""
-
-        with debugging.span("check_status"):
-            response = self.request(
-                "get_txn",
-                headers=headers,
-                path_params={"txn_id": txn_id},
-            )
-            assert response, f"No results from get_transaction('{txn_id}')"
-
-        response_content = response.json()
-        transaction = response_content["transaction"]
-        status: str = transaction['state']
-
-        # remove the transaction from the pending list if it's completed or aborted
-        if status in ["COMPLETED", "ABORTED"]:
-            if txn_id in self._pending_transactions:
-                self._pending_transactions.remove(txn_id)
-
-        if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
-            config_file_path = getattr(self.config, 'file_path', None)
-            timeout_ms = int(transaction.get("timeout_ms", 0))
-            timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
-            raise QueryTimeoutExceededException(
-                timeout_mins=timeout_mins,
-                query_id=txn_id,
-                config_file_path=config_file_path,
-            )
-
-        # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
-        return status == "COMPLETED" or status == "ABORTED"
-
-    def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
-        """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
-        with debugging.span("list_results"):
-            response = self.request(
-                "get_txn_artifacts",
-                headers=headers,
-                path_params={"txn_id": txn_id},
-            )
-            assert response, f"No results from get_transaction_artifacts('{txn_id}')"
-            artifact_info = {}
-            for result in response.json()["results"]:
-                filename = result['filename']
-                # making keys uppercase to match the old behavior
-                artifact_info[filename] = {k.upper(): v for k, v in result.items()}
-            return artifact_info
-
-    def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
-        with debugging.span("get_transaction_problems"):
-            response = self.request(
-                "get_txn_problems",
-                path_params={"txn_id": txn_id},
-            )
-            response_content = response.json()
-            if not response_content:
-                return []
-            return response_content.get("problems", [])
-
-    def get_transaction_events(self, transaction_id: str, continuation_token: str = ''):
-        response = self.request(
-            "get_txn_events",
-            path_params={"txn_id": transaction_id, "stream_name": "profiler"},
-            query_params={"continuation_token": continuation_token},
-        )
-        response_content = response.json()
-        if not response_content:
-            return {
-                "events": [],
-                "continuation_token": None
-            }
-        return response_content
-
-    #--------------------------------------------------
-    # Databases
-    #--------------------------------------------------
-
-    def get_installed_packages(self, database: str) -> Union[Dict, None]:
-        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
-        if use_graph_index:
-            response = self.request(
-                "get_model_package_versions",
-                payload={"model_name": database},
-            )
-        else:
-            response = self.request(
-                "get_package_versions",
-                path_params={"db_name": database},
-            )
-        if response.status_code == 404 and response.json().get("message", "") == "database not found":
-            return None
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to retrieve package versions for {database}.", response
-            )
-
-        content = response.json()
-        if not content:
-            return None
-
-        return safe_json_loads(content["package_versions"])
-
-    def get_database(self, database: str):
-        with debugging.span("get_database", dbname=database):
-            if not database:
-                raise ValueError("Database name must be provided to get database.")
-            response = self.request(
-                "get_db",
-                path_params={},
-                query_params={"name": database},
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to get db. db:{database}", response)
-
-            response_content = response.json()
-
-            if (response_content.get("databases") and len(response_content["databases"]) == 1):
-                db = response_content["databases"][0]
-                return {
-                    "id": db["id"],
-                    "name": db["name"],
-                    "created_by": db.get("created_by"),
-                    "created_on": ms_to_timestamp(db.get("created_on")),
-                    "deleted_by": db.get("deleted_by"),
-                    "deleted_on": ms_to_timestamp(db.get("deleted_on")),
-                    "state": db["state"],
-                }
-            else:
-                return None
-
-    def create_graph(self, name: str):
-        with debugging.span("create_model", dbname=name):
-            return self._create_database(name,"")
-
-    def delete_graph(self, name:str, force=False, language: str = "rel"):
-        prop_hdrs = debugging.gen_current_propagation_headers()
-        if self.config.get("use_graph_index", USE_GRAPH_INDEX):
-            keep_database = not force and self.config.get("reuse_model", True)
-            with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
-                response = self.request(
-                    "release_index",
-                    payload={
-                        "model_name": name,
-                        "keep_database": keep_database,
-                        "language": language,
-                        "user_agent": get_pyrel_version(self.generation),
-                    },
-                    headers=prop_hdrs,
-                )
-                if (
-                    response.status_code != 200
-                    and not (
-                        response.status_code == 404
-                        and "database not found" in response.json().get("message", "")
-                    )
-                ):
-                    raise ResponseStatusException(f"Failed to release index. Model: {name} ", response)
-        else:
-            with debugging.span("delete_model", name=name):
-                self._delete_database(name, headers=prop_hdrs)
-
-    def clone_graph(self, target_name:str, source_name:str, nowait_durable=True, force=False):
-        if force and self.get_graph(target_name):
-            self.delete_graph(target_name)
-        with debugging.span("clone_model", target_name=target_name, source_name=source_name):
-            return self._create_database(target_name,source_name)
-
-    def _delete_database(self, name:str, headers:Dict={}):
-        with debugging.span("_delete_database", dbname=name):
-            response = self.request(
-                "delete_db",
-                path_params={"db_name": name},
-                query_params={},
-                headers=headers,
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to delete db. db:{name} ", response)
-
-    def _create_database(self, name:str, source_name:str):
-        with debugging.span("_create_database", dbname=name):
-            payload = {
-                "name": name,
-                "source_name": source_name,
-            }
-            response = self.request(
-                "create_db", payload=payload, headers={}, query_params={},
-            )
-            if response.status_code != 200:
-                raise ResponseStatusException(f"Failed to create db. db:{name}", response)
-
-    #--------------------------------------------------
-    # Engines
-    #--------------------------------------------------
-
-    def list_engines(self, state: str | None = None):
-        response = self.request("list_engines")
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                "Failed to retrieve engines.", response
-            )
-        response_content = response.json()
-        if not response_content:
-            return []
-        engines = [
-            {
-                "name": engine["name"],
-                "id": engine["id"],
-                "size": engine["size"],
-                "state": engine["status"],  # callers are expecting 'state'
-                "created_by": engine["created_by"],
-                "created_on": engine["created_on"],
-                "updated_on": engine["updated_on"],
-            }
-            for engine in response_content.get("engines", [])
-            if state is None or engine.get("status") == state
-        ]
-        return sorted(engines, key=lambda x: x["name"])
-
-    def get_engine(self, name: str):
-        response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
-        if response.status_code == 404:  # engine not found return 404
-            return None
-        elif response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to retrieve engine {name}.", response
-            )
-        engine = response.json()
-        if not engine:
-            return None
-        engine_state: EngineState = {
-            "name": engine["name"],
-            "id": engine["id"],
-            "size": engine["size"],
-            "state": engine["status"],  # callers are expecting 'state'
-            "created_by": engine["created_by"],
-            "created_on": engine["created_on"],
-            "updated_on": engine["updated_on"],
-            "version": engine["version"],
-            "auto_suspend": engine["auto_suspend_mins"],
-            "suspends_at": engine["suspends_at"],
-        }
-        return engine_state
-
-    def _create_engine(
-        self,
-        name: str,
-        size: str | None = None,
-        auto_suspend_mins: int | None = None,
-        is_async: bool = False,
-        headers: Dict[str, str] | None = None
-    ):
-        # only async engine creation supported via direct access
-        if not is_async:
-            return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
-        payload:Dict[str, Any] = {
-            "name": name,
-        }
-        if auto_suspend_mins is not None:
-            payload["auto_suspend_mins"] = auto_suspend_mins
-        if size is not None:
-            payload["size"] = size
-        response = self.request(
-            "create_engine",
-            payload=payload,
-            path_params={"engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to create engine {name} with size {size}.", response
-            )
-
-    def delete_engine(self, name:str, force:bool = False, headers={}):
-        response = self.request(
-            "delete_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to delete engine {name}.", response
-            )
-
-    def suspend_engine(self, name:str):
-        response = self.request(
-            "suspend_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to suspend engine {name}.", response
-            )
-
-    def resume_engine_async(self, name:str, headers={}):
-        response = self.request(
-            "resume_engine",
-            path_params={"engine_name": name, "engine_type": "logic"},
-            headers=headers,
-            skip_auto_create=True,
-        )
-        if response.status_code != 200:
-            raise ResponseStatusException(
-                f"Failed to resume engine {name}.", response
-            )
-        return {}