apache-airflow-providers-snowflake 6.3.1rc1__py3-none-any.whl → 6.4.0rc1__py3-none-any.whl

This diff shows the contents of two publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of apache-airflow-providers-snowflake might be problematic. See the package's page on the registry for more details.

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
29
29
 
30
30
  __all__ = ["__version__"]
31
31
 
32
- __version__ = "6.3.1"
32
+ __version__ = "6.4.0"
33
33
 
34
34
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
35
35
  "2.10.0"
@@ -136,6 +136,9 @@ class SnowflakeHook(DbApiHook):
136
136
  "session_parameters": "session parameters",
137
137
  "client_request_mfa_token": "client request mfa token",
138
138
  "client_store_temporary_credential": "client store temporary credential (externalbrowser mode)",
139
+ "grant_type": "refresh_token client_credentials",
140
+ "token_endpoint": "token endpoint",
141
+ "refresh_token": "refresh token",
139
142
  },
140
143
  indent=1,
141
144
  ),
@@ -200,18 +203,32 @@ class SnowflakeHook(DbApiHook):
200
203
 
201
204
  return account_identifier
202
205
 
203
- def get_oauth_token(self, conn_config: dict | None = None) -> str:
206
+ def get_oauth_token(
207
+ self,
208
+ conn_config: dict | None = None,
209
+ token_endpoint: str | None = None,
210
+ grant_type: str = "refresh_token",
211
+ ) -> str:
204
212
  """Generate temporary OAuth access token using refresh token in connection details."""
205
213
  if conn_config is None:
206
214
  conn_config = self._get_conn_params
207
215
 
208
- url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
216
+ url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
209
217
 
210
218
  data = {
211
- "grant_type": "refresh_token",
212
- "refresh_token": conn_config["refresh_token"],
219
+ "grant_type": grant_type,
213
220
  "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
214
221
  }
222
+
223
+ if grant_type == "refresh_token":
224
+ data |= {
225
+ "refresh_token": conn_config["refresh_token"],
226
+ }
227
+ elif grant_type == "client_credentials":
228
+ pass # no setup necessary for client credentials grant.
229
+ else:
230
+ raise ValueError(f"Unknown grant_type: {grant_type}")
231
+
215
232
  response = requests.post(
216
233
  url,
217
234
  data=data,
@@ -226,7 +243,8 @@ class SnowflakeHook(DbApiHook):
226
243
  except requests.exceptions.HTTPError as e: # pragma: no cover
227
244
  msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
228
245
  raise AirflowException(msg)
229
- return response.json()["access_token"]
246
+ token = response.json()["access_token"]
247
+ return token
230
248
 
231
249
  @cached_property
232
250
  def _get_conn_params(self) -> dict[str, str | None]:
@@ -329,14 +347,21 @@ class SnowflakeHook(DbApiHook):
329
347
  if refresh_token:
330
348
  conn_config["refresh_token"] = refresh_token
331
349
  conn_config["authenticator"] = "oauth"
350
+
351
+ if conn_config.get("authenticator") == "oauth":
352
+ token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
332
353
  conn_config["client_id"] = conn.login
333
354
  conn_config["client_secret"] = conn.password
355
+ conn_config["token"] = self.get_oauth_token(
356
+ conn_config=conn_config,
357
+ token_endpoint=token_endpoint,
358
+ grant_type=extra_dict.get("grant_type", "refresh_token"),
359
+ )
360
+
334
361
  conn_config.pop("login", None)
335
362
  conn_config.pop("user", None)
336
363
  conn_config.pop("password", None)
337
364
 
338
- conn_config["token"] = self.get_oauth_token(conn_config=conn_config)
339
-
340
365
  # configure custom target hostname and port, if specified
341
366
  snowflake_host = extra_dict.get("host")
342
367
  snowflake_port = extra_dict.get("port")
@@ -137,6 +137,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
137
137
  When executing the statement, Snowflake replaces placeholders (? and :name) in
138
138
  the statement with these specified values.
139
139
  """
140
+ self.query_ids = []
140
141
  conn_config = self._get_conn_params
141
142
 
142
143
  req_id = uuid.uuid4()
@@ -222,14 +223,21 @@ class SnowflakeSqlApiHook(SnowflakeHook):
222
223
  }
223
224
  return headers
224
225
 
225
- def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> str:
226
+ def get_oauth_token(
227
+ self,
228
+ conn_config: dict[str, Any] | None = None,
229
+ token_endpoint: str | None = None,
230
+ grant_type: str = "refresh_token",
231
+ ) -> str:
226
232
  """Generate temporary OAuth access token using refresh token in connection details."""
227
233
  warnings.warn(
228
234
  "This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ",
229
235
  AirflowProviderDeprecationWarning,
230
236
  stacklevel=2,
231
237
  )
232
- return super().get_oauth_token(conn_config=conn_config)
238
+ return super().get_oauth_token(
239
+ conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
240
+ )
233
241
 
234
242
  def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
235
243
  """
@@ -20,6 +20,7 @@ from __future__ import annotations
20
20
  import time
21
21
  from collections.abc import Iterable, Mapping, Sequence
22
22
  from datetime import timedelta
23
+ from functools import cached_property
23
24
  from typing import TYPE_CHECKING, Any, SupportsAbs, cast
24
25
 
25
26
  from airflow.configuration import conf
@@ -390,6 +391,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
390
391
  self.bindings = bindings
391
392
  self.execute_async = False
392
393
  self.deferrable = deferrable
394
+ self.query_ids: list[str] = []
393
395
  if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover
394
396
  hook_params = kwargs.pop("hook_params", {}) # pragma: no cover
395
397
  kwargs["hook_params"] = {
@@ -403,6 +405,16 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
403
405
  }
404
406
  super().__init__(conn_id=snowflake_conn_id, **kwargs) # pragma: no cover
405
407
 
408
+ @cached_property
409
+ def _hook(self):
410
+ return SnowflakeSqlApiHook(
411
+ snowflake_conn_id=self.snowflake_conn_id,
412
+ token_life_time=self.token_life_time,
413
+ token_renewal_delta=self.token_renewal_delta,
414
+ deferrable=self.deferrable,
415
+ **self.hook_params,
416
+ )
417
+
406
418
  def execute(self, context: Context) -> None:
407
419
  """
408
420
  Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids.
@@ -410,13 +422,6 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
410
422
  By deferring the SnowflakeSqlApiTrigger class passed along with query ids.
411
423
  """
412
424
  self.log.info("Executing: %s", self.sql)
413
- self._hook = SnowflakeSqlApiHook(
414
- snowflake_conn_id=self.snowflake_conn_id,
415
- token_life_time=self.token_life_time,
416
- token_renewal_delta=self.token_renewal_delta,
417
- deferrable=self.deferrable,
418
- **self.hook_params,
419
- )
420
425
  self.query_ids = self._hook.execute_query(
421
426
  self.sql, # type: ignore[arg-type]
422
427
  statement_count=self.statement_count,
@@ -504,9 +509,11 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
504
509
  msg = f"{event['status']}: {event['message']}"
505
510
  raise AirflowException(msg)
506
511
  if "status" in event and event["status"] == "success":
507
- hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id)
508
- query_ids = cast("list[str]", event["statement_query_ids"])
509
- hook.check_query_output(query_ids)
512
+ self.query_ids = cast("list[str]", event["statement_query_ids"])
513
+ self._hook.check_query_output(self.query_ids)
510
514
  self.log.info("%s completed successfully.", self.task_id)
515
+ # Re-assign query_ids to hook after coming back from deferral to be consistent for listeners.
516
+ if not self._hook.query_ids:
517
+ self._hook.query_ids = self.query_ids
511
518
  else:
512
519
  self.log.info("%s completed successfully.", self.task_id)
@@ -27,6 +27,15 @@ from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
27
27
  from airflow.providers.snowflake.utils.common import enclose_param
28
28
 
29
29
 
30
+ def _validate_parameter(param_name: str, value: str | None) -> str | None:
31
+ """Validate that the parameter doesn't contain any invalid pattern."""
32
+ if value is None:
33
+ return None
34
+ if ";" in value:
35
+ raise ValueError(f"Invalid {param_name}: semicolons (;) not allowed.")
36
+ return value
37
+
38
+
30
39
  class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
31
40
  """
32
41
  Executes a COPY INTO command to load files from an external stage from clouds to Snowflake.
@@ -91,8 +100,8 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
91
100
  ):
92
101
  super().__init__(**kwargs)
93
102
  self.files = files
94
- self.table = table
95
- self.stage = stage
103
+ self.table = _validate_parameter("table", table)
104
+ self.stage = _validate_parameter("stage", stage)
96
105
  self.prefix = prefix
97
106
  self.file_format = file_format
98
107
  self.schema = schema
@@ -126,7 +135,7 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
126
135
  if self.schema:
127
136
  into = f"{self.schema}.{self.table}"
128
137
  else:
129
- into = self.table
138
+ into = self.table # type: ignore[assignment]
130
139
 
131
140
  if self.columns_array:
132
141
  into = f"{into}({', '.join(self.columns_array)})"
@@ -74,7 +74,6 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
74
74
  self.token_renewal_delta,
75
75
  )
76
76
  try:
77
- statement_query_ids: list[str] = []
78
77
  for query_id in self.query_ids:
79
78
  while True:
80
79
  statement_status = await self.get_query_status(query_id)
@@ -84,12 +83,10 @@ class SnowflakeSqlApiTrigger(BaseTrigger):
84
83
  if statement_status["status"] == "error":
85
84
  yield TriggerEvent(statement_status)
86
85
  return
87
- if statement_status["status"] == "success":
88
- statement_query_ids.extend(statement_status["statement_handles"])
89
86
  yield TriggerEvent(
90
87
  {
91
88
  "status": "success",
92
- "statement_query_ids": statement_query_ids,
89
+ "statement_query_ids": self.query_ids,
93
90
  }
94
91
  )
95
92
  except Exception as e:
@@ -52,7 +52,15 @@ def fix_account_name(name: str) -> str:
52
52
  account, region = spl
53
53
  cloud = "aws"
54
54
  else:
55
- account, region, cloud = spl
55
+ # region can easily get duplicated without crashing snowflake, so we need to handle that as well
56
+ # eg. account_locator.europe-west3.gcp.europe-west3.gcp will be ok for snowflake
57
+ account, region, cloud, *rest = spl
58
+ rest = [x for x in rest if x not in (region, cloud)]
59
+ if rest: # Not sure what could be left here, but leaving this just in case
60
+ log.warning(
61
+ "Unexpected parts found in Snowflake uri hostname and will be ignored by OpenLineage: %s",
62
+ rest,
63
+ )
56
64
  return f"{account}.{region}.{cloud}"
57
65
 
58
66
  # Check for existing accounts with cloud names
@@ -72,13 +80,16 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
72
80
  """
73
81
  Fix snowflake sqlalchemy connection URI to OpenLineage structure.
74
82
 
75
- Snowflake sqlalchemy connection URI has following structure:
83
+ Snowflake sqlalchemy connection URI has the following structure:
76
84
  'snowflake://<user_login_name>:<password>@<account_identifier>/<database_name>/<schema_name>?warehouse=<warehouse_name>&role=<role_name>'
77
85
  We want account identifier normalized. It can have two forms:
78
- - newer, in form of <organization>-<id>. In this case we want to do nothing.
79
- - older, composed of <id>-<region>-<cloud> where region and cloud can be
86
+ - newer, in form of <organization_id>-<account_id>. In this case we want to do nothing.
87
+ - older, composed of <account_locator>.<region>.<cloud> where region and cloud can be
80
88
  optional in some cases. If <cloud> is omitted, it's AWS.
81
89
  If region and cloud are omitted, it's AWS us-west-1
90
+
91
+ Current doc on Snowflake account identifiers:
92
+ https://docs.snowflake.com/en/user-guide/admin-account-identifier
82
93
  """
83
94
  try:
84
95
  parts = urlparse(uri)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: apache-airflow-providers-snowflake
3
- Version: 6.3.1rc1
3
+ Version: 6.4.0rc1
4
4
  Summary: Provider package apache-airflow-providers-snowflake for Apache Airflow
5
5
  Keywords: airflow-provider,snowflake,airflow,integration
6
6
  Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -30,8 +30,8 @@ Requires-Dist: snowflake-sqlalchemy>=1.4.0
30
30
  Requires-Dist: snowflake-snowpark-python>=1.17.0;python_version<'3.12'
31
31
  Requires-Dist: apache-airflow-providers-openlineage>=2.3.0rc1 ; extra == "openlineage"
32
32
  Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
33
- Project-URL: Changelog, https://airflow.staged.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/changelog.html
34
- Project-URL: Documentation, https://airflow.staged.apache.org/docs/apache-airflow-providers-snowflake/6.3.1
33
+ Project-URL: Changelog, https://airflow.staged.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/changelog.html
34
+ Project-URL: Documentation, https://airflow.staged.apache.org/docs/apache-airflow-providers-snowflake/6.4.0
35
35
  Project-URL: Mastodon, https://fosstodon.org/@airflow
36
36
  Project-URL: Slack Chat, https://s.apache.org/airflow-slack
37
37
  Project-URL: Source Code, https://github.com/apache/airflow
@@ -63,7 +63,7 @@ Provides-Extra: openlineage
63
63
 
64
64
  Package ``apache-airflow-providers-snowflake``
65
65
 
66
- Release: ``6.3.1``
66
+ Release: ``6.4.0``
67
67
 
68
68
 
69
69
  `Snowflake <https://www.snowflake.com/>`__
@@ -76,7 +76,7 @@ This is a provider package for ``snowflake`` provider. All classes for this prov
76
76
  are in ``airflow.providers.snowflake`` python package.
77
77
 
78
78
  You can find package information and changelog for the provider
79
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/>`_.
79
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/>`_.
80
80
 
81
81
  Installation
82
82
  ------------
@@ -125,5 +125,5 @@ Dependent package
125
125
  ================================================================================================================== =================
126
126
 
127
127
  The changelog for the provider package can be found in the
128
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.3.1/changelog.html>`_.
128
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-snowflake/6.4.0/changelog.html>`_.
129
129
 
@@ -1,25 +1,25 @@
1
1
  airflow/providers/snowflake/LICENSE,sha256=gXPVwptPlW1TJ4HSuG5OMPg-a3h43OGMkZRR1rpwfJA,10850
2
- airflow/providers/snowflake/__init__.py,sha256=QsK4vg_MkprBynxoW0TPaDLqVXgU_0-k0VW6Zabo7gk,1498
2
+ airflow/providers/snowflake/__init__.py,sha256=hfjXA59cpm_yb00p5Y7jNTIBCe3BPNoYYXkAVAcF_c4,1498
3
3
  airflow/providers/snowflake/get_provider_info.py,sha256=NdNRMfulBbpD-I4yFRr8U533m9djD18ijEMvuxOp4_g,3875
4
4
  airflow/providers/snowflake/version_compat.py,sha256=j5PCtXvZ71aBjixu-EFTNtVDPsngzzs7os0ZQDgFVDk,1536
5
5
  airflow/providers/snowflake/decorators/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
6
6
  airflow/providers/snowflake/decorators/snowpark.py,sha256=tKXOjP8m8SEIu0jx2KSrd0n3jGMaIKDOwG2lMkvk3cI,5523
7
7
  airflow/providers/snowflake/hooks/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
8
- airflow/providers/snowflake/hooks/snowflake.py,sha256=9OH16CYnnJ0-ayAg1D7OdZusEf5lSGjQurWifptp97k,28025
9
- airflow/providers/snowflake/hooks/snowflake_sql_api.py,sha256=-J0mPcdDc9wbB7DcnZfnXJN7H62nbR_NK5WQJxeKZjE,14532
8
+ airflow/providers/snowflake/hooks/snowflake.py,sha256=t-sukzbQ1OYMeyZBDrJ9s9DuJbnDZFJKhBMZn_mQLgY,28933
9
+ airflow/providers/snowflake/hooks/snowflake_sql_api.py,sha256=mREZ0nHc6L-9YSHZARYpgqrETqzVJ3Q6EfbWtEy5TV4,14745
10
10
  airflow/providers/snowflake/operators/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
11
- airflow/providers/snowflake/operators/snowflake.py,sha256=5MisB-bKqUFM9t5Ky913UqewoHlq3k3mCv4bnc-VY7g,22657
11
+ airflow/providers/snowflake/operators/snowflake.py,sha256=Abu0MVsUPwVxfDVNYn5OtwVoUOhQanUp-YSFtLdcn6c,22915
12
12
  airflow/providers/snowflake/operators/snowpark.py,sha256=Wt3wzcsja0ed4q2KE9WyL74XH6mUVSPNZvcCHWEHQtc,5815
13
13
  airflow/providers/snowflake/transfers/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
14
- airflow/providers/snowflake/transfers/copy_into_snowflake.py,sha256=UjbznjbK-QWN071ZFMvBHZXoFddMo0vQFK-7VLv3amo,13191
14
+ airflow/providers/snowflake/transfers/copy_into_snowflake.py,sha256=bXmqkNwthJqUo65DsI_pC3mwk_V_Iikwi646oRCyWus,13590
15
15
  airflow/providers/snowflake/triggers/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
16
- airflow/providers/snowflake/triggers/snowflake_trigger.py,sha256=38tkByMyjbVbSt-69YL8EzRBQT4rhwuOKHgbwHfULL0,4250
16
+ airflow/providers/snowflake/triggers/snowflake_trigger.py,sha256=QXNLijmtZI7NIdPtOwbvS-4ohgrm8RV_jaBKvekosHQ,4051
17
17
  airflow/providers/snowflake/utils/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
18
18
  airflow/providers/snowflake/utils/common.py,sha256=DG-KLy2KpZWAqZqm_XIECm8lmdoUlzwkXv9onmkQThc,1644
19
- airflow/providers/snowflake/utils/openlineage.py,sha256=QjbN76qjboTvpQZtoi0g7s3R9LwutRtD7HZ3DEVLbyY,14372
19
+ airflow/providers/snowflake/utils/openlineage.py,sha256=dr57b0fidPo7A451UE6s6d3PnOf4dOv3iM5aoyE4oBI,15067
20
20
  airflow/providers/snowflake/utils/snowpark.py,sha256=9kzWRkdgoNQ8f3Wnr92LdZylMpcpRasxefpOXrM30Cw,1602
21
21
  airflow/providers/snowflake/utils/sql_api_generate_jwt.py,sha256=9mR-vHIquv60tfAni87f6FAjKsiRHUDDrsVhzw4M9vM,6762
22
- apache_airflow_providers_snowflake-6.3.1rc1.dist-info/entry_points.txt,sha256=bCrl5J1PXUMzbgnrKYho61rkbL2gHRT4I6f_1jlxAX4,105
23
- apache_airflow_providers_snowflake-6.3.1rc1.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
24
- apache_airflow_providers_snowflake-6.3.1rc1.dist-info/METADATA,sha256=0pkDl6pLL8SvxNML46vEzwwQ1IP7TzRKYlbJ2MtTkzI,6242
25
- apache_airflow_providers_snowflake-6.3.1rc1.dist-info/RECORD,,
22
+ apache_airflow_providers_snowflake-6.4.0rc1.dist-info/entry_points.txt,sha256=bCrl5J1PXUMzbgnrKYho61rkbL2gHRT4I6f_1jlxAX4,105
23
+ apache_airflow_providers_snowflake-6.4.0rc1.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
24
+ apache_airflow_providers_snowflake-6.4.0rc1.dist-info/METADATA,sha256=mDqBewdnQi0Pa9w3DNrTvh2i0csPs95pMlAsZDp1kX8,6242
25
+ apache_airflow_providers_snowflake-6.4.0rc1.dist-info/RECORD,,