relationalai 0.11.4__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- relationalai/clients/config.py +7 -0
- relationalai/clients/direct_access_client.py +113 -0
- relationalai/clients/snowflake.py +35 -106
- relationalai/early_access/metamodel/rewrite/__init__.py +5 -3
- relationalai/early_access/rel/rewrite/__init__.py +1 -1
- relationalai/errors.py +24 -3
- relationalai/semantics/internal/annotations.py +1 -0
- relationalai/semantics/lqp/builtins.py +1 -0
- relationalai/semantics/lqp/passes.py +3 -4
- relationalai/semantics/{rel → lqp}/rewrite/__init__.py +6 -0
- relationalai/semantics/metamodel/builtins.py +12 -1
- relationalai/semantics/metamodel/rewrite/__init__.py +3 -9
- relationalai/semantics/reasoners/graph/core.py +221 -71
- relationalai/semantics/rel/builtins.py +5 -1
- relationalai/semantics/rel/compiler.py +3 -3
- relationalai/semantics/sql/compiler.py +2 -3
- relationalai/semantics/sql/executor/duck_db.py +8 -4
- relationalai/tools/cli.py +11 -4
- {relationalai-0.11.4.dist-info → relationalai-0.12.0.dist-info}/METADATA +5 -4
- {relationalai-0.11.4.dist-info → relationalai-0.12.0.dist-info}/RECORD +29 -30
- relationalai/semantics/metamodel/rewrite/gc_nodes.py +0 -58
- relationalai/semantics/metamodel/rewrite/list_types.py +0 -109
- /relationalai/semantics/{rel → lqp}/rewrite/cdc.py +0 -0
- /relationalai/semantics/{rel → lqp}/rewrite/extract_common.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/extract_keys.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/fd_constraints.py +0 -0
- /relationalai/semantics/{rel → lqp}/rewrite/quantify_vars.py +0 -0
- /relationalai/semantics/{metamodel → lqp}/rewrite/splinter.py +0 -0
- {relationalai-0.11.4.dist-info → relationalai-0.12.0.dist-info}/WHEEL +0 -0
- {relationalai-0.11.4.dist-info → relationalai-0.12.0.dist-info}/entry_points.txt +0 -0
- {relationalai-0.11.4.dist-info → relationalai-0.12.0.dist-info}/licenses/LICENSE +0 -0
relationalai/clients/config.py
CHANGED
@@ -468,6 +468,13 @@ class Config():
         if not self.file_path:
             self.file_path = "__inline__"
         self._handle_snowflake_fallback_configurations()
+        # Check if Azure platform is being used without the legacy dependency
+        if self.get("platform", "") == "azure":
+            try:
+                import railib  # noqa
+            except ImportError:
+                from relationalai.errors import AzureLegacyDependencyMissingException
+                raise AzureLegacyDependencyMissingException() from None
 
     def fetch(self, profile:str|None=None):
         from relationalai.environments import runtime_env, TerminalEnvironment
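Note: the new guard makes `platform = "azure"` fail fast at config-validation time when the optional `railib` dependency is absent, instead of failing later at request time. A minimal sketch of the resulting behavior; the inline-profile constructor call is an assumption (the `__inline__` file_path above suggests inline configs exist, but the exact `Config` signature is not shown in this diff):

    # Hypothetical illustration; the Config constructor signature is assumed.
    from relationalai.clients.config import Config
    from relationalai.errors import AzureLegacyDependencyMissingException

    try:
        Config({"platform": "azure"})
    except AzureLegacyDependencyMissingException:
        # Raised when `railib` (the legacy SDK) is not importable.
        print("pip install 'relationalai[legacy]'")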
relationalai/clients/direct_access_client.py
ADDED
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import requests
+from dataclasses import dataclass
+from urllib.parse import urlencode, quote
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+from typing import Any, Dict, Optional, Tuple
+
+from relationalai.auth.token_handler import TokenHandler
+from relationalai.clients.config import Config
+from relationalai.clients.util import get_pyrel_version
+from relationalai import debugging
+from relationalai.tools.constants import Generation
+from relationalai.environments import runtime_env, SnowbookEnvironment
+
+@dataclass
+class Endpoint:
+    method: str
+    endpoint: str
+
+class DirectAccessClient:
+    """
+    DirectAccessClient is a client for direct service access without service function calls.
+    """
+
+    def __init__(self, config: Config, token_handler: TokenHandler, service_endpoint: str, generation: Optional[Generation] = None):
+        self._config: Config = config
+        self._token_handler: TokenHandler = token_handler
+        self.service_endpoint: str = service_endpoint
+        self.generation: Optional[Generation] = generation
+        self._is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
+        self.endpoints: Dict[str, Endpoint] = {
+            "create_txn": Endpoint(method="POST", endpoint="/v1alpha1/transactions"),
+            "get_txn": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}"),
+            "get_txn_artifacts": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/artifacts"),
+            "get_txn_problems": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/problems"),
+            "get_txn_events": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/events/{stream_name}"),
+            "get_package_versions": Endpoint(method="GET", endpoint="/v1alpha1/databases/{db_name}/package_versions"),
+            "get_model_package_versions": Endpoint(method="POST", endpoint="/v1alpha1/models/get_package_versions"),
+            "create_db": Endpoint(method="POST", endpoint="/v1alpha1/databases"),
+            "get_db": Endpoint(method="GET", endpoint="/v1alpha1/databases"),
+            "delete_db": Endpoint(method="DELETE", endpoint="/v1alpha1/databases/{db_name}"),
+            "release_index": Endpoint(method="POST", endpoint="/v1alpha1/index/release"),
+            "list_engines": Endpoint(method="GET", endpoint="/v1alpha1/engines"),
+            "get_engine": Endpoint(method="GET", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
+            "create_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}"),
+            "delete_engine": Endpoint(method="DELETE", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
+            "suspend_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/suspend"),
+            "resume_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/resume_async"),
+            "prepare_index": Endpoint(method="POST", endpoint="/v1alpha1/index/prepare"),
+        }
+        self.http_session = self._create_retry_session()
+
+    def _create_retry_session(self) -> requests.Session:
+        http_session = requests.Session()
+        retries = Retry(
+            total=3,
+            backoff_factor=0.3,
+            status_forcelist=[500, 502, 503, 504],
+            allowed_methods=frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}),
+            raise_on_status=False
+        )
+        adapter = HTTPAdapter(max_retries=retries)
+        http_session.mount("http://", adapter)
+        http_session.mount("https://", adapter)
+        http_session.headers.update({"Connection": "keep-alive"})
+        return http_session
+
+    def request(
+        self,
+        endpoint: str,
+        payload: Dict[str, Any] | None = None,
+        headers: Dict[str, str] | None = None,
+        path_params: Dict[str, str] | None = None,
+        query_params: Dict[str, str] | None = None,
+    ) -> requests.Response:
+        """
+        Send a request to the service endpoint.
+        """
+        url, method = self._prepare_url(endpoint, path_params, query_params)
+        request_headers = self._prepare_headers(headers)
+        return self.http_session.request(method, url, json=payload, headers=request_headers)
+
+    def _prepare_url(self, endpoint: str, path_params: Dict[str, str] | None = None, query_params: Dict[str, str] | None = None) -> Tuple[str, str]:
+        try:
+            ep = self.endpoints[endpoint]
+        except KeyError:
+            raise ValueError(f"Invalid endpoint: {endpoint}. Available endpoints: {list(self.endpoints.keys())}")
+        url = f"{self.service_endpoint}{ep.endpoint}"
+        if path_params:
+            escaped_path_params = {k: quote(v, safe='') for k, v in path_params.items()}
+            url = url.format(**escaped_path_params)
+        if query_params:
+            url += '?' + urlencode(query_params)
+        return url, ep.method
+
+    def _prepare_headers(self, headers: Dict[str, str] | None) -> Dict[str, str]:
+        request_headers = {}
+        if headers:
+            request_headers.update(headers)
+        # Authorization tokens are not needed in a snowflake notebook environment
+        if not self._is_snowflake_notebook:
+            request_headers["Authorization"] = f'Snowflake Token="{self._token_handler.get_ingress_token(self.service_endpoint)}"'
+            # needed for oauth, does no harm for other authentication methods
+            request_headers["X-SF-SPCS-Authentication-Method"] = 'OAUTH'
+        request_headers["Content-Type"] = 'application/x-www-form-urlencoded'
+        request_headers["Accept"] = "application/json"
+
+        request_headers["user-agent"] = get_pyrel_version(self.generation)
+        request_headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
+
+        return debugging.add_current_propagation_headers(request_headers)
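Note: a hedged usage sketch of the extracted client. Constructing `Config` and `TokenHandler` is deployment-specific and elided here (they are assumed to already exist); the endpoint name, path-parameter escaping, and retry-backed session come straight from the code above:

    # Sketch only: `config` and `token_handler` are assumed to be constructed elsewhere.
    from relationalai.clients.direct_access_client import DirectAccessClient

    client = DirectAccessClient(config, token_handler, "https://<service-host>")
    # "get_txn" is a key of client.endpoints; the txn_id is URL-escaped into the
    # "/v1alpha1/transactions/{txn_id}" template by _prepare_url.
    resp = client.request("get_txn", path_params={"txn_id": "<some-txn-id>"})
    resp.raise_for_status()
    print(resp.json())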
relationalai/clients/snowflake.py
CHANGED
@@ -28,15 +28,11 @@ import requests
 import snowflake.connector
 import pyarrow as pa
 
-from dataclasses import dataclass
 from snowflake.snowpark import Session
 from snowflake.snowpark.context import get_active_session
 from . import result_helpers
 from .. import debugging
 from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
-from urllib.parse import urlencode, quote
-from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
 
 from pandas import DataFrame
 
@@ -44,10 +40,11 @@ from ..tools.cli_controls import Spinner
 from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
 from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
 from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
+from ..clients.direct_access_client import DirectAccessClient
 from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp
 from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
 from .. import dsl, rel, metamodel as m
-from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning
+from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, date, timedelta
 from snowflake.snowpark.types import StringType, StructField, StructType
@@ -92,6 +89,8 @@ TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
 
 DUO_TEXT = "duo security"
 
+TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+
 #--------------------------------------------------
 # Helpers
 #--------------------------------------------------
@@ -1307,7 +1306,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
         response = self._exec(f"CALL {APP_NAME}.api.get_transaction('{txn_id}',{headers});")
         assert response, f"No results from get_transaction('{txn_id}')"
 
-        response_row = next(iter(response))
+        response_row = next(iter(response)).asDict()
         status: str = response_row['STATE']
 
         # remove the transaction from the pending list if it's completed or aborted
@@ -1315,6 +1314,16 @@ Otherwise, remove it from your '{profile}' configuration profile.
             if txn_id in self._pending_transactions:
                 self._pending_transactions.remove(txn_id)
 
+        if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
+            config_file_path = getattr(self.config, 'file_path', None)
+            # todo: use the timeout returned alongside the transaction as soon as it's exposed
+            timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+            raise QueryTimeoutExceededException(
+                timeout_mins=timeout_mins,
+                query_id=txn_id,
+                config_file_path=config_file_path,
+            )
+
         # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
         return status == "COMPLETED" or status == "ABORTED"
@@ -2957,104 +2966,6 @@ def Graph(
 #--------------------------------------------------
 # Note: All direct access components should live in a separate file
 
-@dataclass
-class Endpoint:
-    method: str
-    endpoint: str
-
-class DirectAccessClient:
-    """
-    DirectAccessClient is a client for direct service access without service function calls.
-    """
-
-    def __init__(self, config: Config, token_handler: TokenHandler, service_endpoint: str, generation: Optional[Generation] = None):
-        self._config: Config = config
-        self._token_handler: TokenHandler = token_handler
-        self.service_endpoint: str = service_endpoint
-        self.generation: Optional[Generation] = generation
-        self._is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
-        self.endpoints: Dict[str, Endpoint] = {
-            "create_txn": Endpoint(method="POST", endpoint="/v1alpha1/transactions"),
-            "get_txn": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}"),
-            "get_txn_artifacts": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/artifacts"),
-            "get_txn_problems": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/problems"),
-            "get_txn_events": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/events/{stream_name}"),
-            "get_package_versions": Endpoint(method="GET", endpoint="/v1alpha1/databases/{db_name}/package_versions"),
-            "get_model_package_versions": Endpoint(method="POST", endpoint="/v1alpha1/models/get_package_versions"),
-            "create_db": Endpoint(method="POST", endpoint="/v1alpha1/databases"),
-            "get_db": Endpoint(method="GET", endpoint="/v1alpha1/databases"),
-            "delete_db": Endpoint(method="DELETE", endpoint="/v1alpha1/databases/{db_name}"),
-            "release_index": Endpoint(method="POST", endpoint="/v1alpha1/index/release"),
-            "list_engines": Endpoint(method="GET", endpoint="/v1alpha1/engines"),
-            "get_engine": Endpoint(method="GET", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
-            "create_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}"),
-            "delete_engine": Endpoint(method="DELETE", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
-            "suspend_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/suspend"),
-            "resume_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/resume_async"),
-            "prepare_index": Endpoint(method="POST", endpoint="/v1alpha1/index/prepare"),
-        }
-        self.http_session = self._create_retry_session()
-
-    def _create_retry_session(self) -> requests.Session:
-        http_session = requests.Session()
-        retries = Retry(
-            total=3,
-            backoff_factor=0.3,
-            status_forcelist=[500, 502, 503, 504],
-            allowed_methods=frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}),
-            raise_on_status=False
-        )
-        adapter = HTTPAdapter(max_retries=retries)
-        http_session.mount("http://", adapter)
-        http_session.mount("https://", adapter)
-        http_session.headers.update({"Connection": "keep-alive"})
-        return http_session
-
-    def request(
-        self,
-        endpoint: str,
-        payload: Dict[str, Any] | None = None,
-        headers: Dict[str, str] | None = None,
-        path_params: Dict[str, str] | None = None,
-        query_params: Dict[str, str] | None = None,
-    ) -> requests.Response:
-        """
-        Send a request to the service endpoint.
-        """
-        url, method = self._prepare_url(endpoint, path_params, query_params)
-        request_headers = self._prepare_headers(headers)
-        return self.http_session.request(method, url, json=payload, headers=request_headers)
-
-    def _prepare_url(self, endpoint: str, path_params: Dict[str, str] | None = None, query_params: Dict[str, str] | None = None) -> Tuple[str, str]:
-        try:
-            ep = self.endpoints[endpoint]
-        except KeyError:
-            raise ValueError(f"Invalid endpoint: {endpoint}. Available endpoints: {list(self.endpoints.keys())}")
-        url = f"{self.service_endpoint}{ep.endpoint}"
-        if path_params:
-            escaped_path_params = {k: quote(v, safe='') for k, v in path_params.items()}
-            url = url.format(**escaped_path_params)
-        if query_params:
-            url += '?' + urlencode(query_params)
-        return url, ep.method
-
-    def _prepare_headers(self, headers: Dict[str, str] | None) -> Dict[str, str]:
-        request_headers = {}
-        if headers:
-            request_headers.update(headers)
-        # Authorization tokens are not needed in a snowflake notebook environment
-        if not self._is_snowflake_notebook:
-            request_headers["Authorization"] = f'Snowflake Token="{self._token_handler.get_ingress_token(self.service_endpoint)}"'
-            # needed for oauth, does no harm for other authentication methods
-            request_headers["X-SF-SPCS-Authentication-Method"] = 'OAUTH'
-        request_headers["Content-Type"] = 'application/x-www-form-urlencoded'
-        request_headers["Accept"] = "application/json"
-
-        request_headers["user-agent"] = get_pyrel_version(self.generation)
-        request_headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
-
-        return debugging.add_current_propagation_headers(request_headers)
-
 class DirectAccessResources(Resources):
     """
     Resources class for Direct Service Access avoiding Snowflake service functions.
@@ -3068,7 +2979,14 @@ class DirectAccessResources(Resources):
         reset_session: bool = False,
         generation: Optional[Generation] = None,
     ):
-        super().__init__(
+        super().__init__(
+            generation=generation,
+            profile=profile,
+            config=config,
+            connection=connection,
+            reset_session=reset_session,
+            dry_run=dry_run,
+        )
         self._endpoint_info = ConfigStore(ENDPOINT_FILE)
         self._service_endpoint = ""
         self._direct_access_client = None
@@ -3310,13 +3228,24 @@ class DirectAccessResources(Resources):
         assert response, f"No results from get_transaction('{txn_id}')"
 
         response_content = response.json()
-
+        transaction = response_content["transaction"]
+        status: str = transaction['state']
 
         # remove the transaction from the pending list if it's completed or aborted
        if status in ["COMPLETED", "ABORTED"]:
            if txn_id in self._pending_transactions:
                self._pending_transactions.remove(txn_id)
 
+        if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
+            config_file_path = getattr(self.config, 'file_path', None)
+            timeout_ms = int(transaction.get("timeout_ms", 0))
+            timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+            raise QueryTimeoutExceededException(
+                timeout_mins=timeout_mins,
+                query_id=txn_id,
+                config_file_path=config_file_path,
+            )
+
         # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
         return status == "COMPLETED" or status == "ABORTED"
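Note: both code paths now surface timeout aborts as `QueryTimeoutExceededException`. The direct-access variant prefers the transaction's own `timeout_ms`, falling back to the configured value; the arithmetic in isolation, with a placeholder standing in for `DEFAULT_QUERY_TIMEOUT_MINS` (its real value is not shown in this diff):

    DEFAULT_QUERY_TIMEOUT_MINS = 60  # placeholder; actual default not in this diff

    def timeout_minutes(transaction: dict, configured: int | None) -> int:
        # Prefer the server-reported timeout, converted to whole minutes.
        timeout_ms = int(transaction.get("timeout_ms", 0))
        if timeout_ms > 0:
            return timeout_ms // 60000
        return int(configured or DEFAULT_QUERY_TIMEOUT_MINS)

    assert timeout_minutes({"timeout_ms": 120000}, None) == 2
    assert timeout_minutes({}, 15) == 15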
relationalai/early_access/metamodel/rewrite/__init__.py
CHANGED
@@ -1,5 +1,7 @@
-from relationalai.semantics.metamodel.rewrite import
-    DNFUnionSplitter,
+from relationalai.semantics.metamodel.rewrite import Flatten, \
+    DNFUnionSplitter, ExtractNestedLogicals, flatten
+from relationalai.semantics.lqp.rewrite import Splinter, \
+    ExtractKeys, FDConstraints
 
-__all__ = ["Splinter", "
+__all__ = ["Splinter", "Flatten", "DNFUnionSplitter", "ExtractKeys",
            "ExtractNestedLogicals", "FDConstraints", "flatten"]
relationalai/errors.py
CHANGED
@@ -2397,17 +2397,18 @@ class UnsupportedColumnTypesWarning(RAIWarning):
         """)
 
 class QueryTimeoutExceededException(RAIException):
-    def __init__(self, timeout_mins: int, config_file_path: str | None = None):
+    def __init__(self, timeout_mins: int, query_id: str | None = None, config_file_path: str | None = None):
         self.timeout_mins = timeout_mins
-        self.message = f"Query execution time exceeded the specified timeout of {timeout_mins} minutes."
         self.name = "Query Timeout Exceeded"
+        self.message = f"Query execution time exceeded the specified timeout of {self.timeout_mins} minutes."
+        self.query_id = query_id or ""
         self.config_file_path = config_file_path or ""
         self.content = self.format_message()
         super().__init__(self.message, self.name, self.content)
 
     def format_message(self):
         return textwrap.dedent(f"""
-            {self.
+            Query execution time exceeded the specified timeout of {self.timeout_mins} minutes{f' for query with ID: {self.query_id}' if self.query_id else ''}.
 
             Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
         """)
@@ -2432,3 +2433,23 @@ class AzureUnsupportedQueryTimeoutException(RAIException):
             Please remove the 'query_timeout_mins' from your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} when running on platform Azure.
         """)
 
+class AzureLegacyDependencyMissingException(RAIException):
+    def __init__(self):
+        self.message = "The Azure platform requires the 'legacy' extras to be installed."
+        self.name = "Azure Legacy Dependency Missing"
+        self.content = self.format_message()
+        super().__init__(self.message, self.name, self.content)
+
+    def format_message(self):
+        return textwrap.dedent("""
+            The Azure platform requires the 'rai-sdk' package, which is not installed.
+
+            To use the Azure platform, please install the legacy extras:
+
+                pip install relationalai[legacy]
+
+            Or if upgrading an existing installation:
+
+                pip install --upgrade relationalai[legacy]
+        """)
+
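Note: a quick check of the widened exception, showing the new optional `query_id` flowing into the formatted message (attribute names are taken from the diff above):

    from relationalai.errors import QueryTimeoutExceededException

    try:
        raise QueryTimeoutExceededException(timeout_mins=10, query_id="txn-123")
    except QueryTimeoutExceededException as e:
        assert e.query_id == "txn-123"
        print(e.content)  # names the 10-minute timeout and the query ID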
relationalai/semantics/lqp/passes.py
CHANGED
@@ -2,13 +2,12 @@ from relationalai.semantics.metamodel.compiler import Pass
 from relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
 from relationalai.semantics.metamodel.typer import Checker, InferTypes, typer
 from relationalai.semantics.metamodel import helpers, types
-from relationalai.semantics.metamodel.rewrite import (Splinter, ExtractNestedLogicals, ExtractKeys, FDConstraints,
-                                                      DNFUnionSplitter, DischargeConstraints)
 from relationalai.semantics.metamodel.util import FrozenOrderedSet
 
 from relationalai.semantics.metamodel.rewrite import Flatten
-
-from
+
+from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals
+from .rewrite import CDC, ExtractCommon, ExtractKeys, FDConstraints, QuantifyVars, Splinter
 
 from relationalai.semantics.lqp.utils import output_names
 
relationalai/semantics/{rel → lqp}/rewrite/__init__.py
RENAMED
@@ -1,9 +1,15 @@
 from .cdc import CDC
 from .extract_common import ExtractCommon
+from .extract_keys import ExtractKeys
+from .fd_constraints import FDConstraints
 from .quantify_vars import QuantifyVars
+from .splinter import Splinter
 
 __all__ = [
     "CDC",
     "ExtractCommon",
+    "ExtractKeys",
+    "FDConstraints",
     "QuantifyVars",
+    "Splinter",
 ]
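Note: for downstream code, the practical consequence of the rename is a new import path for these passes; the full public surface after this change, per the `__all__` above:

    # All six rewrite passes are now importable from the lqp package.
    from relationalai.semantics.lqp.rewrite import (
        CDC, ExtractCommon, ExtractKeys, FDConstraints, QuantifyVars, Splinter,
    )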
relationalai/semantics/metamodel/builtins.py
CHANGED
@@ -496,6 +496,17 @@ function = f.relation("function", [f.input_field("code", types.Symbol)])
 function_checked_annotation = f.annotation(function, [f.lit("checked")])
 function_annotation = f.annotation(function, [])
 
+# Indicates this relation should be tracked in telemetry. Only supported for Relationships.
+# `RAI_BackIR.with_relation_tracking` produces log messages at the start and end of each
+# SCC evaluation, if any declarations bear the `track` annotation.
+track = f.relation("track", [
+    # BackIR evaluation expects 2 parameters on the track annotation: the tracking
+    # library name and tracking relation name, which appear as log metadata fields.
+    f.input_field("library", types.Symbol),
+    f.input_field("relation", types.Symbol)
+])
+track_annotation = f.annotation(track, [])
+
 # All ir nodes marked by this annotation will be removed from the final metamodel before compilation.
 # Specifically it happens in `Flatten` pass when rewrites for `require` happen
 discharged = f.relation("discharged", [])
@@ -672,7 +683,7 @@ def _compute_builtin_overloads() -> list[ir.Relation]:
     return overloads
 
 # manually maintain the list of relations that are actually annotations
-builtin_annotations = [external, export, concept_population, from_cdc, from_cast]
+builtin_annotations = [external, export, concept_population, from_cdc, from_cast, track]
 builtin_annotations_by_name = dict((r.name, r) for r in builtin_annotations)
 
 builtin_relations = _compute_builtin_relations()
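Note: since `track` is appended to `builtin_annotations`, it becomes resolvable through the by-name registry like any other builtin annotation; a small sketch grounded in the names above:

    from relationalai.semantics.metamodel import builtins

    # The new annotation relation is registered by name.
    assert builtins.builtin_annotations_by_name["track"] is builtins.track
    # Per the comment in builtins.py, it carries two Symbol fields: the
    # tracking library name and the tracking relation name, which surface
    # as log metadata fields during SCC evaluation.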
relationalai/semantics/metamodel/rewrite/__init__.py
CHANGED
@@ -1,12 +1,6 @@
-from .
-from .list_types import RewriteListTypes
-from .gc_nodes import GarbageCollectNodes
-from .flatten import Flatten
+from .discharge_constraints import DischargeConstraints
 from .dnf_union_splitter import DNFUnionSplitter
-from .extract_keys import ExtractKeys
 from .extract_nested_logicals import ExtractNestedLogicals
-from .
-from .discharge_constraints import DischargeConstraints
+from .flatten import Flatten
 
-__all__ = ["
-"ExtractNestedLogicals", "FDConstraints", "DischargeConstraints"]
+__all__ = ["DischargeConstraints", "DNFUnionSplitter", "ExtractNestedLogicals", "Flatten"]