apache-airflow-providers-snowflake 6.3.0__py3-none-any.whl → 6.8.0rc1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- airflow/providers/snowflake/__init__.py +3 -3
- airflow/providers/snowflake/decorators/snowpark.py +2 -12
- airflow/providers/snowflake/get_provider_info.py +16 -0
- airflow/providers/snowflake/hooks/snowflake.py +100 -37
- airflow/providers/snowflake/hooks/snowflake_sql_api.py +226 -29
- airflow/providers/snowflake/operators/snowflake.py +37 -27
- airflow/providers/snowflake/operators/snowpark.py +2 -2
- airflow/providers/snowflake/transfers/copy_into_snowflake.py +13 -4
- airflow/providers/snowflake/triggers/snowflake_trigger.py +1 -4
- airflow/providers/snowflake/utils/openlineage.py +141 -93
- airflow/providers/snowflake/utils/snowpark.py +2 -1
- airflow/providers/snowflake/version_compat.py +4 -0
- {apache_airflow_providers_snowflake-6.3.0.dist-info → apache_airflow_providers_snowflake-6.8.0rc1.dist-info}/METADATA +61 -38
- apache_airflow_providers_snowflake-6.8.0rc1.dist-info/RECORD +26 -0
- apache_airflow_providers_snowflake-6.8.0rc1.dist-info/licenses/NOTICE +5 -0
- apache_airflow_providers_snowflake-6.3.0.dist-info/RECORD +0 -25
- {apache_airflow_providers_snowflake-6.3.0.dist-info → apache_airflow_providers_snowflake-6.8.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_snowflake-6.3.0.dist-info → apache_airflow_providers_snowflake-6.8.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/snowflake → apache_airflow_providers_snowflake-6.8.0rc1.dist-info/licenses}/LICENSE +0 -0
airflow/providers/snowflake/hooks/snowflake_sql_api.py +226 -29

@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import base64
+import time
 import uuid
 import warnings
 from datetime import timedelta
@@ -25,10 +26,21 @@ from typing import Any
 
 import aiohttp
 import requests
+from aiohttp import ClientConnectionError, ClientResponseError
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
-
-from
+from requests.exceptions import ConnectionError, HTTPError, Timeout
+from tenacity import (
+    AsyncRetrying,
+    Retrying,
+    before_sleep_log,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
 from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator
 
@@ -65,6 +77,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
     :param token_life_time: lifetime of the JWT Token in timedelta
     :param token_renewal_delta: Renewal time of the JWT Token in timedelta
     :param deferrable: Run operator in the deferrable mode.
+    :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
     """
 
     LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
@@ -75,15 +88,27 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         snowflake_conn_id: str,
         token_life_time: timedelta = LIFETIME,
         token_renewal_delta: timedelta = RENEWAL_DELTA,
+        api_retry_args: dict[Any, Any] | None = None,  # Optional retry arguments passed to tenacity.retry
         *args: Any,
         **kwargs: Any,
     ):
         self.snowflake_conn_id = snowflake_conn_id
         self.token_life_time = token_life_time
         self.token_renewal_delta = token_renewal_delta
+
         super().__init__(snowflake_conn_id, *args, **kwargs)
         self.private_key: Any = None
 
+        self.retry_config = {
+            "retry": retry_if_exception(self._should_retry_on_error),
+            "wait": wait_exponential(multiplier=1, min=1, max=60),
+            "stop": stop_after_attempt(5),
+            "before_sleep": before_sleep_log(self.log, log_level=20),  # type: ignore[arg-type]
+            "reraise": True,
+        }
+        if api_retry_args:
+            self.retry_config.update(api_retry_args)
+
     def get_private_key(self) -> None:
         """Get the private key from snowflake connection."""
         conn = self.get_connection(self.snowflake_conn_id)
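The api_retry_args argument added above is merged into the hook's default tenacity configuration, so callers can override individual settings such as the stop or wait policy. A minimal usage sketch (the connection id and the override values are illustrative, not part of the diff):

    from tenacity import stop_after_attempt, wait_exponential

    from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

    # Assumed connection name; retry up to 10 times with a higher backoff ceiling,
    # keeping the hook's default retry predicate and before-sleep logging.
    hook = SnowflakeSqlApiHook(
        snowflake_conn_id="snowflake_default",
        api_retry_args={
            "stop": stop_after_attempt(10),
            "wait": wait_exponential(multiplier=2, min=1, max=120),
        },
    )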
@@ -137,6 +162,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         When executing the statement, Snowflake replaces placeholders (? and :name) in
         the statement with these specified values.
         """
+        self.query_ids = []
         conn_config = self._get_conn_params
 
         req_id = uuid.uuid4()
@@ -167,13 +193,8 @@ class SnowflakeSqlApiHook(SnowflakeHook):
                 "query_tag": query_tag,
             },
         }
-
-
-            response.raise_for_status()
-        except requests.exceptions.HTTPError as e:  # pragma: no cover
-            msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
-            raise AirflowException(msg)
-        json_response = response.json()
+
+        _, json_response = self._make_api_call_with_retries("POST", url, headers, params, data)
         self.log.info("Snowflake SQL POST API response: %s", json_response)
         if "statementHandles" in json_response:
             self.query_ids = json_response["statementHandles"]
@@ -222,25 +243,37 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         }
         return headers
 
-    def get_oauth_token(
+    def get_oauth_token(
+        self,
+        conn_config: dict[str, Any] | None = None,
+        token_endpoint: str | None = None,
+        grant_type: str = "refresh_token",
+    ) -> str:
         """Generate temporary OAuth access token using refresh token in connection details."""
         warnings.warn(
             "This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ",
             AirflowProviderDeprecationWarning,
             stacklevel=2,
         )
-        return super().get_oauth_token(
+        return super().get_oauth_token(
+            conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
+        )
 
-    def get_request_url_header_params(
+    def get_request_url_header_params(
+        self, query_id: str, url_suffix: str | None = None
+    ) -> tuple[dict[str, Any], dict[str, Any], str]:
         """
         Build the request header Url with account name identifier and query id from the connection params.
 
         :param query_id: statement handles query ids for the individual statements.
+        :param url_suffix: Optional path suffix to append to the URL. Must start with '/', e.g. '/cancel' or '/result'.
         """
         req_id = uuid.uuid4()
         header = self.get_headers()
         params = {"requestId": str(req_id)}
         url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
+        if url_suffix:
+            url += url_suffix
         return header, params, url
 
     def check_query_output(self, query_ids: list[str]) -> None:
@@ -251,20 +284,31 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         """
         for query_id in query_ids:
             header, params, url = self.get_request_url_header_params(query_id)
-
-
-
-
-            except requests.exceptions.HTTPError as e:
-                msg = f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}"
-                raise AirflowException(msg)
+            _, response_json = self._make_api_call_with_retries(
+                method="GET", url=url, headers=header, params=params
+            )
+            self.log.info(response_json)
 
     def _process_response(self, status_code, resp):
         self.log.info("Snowflake SQL GET statements status API response: %s", resp)
         if status_code == 202:
             return {"status": "running", "message": "Query statements are still running"}
         if status_code == 422:
-
+            error_message = resp.get("message", "Unknown error occurred")
+            error_details = []
+            if code := resp.get("code"):
+                error_details.append(f"Code: {code}")
+            if sql_state := resp.get("sqlState"):
+                error_details.append(f"SQL State: {sql_state}")
+            if statement_handle := resp.get("statementHandle"):
+                error_details.append(f"Statement Handle: {statement_handle}")
+
+            if error_details:
+                enhanced_message = f"{error_message} ({', '.join(error_details)})"
+            else:
+                enhanced_message = error_message
+
+            return {"status": "error", "message": enhanced_message}
         if status_code == 200:
             if resp_statement_handles := resp.get("statementHandles"):
                 statement_handles = resp_statement_handles
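For the 422 branch above, a hypothetical error payload (all field values invented) shows the enhanced message the hook now builds:

    resp = {
        "message": "SQL compilation error",
        "code": "001003",
        "sqlState": "42000",
        "statementHandle": "01aa0000-0000-0000-0000-000000000000",
    }
    # _process_response(422, resp) returns:
    # {
    #     "status": "error",
    #     "message": "SQL compilation error (Code: 001003, SQL State: 42000, "
    #                "Statement Handle: 01aa0000-0000-0000-0000-000000000000)",
    # }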
@@ -287,11 +331,83 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         """
         self.log.info("Retrieving status for query id %s", query_id)
         header, params, url = self.get_request_url_header_params(query_id)
-
-        status_code = response.status_code
-        resp = response.json()
+        status_code, resp = self._make_api_call_with_retries("GET", url, header, params)
         return self._process_response(status_code, resp)
 
+    def wait_for_query(
+        self, query_id: str, raise_error: bool = False, poll_interval: int = 5, timeout: int = 60
+    ) -> dict[str, str | list[str]]:
+        """
+        Wait for query to finish either successfully or with error.
+
+        :param query_id: statement handle id for the individual statement.
+        :param raise_error: whether to raise an error if the query failed.
+        :param poll_interval: time (in seconds) between checking the query status.
+        :param timeout: max time (in seconds) to wait for the query to finish before raising a TimeoutError.
+
+        :raises RuntimeError: If the query status is 'error' and `raise_error` is True.
+        :raises TimeoutError: If the query doesn't finish within the specified timeout.
+        """
+        start_time = time.time()
+
+        while True:
+            response = self.get_sql_api_query_status(query_id=query_id)
+            self.log.debug("Query status `%s`", response["status"])
+
+            if time.time() - start_time > timeout:
+                raise TimeoutError(
+                    f"Query `{query_id}` did not finish within the timeout period of {timeout} seconds."
+                )
+
+            if response["status"] != "running":
+                self.log.info("Query status `%s`", response["status"])
+                break
+
+            time.sleep(poll_interval)
+
+        if response["status"] == "error" and raise_error:
+            raise RuntimeError(response["message"])
+
+        return response
+
+    def get_result_from_successful_sql_api_query(self, query_id: str) -> list[dict[str, Any]]:
+        """
+        Based on the query id HTTP requests are made to snowflake SQL API and return result data.
+
+        :param query_id: statement handle id for the individual statement.
+
+        :raises RuntimeError: If the query status is not 'success'.
+        """
+        self.log.info("Retrieving data for query id %s", query_id)
+        header, params, url = self.get_request_url_header_params(query_id)
+        status_code, response = self._make_api_call_with_retries("GET", url, header, params)
+
+        if (query_status := self._process_response(status_code, response)["status"]) != "success":
+            msg = f"Query must have status `success` to retrieve data; got `{query_status}`."
+            raise RuntimeError(msg)
+
+        # Below fields should always be present in response, but added some safety checks
+        data = response.get("data", [])
+        if not data:
+            self.log.warning("No data found in the API response.")
+            return []
+        metadata = response.get("resultSetMetaData", {})
+        col_names = [row["name"] for row in metadata.get("rowType", [])]
+        if not col_names:
+            self.log.warning("No column metadata found in the API response.")
+            return []
+
+        num_partitions = len(metadata.get("partitionInfo", []))
+        if num_partitions > 1:
+            self.log.debug("Result data is returned as multiple partitions. Will perform additional queries.")
+            url += "?partition="
+            for partition_no in range(1, num_partitions):  # First partition was already returned
+                self.log.debug("Querying for partition no. %s", partition_no)
+                _, response = self._make_api_call_with_retries("GET", url + str(partition_no), header, params)
+                data.extend(response.get("data", []))
+
+        return [dict(zip(col_names, row)) for row in data]  # Merged column names with data
+
     async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
         """
         Based on the query id async HTTP request is made to snowflake SQL API and return response.
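The two helpers above make it possible to run a statement through the SQL API and read its rows back without the operator. A sketch under assumed values (connection name, SQL text, and timeouts are illustrative):

    from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

    hook = SnowflakeSqlApiHook(snowflake_conn_id="snowflake_default")
    query_ids = hook.execute_query("SELECT CURRENT_TIMESTAMP() AS now", statement_count=1)

    # Poll until the statement leaves the "running" state; raise on error or after 5 minutes.
    hook.wait_for_query(query_ids[0], raise_error=True, poll_interval=5, timeout=300)

    # Rows come back as a list of {column_name: value} dictionaries, partitions merged.
    rows = hook.get_result_from_successful_sql_api_query(query_ids[0])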
@@ -300,10 +416,91 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         """
         self.log.info("Retrieving status for query id %s", query_id)
         header, params, url = self.get_request_url_header_params(query_id)
-
-
-
+        status_code, resp = await self._make_api_call_with_retries_async("GET", url, header, params)
+        return self._process_response(status_code, resp)
+
+    def _cancel_sql_api_query_execution(self, query_id: str) -> dict[str, str | list[str]]:
+        self.log.info("Cancelling query id %s", query_id)
+        header, params, url = self.get_request_url_header_params(query_id, "/cancel")
+        status_code, resp = self._make_api_call_with_retries("POST", url, header, params)
+        return self._process_response(status_code, resp)
+
+    def cancel_queries(self, query_ids: list[str]) -> None:
+        for query_id in query_ids:
+            self._cancel_sql_api_query_execution(query_id)
+
+    @staticmethod
+    def _should_retry_on_error(exception) -> bool:
+        """
+        Determine if the exception should trigger a retry based on error type and status code.
+
+        Retries on HTTP errors 429 (Too Many Requests), 503 (Service Unavailable),
+        and 504 (Gateway Timeout) as recommended by Snowflake error handling docs.
+        Retries on connection errors and timeouts.
+
+        :param exception: The exception to check
+        :return: True if the request should be retried, False otherwise
+        """
+        if isinstance(exception, HTTPError):
+            return exception.response.status_code in [429, 503, 504]
+        if isinstance(exception, ClientResponseError):
+            return exception.status in [429, 503, 504]
+        if isinstance(
+            exception,
+            ConnectionError | Timeout | ClientConnectionError,
         ):
-
-
-
+            return True
+        return False
+
+    def _make_api_call_with_retries(
+        self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None
+    ):
+        """
+        Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors.
+
+        Error handling implemented based on Snowflake error handling docs:
+        https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
+
+        :param method: The HTTP method to use for the API call.
+        :param url: The URL for the API endpoint.
+        :param headers: The headers to include in the API call.
+        :param params: (Optional) The query parameters to include in the API call.
+        :param json: (Optional) The data to include in the API call.
+        :return: The response object from the API call.
+        """
+        with requests.Session() as session:
+            for attempt in Retrying(**self.retry_config):  # type: ignore
+                with attempt:
+                    if method.upper() in ("GET", "POST"):
+                        response = session.request(
+                            method=method.lower(), url=url, headers=headers, params=params, json=json
+                        )
+                    else:
+                        raise ValueError(f"Unsupported HTTP method: {method}")
+                    response.raise_for_status()
+                    return response.status_code, response.json()
+
+    async def _make_api_call_with_retries_async(self, method, url, headers, params=None):
+        """
+        Make an API call to the Snowflake SQL API asynchronously with retry logic for specific HTTP errors.
+
+        Error handling implemented based on Snowflake error handling docs:
+        https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
+
+        :param method: The HTTP method to use for the API call. Only GET is supported as is synchronous.
+        :param url: The URL for the API endpoint.
+        :param headers: The headers to include in the API call.
+        :param params: (Optional) The query parameters to include in the API call.
+        :return: The response object from the API call.
+        """
+        async with aiohttp.ClientSession(headers=headers) as session:
+            async for attempt in AsyncRetrying(**self.retry_config):
+                with attempt:
+                    if method.upper() == "GET":
+                        async with session.request(method=method.lower(), url=url, params=params) as response:
+                            response.raise_for_status()
+                            # Return status and json content for async processing
+                            content = await response.json()
+                            return response.status, content
+                    else:
+                        raise ValueError(f"Unsupported HTTP method: {method}")
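Both wrappers rely on tenacity's attempt-iteration pattern: the configured Retrying/AsyncRetrying object yields attempt contexts, any exception accepted by the retry predicate triggers a backoff and another pass, and reraise=True surfaces the final failure. A self-contained sketch of that pattern, independent of the hook and with a made-up flaky function:

    import logging

    from tenacity import Retrying, before_sleep_log, retry_if_exception, stop_after_attempt, wait_exponential

    log = logging.getLogger(__name__)
    calls = {"n": 0}

    def flaky() -> str:
        # Hypothetical stand-in for an HTTP call that fails twice with a transient error.
        calls["n"] += 1
        if calls["n"] < 3:
            raise ConnectionError("transient network hiccup")
        return "ok"

    retry_config = {
        "retry": retry_if_exception(lambda exc: isinstance(exc, ConnectionError)),
        "wait": wait_exponential(multiplier=1, min=1, max=60),
        "stop": stop_after_attempt(5),
        "before_sleep": before_sleep_log(log, log_level=logging.INFO),
        "reraise": True,
    }

    for attempt in Retrying(**retry_config):
        with attempt:
            result = flaky()  # succeeds on the third attempt
    print(result)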
airflow/providers/snowflake/operators/snowflake.py +37 -27

@@ -20,10 +20,11 @@ from __future__ import annotations
 import time
 from collections.abc import Iterable, Mapping, Sequence
 from datetime import timedelta
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, SupportsAbs, cast
 
 from airflow.configuration import conf
-from airflow.
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
     SQLExecuteQueryOperator,
@@ -34,11 +35,7 @@ from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiH
 from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger
 
 if TYPE_CHECKING:
-
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context
+    from airflow.providers.common.compat.sdk import Context
 
 
 class SnowflakeCheckOperator(SQLCheckOperator):
@@ -75,8 +72,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         Template references are recognized by str ending in '.sql'
     :param snowflake_conn_id: Reference to
         :ref:`Snowflake connection id<howto/connection:snowflake>`
-    :param autocommit: if True, each command is automatically committed.
-        (default value: True)
     :param parameters: (optional) the parameters to render the SQL query with.
     :param warehouse: name of warehouse (will overwrite any warehouse
         defined in the connection's extra JSON)
@@ -108,8 +103,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         sql: str,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -178,8 +171,6 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
         tolerance: Any = None,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -201,7 +192,12 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
             **hook_params,
         }
         super().__init__(
-            sql=sql,
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance,
+            conn_id=snowflake_conn_id,
+            parameters=parameters,
+            **kwargs,
         )
         self.query_ids: list[str] = []
 
@@ -258,9 +254,6 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
         date_filter_column: str = "ds",
         days_back: SupportsAbs[int] = -7,
         snowflake_conn_id: str = "snowflake_default",
-        parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -354,6 +347,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
        When executing the statement, Snowflake replaces placeholders (? and :name) in
        the statement with these specified values.
    :param deferrable: Run operator in the deferrable mode.
+    :param snowflake_api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
     """
 
     LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minutes lifetime
@@ -380,6 +374,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         token_renewal_delta: timedelta = RENEWAL_DELTA,
         bindings: dict[str, Any] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        snowflake_api_retry_args: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         self.snowflake_conn_id = snowflake_conn_id
@@ -389,7 +384,9 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         self.token_renewal_delta = token_renewal_delta
         self.bindings = bindings
         self.execute_async = False
+        self.snowflake_api_retry_args = snowflake_api_retry_args or {}
         self.deferrable = deferrable
+        self.query_ids: list[str] = []
         if any([warehouse, database, role, schema, authenticator, session_parameters]):  # pragma: no cover
             hook_params = kwargs.pop("hook_params", {})  # pragma: no cover
             kwargs["hook_params"] = {
@@ -403,6 +400,17 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
            }
        super().__init__(conn_id=snowflake_conn_id, **kwargs)  # pragma: no cover
 
+    @cached_property
+    def _hook(self):
+        return SnowflakeSqlApiHook(
+            snowflake_conn_id=self.snowflake_conn_id,
+            token_life_time=self.token_life_time,
+            token_renewal_delta=self.token_renewal_delta,
+            deferrable=self.deferrable,
+            api_retry_args=self.snowflake_api_retry_args,
+            **self.hook_params,
+        )
+
     def execute(self, context: Context) -> None:
         """
         Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids.
@@ -410,15 +418,8 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         By deferring the SnowflakeSqlApiTrigger class passed along with query ids.
         """
         self.log.info("Executing: %s", self.sql)
-        self._hook = SnowflakeSqlApiHook(
-            snowflake_conn_id=self.snowflake_conn_id,
-            token_life_time=self.token_life_time,
-            token_renewal_delta=self.token_renewal_delta,
-            deferrable=self.deferrable,
-            **self.hook_params,
-        )
         self.query_ids = self._hook.execute_query(
-            self.sql,
+            self.sql,
             statement_count=self.statement_count,
             bindings=self.bindings,
         )
@@ -504,9 +505,18 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
             msg = f"{event['status']}: {event['message']}"
             raise AirflowException(msg)
         if "status" in event and event["status"] == "success":
-
-            query_ids
-            hook.check_query_output(query_ids)
+            self.query_ids = cast("list[str]", event["statement_query_ids"])
+            self._hook.check_query_output(self.query_ids)
             self.log.info("%s completed successfully.", self.task_id)
+            # Re-assign query_ids to hook after coming back from deferral to be consistent for listeners.
+            if not self._hook.query_ids:
+                self._hook.query_ids = self.query_ids
         else:
             self.log.info("%s completed successfully.", self.task_id)
+
+    def on_kill(self) -> None:
+        """Cancel the running query."""
+        if self.query_ids:
+            self.log.info("Cancelling the query ids %s", self.query_ids)
+            self._hook.cancel_queries(self.query_ids)
+            self.log.info("Query ids %s cancelled successfully", self.query_ids)
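At the DAG level, these operator changes expose the hook's retry configuration through snowflake_api_retry_args and add query cancellation via on_kill(). A usage sketch (DAG id, connection id, SQL, and retry values are illustrative):

    from datetime import datetime

    from tenacity import stop_after_attempt

    from airflow import DAG
    from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator

    with DAG("snowflake_sql_api_example", start_date=datetime(2025, 1, 1), schedule=None):
        # If the task is killed, on_kill() now cancels the submitted statement handles.
        SnowflakeSqlApiOperator(
            task_id="run_statements",
            snowflake_conn_id="snowflake_default",
            sql="CREATE OR REPLACE TABLE t AS SELECT 1 AS a; SELECT COUNT(*) FROM t;",
            statement_count=2,
            deferrable=True,
            snowflake_api_retry_args={"stop": stop_after_attempt(3)},
        )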
airflow/providers/snowflake/operators/snowpark.py +2 -2

@@ -17,8 +17,8 @@
 
 from __future__ import annotations
 
-from collections.abc import Collection, Mapping, Sequence
-from typing import Any
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import Any
 
 from airflow.providers.common.compat.standard.operators import PythonOperator, get_current_context
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
airflow/providers/snowflake/transfers/copy_into_snowflake.py +13 -4

@@ -22,11 +22,20 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import Any
 
-from airflow.
+from airflow.providers.common.compat.sdk import BaseOperator
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
 from airflow.providers.snowflake.utils.common import enclose_param
 
 
+def _validate_parameter(param_name: str, value: str | None) -> str | None:
+    """Validate that the parameter doesn't contain any invalid pattern."""
+    if value is None:
+        return None
+    if ";" in value:
+        raise ValueError(f"Invalid {param_name}: semicolons (;) not allowed.")
+    return value
+
+
 class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
     """
     Executes a COPY INTO command to load files from an external stage from clouds to Snowflake.
@@ -91,8 +100,8 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
     ):
         super().__init__(**kwargs)
         self.files = files
-        self.table = table
-        self.stage = stage
+        self.table = _validate_parameter("table", table)
+        self.stage = _validate_parameter("stage", stage)
         self.prefix = prefix
         self.file_format = file_format
         self.schema = schema
@@ -126,7 +135,7 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         if self.schema:
             into = f"{self.schema}.{self.table}"
         else:
-            into = self.table
+            into = self.table  # type: ignore[assignment]
 
         if self.columns_array:
             into = f"{into}({', '.join(self.columns_array)})"
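The new _validate_parameter guard rejects table and stage values that try to smuggle in extra statements, and it runs at operator construction time. A quick illustration of its behaviour (inputs invented):

    from airflow.providers.snowflake.transfers.copy_into_snowflake import _validate_parameter

    _validate_parameter("table", "my_table")   # returns "my_table"
    _validate_parameter("stage", None)          # returns None

    # Raises ValueError: Invalid table: semicolons (;) not allowed.
    _validate_parameter("table", "my_table; DROP TABLE users")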
airflow/providers/snowflake/triggers/snowflake_trigger.py +1 -4

@@ -74,7 +74,6 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
             self.token_renewal_delta,
         )
         try:
-            statement_query_ids: list[str] = []
             for query_id in self.query_ids:
                 while True:
                     statement_status = await self.get_query_status(query_id)
@@ -84,12 +83,10 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
                     if statement_status["status"] == "error":
                         yield TriggerEvent(statement_status)
                         return
-                    if statement_status["status"] == "success":
-                        statement_query_ids.extend(statement_status["statement_handles"])
             yield TriggerEvent(
                 {
                     "status": "success",
-                    "statement_query_ids":
+                    "statement_query_ids": self.query_ids,
                 }
             )
         except Exception as e: