relationalai 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/config.py +7 -0
- relationalai/clients/direct_access_client.py +113 -0
- relationalai/clients/snowflake.py +41 -107
- relationalai/clients/use_index_poller.py +349 -188
- relationalai/early_access/dsl/bindings/csv.py +2 -2
- relationalai/early_access/metamodel/rewrite/__init__.py +5 -3
- relationalai/early_access/rel/rewrite/__init__.py +1 -1
- relationalai/errors.py +24 -3
- relationalai/semantics/internal/annotations.py +1 -0
- relationalai/semantics/internal/internal.py +22 -4
- relationalai/semantics/lqp/builtins.py +1 -0
- relationalai/semantics/lqp/executor.py +61 -12
- relationalai/semantics/lqp/intrinsics.py +23 -0
- relationalai/semantics/lqp/model2lqp.py +13 -4
- relationalai/semantics/lqp/passes.py +4 -6
- relationalai/semantics/lqp/primitives.py +12 -1
- relationalai/semantics/{rel → lqp}/rewrite/__init__.py +6 -0
- relationalai/semantics/lqp/rewrite/extract_common.py +362 -0
- relationalai/semantics/metamodel/builtins.py +20 -2
- relationalai/semantics/metamodel/factory.py +3 -2
- relationalai/semantics/metamodel/rewrite/__init__.py +3 -9
- relationalai/semantics/reasoners/graph/core.py +273 -71
- relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
- relationalai/semantics/rel/builtins.py +5 -1
- relationalai/semantics/rel/compiler.py +7 -19
- relationalai/semantics/rel/executor.py +2 -2
- relationalai/semantics/rel/rel.py +6 -0
- relationalai/semantics/rel/rel_utils.py +8 -1
- relationalai/semantics/sql/compiler.py +122 -42
- relationalai/semantics/sql/executor/duck_db.py +28 -3
- relationalai/semantics/sql/rewrite/denormalize.py +4 -6
- relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
- relationalai/semantics/sql/sql.py +27 -0
- relationalai/semantics/std/__init__.py +2 -1
- relationalai/semantics/std/datetime.py +4 -0
- relationalai/semantics/std/re.py +83 -0
- relationalai/semantics/std/strings.py +1 -1
- relationalai/tools/cli.py +11 -4
- relationalai/tools/cli_controls.py +445 -60
- relationalai/util/format.py +78 -1
- {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/METADATA +7 -5
- {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/RECORD +51 -50
- relationalai/semantics/metamodel/rewrite/gc_nodes.py +0 -58
- relationalai/semantics/metamodel/rewrite/list_types.py +0 -109
- relationalai/semantics/rel/rewrite/extract_common.py +0 -451
- /relationalai/semantics/{rel → lqp}/rewrite/cdc.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/extract_keys.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/fd_constraints.py +0 -0
- /relationalai/semantics/{rel → lqp}/rewrite/quantify_vars.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/splinter.py +0 -0
- {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/WHEEL +0 -0
- {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/entry_points.txt +0 -0
- {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/licenses/LICENSE +0 -0
relationalai/clients/config.py
CHANGED
|
@@ -468,6 +468,13 @@ class Config():
|
|
|
468
468
|
if not self.file_path:
|
|
469
469
|
self.file_path = "__inline__"
|
|
470
470
|
self._handle_snowflake_fallback_configurations()
|
|
471
|
+
# Check if Azure platform is being used without the legacy dependency
|
|
472
|
+
if self.get("platform", "") == "azure":
|
|
473
|
+
try:
|
|
474
|
+
import railib # noqa
|
|
475
|
+
except ImportError:
|
|
476
|
+
from relationalai.errors import AzureLegacyDependencyMissingException
|
|
477
|
+
raise AzureLegacyDependencyMissingException() from None
|
|
471
478
|
|
|
472
479
|
def fetch(self, profile:str|None=None):
|
|
473
480
|
from relationalai.environments import runtime_env, TerminalEnvironment
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import requests
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from urllib.parse import urlencode, quote
|
|
6
|
+
from requests.adapters import HTTPAdapter
|
|
7
|
+
from urllib3.util.retry import Retry
|
|
8
|
+
from typing import Any, Dict, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
from relationalai.auth.token_handler import TokenHandler
|
|
11
|
+
from relationalai.clients.config import Config
|
|
12
|
+
from relationalai.clients.util import get_pyrel_version
|
|
13
|
+
from relationalai import debugging
|
|
14
|
+
from relationalai.tools.constants import Generation
|
|
15
|
+
from relationalai.environments import runtime_env, SnowbookEnvironment
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class Endpoint:
|
|
19
|
+
method: str
|
|
20
|
+
endpoint: str
|
|
21
|
+
|
|
22
|
+
class DirectAccessClient:
|
|
23
|
+
"""
|
|
24
|
+
DirectAccessClient is a client for direct service access without service function calls.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(self, config: Config, token_handler: TokenHandler, service_endpoint: str, generation: Optional[Generation] = None):
|
|
28
|
+
self._config: Config = config
|
|
29
|
+
self._token_handler: TokenHandler = token_handler
|
|
30
|
+
self.service_endpoint: str = service_endpoint
|
|
31
|
+
self.generation: Optional[Generation] = generation
|
|
32
|
+
self._is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
|
|
33
|
+
self.endpoints: Dict[str, Endpoint] = {
|
|
34
|
+
"create_txn": Endpoint(method="POST", endpoint="/v1alpha1/transactions"),
|
|
35
|
+
"get_txn": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}"),
|
|
36
|
+
"get_txn_artifacts": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/artifacts"),
|
|
37
|
+
"get_txn_problems": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/problems"),
|
|
38
|
+
"get_txn_events": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/events/{stream_name}"),
|
|
39
|
+
"get_package_versions": Endpoint(method="GET", endpoint="/v1alpha1/databases/{db_name}/package_versions"),
|
|
40
|
+
"get_model_package_versions": Endpoint(method="POST", endpoint="/v1alpha1/models/get_package_versions"),
|
|
41
|
+
"create_db": Endpoint(method="POST", endpoint="/v1alpha1/databases"),
|
|
42
|
+
"get_db": Endpoint(method="GET", endpoint="/v1alpha1/databases"),
|
|
43
|
+
"delete_db": Endpoint(method="DELETE", endpoint="/v1alpha1/databases/{db_name}"),
|
|
44
|
+
"release_index": Endpoint(method="POST", endpoint="/v1alpha1/index/release"),
|
|
45
|
+
"list_engines": Endpoint(method="GET", endpoint="/v1alpha1/engines"),
|
|
46
|
+
"get_engine": Endpoint(method="GET", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
|
|
47
|
+
"create_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}"),
|
|
48
|
+
"delete_engine": Endpoint(method="DELETE", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
|
|
49
|
+
"suspend_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/suspend"),
|
|
50
|
+
"resume_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/resume_async"),
|
|
51
|
+
"prepare_index": Endpoint(method="POST", endpoint="/v1alpha1/index/prepare"),
|
|
52
|
+
}
|
|
53
|
+
self.http_session = self._create_retry_session()
|
|
54
|
+
|
|
55
|
+
def _create_retry_session(self) -> requests.Session:
|
|
56
|
+
http_session = requests.Session()
|
|
57
|
+
retries = Retry(
|
|
58
|
+
total=3,
|
|
59
|
+
backoff_factor=0.3,
|
|
60
|
+
status_forcelist=[500, 502, 503, 504],
|
|
61
|
+
allowed_methods=frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}),
|
|
62
|
+
raise_on_status=False
|
|
63
|
+
)
|
|
64
|
+
adapter = HTTPAdapter(max_retries=retries)
|
|
65
|
+
http_session.mount("http://", adapter)
|
|
66
|
+
http_session.mount("https://", adapter)
|
|
67
|
+
http_session.headers.update({"Connection": "keep-alive"})
|
|
68
|
+
return http_session
|
|
69
|
+
|
|
70
|
+
def request(
|
|
71
|
+
self,
|
|
72
|
+
endpoint: str,
|
|
73
|
+
payload: Dict[str, Any] | None = None,
|
|
74
|
+
headers: Dict[str, str] | None = None,
|
|
75
|
+
path_params: Dict[str, str] | None = None,
|
|
76
|
+
query_params: Dict[str, str] | None = None,
|
|
77
|
+
) -> requests.Response:
|
|
78
|
+
"""
|
|
79
|
+
Send a request to the service endpoint.
|
|
80
|
+
"""
|
|
81
|
+
url, method = self._prepare_url(endpoint, path_params, query_params)
|
|
82
|
+
request_headers = self._prepare_headers(headers)
|
|
83
|
+
return self.http_session.request(method, url, json=payload, headers=request_headers)
|
|
84
|
+
|
|
85
|
+
def _prepare_url(self, endpoint: str, path_params: Dict[str, str] | None = None, query_params: Dict[str, str] | None = None) -> Tuple[str, str]:
|
|
86
|
+
try:
|
|
87
|
+
ep = self.endpoints[endpoint]
|
|
88
|
+
except KeyError:
|
|
89
|
+
raise ValueError(f"Invalid endpoint: {endpoint}. Available endpoints: {list(self.endpoints.keys())}")
|
|
90
|
+
url = f"{self.service_endpoint}{ep.endpoint}"
|
|
91
|
+
if path_params:
|
|
92
|
+
escaped_path_params = {k: quote(v, safe='') for k, v in path_params.items()}
|
|
93
|
+
url = url.format(**escaped_path_params)
|
|
94
|
+
if query_params:
|
|
95
|
+
url += '?' + urlencode(query_params)
|
|
96
|
+
return url, ep.method
|
|
97
|
+
|
|
98
|
+
def _prepare_headers(self, headers: Dict[str, str] | None) -> Dict[str, str]:
|
|
99
|
+
request_headers = {}
|
|
100
|
+
if headers:
|
|
101
|
+
request_headers.update(headers)
|
|
102
|
+
# Authorization tokens are not needed in a snowflake notebook environment
|
|
103
|
+
if not self._is_snowflake_notebook:
|
|
104
|
+
request_headers["Authorization"] = f'Snowflake Token="{self._token_handler.get_ingress_token(self.service_endpoint)}"'
|
|
105
|
+
# needed for oauth, does no harm for other authentication methods
|
|
106
|
+
request_headers["X-SF-SPCS-Authentication-Method"] = 'OAUTH'
|
|
107
|
+
request_headers["Content-Type"] = 'application/x-www-form-urlencoded'
|
|
108
|
+
request_headers["Accept"] = "application/json"
|
|
109
|
+
|
|
110
|
+
request_headers["user-agent"] = get_pyrel_version(self.generation)
|
|
111
|
+
request_headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
|
|
112
|
+
|
|
113
|
+
return debugging.add_current_propagation_headers(request_headers)
|
|
@@ -28,15 +28,11 @@ import requests
|
|
|
28
28
|
import snowflake.connector
|
|
29
29
|
import pyarrow as pa
|
|
30
30
|
|
|
31
|
-
from dataclasses import dataclass
|
|
32
31
|
from snowflake.snowpark import Session
|
|
33
32
|
from snowflake.snowpark.context import get_active_session
|
|
34
33
|
from . import result_helpers
|
|
35
34
|
from .. import debugging
|
|
36
35
|
from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
|
|
37
|
-
from urllib.parse import urlencode, quote
|
|
38
|
-
from requests.adapters import HTTPAdapter
|
|
39
|
-
from urllib3.util.retry import Retry
|
|
40
36
|
|
|
41
37
|
from pandas import DataFrame
|
|
42
38
|
|
|
@@ -44,10 +40,11 @@ from ..tools.cli_controls import Spinner
|
|
|
44
40
|
from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
|
|
45
41
|
from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
|
|
46
42
|
from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
|
|
43
|
+
from ..clients.direct_access_client import DirectAccessClient
|
|
47
44
|
from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp
|
|
48
45
|
from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
|
|
49
46
|
from .. import dsl, rel, metamodel as m
|
|
50
|
-
from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning
|
|
47
|
+
from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
|
|
51
48
|
from concurrent.futures import ThreadPoolExecutor
|
|
52
49
|
from datetime import datetime, date, timedelta
|
|
53
50
|
from snowflake.snowpark.types import StringType, StructField, StructType
|
|
@@ -92,6 +89,8 @@ TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
|
|
|
92
89
|
|
|
93
90
|
DUO_TEXT = "duo security"
|
|
94
91
|
|
|
92
|
+
TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
|
|
93
|
+
|
|
95
94
|
#--------------------------------------------------
|
|
96
95
|
# Helpers
|
|
97
96
|
#--------------------------------------------------
|
|
@@ -1307,7 +1306,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1307
1306
|
response = self._exec(f"CALL {APP_NAME}.api.get_transaction('{txn_id}',{headers});")
|
|
1308
1307
|
assert response, f"No results from get_transaction('{txn_id}')"
|
|
1309
1308
|
|
|
1310
|
-
response_row = next(iter(response))
|
|
1309
|
+
response_row = next(iter(response)).asDict()
|
|
1311
1310
|
status: str = response_row['STATE']
|
|
1312
1311
|
|
|
1313
1312
|
# remove the transaction from the pending list if it's completed or aborted
|
|
@@ -1315,6 +1314,16 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1315
1314
|
if txn_id in self._pending_transactions:
|
|
1316
1315
|
self._pending_transactions.remove(txn_id)
|
|
1317
1316
|
|
|
1317
|
+
if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
|
|
1318
|
+
config_file_path = getattr(self.config, 'file_path', None)
|
|
1319
|
+
# todo: use the timeout returned alongside the transaction as soon as it's exposed
|
|
1320
|
+
timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
|
|
1321
|
+
raise QueryTimeoutExceededException(
|
|
1322
|
+
timeout_mins=timeout_mins,
|
|
1323
|
+
query_id=txn_id,
|
|
1324
|
+
config_file_path=config_file_path,
|
|
1325
|
+
)
|
|
1326
|
+
|
|
1318
1327
|
# @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
|
|
1319
1328
|
return status == "COMPLETED" or status == "ABORTED"
|
|
1320
1329
|
|
|
@@ -2957,104 +2966,6 @@ def Graph(
|
|
|
2957
2966
|
#--------------------------------------------------
|
|
2958
2967
|
# Note: All direct access components should live in a separate file
|
|
2959
2968
|
|
|
2960
|
-
@dataclass
|
|
2961
|
-
class Endpoint:
|
|
2962
|
-
method: str
|
|
2963
|
-
endpoint: str
|
|
2964
|
-
|
|
2965
|
-
class DirectAccessClient:
|
|
2966
|
-
"""
|
|
2967
|
-
DirectAccessClient is a client for direct service access without service function calls.
|
|
2968
|
-
"""
|
|
2969
|
-
|
|
2970
|
-
def __init__(self, config: Config, token_handler: TokenHandler, service_endpoint: str, generation: Optional[Generation] = None):
|
|
2971
|
-
self._config: Config = config
|
|
2972
|
-
self._token_handler: TokenHandler = token_handler
|
|
2973
|
-
self.service_endpoint: str = service_endpoint
|
|
2974
|
-
self.generation: Optional[Generation] = generation
|
|
2975
|
-
self._is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
|
|
2976
|
-
self.endpoints: Dict[str, Endpoint] = {
|
|
2977
|
-
"create_txn": Endpoint(method="POST", endpoint="/v1alpha1/transactions"),
|
|
2978
|
-
"get_txn": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}"),
|
|
2979
|
-
"get_txn_artifacts": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/artifacts"),
|
|
2980
|
-
"get_txn_problems": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/problems"),
|
|
2981
|
-
"get_txn_events": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/events/{stream_name}"),
|
|
2982
|
-
"get_package_versions": Endpoint(method="GET", endpoint="/v1alpha1/databases/{db_name}/package_versions"),
|
|
2983
|
-
"get_model_package_versions": Endpoint(method="POST", endpoint="/v1alpha1/models/get_package_versions"),
|
|
2984
|
-
"create_db": Endpoint(method="POST", endpoint="/v1alpha1/databases"),
|
|
2985
|
-
"get_db": Endpoint(method="GET", endpoint="/v1alpha1/databases"),
|
|
2986
|
-
"delete_db": Endpoint(method="DELETE", endpoint="/v1alpha1/databases/{db_name}"),
|
|
2987
|
-
"release_index": Endpoint(method="POST", endpoint="/v1alpha1/index/release"),
|
|
2988
|
-
"list_engines": Endpoint(method="GET", endpoint="/v1alpha1/engines"),
|
|
2989
|
-
"get_engine": Endpoint(method="GET", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
|
|
2990
|
-
"create_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}"),
|
|
2991
|
-
"delete_engine": Endpoint(method="DELETE", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
|
|
2992
|
-
"suspend_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/suspend"),
|
|
2993
|
-
"resume_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/resume_async"),
|
|
2994
|
-
"prepare_index": Endpoint(method="POST", endpoint="/v1alpha1/index/prepare"),
|
|
2995
|
-
}
|
|
2996
|
-
self.http_session = self._create_retry_session()
|
|
2997
|
-
|
|
2998
|
-
def _create_retry_session(self) -> requests.Session:
|
|
2999
|
-
http_session = requests.Session()
|
|
3000
|
-
retries = Retry(
|
|
3001
|
-
total=3,
|
|
3002
|
-
backoff_factor=0.3,
|
|
3003
|
-
status_forcelist=[500, 502, 503, 504],
|
|
3004
|
-
allowed_methods=frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}),
|
|
3005
|
-
raise_on_status=False
|
|
3006
|
-
)
|
|
3007
|
-
adapter = HTTPAdapter(max_retries=retries)
|
|
3008
|
-
http_session.mount("http://", adapter)
|
|
3009
|
-
http_session.mount("https://", adapter)
|
|
3010
|
-
http_session.headers.update({"Connection": "keep-alive"})
|
|
3011
|
-
return http_session
|
|
3012
|
-
|
|
3013
|
-
def request(
|
|
3014
|
-
self,
|
|
3015
|
-
endpoint: str,
|
|
3016
|
-
payload: Dict[str, Any] | None = None,
|
|
3017
|
-
headers: Dict[str, str] | None = None,
|
|
3018
|
-
path_params: Dict[str, str] | None = None,
|
|
3019
|
-
query_params: Dict[str, str] | None = None,
|
|
3020
|
-
) -> requests.Response:
|
|
3021
|
-
"""
|
|
3022
|
-
Send a request to the service endpoint.
|
|
3023
|
-
"""
|
|
3024
|
-
url, method = self._prepare_url(endpoint, path_params, query_params)
|
|
3025
|
-
request_headers = self._prepare_headers(headers)
|
|
3026
|
-
return self.http_session.request(method, url, json=payload, headers=request_headers)
|
|
3027
|
-
|
|
3028
|
-
def _prepare_url(self, endpoint: str, path_params: Dict[str, str] | None = None, query_params: Dict[str, str] | None = None) -> Tuple[str, str]:
|
|
3029
|
-
try:
|
|
3030
|
-
ep = self.endpoints[endpoint]
|
|
3031
|
-
except KeyError:
|
|
3032
|
-
raise ValueError(f"Invalid endpoint: {endpoint}. Available endpoints: {list(self.endpoints.keys())}")
|
|
3033
|
-
url = f"{self.service_endpoint}{ep.endpoint}"
|
|
3034
|
-
if path_params:
|
|
3035
|
-
escaped_path_params = {k: quote(v, safe='') for k, v in path_params.items()}
|
|
3036
|
-
url = url.format(**escaped_path_params)
|
|
3037
|
-
if query_params:
|
|
3038
|
-
url += '?' + urlencode(query_params)
|
|
3039
|
-
return url, ep.method
|
|
3040
|
-
|
|
3041
|
-
def _prepare_headers(self, headers: Dict[str, str] | None) -> Dict[str, str]:
|
|
3042
|
-
request_headers = {}
|
|
3043
|
-
if headers:
|
|
3044
|
-
request_headers.update(headers)
|
|
3045
|
-
# Authorization tokens are not needed in a snowflake notebook environment
|
|
3046
|
-
if not self._is_snowflake_notebook:
|
|
3047
|
-
request_headers["Authorization"] = f'Snowflake Token="{self._token_handler.get_ingress_token(self.service_endpoint)}"'
|
|
3048
|
-
# needed for oauth, does no harm for other authentication methods
|
|
3049
|
-
request_headers["X-SF-SPCS-Authentication-Method"] = 'OAUTH'
|
|
3050
|
-
request_headers["Content-Type"] = 'application/x-www-form-urlencoded'
|
|
3051
|
-
request_headers["Accept"] = "application/json"
|
|
3052
|
-
|
|
3053
|
-
request_headers["user-agent"] = get_pyrel_version(self.generation)
|
|
3054
|
-
request_headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
|
|
3055
|
-
|
|
3056
|
-
return debugging.add_current_propagation_headers(request_headers)
|
|
3057
|
-
|
|
3058
2969
|
class DirectAccessResources(Resources):
|
|
3059
2970
|
"""
|
|
3060
2971
|
Resources class for Direct Service Access avoiding Snowflake service functions.
|
|
@@ -3068,7 +2979,14 @@ class DirectAccessResources(Resources):
|
|
|
3068
2979
|
reset_session: bool = False,
|
|
3069
2980
|
generation: Optional[Generation] = None,
|
|
3070
2981
|
):
|
|
3071
|
-
super().__init__(
|
|
2982
|
+
super().__init__(
|
|
2983
|
+
generation=generation,
|
|
2984
|
+
profile=profile,
|
|
2985
|
+
config=config,
|
|
2986
|
+
connection=connection,
|
|
2987
|
+
reset_session=reset_session,
|
|
2988
|
+
dry_run=dry_run,
|
|
2989
|
+
)
|
|
3072
2990
|
self._endpoint_info = ConfigStore(ENDPOINT_FILE)
|
|
3073
2991
|
self._service_endpoint = ""
|
|
3074
2992
|
self._direct_access_client = None
|
|
@@ -3140,7 +3058,12 @@ class DirectAccessResources(Resources):
|
|
|
3140
3058
|
try:
|
|
3141
3059
|
response = _send_request()
|
|
3142
3060
|
if response.status_code != 200:
|
|
3143
|
-
|
|
3061
|
+
try:
|
|
3062
|
+
message = response.json().get("message", "")
|
|
3063
|
+
except requests.exceptions.JSONDecodeError:
|
|
3064
|
+
raise ResponseStatusException(
|
|
3065
|
+
f"Failed to parse error response from endpoint {endpoint}.", response
|
|
3066
|
+
)
|
|
3144
3067
|
|
|
3145
3068
|
# fix engine on engine error and retry
|
|
3146
3069
|
if _is_engine_issue(message):
|
|
@@ -3305,13 +3228,24 @@ class DirectAccessResources(Resources):
|
|
|
3305
3228
|
assert response, f"No results from get_transaction('{txn_id}')"
|
|
3306
3229
|
|
|
3307
3230
|
response_content = response.json()
|
|
3308
|
-
|
|
3231
|
+
transaction = response_content["transaction"]
|
|
3232
|
+
status: str = transaction['state']
|
|
3309
3233
|
|
|
3310
3234
|
# remove the transaction from the pending list if it's completed or aborted
|
|
3311
3235
|
if status in ["COMPLETED", "ABORTED"]:
|
|
3312
3236
|
if txn_id in self._pending_transactions:
|
|
3313
3237
|
self._pending_transactions.remove(txn_id)
|
|
3314
3238
|
|
|
3239
|
+
if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
|
|
3240
|
+
config_file_path = getattr(self.config, 'file_path', None)
|
|
3241
|
+
timeout_ms = int(transaction.get("timeout_ms", 0))
|
|
3242
|
+
timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
|
|
3243
|
+
raise QueryTimeoutExceededException(
|
|
3244
|
+
timeout_mins=timeout_mins,
|
|
3245
|
+
query_id=txn_id,
|
|
3246
|
+
config_file_path=config_file_path,
|
|
3247
|
+
)
|
|
3248
|
+
|
|
3315
3249
|
# @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
|
|
3316
3250
|
return status == "COMPLETED" or status == "ABORTED"
|
|
3317
3251
|
|