apache-airflow-providers-snowflake 6.3.0__py3-none-any.whl → 6.8.0rc1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -17,6 +17,7 @@
 from __future__ import annotations

 import base64
+import time
 import uuid
 import warnings
 from datetime import timedelta
@@ -25,10 +26,21 @@ from typing import Any

 import aiohttp
 import requests
+from aiohttp import ClientConnectionError, ClientResponseError
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
-
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from requests.exceptions import ConnectionError, HTTPError, Timeout
+from tenacity import (
+    AsyncRetrying,
+    Retrying,
+    before_sleep_log,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
 from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator

@@ -65,6 +77,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
     :param token_life_time: lifetime of the JWT Token in timedelta
     :param token_renewal_delta: Renewal time of the JWT Token in timedelta
     :param deferrable: Run operator in the deferrable mode.
+    :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
     """

     LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
@@ -75,15 +88,27 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         snowflake_conn_id: str,
         token_life_time: timedelta = LIFETIME,
         token_renewal_delta: timedelta = RENEWAL_DELTA,
+        api_retry_args: dict[Any, Any] | None = None,  # Optional retry arguments passed to tenacity.retry
         *args: Any,
         **kwargs: Any,
     ):
         self.snowflake_conn_id = snowflake_conn_id
         self.token_life_time = token_life_time
         self.token_renewal_delta = token_renewal_delta
+
         super().__init__(snowflake_conn_id, *args, **kwargs)
         self.private_key: Any = None

+        self.retry_config = {
+            "retry": retry_if_exception(self._should_retry_on_error),
+            "wait": wait_exponential(multiplier=1, min=1, max=60),
+            "stop": stop_after_attempt(5),
+            "before_sleep": before_sleep_log(self.log, log_level=20),  # type: ignore[arg-type]
+            "reraise": True,
+        }
+        if api_retry_args:
+            self.retry_config.update(api_retry_args)
+
     def get_private_key(self) -> None:
         """Get the private key from snowflake connection."""
         conn = self.get_connection(self.snowflake_conn_id)
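
For orientation, and not part of the release diff: a minimal sketch of how the new api_retry_args argument might be used to override the default policy (5 attempts, exponential backoff between 1 and 60 seconds). The connection id and the tenacity values are placeholders.

    from tenacity import stop_after_attempt, wait_exponential

    from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

    # Placeholder connection id and retry values; any keyword accepted by
    # tenacity.Retrying / tenacity.AsyncRetrying can be overridden here.
    hook = SnowflakeSqlApiHook(
        snowflake_conn_id="snowflake_default",
        api_retry_args={
            "stop": stop_after_attempt(10),
            "wait": wait_exponential(multiplier=2, min=2, max=120),
        },
    )
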
@@ -137,6 +162,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         When executing the statement, Snowflake replaces placeholders (? and :name) in
         the statement with these specified values.
         """
+        self.query_ids = []
         conn_config = self._get_conn_params

         req_id = uuid.uuid4()
@@ -167,13 +193,8 @@ class SnowflakeSqlApiHook(SnowflakeHook):
                 "query_tag": query_tag,
             },
         }
-        response = requests.post(url, json=data, headers=headers, params=params)
-        try:
-            response.raise_for_status()
-        except requests.exceptions.HTTPError as e:  # pragma: no cover
-            msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
-            raise AirflowException(msg)
-        json_response = response.json()
+
+        _, json_response = self._make_api_call_with_retries("POST", url, headers, params, data)
         self.log.info("Snowflake SQL POST API response: %s", json_response)
         if "statementHandles" in json_response:
             self.query_ids = json_response["statementHandles"]
@@ -222,25 +243,37 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         }
         return headers

-    def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> str:
+    def get_oauth_token(
+        self,
+        conn_config: dict[str, Any] | None = None,
+        token_endpoint: str | None = None,
+        grant_type: str = "refresh_token",
+    ) -> str:
         """Generate temporary OAuth access token using refresh token in connection details."""
         warnings.warn(
             "This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ",
             AirflowProviderDeprecationWarning,
             stacklevel=2,
         )
-        return super().get_oauth_token(conn_config=conn_config)
+        return super().get_oauth_token(
+            conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
+        )

-    def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
+    def get_request_url_header_params(
+        self, query_id: str, url_suffix: str | None = None
+    ) -> tuple[dict[str, Any], dict[str, Any], str]:
         """
         Build the request header Url with account name identifier and query id from the connection params.

         :param query_id: statement handles query ids for the individual statements.
+        :param url_suffix: Optional path suffix to append to the URL. Must start with '/', e.g. '/cancel' or '/result'.
         """
         req_id = uuid.uuid4()
         header = self.get_headers()
         params = {"requestId": str(req_id)}
         url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
+        if url_suffix:
+            url += url_suffix
         return header, params, url

     def check_query_output(self, query_ids: list[str]) -> None:
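
Illustration only, not from the package: the new url_suffix argument is how sub-resource URLs such as the cancel endpoint are built; the query id below is made up and `hook` is assumed to be an existing SnowflakeSqlApiHook instance.

    # Hypothetical statement handle, shown only to illustrate url_suffix.
    header, params, url = hook.get_request_url_header_params(
        "01b2c3d4-0000-0000-0000-000000000000", url_suffix="/cancel"
    )
    # url now ends with ".../api/v2/statements/<query_id>/cancel"
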
@@ -251,20 +284,31 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         """
         for query_id in query_ids:
             header, params, url = self.get_request_url_header_params(query_id)
-            try:
-                response = requests.get(url, headers=header, params=params)
-                response.raise_for_status()
-                self.log.info(response.json())
-            except requests.exceptions.HTTPError as e:
-                msg = f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}"
-                raise AirflowException(msg)
+            _, response_json = self._make_api_call_with_retries(
+                method="GET", url=url, headers=header, params=params
+            )
+            self.log.info(response_json)

     def _process_response(self, status_code, resp):
         self.log.info("Snowflake SQL GET statements status API response: %s", resp)
         if status_code == 202:
             return {"status": "running", "message": "Query statements are still running"}
         if status_code == 422:
-            return {"status": "error", "message": resp["message"]}
+            error_message = resp.get("message", "Unknown error occurred")
+            error_details = []
+            if code := resp.get("code"):
+                error_details.append(f"Code: {code}")
+            if sql_state := resp.get("sqlState"):
+                error_details.append(f"SQL State: {sql_state}")
+            if statement_handle := resp.get("statementHandle"):
+                error_details.append(f"Statement Handle: {statement_handle}")
+
+            if error_details:
+                enhanced_message = f"{error_message} ({', '.join(error_details)})"
+            else:
+                enhanced_message = error_message
+
+            return {"status": "error", "message": enhanced_message}
         if status_code == 200:
             if resp_statement_handles := resp.get("statementHandles"):
                 statement_handles = resp_statement_handles
@@ -287,11 +331,83 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         """
         self.log.info("Retrieving status for query id %s", query_id)
         header, params, url = self.get_request_url_header_params(query_id)
-        response = requests.get(url, params=params, headers=header)
-        status_code = response.status_code
-        resp = response.json()
+        status_code, resp = self._make_api_call_with_retries("GET", url, header, params)
         return self._process_response(status_code, resp)

+    def wait_for_query(
+        self, query_id: str, raise_error: bool = False, poll_interval: int = 5, timeout: int = 60
+    ) -> dict[str, str | list[str]]:
+        """
+        Wait for query to finish either successfully or with error.
+
+        :param query_id: statement handle id for the individual statement.
+        :param raise_error: whether to raise an error if the query failed.
+        :param poll_interval: time (in seconds) between checking the query status.
+        :param timeout: max time (in seconds) to wait for the query to finish before raising a TimeoutError.
+
+        :raises RuntimeError: If the query status is 'error' and `raise_error` is True.
+        :raises TimeoutError: If the query doesn't finish within the specified timeout.
+        """
+        start_time = time.time()
+
+        while True:
+            response = self.get_sql_api_query_status(query_id=query_id)
+            self.log.debug("Query status `%s`", response["status"])
+
+            if time.time() - start_time > timeout:
+                raise TimeoutError(
+                    f"Query `{query_id}` did not finish within the timeout period of {timeout} seconds."
+                )
+
+            if response["status"] != "running":
+                self.log.info("Query status `%s`", response["status"])
+                break
+
+            time.sleep(poll_interval)
+
+        if response["status"] == "error" and raise_error:
+            raise RuntimeError(response["message"])
+
+        return response
+
+    def get_result_from_successful_sql_api_query(self, query_id: str) -> list[dict[str, Any]]:
+        """
+        Based on the query id HTTP requests are made to snowflake SQL API and return result data.
+
+        :param query_id: statement handle id for the individual statement.
+
+        :raises RuntimeError: If the query status is not 'success'.
+        """
+        self.log.info("Retrieving data for query id %s", query_id)
+        header, params, url = self.get_request_url_header_params(query_id)
+        status_code, response = self._make_api_call_with_retries("GET", url, header, params)
+
+        if (query_status := self._process_response(status_code, response)["status"]) != "success":
+            msg = f"Query must have status `success` to retrieve data; got `{query_status}`."
+            raise RuntimeError(msg)
+
+        # Below fields should always be present in response, but added some safety checks
+        data = response.get("data", [])
+        if not data:
+            self.log.warning("No data found in the API response.")
+            return []
+        metadata = response.get("resultSetMetaData", {})
+        col_names = [row["name"] for row in metadata.get("rowType", [])]
+        if not col_names:
+            self.log.warning("No column metadata found in the API response.")
+            return []
+
+        num_partitions = len(metadata.get("partitionInfo", []))
+        if num_partitions > 1:
+            self.log.debug("Result data is returned as multiple partitions. Will perform additional queries.")
+            url += "?partition="
+            for partition_no in range(1, num_partitions):  # First partition was already returned
+                self.log.debug("Querying for partition no. %s", partition_no)
+                _, response = self._make_api_call_with_retries("GET", url + str(partition_no), header, params)
+                data.extend(response.get("data", []))
+
+        return [dict(zip(col_names, row)) for row in data]  # Merged column names with data
+
     async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
         """
         Based on the query id async HTTP request is made to snowflake SQL API and return response.
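
A hedged sketch, not part of the diff, of how the new wait_for_query and get_result_from_successful_sql_api_query helpers could be combined outside the operator. The connection id, SQL text, and timeout below are placeholders.

    from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

    hook = SnowflakeSqlApiHook(snowflake_conn_id="snowflake_default")  # placeholder connection
    query_ids = hook.execute_query("SELECT 1 AS answer", statement_count=1)

    for query_id in query_ids:
        # Block until the statement leaves the "running" state (or the timeout hits).
        status = hook.wait_for_query(query_id, raise_error=True, poll_interval=5, timeout=300)
        if status["status"] == "success":
            rows = hook.get_result_from_successful_sql_api_query(query_id)
            print(rows)  # list of {column_name: value} dicts
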
@@ -300,10 +416,91 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         """
         self.log.info("Retrieving status for query id %s", query_id)
         header, params, url = self.get_request_url_header_params(query_id)
-        async with (
-            aiohttp.ClientSession(headers=header) as session,
-            session.get(url, params=params) as response,
+        status_code, resp = await self._make_api_call_with_retries_async("GET", url, header, params)
+        return self._process_response(status_code, resp)
+
+    def _cancel_sql_api_query_execution(self, query_id: str) -> dict[str, str | list[str]]:
+        self.log.info("Cancelling query id %s", query_id)
+        header, params, url = self.get_request_url_header_params(query_id, "/cancel")
+        status_code, resp = self._make_api_call_with_retries("POST", url, header, params)
+        return self._process_response(status_code, resp)
+
+    def cancel_queries(self, query_ids: list[str]) -> None:
+        for query_id in query_ids:
+            self._cancel_sql_api_query_execution(query_id)
+
+    @staticmethod
+    def _should_retry_on_error(exception) -> bool:
+        """
+        Determine if the exception should trigger a retry based on error type and status code.
+
+        Retries on HTTP errors 429 (Too Many Requests), 503 (Service Unavailable),
+        and 504 (Gateway Timeout) as recommended by Snowflake error handling docs.
+        Retries on connection errors and timeouts.
+
+        :param exception: The exception to check
+        :return: True if the request should be retried, False otherwise
+        """
+        if isinstance(exception, HTTPError):
+            return exception.response.status_code in [429, 503, 504]
+        if isinstance(exception, ClientResponseError):
+            return exception.status in [429, 503, 504]
+        if isinstance(
+            exception,
+            ConnectionError | Timeout | ClientConnectionError,
         ):
-            status_code = response.status
-            resp = await response.json()
-            return self._process_response(status_code, resp)
+            return True
+        return False
+
+    def _make_api_call_with_retries(
+        self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None
+    ):
+        """
+        Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors.
+
+        Error handling implemented based on Snowflake error handling docs:
+        https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
+
+        :param method: The HTTP method to use for the API call.
+        :param url: The URL for the API endpoint.
+        :param headers: The headers to include in the API call.
+        :param params: (Optional) The query parameters to include in the API call.
+        :param json: (Optional) The data to include in the API call.
+        :return: The response object from the API call.
+        """
+        with requests.Session() as session:
+            for attempt in Retrying(**self.retry_config):  # type: ignore
+                with attempt:
+                    if method.upper() in ("GET", "POST"):
+                        response = session.request(
+                            method=method.lower(), url=url, headers=headers, params=params, json=json
+                        )
+                    else:
+                        raise ValueError(f"Unsupported HTTP method: {method}")
+                    response.raise_for_status()
+                    return response.status_code, response.json()
+
+    async def _make_api_call_with_retries_async(self, method, url, headers, params=None):
+        """
+        Make an API call to the Snowflake SQL API asynchronously with retry logic for specific HTTP errors.
+
+        Error handling implemented based on Snowflake error handling docs:
+        https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors
+
+        :param method: The HTTP method to use for the API call. Only GET is supported as is synchronous.
+        :param url: The URL for the API endpoint.
+        :param headers: The headers to include in the API call.
+        :param params: (Optional) The query parameters to include in the API call.
+        :return: The response object from the API call.
+        """
+        async with aiohttp.ClientSession(headers=headers) as session:
+            async for attempt in AsyncRetrying(**self.retry_config):
+                with attempt:
+                    if method.upper() == "GET":
+                        async with session.request(method=method.lower(), url=url, params=params) as response:
+                            response.raise_for_status()
+                            # Return status and json content for async processing
+                            content = await response.json()
+                            return response.status, content
+                    else:
+                        raise ValueError(f"Unsupported HTTP method: {method}")
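
To illustrate the retry predicate above (example code, not shipped in the package), the static method can be probed directly with the kinds of exceptions it classifies:

    import requests

    from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

    # Build an HTTPError carrying a 503 response; 429/503/504 are retried.
    resp = requests.Response()
    resp.status_code = 503
    err_503 = requests.exceptions.HTTPError(response=resp)

    assert SnowflakeSqlApiHook._should_retry_on_error(err_503) is True
    assert SnowflakeSqlApiHook._should_retry_on_error(requests.exceptions.Timeout()) is True
    assert SnowflakeSqlApiHook._should_retry_on_error(ValueError("bad SQL")) is False
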
@@ -20,10 +20,11 @@ from __future__ import annotations
 import time
 from collections.abc import Iterable, Mapping, Sequence
 from datetime import timedelta
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, SupportsAbs, cast

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
     SQLExecuteQueryOperator,
@@ -34,11 +35,7 @@ from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiH
 from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger

 if TYPE_CHECKING:
-    try:
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context
+    from airflow.providers.common.compat.sdk import Context


 class SnowflakeCheckOperator(SQLCheckOperator):
@@ -75,8 +72,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         Template references are recognized by str ending in '.sql'
     :param snowflake_conn_id: Reference to
         :ref:`Snowflake connection id<howto/connection:snowflake>`
-    :param autocommit: if True, each command is automatically committed.
-        (default value: True)
     :param parameters: (optional) the parameters to render the SQL query with.
     :param warehouse: name of warehouse (will overwrite any warehouse
         defined in the connection's extra JSON)
@@ -108,8 +103,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         sql: str,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -178,8 +171,6 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
         tolerance: Any = None,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -201,7 +192,12 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
                 **hook_params,
             }
         super().__init__(
-            sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance,
+            conn_id=snowflake_conn_id,
+            parameters=parameters,
+            **kwargs,
         )
         self.query_ids: list[str] = []

@@ -258,9 +254,6 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
         date_filter_column: str = "ds",
         days_back: SupportsAbs[int] = -7,
         snowflake_conn_id: str = "snowflake_default",
-        parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -354,6 +347,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         When executing the statement, Snowflake replaces placeholders (? and :name) in
         the statement with these specified values.
     :param deferrable: Run operator in the deferrable mode.
+    :param snowflake_api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
     """

     LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minutes lifetime
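
A sketch, not taken from the diff, of a task definition that forwards the new snowflake_api_retry_args to the hook created by the operator. The task id, SQL, and retry values are placeholders.

    from tenacity import stop_after_attempt

    from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator

    run_statements = SnowflakeSqlApiOperator(
        task_id="run_statements",                 # placeholder task id
        snowflake_conn_id="snowflake_default",    # placeholder connection id
        sql="CREATE TABLE IF NOT EXISTS t (id INT); INSERT INTO t VALUES (1);",
        statement_count=2,
        deferrable=True,
        snowflake_api_retry_args={"stop": stop_after_attempt(3)},
    )
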
@@ -380,6 +374,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         token_renewal_delta: timedelta = RENEWAL_DELTA,
         bindings: dict[str, Any] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        snowflake_api_retry_args: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         self.snowflake_conn_id = snowflake_conn_id
@@ -389,7 +384,9 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         self.token_renewal_delta = token_renewal_delta
         self.bindings = bindings
         self.execute_async = False
+        self.snowflake_api_retry_args = snowflake_api_retry_args or {}
         self.deferrable = deferrable
+        self.query_ids: list[str] = []
         if any([warehouse, database, role, schema, authenticator, session_parameters]):  # pragma: no cover
             hook_params = kwargs.pop("hook_params", {})  # pragma: no cover
             kwargs["hook_params"] = {
@@ -403,6 +400,17 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
             }
         super().__init__(conn_id=snowflake_conn_id, **kwargs)  # pragma: no cover

+    @cached_property
+    def _hook(self):
+        return SnowflakeSqlApiHook(
+            snowflake_conn_id=self.snowflake_conn_id,
+            token_life_time=self.token_life_time,
+            token_renewal_delta=self.token_renewal_delta,
+            deferrable=self.deferrable,
+            api_retry_args=self.snowflake_api_retry_args,
+            **self.hook_params,
+        )
+
     def execute(self, context: Context) -> None:
         """
         Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids.
@@ -410,15 +418,8 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         By deferring the SnowflakeSqlApiTrigger class passed along with query ids.
         """
         self.log.info("Executing: %s", self.sql)
-        self._hook = SnowflakeSqlApiHook(
-            snowflake_conn_id=self.snowflake_conn_id,
-            token_life_time=self.token_life_time,
-            token_renewal_delta=self.token_renewal_delta,
-            deferrable=self.deferrable,
-            **self.hook_params,
-        )
         self.query_ids = self._hook.execute_query(
-            self.sql,  # type: ignore[arg-type]
+            self.sql,
             statement_count=self.statement_count,
             bindings=self.bindings,
         )
@@ -504,9 +505,18 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
                 msg = f"{event['status']}: {event['message']}"
                 raise AirflowException(msg)
             if "status" in event and event["status"] == "success":
-                hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id)
-                query_ids = cast("list[str]", event["statement_query_ids"])
-                hook.check_query_output(query_ids)
+                self.query_ids = cast("list[str]", event["statement_query_ids"])
+                self._hook.check_query_output(self.query_ids)
                 self.log.info("%s completed successfully.", self.task_id)
+                # Re-assign query_ids to hook after coming back from deferral to be consistent for listeners.
+                if not self._hook.query_ids:
+                    self._hook.query_ids = self.query_ids
         else:
             self.log.info("%s completed successfully.", self.task_id)
+
+    def on_kill(self) -> None:
+        """Cancel the running query."""
+        if self.query_ids:
+            self.log.info("Cancelling the query ids %s", self.query_ids)
+            self._hook.cancel_queries(self.query_ids)
+            self.log.info("Query ids %s cancelled successfully", self.query_ids)
@@ -17,8 +17,8 @@

 from __future__ import annotations

-from collections.abc import Collection, Mapping, Sequence
-from typing import Any, Callable
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import Any

 from airflow.providers.common.compat.standard.operators import PythonOperator, get_current_context
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
@@ -22,11 +22,20 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import Any

-from airflow.models import BaseOperator
+from airflow.providers.common.compat.sdk import BaseOperator
 from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
 from airflow.providers.snowflake.utils.common import enclose_param


+def _validate_parameter(param_name: str, value: str | None) -> str | None:
+    """Validate that the parameter doesn't contain any invalid pattern."""
+    if value is None:
+        return None
+    if ";" in value:
+        raise ValueError(f"Invalid {param_name}: semicolons (;) not allowed.")
+    return value
+
+
 class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
     """
     Executes a COPY INTO command to load files from an external stage from clouds to Snowflake.
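
Illustration only, not part of the package: the new _validate_parameter helper rejects table and stage values containing semicolons, so a suspicious value now fails at operator construction time. The task id and parameter values below are made up.

    from airflow.providers.snowflake.transfers.copy_into_snowflake import (
        CopyFromExternalStageToSnowflakeOperator,
    )

    try:
        CopyFromExternalStageToSnowflakeOperator(
            task_id="copy_into",                      # placeholder task id
            table="my_table; DROP TABLE users",       # rejected: contains a semicolon
            stage="my_stage",
            file_format="(type = 'CSV')",
        )
    except ValueError as err:
        print(err)  # Invalid table: semicolons (;) not allowed.
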
@@ -91,8 +100,8 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
     ):
         super().__init__(**kwargs)
         self.files = files
-        self.table = table
-        self.stage = stage
+        self.table = _validate_parameter("table", table)
+        self.stage = _validate_parameter("stage", stage)
         self.prefix = prefix
         self.file_format = file_format
         self.schema = schema
@@ -126,7 +135,7 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         if self.schema:
             into = f"{self.schema}.{self.table}"
         else:
-            into = self.table
+            into = self.table  # type: ignore[assignment]

         if self.columns_array:
             into = f"{into}({', '.join(self.columns_array)})"
@@ -74,7 +74,6 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
             self.token_renewal_delta,
         )
         try:
-            statement_query_ids: list[str] = []
             for query_id in self.query_ids:
                 while True:
                     statement_status = await self.get_query_status(query_id)
@@ -84,12 +83,10 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
                 if statement_status["status"] == "error":
                     yield TriggerEvent(statement_status)
                     return
-                if statement_status["status"] == "success":
-                    statement_query_ids.extend(statement_status["statement_handles"])
             yield TriggerEvent(
                 {
                     "status": "success",
-                    "statement_query_ids": statement_query_ids,
+                    "statement_query_ids": self.query_ids,
                 }
             )
         except Exception as e: