relationalai 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- relationalai/semantics/frontend/base.py +3 -0
- relationalai/semantics/frontend/front_compiler.py +5 -2
- relationalai/semantics/metamodel/builtins.py +2 -1
- relationalai/semantics/metamodel/metamodel.py +32 -4
- relationalai/semantics/metamodel/pprint.py +5 -3
- relationalai/semantics/metamodel/typer.py +324 -297
- relationalai/semantics/std/aggregates.py +0 -1
- relationalai/semantics/std/datetime.py +4 -1
- relationalai/shims/executor.py +26 -5
- relationalai/shims/mm2v0.py +119 -44
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/METADATA +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/RECORD +57 -48
- v0/relationalai/__init__.py +69 -22
- v0/relationalai/clients/__init__.py +15 -2
- v0/relationalai/clients/client.py +4 -4
- v0/relationalai/clients/local.py +5 -5
- v0/relationalai/clients/resources/__init__.py +8 -0
- v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
- v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
- v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
- v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
- v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
- v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
- v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
- v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
- v0/relationalai/clients/resources/snowflake/util.py +387 -0
- v0/relationalai/early_access/dsl/ir/executor.py +4 -4
- v0/relationalai/early_access/dsl/snow/api.py +2 -1
- v0/relationalai/errors.py +23 -0
- v0/relationalai/experimental/solvers.py +7 -7
- v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
- v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
- v0/relationalai/semantics/internal/internal.py +4 -4
- v0/relationalai/semantics/internal/snowflake.py +3 -2
- v0/relationalai/semantics/lqp/executor.py +20 -22
- v0/relationalai/semantics/lqp/model2lqp.py +42 -4
- v0/relationalai/semantics/lqp/passes.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/cdc.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +53 -12
- v0/relationalai/semantics/metamodel/builtins.py +8 -6
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
- v0/relationalai/semantics/metamodel/util.py +6 -5
- v0/relationalai/semantics/reasoners/graph/core.py +8 -9
- v0/relationalai/semantics/rel/executor.py +14 -11
- v0/relationalai/semantics/sql/compiler.py +2 -2
- v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
- v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
- v0/relationalai/tools/cli.py +26 -30
- v0/relationalai/tools/cli_helpers.py +10 -2
- v0/relationalai/util/otel_configuration.py +2 -1
- v0/relationalai/util/otel_handler.py +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/top_level.txt +0 -0
- /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
v0/relationalai/clients/resources/snowflake/direct_access_resources.py (new file)
@@ -0,0 +1,711 @@
+"""
+Direct Access Resources - Resources class for Direct Service Access.
+This class overrides methods to use direct HTTP requests instead of Snowflake service functions.
+"""
+from __future__ import annotations
+from typing import Any, Dict, List, Optional, Union
+import requests
+
+from .... import debugging
+from ....tools.constants import USE_GRAPH_INDEX, DEFAULT_QUERY_TIMEOUT_MINS, Generation
+from ....environments import runtime_env, SnowbookEnvironment
+from ...config import Config, ConfigStore, ENDPOINT_FILE
+from ...direct_access_client import DirectAccessClient
+from ...types import EngineState
+from ...util import get_pyrel_version, poll_with_specified_overhead, safe_json_loads, ms_to_timestamp
+from ....errors import ResponseStatusException, QueryTimeoutExceededException
+from snowflake.snowpark import Session
+
+# Import UseIndexResources to enable use_index functionality with direct access
+from .use_index_resources import UseIndexResources
+
+# Import helper functions from util
+from .util import is_engine_issue as _is_engine_issue, is_database_issue as _is_database_issue, collect_error_messages
+
+from .use_index_poller import DirectUseIndexPoller
+from typing import Iterable
+
+# Constants
+TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+
+
+class DirectAccessResources(UseIndexResources):
+    """
+    Resources class for Direct Service Access avoiding Snowflake service functions.
+    Uses HTTP requests instead of Snowflake SQL for execution.
+    """
+    def __init__(
+        self,
+        profile: Union[str, None] = None,
+        config: Union[Config, None] = None,
+        connection: Union[Session, None] = None,
+        dry_run: bool = False,
+        reset_session: bool = False,
+        generation: Optional[Generation] = None,
+        language: str = "rel",
+    ):
+        super().__init__(
+            generation=generation,
+            profile=profile,
+            config=config,
+            connection=connection,
+            reset_session=reset_session,
+            dry_run=dry_run,
+            language=language,
+        )
+        self._endpoint_info = ConfigStore(ENDPOINT_FILE)
+        self._service_endpoint = ""
+        self._direct_access_client = None
+        # database and language are already set by UseIndexResources.__init__
+
+    @property
+    def service_endpoint(self) -> str:
+        return self._retrieve_service_endpoint()
+
+    def _retrieve_service_endpoint(self, enforce_update=False) -> str:
+        account = self.config.get("account")
+        app_name = self.config.get("rai_app_name")
+        service_endpoint_key = f"{account}.{app_name}.service_endpoint"
+        if self._service_endpoint and not enforce_update:
+            return self._service_endpoint
+        if self._endpoint_info.get(service_endpoint_key, "") and not enforce_update:
+            self._service_endpoint = str(self._endpoint_info.get(service_endpoint_key, ""))
+            return self._service_endpoint
+
+        is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
+        query = f"CALL {self.get_app_name()}.app.service_endpoint({not is_snowflake_notebook});"
+        result = self._exec(query)
+        assert result, f"Could not retrieve service endpoint for {self.get_app_name()}"
+        if is_snowflake_notebook:
+            self._service_endpoint = f"http://{result[0]['SERVICE_ENDPOINT']}"
+        else:
+            self._service_endpoint = f"https://{result[0]['SERVICE_ENDPOINT']}"
+
+        self._endpoint_info.set(service_endpoint_key, self._service_endpoint)
+        # save the endpoint to `ENDPOINT_FILE` to avoid calling the endpoint with every
+        # pyrel execution
+        try:
+            self._endpoint_info.save()
+        except Exception:
+            print("Failed to persist endpoints to file. This might slow down future executions.")
+
+        return self._service_endpoint
+
+    @property
+    def direct_access_client(self) -> DirectAccessClient:
+        if self._direct_access_client:
+            return self._direct_access_client
+        try:
+            service_endpoint = self.service_endpoint
+            self._direct_access_client = DirectAccessClient(
+                self.config, self.token_handler, service_endpoint, self.generation,
+            )
+        except Exception as e:
+            raise e
+        return self._direct_access_client
+
+    def request(
+        self,
+        endpoint: str,
+        payload: Dict[str, Any] | None = None,
+        headers: Dict[str, str] | None = None,
+        path_params: Dict[str, str] | None = None,
+        query_params: Dict[str, str] | None = None,
+        skip_auto_create: bool = False,
+        skip_engine_db_error_retry: bool = False,
+    ) -> requests.Response:
+        with debugging.span("direct_access_request"):
+            def _send_request():
+                return self.direct_access_client.request(
+                    endpoint=endpoint,
+                    payload=payload,
+                    headers=headers,
+                    path_params=path_params,
+                    query_params=query_params,
+                )
+            try:
+                response = _send_request()
+                if response.status_code != 200:
+                    # For 404 responses with skip_auto_create=True, return immediately to let caller handle it
+                    # (e.g., get_engine needs to check 404 and return None for auto_create_engine)
+                    # For skip_auto_create=False, continue to auto-creation logic below
+                    if response.status_code == 404 and skip_auto_create:
+                        return response
+
+                    try:
+                        message = response.json().get("message", "")
+                    except requests.exceptions.JSONDecodeError:
+                        # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
+                        # this should have been caught by the 404 check above, so this is an error.
+                        # For skip_auto_create=False, we explicitly check status_code below,
+                        # so we don't need to parse the message.
+                        if skip_auto_create:
+                            raise ResponseStatusException(
+                                f"Failed to parse error response from endpoint {endpoint}.", response
+                            )
+                        message = "" # Not used when we check status_code directly
+
+                    # fix engine on engine error and retry
+                    # Skip setting up GI if skip_auto_create is True to avoid recursion or skip_engine_db_error_retry is true to let _exec_async_v2 perform the retry with the correct headers.
+                    if ((_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message)) and not skip_engine_db_error_retry:
+                        engine_name = payload.get("caller_engine_name", "") if payload else ""
+                        engine_name = engine_name or self.get_default_engine_name()
+                        engine_size = self.config.get_default_engine_size()
+                        # Use the mixin's _poll_use_index method
+                        self._poll_use_index(
+                            app_name=self.get_app_name(),
+                            sources=self.sources,
+                            model=self.database,
+                            engine_name=engine_name,
+                            engine_size=engine_size,
+                            headers=headers,
+                        )
+                        response = _send_request()
+            except requests.exceptions.ConnectionError as e:
+                messages = collect_error_messages(e)
+                if any("nameresolutionerror" in msg for msg in messages):
+                    # when we can not resolve the service endpoint, we assume it is outdated
+                    # hence, we try to retrieve it again and query again.
+                    self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
+                        enforce_update=True,
+                    )
+                    return _send_request()
+                # raise in all other cases
+                raise e
+            return response
+
+    def _txn_request_with_gi_retry(
+        self,
+        payload: Dict,
+        headers: Dict[str, str],
+        query_params: Dict,
+        engine: Union[str, None],
+    ):
+        """Make request with graph index retry logic.
+
+        Attempts request with gi_setup_skipped=True first. If an engine or database
+        issue occurs, polls use_index and retries with gi_setup_skipped=False.
+        """
+        response = self.request(
+            "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
+        )
+
+        if response.status_code != 200:
+            try:
+                message = response.json().get("message", "")
+            except requests.exceptions.JSONDecodeError:
+                message = ""
+
+            if _is_engine_issue(message) or _is_database_issue(message):
+                engine_name = engine or self.get_default_engine_name()
+                engine_size = self.config.get_default_engine_size()
+                # Use the mixin's _poll_use_index method
+                self._poll_use_index(
+                    app_name=self.get_app_name(),
+                    sources=self.sources,
+                    model=self.database,
+                    engine_name=engine_name,
+                    engine_size=engine_size,
+                    headers=headers,
+                )
+                headers['gi_setup_skipped'] = 'False'
+                response = self.request(
+                    "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
+                )
+            else:
+                raise ResponseStatusException("Failed to create transaction.", response)
+
+        return response
+
+    def _exec_async_v2(
+        self,
+        database: str,
+        engine: Union[str, None],
+        raw_code: str,
+        inputs: Dict | None = None,
+        readonly=True,
+        nowait_durable=False,
+        headers: Dict[str, str] | None = None,
+        bypass_index=False,
+        language: str = "rel",
+        query_timeout_mins: int | None = None,
+        gi_setup_skipped: bool = False,
+    ):
+
+        with debugging.span("transaction") as txn_span:
+            with debugging.span("create_v2") as create_span:
+
+                use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
+
+                payload = {
+                    "dbname": database,
+                    "engine_name": engine,
+                    "query": raw_code,
+                    "v1_inputs": inputs,
+                    "nowait_durable": nowait_durable,
+                    "readonly": readonly,
+                    "language": language,
+                }
+                if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
+                    query_timeout_mins = int(timeout_value)
+                if query_timeout_mins is not None:
+                    payload["timeout_mins"] = query_timeout_mins
+                query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
+
+                # Add gi_setup_skipped to headers
+                if headers is None:
+                    headers = {}
+                headers["gi_setup_skipped"] = str(gi_setup_skipped)
+                headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
+
+                response = self._txn_request_with_gi_retry(
+                    payload, headers, query_params, engine
+                )
+
+                artifact_info = {}
+                response_content = response.json()
+
+                txn_id = response_content["transaction"]['id']
+                state = response_content["transaction"]['state']
+
+                txn_span["txn_id"] = txn_id
+                create_span["txn_id"] = txn_id
+                debugging.event("transaction_created", txn_span, txn_id=txn_id)
+
+                # fast path: transaction already finished
+                if state in ["COMPLETED", "ABORTED"]:
+                    if txn_id in self._pending_transactions:
+                        self._pending_transactions.remove(txn_id)
+
+                    # Process rows to get the rest of the artifacts
+                    for result in response_content.get("results", []):
+                        filename = result['filename']
+                        # making keys uppercase to match the old behavior
+                        artifact_info[filename] = {k.upper(): v for k, v in result.items()}
+
+                # Slow path: transaction not done yet; start polling
+                else:
+                    self._pending_transactions.append(txn_id)
+                    with debugging.span("wait", txn_id=txn_id):
+                        poll_with_specified_overhead(
+                            lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
+                        )
+                    artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
+
+            with debugging.span("fetch"):
+                return self._download_results(artifact_info, txn_id, state)
+
+    def _prepare_index(
+        self,
+        model: str,
+        engine_name: str,
+        engine_size: str = "",
+        language: str = "rel",
+        rai_relations: List[str] | None = None,
+        pyrel_program_id: str | None = None,
+        skip_pull_relations: bool = False,
+        headers: Dict | None = None,
+    ):
+        """
+        Prepare the index for the given engine and model.
+        """
+        with debugging.span("prepare_index"):
+            if headers is None:
+                headers = {}
+
+            payload = {
+                "model_name": model,
+                "caller_engine_name": engine_name,
+                "language": language,
+                "pyrel_program_id": pyrel_program_id,
+                "skip_pull_relations": skip_pull_relations,
+                "rai_relations": rai_relations or [],
+                "user_agent": get_pyrel_version(self.generation),
+            }
+            # Only include engine_size if it has a non-empty string value
+            if engine_size and engine_size.strip():
+                payload["caller_engine_size"] = engine_size
+
+            response = self.request(
+                "prepare_index", payload=payload, headers=headers
+            )
+
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to prepare index.", response)
+
+            return response.json()
+
+    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
+        """Check whether the given transaction has completed."""
+
+        with debugging.span("check_status"):
+            response = self.request(
+                "get_txn",
+                headers=headers,
+                path_params={"txn_id": txn_id},
+            )
+        assert response, f"No results from get_transaction('{txn_id}')"
+
+        response_content = response.json()
+        transaction = response_content["transaction"]
+        status: str = transaction['state']
+
+        # remove the transaction from the pending list if it's completed or aborted
+        if status in ["COMPLETED", "ABORTED"]:
+            if txn_id in self._pending_transactions:
+                self._pending_transactions.remove(txn_id)
+
+        if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
+            config_file_path = getattr(self.config, 'file_path', None)
+            timeout_ms = int(transaction.get("timeout_ms", 0))
+            timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+            raise QueryTimeoutExceededException(
+                timeout_mins=timeout_mins,
+                query_id=txn_id,
+                config_file_path=config_file_path,
+            )
+
+        # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
+        return status == "COMPLETED" or status == "ABORTED"
+
+    def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
+        """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
+        with debugging.span("list_results"):
+            response = self.request(
+                "get_txn_artifacts",
+                headers=headers,
+                path_params={"txn_id": txn_id},
+            )
+            assert response, f"No results from get_transaction_artifacts('{txn_id}')"
+            artifact_info = {}
+            for result in response.json()["results"]:
+                filename = result['filename']
+                # making keys uppercase to match the old behavior
+                artifact_info[filename] = {k.upper(): v for k, v in result.items()}
+            return artifact_info
+
+    def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
+        with debugging.span("get_transaction_problems"):
+            response = self.request(
+                "get_txn_problems",
+                path_params={"txn_id": txn_id},
+            )
+            response_content = response.json()
+            if not response_content:
+                return []
+            return response_content.get("problems", [])
+
+    def get_transaction_events(self, transaction_id: str, continuation_token: str = ''):
+        response = self.request(
+            "get_txn_events",
+            path_params={"txn_id": transaction_id, "stream_name": "profiler"},
+            query_params={"continuation_token": continuation_token},
+        )
+        response_content = response.json()
+        if not response_content:
+            return {
+                "events": [],
+                "continuation_token": None
+            }
+        return response_content
+
+    #--------------------------------------------------
+    # Databases
+    #--------------------------------------------------
+
+    def get_installed_packages(self, database: str) -> Union[Dict, None]:
+        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
+        if use_graph_index:
+            response = self.request(
+                "get_model_package_versions",
+                payload={"model_name": database},
+            )
+        else:
+            response = self.request(
+                "get_package_versions",
+                path_params={"db_name": database},
+            )
+        if response.status_code == 404 and response.json().get("message", "") == "database not found":
+            return None
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to retrieve package versions for {database}.", response
+            )
+
+        content = response.json()
+        if not content:
+            return None
+
+        return safe_json_loads(content["package_versions"])
+
+    def get_database(self, database: str):
+        with debugging.span("get_database", dbname=database):
+            if not database:
+                raise ValueError("Database name must be provided to get database.")
+            response = self.request(
+                "get_db",
+                path_params={},
+                query_params={"name": database},
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException(f"Failed to get db. db:{database}", response)
+
+            response_content = response.json()
+
+            if (response_content.get("databases") and len(response_content["databases"]) == 1):
+                db = response_content["databases"][0]
+                return {
+                    "id": db["id"],
+                    "name": db["name"],
+                    "created_by": db.get("created_by"),
+                    "created_on": ms_to_timestamp(db.get("created_on")),
+                    "deleted_by": db.get("deleted_by"),
+                    "deleted_on": ms_to_timestamp(db.get("deleted_on")),
+                    "state": db["state"],
+                }
+            else:
+                return None
+
+    def create_graph(self, name: str):
+        with debugging.span("create_model", dbname=name):
+            return self._create_database(name,"")
+
+    def delete_graph(self, name:str, force=False, language: str = "rel"):
+        prop_hdrs = debugging.gen_current_propagation_headers()
+        if self.config.get("use_graph_index", USE_GRAPH_INDEX):
+            keep_database = not force and self.config.get("reuse_model", True)
+            with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
+                response = self.request(
+                    "release_index",
+                    payload={
+                        "model_name": name,
+                        "keep_database": keep_database,
+                        "language": language,
+                        "user_agent": get_pyrel_version(self.generation),
+                    },
+                    headers=prop_hdrs,
+                )
+                if (
+                    response.status_code != 200
+                    and not (
+                        response.status_code == 404
+                        and "database not found" in response.json().get("message", "")
+                    )
+                ):
+                    raise ResponseStatusException(f"Failed to release index. Model: {name} ", response)
+        else:
+            with debugging.span("delete_model", name=name):
+                self._delete_database(name, headers=prop_hdrs)
+
+    def clone_graph(self, target_name:str, source_name:str, nowait_durable=True, force=False):
+        if force and self.get_graph(target_name):
+            self.delete_graph(target_name)
+        with debugging.span("clone_model", target_name=target_name, source_name=source_name):
+            return self._create_database(target_name,source_name)
+
+    def _delete_database(self, name:str, headers:Dict={}):
+        with debugging.span("_delete_database", dbname=name):
+            response = self.request(
+                "delete_db",
+                path_params={"db_name": name},
+                query_params={},
+                headers=headers,
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException(f"Failed to delete db. db:{name} ", response)
+
+    def _create_database(self, name:str, source_name:str):
+        with debugging.span("_create_database", dbname=name):
+            payload = {
+                "name": name,
+                "source_name": source_name,
+            }
+            response = self.request(
+                "create_db", payload=payload, headers={}, query_params={},
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException(f"Failed to create db. db:{name}", response)
+
+    #--------------------------------------------------
+    # Engines
+    #--------------------------------------------------
+
+    def list_engines(self, state: str | None = None):
+        response = self.request("list_engines")
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                "Failed to retrieve engines.", response
+            )
+        response_content = response.json()
+        if not response_content:
+            return []
+        engines = [
+            {
+                "name": engine["name"],
+                "id": engine["id"],
+                "size": engine["size"],
+                "state": engine["status"], # callers are expecting 'state'
+                "created_by": engine["created_by"],
+                "created_on": engine["created_on"],
+                "updated_on": engine["updated_on"],
+            }
+            for engine in response_content.get("engines", [])
+            if state is None or engine.get("status") == state
+        ]
+        return sorted(engines, key=lambda x: x["name"])
+
+    def get_engine(self, name: str):
+        response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
+        if response.status_code == 404: # engine not found return 404
+            return None
+        elif response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to retrieve engine {name}.", response
+            )
+        engine = response.json()
+        if not engine:
+            return None
+        engine_state: EngineState = {
+            "name": engine["name"],
+            "id": engine["id"],
+            "size": engine["size"],
+            "state": engine["status"], # callers are expecting 'state'
+            "created_by": engine["created_by"],
+            "created_on": engine["created_on"],
+            "updated_on": engine["updated_on"],
+            "version": engine["version"],
+            "auto_suspend": engine["auto_suspend_mins"],
+            "suspends_at": engine["suspends_at"],
+        }
+        return engine_state
+
+    def _create_engine(
+        self,
+        name: str,
+        size: str | None = None,
+        auto_suspend_mins: int | None = None,
+        is_async: bool = False,
+        headers: Dict[str, str] | None = None
+    ):
+        # only async engine creation supported via direct access
+        if not is_async:
+            return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
+        payload:Dict[str, Any] = {
+            "name": name,
+        }
+        if auto_suspend_mins is not None:
+            payload["auto_suspend_mins"] = auto_suspend_mins
+        if size is not None:
+            payload["size"] = size
+        response = self.request(
+            "create_engine",
+            payload=payload,
+            path_params={"engine_type": "logic"},
+            headers=headers,
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to create engine {name} with size {size}.", response
+            )
+
+    def delete_engine(self, name:str, force:bool = False, headers={}):
+        response = self.request(
+            "delete_engine",
+            path_params={"engine_name": name, "engine_type": "logic"},
+            headers=headers,
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to delete engine {name}.", response
+            )
+
+    def suspend_engine(self, name:str):
+        response = self.request(
+            "suspend_engine",
+            path_params={"engine_name": name, "engine_type": "logic"},
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to suspend engine {name}.", response
+            )
+
+    def resume_engine_async(self, name:str, headers={}):
+        response = self.request(
+            "resume_engine",
+            path_params={"engine_name": name, "engine_type": "logic"},
+            headers=headers,
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to resume engine {name}.", response
+            )
+        return {}
+
+    def _poll_use_index(
+        self,
+        app_name: str,
+        sources: Iterable[str],
+        model: str,
+        engine_name: str,
+        engine_size: str | None = None,
+        program_span_id: str | None = None,
+        headers: Dict | None = None,
+    ):
+        """Poll use_index to prepare indices for the given sources using DirectUseIndexPoller."""
+        return DirectUseIndexPoller(
+            self,
+            app_name=app_name,
+            sources=sources,
+            model=model,
+            engine_name=engine_name,
+            engine_size=engine_size,
+            language=self.language,
+            program_span_id=program_span_id,
+            headers=headers,
+            generation=self.generation,
+        ).poll()
+
+    def maybe_poll_use_index(
+        self,
+        app_name: str,
+        sources: Iterable[str],
+        model: str,
+        engine_name: str,
+        engine_size: str | None = None,
+        program_span_id: str | None = None,
+        headers: Dict | None = None,
+    ):
+        """Only call poll() if there are sources to process and cache is not valid."""
+        sources_list = list(sources)
+        self.database = model
+        if sources_list:
+            poller = DirectUseIndexPoller(
+                self,
+                app_name=app_name,
+                sources=sources_list,
+                model=model,
+                engine_name=engine_name,
+                engine_size=engine_size,
+                language=self.language,
+                program_span_id=program_span_id,
+                headers=headers,
+                generation=self.generation,
+            )
+            # If cache is valid (data freshness has not expired), skip polling
+            if poller.cache.is_valid():
+                cached_sources = len(poller.cache.sources)
+                total_sources = len(sources_list)
+                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+                if cached_timestamp:
+                    print(f"\n{message} (cached at {cached_timestamp})\n")
+                else:
+                    print(f"\n{message}\n")
+            else:
+                return poller.poll()
+
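Note: the sketch below is illustrative only and is not part of the package diff. It assumes an already-configured profile (the name "default", the engine and model names, and the "READY" state filter are placeholders), infers the import path from the file's location in the wheel, and only exercises methods that appear in the hunk above.

# Illustrative sketch, not part of the wheel contents.
from v0.relationalai.clients.resources.snowflake.direct_access_resources import DirectAccessResources

res = DirectAccessResources(profile="default", language="rel")  # placeholder profile

# Engine and database management go over HTTP via request() instead of
# Snowflake service functions.
if res.get_engine("my_engine") is None:                 # placeholder engine name
    res._create_engine("my_engine", size="S", is_async=True)
print(res.list_engines(state="READY"))                  # placeholder state filter

res.create_graph("my_model")                            # placeholder model name
# Queries would go through _exec_async_v2, which creates a transaction via the
# "create_txn" endpoint, polls its status, and downloads the result artifacts.
res.delete_graph("my_model")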