relationalai-0.11.3-py3-none-any.whl → relationalai-0.12.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. relationalai/clients/config.py +7 -0
  2. relationalai/clients/direct_access_client.py +113 -0
  3. relationalai/clients/snowflake.py +41 -107
  4. relationalai/clients/use_index_poller.py +349 -188
  5. relationalai/early_access/dsl/bindings/csv.py +2 -2
  6. relationalai/early_access/metamodel/rewrite/__init__.py +5 -3
  7. relationalai/early_access/rel/rewrite/__init__.py +1 -1
  8. relationalai/errors.py +24 -3
  9. relationalai/semantics/internal/annotations.py +1 -0
  10. relationalai/semantics/internal/internal.py +22 -4
  11. relationalai/semantics/lqp/builtins.py +1 -0
  12. relationalai/semantics/lqp/executor.py +61 -12
  13. relationalai/semantics/lqp/intrinsics.py +23 -0
  14. relationalai/semantics/lqp/model2lqp.py +13 -4
  15. relationalai/semantics/lqp/passes.py +4 -6
  16. relationalai/semantics/lqp/primitives.py +12 -1
  17. relationalai/semantics/{rel → lqp}/rewrite/__init__.py +6 -0
  18. relationalai/semantics/lqp/rewrite/extract_common.py +362 -0
  19. relationalai/semantics/metamodel/builtins.py +20 -2
  20. relationalai/semantics/metamodel/factory.py +3 -2
  21. relationalai/semantics/metamodel/rewrite/__init__.py +3 -9
  22. relationalai/semantics/reasoners/graph/core.py +273 -71
  23. relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
  24. relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
  25. relationalai/semantics/rel/builtins.py +5 -1
  26. relationalai/semantics/rel/compiler.py +7 -19
  27. relationalai/semantics/rel/executor.py +2 -2
  28. relationalai/semantics/rel/rel.py +6 -0
  29. relationalai/semantics/rel/rel_utils.py +8 -1
  30. relationalai/semantics/sql/compiler.py +122 -42
  31. relationalai/semantics/sql/executor/duck_db.py +28 -3
  32. relationalai/semantics/sql/rewrite/denormalize.py +4 -6
  33. relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
  34. relationalai/semantics/sql/sql.py +27 -0
  35. relationalai/semantics/std/__init__.py +2 -1
  36. relationalai/semantics/std/datetime.py +4 -0
  37. relationalai/semantics/std/re.py +83 -0
  38. relationalai/semantics/std/strings.py +1 -1
  39. relationalai/tools/cli.py +11 -4
  40. relationalai/tools/cli_controls.py +445 -60
  41. relationalai/util/format.py +78 -1
  42. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/METADATA +7 -5
  43. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/RECORD +51 -50
  44. relationalai/semantics/metamodel/rewrite/gc_nodes.py +0 -58
  45. relationalai/semantics/metamodel/rewrite/list_types.py +0 -109
  46. relationalai/semantics/rel/rewrite/extract_common.py +0 -451
  47. /relationalai/semantics/{rel → lqp}/rewrite/cdc.py +0 -0
  48. /relationalai/semantics/{metamodel → lqp}/rewrite/extract_keys.py +0 -0
  49. /relationalai/semantics/{metamodel → lqp}/rewrite/fd_constraints.py +0 -0
  50. /relationalai/semantics/{rel → lqp}/rewrite/quantify_vars.py +0 -0
  51. /relationalai/semantics/{metamodel → lqp}/rewrite/splinter.py +0 -0
  52. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/WHEEL +0 -0
  53. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/entry_points.txt +0 -0
  54. {relationalai-0.11.3.dist-info → relationalai-0.12.0.dist-info}/licenses/LICENSE +0 -0
@@ -468,6 +468,13 @@ class Config():
468
468
  if not self.file_path:
469
469
  self.file_path = "__inline__"
470
470
  self._handle_snowflake_fallback_configurations()
471
+ # Check if Azure platform is being used without the legacy dependency
472
+ if self.get("platform", "") == "azure":
473
+ try:
474
+ import railib # noqa
475
+ except ImportError:
476
+ from relationalai.errors import AzureLegacyDependencyMissingException
477
+ raise AzureLegacyDependencyMissingException() from None
471
478
 
472
479
  def fetch(self, profile:str|None=None):
473
480
  from relationalai.environments import runtime_env, TerminalEnvironment
@@ -0,0 +1,113 @@
1
+ from __future__ import annotations
2
+
3
+ import requests
4
+ from dataclasses import dataclass
5
+ from urllib.parse import urlencode, quote
6
+ from requests.adapters import HTTPAdapter
7
+ from urllib3.util.retry import Retry
8
+ from typing import Any, Dict, Optional, Tuple
9
+
10
+ from relationalai.auth.token_handler import TokenHandler
11
+ from relationalai.clients.config import Config
12
+ from relationalai.clients.util import get_pyrel_version
13
+ from relationalai import debugging
14
+ from relationalai.tools.constants import Generation
15
+ from relationalai.environments import runtime_env, SnowbookEnvironment
16
+
17
+ @dataclass
18
+ class Endpoint:
19
+ method: str
20
+ endpoint: str
21
+
22
+ class DirectAccessClient:
23
+ """
24
+ DirectAccessClient is a client for direct service access without service function calls.
25
+ """
26
+
27
+ def __init__(self, config: Config, token_handler: TokenHandler, service_endpoint: str, generation: Optional[Generation] = None):
28
+ self._config: Config = config
29
+ self._token_handler: TokenHandler = token_handler
30
+ self.service_endpoint: str = service_endpoint
31
+ self.generation: Optional[Generation] = generation
32
+ self._is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
33
+ self.endpoints: Dict[str, Endpoint] = {
34
+ "create_txn": Endpoint(method="POST", endpoint="/v1alpha1/transactions"),
35
+ "get_txn": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}"),
36
+ "get_txn_artifacts": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/artifacts"),
37
+ "get_txn_problems": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/problems"),
38
+ "get_txn_events": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/events/{stream_name}"),
39
+ "get_package_versions": Endpoint(method="GET", endpoint="/v1alpha1/databases/{db_name}/package_versions"),
40
+ "get_model_package_versions": Endpoint(method="POST", endpoint="/v1alpha1/models/get_package_versions"),
41
+ "create_db": Endpoint(method="POST", endpoint="/v1alpha1/databases"),
42
+ "get_db": Endpoint(method="GET", endpoint="/v1alpha1/databases"),
43
+ "delete_db": Endpoint(method="DELETE", endpoint="/v1alpha1/databases/{db_name}"),
44
+ "release_index": Endpoint(method="POST", endpoint="/v1alpha1/index/release"),
45
+ "list_engines": Endpoint(method="GET", endpoint="/v1alpha1/engines"),
46
+ "get_engine": Endpoint(method="GET", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
47
+ "create_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}"),
48
+ "delete_engine": Endpoint(method="DELETE", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
49
+ "suspend_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/suspend"),
50
+ "resume_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/resume_async"),
51
+ "prepare_index": Endpoint(method="POST", endpoint="/v1alpha1/index/prepare"),
52
+ }
53
+ self.http_session = self._create_retry_session()
54
+
55
+ def _create_retry_session(self) -> requests.Session:
56
+ http_session = requests.Session()
57
+ retries = Retry(
58
+ total=3,
59
+ backoff_factor=0.3,
60
+ status_forcelist=[500, 502, 503, 504],
61
+ allowed_methods=frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}),
62
+ raise_on_status=False
63
+ )
64
+ adapter = HTTPAdapter(max_retries=retries)
65
+ http_session.mount("http://", adapter)
66
+ http_session.mount("https://", adapter)
67
+ http_session.headers.update({"Connection": "keep-alive"})
68
+ return http_session
69
+
70
+ def request(
71
+ self,
72
+ endpoint: str,
73
+ payload: Dict[str, Any] | None = None,
74
+ headers: Dict[str, str] | None = None,
75
+ path_params: Dict[str, str] | None = None,
76
+ query_params: Dict[str, str] | None = None,
77
+ ) -> requests.Response:
78
+ """
79
+ Send a request to the service endpoint.
80
+ """
81
+ url, method = self._prepare_url(endpoint, path_params, query_params)
82
+ request_headers = self._prepare_headers(headers)
83
+ return self.http_session.request(method, url, json=payload, headers=request_headers)
84
+
85
+ def _prepare_url(self, endpoint: str, path_params: Dict[str, str] | None = None, query_params: Dict[str, str] | None = None) -> Tuple[str, str]:
86
+ try:
87
+ ep = self.endpoints[endpoint]
88
+ except KeyError:
89
+ raise ValueError(f"Invalid endpoint: {endpoint}. Available endpoints: {list(self.endpoints.keys())}")
90
+ url = f"{self.service_endpoint}{ep.endpoint}"
91
+ if path_params:
92
+ escaped_path_params = {k: quote(v, safe='') for k, v in path_params.items()}
93
+ url = url.format(**escaped_path_params)
94
+ if query_params:
95
+ url += '?' + urlencode(query_params)
96
+ return url, ep.method
97
+
98
+ def _prepare_headers(self, headers: Dict[str, str] | None) -> Dict[str, str]:
99
+ request_headers = {}
100
+ if headers:
101
+ request_headers.update(headers)
102
+ # Authorization tokens are not needed in a snowflake notebook environment
103
+ if not self._is_snowflake_notebook:
104
+ request_headers["Authorization"] = f'Snowflake Token="{self._token_handler.get_ingress_token(self.service_endpoint)}"'
105
+ # needed for oauth, does no harm for other authentication methods
106
+ request_headers["X-SF-SPCS-Authentication-Method"] = 'OAUTH'
107
+ request_headers["Content-Type"] = 'application/x-www-form-urlencoded'
108
+ request_headers["Accept"] = "application/json"
109
+
110
+ request_headers["user-agent"] = get_pyrel_version(self.generation)
111
+ request_headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
112
+
113
+ return debugging.add_current_propagation_headers(request_headers)
@@ -28,15 +28,11 @@ import requests
28
28
  import snowflake.connector
29
29
  import pyarrow as pa
30
30
 
31
- from dataclasses import dataclass
32
31
  from snowflake.snowpark import Session
33
32
  from snowflake.snowpark.context import get_active_session
34
33
  from . import result_helpers
35
34
  from .. import debugging
36
35
  from typing import Any, Dict, Iterable, Optional, Tuple, List, Literal, Union, cast
37
- from urllib.parse import urlencode, quote
38
- from requests.adapters import HTTPAdapter
39
- from urllib3.util.retry import Retry
40
36
 
41
37
  from pandas import DataFrame
42
38
 
@@ -44,10 +40,11 @@ from ..tools.cli_controls import Spinner
44
40
  from ..clients.types import AvailableModel, EngineState, Import, ImportSource, ImportSourceTable, ImportsStatus, SourceInfo, TransactionAsyncResponse
45
41
  from ..clients.config import Config, ConfigStore, ENDPOINT_FILE
46
42
  from ..clients.client import Client, ExportParams, ProviderBase, ResourcesBase
43
+ from ..clients.direct_access_client import DirectAccessClient
47
44
  from ..clients.util import IdentityParser, escape_for_f_string, get_pyrel_version, get_with_retries, poll_with_specified_overhead, safe_json_loads, sanitize_module_name, scrub_exception, wrap_with_request_id, ms_to_timestamp
48
45
  from ..environments import runtime_env, HexEnvironment, SnowbookEnvironment
49
46
  from .. import dsl, rel, metamodel as m
50
- from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning
47
+ from ..errors import DuoSecurityFailed, EngineProvisioningFailed, EngineNameValidationException, EngineNotFoundException, EnginePending, EngineSizeMismatchWarning, EngineResumeFailed, Errors, InvalidAliasError, InvalidEngineSizeError, InvalidSourceTypeWarning, RAIAbortedTransactionError, RAIException, HexSessionException, SnowflakeAppMissingException, SnowflakeChangeTrackingNotEnabledException, SnowflakeDatabaseException, SnowflakeImportMissingException, SnowflakeInvalidSource, SnowflakeMissingConfigValuesException, SnowflakeProxyAPIDeprecationWarning, SnowflakeProxySourceError, SnowflakeRaiAppNotStarted, ModelNotFoundException, UnknownSourceWarning, ResponseStatusException, RowsDroppedFromTargetTableWarning, QueryTimeoutExceededException
51
48
  from concurrent.futures import ThreadPoolExecutor
52
49
  from datetime import datetime, date, timedelta
53
50
  from snowflake.snowpark.types import StringType, StructField, StructType
@@ -92,6 +89,8 @@ TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
92
89
 
93
90
  DUO_TEXT = "duo security"
94
91
 
92
+ TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
93
+
95
94
  #--------------------------------------------------
96
95
  # Helpers
97
96
  #--------------------------------------------------
@@ -1307,7 +1306,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
1307
1306
  response = self._exec(f"CALL {APP_NAME}.api.get_transaction('{txn_id}',{headers});")
1308
1307
  assert response, f"No results from get_transaction('{txn_id}')"
1309
1308
 
1310
- response_row = next(iter(response))
1309
+ response_row = next(iter(response)).asDict()
1311
1310
  status: str = response_row['STATE']
1312
1311
 
1313
1312
  # remove the transaction from the pending list if it's completed or aborted
@@ -1315,6 +1314,16 @@ Otherwise, remove it from your '{profile}' configuration profile.
1315
1314
  if txn_id in self._pending_transactions:
1316
1315
  self._pending_transactions.remove(txn_id)
1317
1316
 
1317
+ if status == "ABORTED" and response_row.get("ABORT_REASON", "") == TXN_ABORT_REASON_TIMEOUT:
1318
+ config_file_path = getattr(self.config, 'file_path', None)
1319
+ # todo: use the timeout returned alongside the transaction as soon as it's exposed
1320
+ timeout_mins = int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
1321
+ raise QueryTimeoutExceededException(
1322
+ timeout_mins=timeout_mins,
1323
+ query_id=txn_id,
1324
+ config_file_path=config_file_path,
1325
+ )
1326
+
1318
1327
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
1319
1328
  return status == "COMPLETED" or status == "ABORTED"
1320
1329
 
@@ -2957,104 +2966,6 @@ def Graph(
2957
2966
  #--------------------------------------------------
2958
2967
  # Note: All direct access components should live in a separate file
2959
2968
 
2960
- @dataclass
2961
- class Endpoint:
2962
- method: str
2963
- endpoint: str
2964
-
2965
- class DirectAccessClient:
2966
- """
2967
- DirectAccessClient is a client for direct service access without service function calls.
2968
- """
2969
-
2970
- def __init__(self, config: Config, token_handler: TokenHandler, service_endpoint: str, generation: Optional[Generation] = None):
2971
- self._config: Config = config
2972
- self._token_handler: TokenHandler = token_handler
2973
- self.service_endpoint: str = service_endpoint
2974
- self.generation: Optional[Generation] = generation
2975
- self._is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
2976
- self.endpoints: Dict[str, Endpoint] = {
2977
- "create_txn": Endpoint(method="POST", endpoint="/v1alpha1/transactions"),
2978
- "get_txn": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}"),
2979
- "get_txn_artifacts": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/artifacts"),
2980
- "get_txn_problems": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/problems"),
2981
- "get_txn_events": Endpoint(method="GET", endpoint="/v1alpha1/transactions/{txn_id}/events/{stream_name}"),
2982
- "get_package_versions": Endpoint(method="GET", endpoint="/v1alpha1/databases/{db_name}/package_versions"),
2983
- "get_model_package_versions": Endpoint(method="POST", endpoint="/v1alpha1/models/get_package_versions"),
2984
- "create_db": Endpoint(method="POST", endpoint="/v1alpha1/databases"),
2985
- "get_db": Endpoint(method="GET", endpoint="/v1alpha1/databases"),
2986
- "delete_db": Endpoint(method="DELETE", endpoint="/v1alpha1/databases/{db_name}"),
2987
- "release_index": Endpoint(method="POST", endpoint="/v1alpha1/index/release"),
2988
- "list_engines": Endpoint(method="GET", endpoint="/v1alpha1/engines"),
2989
- "get_engine": Endpoint(method="GET", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
2990
- "create_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}"),
2991
- "delete_engine": Endpoint(method="DELETE", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}"),
2992
- "suspend_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/suspend"),
2993
- "resume_engine": Endpoint(method="POST", endpoint="/v1alpha1/engines/{engine_type}/{engine_name}/resume_async"),
2994
- "prepare_index": Endpoint(method="POST", endpoint="/v1alpha1/index/prepare"),
2995
- }
2996
- self.http_session = self._create_retry_session()
2997
-
2998
- def _create_retry_session(self) -> requests.Session:
2999
- http_session = requests.Session()
3000
- retries = Retry(
3001
- total=3,
3002
- backoff_factor=0.3,
3003
- status_forcelist=[500, 502, 503, 504],
3004
- allowed_methods=frozenset({"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}),
3005
- raise_on_status=False
3006
- )
3007
- adapter = HTTPAdapter(max_retries=retries)
3008
- http_session.mount("http://", adapter)
3009
- http_session.mount("https://", adapter)
3010
- http_session.headers.update({"Connection": "keep-alive"})
3011
- return http_session
3012
-
3013
- def request(
3014
- self,
3015
- endpoint: str,
3016
- payload: Dict[str, Any] | None = None,
3017
- headers: Dict[str, str] | None = None,
3018
- path_params: Dict[str, str] | None = None,
3019
- query_params: Dict[str, str] | None = None,
3020
- ) -> requests.Response:
3021
- """
3022
- Send a request to the service endpoint.
3023
- """
3024
- url, method = self._prepare_url(endpoint, path_params, query_params)
3025
- request_headers = self._prepare_headers(headers)
3026
- return self.http_session.request(method, url, json=payload, headers=request_headers)
3027
-
3028
- def _prepare_url(self, endpoint: str, path_params: Dict[str, str] | None = None, query_params: Dict[str, str] | None = None) -> Tuple[str, str]:
3029
- try:
3030
- ep = self.endpoints[endpoint]
3031
- except KeyError:
3032
- raise ValueError(f"Invalid endpoint: {endpoint}. Available endpoints: {list(self.endpoints.keys())}")
3033
- url = f"{self.service_endpoint}{ep.endpoint}"
3034
- if path_params:
3035
- escaped_path_params = {k: quote(v, safe='') for k, v in path_params.items()}
3036
- url = url.format(**escaped_path_params)
3037
- if query_params:
3038
- url += '?' + urlencode(query_params)
3039
- return url, ep.method
3040
-
3041
- def _prepare_headers(self, headers: Dict[str, str] | None) -> Dict[str, str]:
3042
- request_headers = {}
3043
- if headers:
3044
- request_headers.update(headers)
3045
- # Authorization tokens are not needed in a snowflake notebook environment
3046
- if not self._is_snowflake_notebook:
3047
- request_headers["Authorization"] = f'Snowflake Token="{self._token_handler.get_ingress_token(self.service_endpoint)}"'
3048
- # needed for oauth, does no harm for other authentication methods
3049
- request_headers["X-SF-SPCS-Authentication-Method"] = 'OAUTH'
3050
- request_headers["Content-Type"] = 'application/x-www-form-urlencoded'
3051
- request_headers["Accept"] = "application/json"
3052
-
3053
- request_headers["user-agent"] = get_pyrel_version(self.generation)
3054
- request_headers["pyrel_program_id"] = debugging.get_program_span_id() or ""
3055
-
3056
- return debugging.add_current_propagation_headers(request_headers)
3057
-
3058
2969
  class DirectAccessResources(Resources):
3059
2970
  """
3060
2971
  Resources class for Direct Service Access avoiding Snowflake service functions.
@@ -3068,7 +2979,14 @@ class DirectAccessResources(Resources):
3068
2979
  reset_session: bool = False,
3069
2980
  generation: Optional[Generation] = None,
3070
2981
  ):
3071
- super().__init__(generation=generation, profile=profile, config=config, connection=connection, dry_run=dry_run)
2982
+ super().__init__(
2983
+ generation=generation,
2984
+ profile=profile,
2985
+ config=config,
2986
+ connection=connection,
2987
+ reset_session=reset_session,
2988
+ dry_run=dry_run,
2989
+ )
3072
2990
  self._endpoint_info = ConfigStore(ENDPOINT_FILE)
3073
2991
  self._service_endpoint = ""
3074
2992
  self._direct_access_client = None
@@ -3140,7 +3058,12 @@ class DirectAccessResources(Resources):
3140
3058
  try:
3141
3059
  response = _send_request()
3142
3060
  if response.status_code != 200:
3143
- message = response.json().get("message", "")
3061
+ try:
3062
+ message = response.json().get("message", "")
3063
+ except requests.exceptions.JSONDecodeError:
3064
+ raise ResponseStatusException(
3065
+ f"Failed to parse error response from endpoint {endpoint}.", response
3066
+ )
3144
3067
 
3145
3068
  # fix engine on engine error and retry
3146
3069
  if _is_engine_issue(message):
@@ -3305,13 +3228,24 @@ class DirectAccessResources(Resources):
3305
3228
  assert response, f"No results from get_transaction('{txn_id}')"
3306
3229
 
3307
3230
  response_content = response.json()
3308
- status: str = response_content["transaction"]['state']
3231
+ transaction = response_content["transaction"]
3232
+ status: str = transaction['state']
3309
3233
 
3310
3234
  # remove the transaction from the pending list if it's completed or aborted
3311
3235
  if status in ["COMPLETED", "ABORTED"]:
3312
3236
  if txn_id in self._pending_transactions:
3313
3237
  self._pending_transactions.remove(txn_id)
3314
3238
 
3239
+ if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
3240
+ config_file_path = getattr(self.config, 'file_path', None)
3241
+ timeout_ms = int(transaction.get("timeout_ms", 0))
3242
+ timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
3243
+ raise QueryTimeoutExceededException(
3244
+ timeout_mins=timeout_mins,
3245
+ query_id=txn_id,
3246
+ config_file_path=config_file_path,
3247
+ )
3248
+
3315
3249
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
3316
3250
  return status == "COMPLETED" or status == "ABORTED"
3317
3251