apache-airflow-providers-snowflake 6.3.0__py3-none-any.whl → 6.8.0rc1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -29,11 +29,11 @@ from airflow import __version__ as airflow_version
29
29
 
30
30
  __all__ = ["__version__"]
31
31
 
32
- __version__ = "6.3.0"
32
+ __version__ = "6.8.0"
33
33
 
34
34
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
35
- "2.10.0"
35
+ "2.11.0"
36
36
  ):
37
37
  raise RuntimeError(
38
- f"The package `apache-airflow-providers-snowflake:{__version__}` needs Apache Airflow 2.10.0+"
38
+ f"The package `apache-airflow-providers-snowflake:{__version__}` needs Apache Airflow 2.11.0+"
39
39
  )
@@ -17,22 +17,12 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from collections.abc import Sequence
21
- from typing import TYPE_CHECKING, Callable
22
-
23
- from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS
24
-
25
- if AIRFLOW_V_3_0_PLUS:
26
- from airflow.sdk.bases.decorator import DecoratedOperator, task_decorator_factory
27
- else:
28
- from airflow.decorators.base import DecoratedOperator, task_decorator_factory # type: ignore[no-redef]
20
+ from collections.abc import Callable, Sequence
29
21
 
22
+ from airflow.providers.common.compat.sdk import DecoratedOperator, TaskDecorator, task_decorator_factory
30
23
  from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
31
24
  from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs
32
25
 
33
- if TYPE_CHECKING:
34
- from airflow.sdk.bases.decorator import TaskDecorator
35
-
36
26
 
37
27
  class _SnowparkDecoratedOperator(DecoratedOperator, SnowparkOperator):
38
28
  """
@@ -73,11 +73,13 @@ def get_provider_info():
73
73
  "source-integration-name": "Google Cloud Storage (GCS)",
74
74
  "target-integration-name": "Snowflake",
75
75
  "python-module": "airflow.providers.snowflake.transfers.copy_into_snowflake",
76
+ "how-to-guide": "/docs/apache-airflow-providers-snowflake/operators/copy_into_snowflake.rst",
76
77
  },
77
78
  {
78
79
  "source-integration-name": "Microsoft Azure Blob Storage",
79
80
  "target-integration-name": "Snowflake",
80
81
  "python-module": "airflow.providers.snowflake.transfers.copy_into_snowflake",
82
+ "how-to-guide": "/docs/apache-airflow-providers-snowflake/operators/copy_into_snowflake.rst",
81
83
  },
82
84
  ],
83
85
  "connection-types": [
@@ -92,4 +94,18 @@ def get_provider_info():
92
94
  "python-modules": ["airflow.providers.snowflake.triggers.snowflake_trigger"],
93
95
  }
94
96
  ],
97
+ "config": {
98
+ "snowflake": {
99
+ "description": "Configuration for Snowflake hooks and operators.\n",
100
+ "options": {
101
+ "azure_oauth_scope": {
102
+ "description": "The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication.\n",
103
+ "version_added": "6.6.0",
104
+ "type": "string",
105
+ "example": None,
106
+ "default": "api://snowflake_oauth_server/.default",
107
+ }
108
+ },
109
+ }
110
+ },
95
111
  }
@@ -19,12 +19,12 @@ from __future__ import annotations
19
19
 
20
20
  import base64
21
21
  import os
22
- from collections.abc import Iterable, Mapping
22
+ from collections.abc import Callable, Iterable, Mapping
23
23
  from contextlib import closing, contextmanager
24
24
  from functools import cached_property
25
25
  from io import StringIO
26
26
  from pathlib import Path
27
- from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
27
+ from typing import TYPE_CHECKING, Any, TypeVar, overload
28
28
  from urllib.parse import urlparse
29
29
 
30
30
  import requests
@@ -36,7 +36,9 @@ from snowflake.connector import DictCursor, SnowflakeConnection, util_text
36
36
  from snowflake.sqlalchemy import URL
37
37
  from sqlalchemy import create_engine
38
38
 
39
- from airflow.exceptions import AirflowException
39
+ from airflow.configuration import conf
40
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
41
+ from airflow.providers.common.compat.sdk import AirflowException, Connection
40
42
  from airflow.providers.common.sql.hooks.handlers import return_single_query_results
41
43
  from airflow.providers.common.sql.hooks.sql import DbApiHook
42
44
  from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
@@ -94,6 +96,7 @@ class SnowflakeHook(DbApiHook):
94
96
  hook_name = "Snowflake"
95
97
  supports_autocommit = True
96
98
  _test_connection_sql = "select 1"
99
+ default_azure_oauth_scope = "api://snowflake_oauth_server/.default"
97
100
 
98
101
  @classmethod
99
102
  def get_connection_form_widgets(cls) -> dict[str, Any]:
@@ -136,6 +139,10 @@ class SnowflakeHook(DbApiHook):
136
139
  "session_parameters": "session parameters",
137
140
  "client_request_mfa_token": "client request mfa token",
138
141
  "client_store_temporary_credential": "client store temporary credential (externalbrowser mode)",
142
+ "grant_type": "refresh_token client_credentials",
143
+ "token_endpoint": "token endpoint",
144
+ "refresh_token": "refresh token",
145
+ "scope": "scope",
139
146
  },
140
147
  indent=1,
141
148
  ),
@@ -200,18 +207,37 @@ class SnowflakeHook(DbApiHook):
200
207
 
201
208
  return account_identifier
202
209
 
203
- def get_oauth_token(self, conn_config: dict | None = None) -> str:
210
+ def get_oauth_token(
211
+ self,
212
+ conn_config: dict | None = None,
213
+ token_endpoint: str | None = None,
214
+ grant_type: str = "refresh_token",
215
+ ) -> str:
204
216
  """Generate temporary OAuth access token using refresh token in connection details."""
205
217
  if conn_config is None:
206
218
  conn_config = self._get_conn_params
207
219
 
208
- url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
220
+ url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
209
221
 
210
222
  data = {
211
- "grant_type": "refresh_token",
212
- "refresh_token": conn_config["refresh_token"],
223
+ "grant_type": grant_type,
213
224
  "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
214
225
  }
226
+
227
+ scope = conn_config.get("scope")
228
+
229
+ if scope:
230
+ data["scope"] = scope
231
+
232
+ if grant_type == "refresh_token":
233
+ data |= {
234
+ "refresh_token": conn_config["refresh_token"],
235
+ }
236
+ elif grant_type == "client_credentials":
237
+ pass # no setup necessary for client credentials grant.
238
+ else:
239
+ raise ValueError(f"Unknown grant_type: {grant_type}")
240
+
215
241
  response = requests.post(
216
242
  url,
217
243
  data=data,
@@ -226,7 +252,40 @@ class SnowflakeHook(DbApiHook):
226
252
  except requests.exceptions.HTTPError as e: # pragma: no cover
227
253
  msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
228
254
  raise AirflowException(msg)
229
- return response.json()["access_token"]
255
+ token = response.json()["access_token"]
256
+ return token
257
+
258
+ def get_azure_oauth_token(self, azure_conn_id: str) -> str:
259
+ """
260
+ Generate OAuth access token using Azure connection id.
261
+
262
+ This uses AzureBaseHook on the connection id to retrieve the token. Scope for the OAuth token can be
263
+ set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``.
264
+
265
+ :param azure_conn_id: The connection id for the Azure connection that will be used to fetch the token.
266
+ :raises AttributeError: If AzureBaseHook does not have a get_token method which happens when
267
+ package apache-airflow-providers-microsoft-azure<12.8.0.
268
+ :returns: The OAuth access token string.
269
+ """
270
+ if TYPE_CHECKING:
271
+ from airflow.providers.microsoft.azure.hooks.azure_base import AzureBaseHook
272
+
273
+ try:
274
+ azure_conn = Connection.get(azure_conn_id)
275
+ except AttributeError:
276
+ azure_conn = Connection.get_connection_from_secrets(azure_conn_id) # type: ignore[attr-defined]
277
+ try:
278
+ azure_base_hook: AzureBaseHook = azure_conn.get_hook()
279
+ except TypeError as e:
280
+ if "required positional argument: 'sdk_client'" in str(e):
281
+ raise AirflowOptionalProviderFeatureException(
282
+ "Getting azure token is not supported by current version of 'AzureBaseHook'. "
283
+ "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0"
284
+ ) from e
285
+ raise
286
+ scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
287
+ token = azure_base_hook.get_token(scope).token
288
+ return token
230
289
 
231
290
  @cached_property
232
291
  def _get_conn_params(self) -> dict[str, str | None]:
@@ -235,7 +294,7 @@ class SnowflakeHook(DbApiHook):
235
294
 
236
295
  This is used in ``get_uri()`` and ``get_connection()``.
237
296
  """
238
- conn = self.get_connection(self.snowflake_conn_id) # type: ignore[attr-defined]
297
+ conn = self.get_connection(self.get_conn_id())
239
298
  extra_dict = conn.extra_dejson
240
299
  account = self._get_field(extra_dict, "account") or ""
241
300
  warehouse = self._get_field(extra_dict, "warehouse") or ""
@@ -329,14 +388,26 @@ class SnowflakeHook(DbApiHook):
329
388
  if refresh_token:
330
389
  conn_config["refresh_token"] = refresh_token
331
390
  conn_config["authenticator"] = "oauth"
332
- conn_config["client_id"] = conn.login
333
- conn_config["client_secret"] = conn.password
391
+
392
+ if conn_config.get("authenticator") == "oauth":
393
+ if extra_dict.get("azure_conn_id"):
394
+ conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"])
395
+ else:
396
+ token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
397
+ conn_config["scope"] = self._get_field(extra_dict, "scope")
398
+ conn_config["client_id"] = conn.login
399
+ conn_config["client_secret"] = conn.password
400
+
401
+ conn_config["token"] = self.get_oauth_token(
402
+ conn_config=conn_config,
403
+ token_endpoint=token_endpoint,
404
+ grant_type=extra_dict.get("grant_type", "refresh_token"),
405
+ )
406
+
334
407
  conn_config.pop("login", None)
335
408
  conn_config.pop("user", None)
336
409
  conn_config.pop("password", None)
337
410
 
338
- conn_config["token"] = self.get_oauth_token(conn_config=conn_config)
339
-
340
411
  # configure custom target hostname and port, if specified
341
412
  snowflake_host = extra_dict.get("host")
342
413
  snowflake_port = extra_dict.get("port")
@@ -436,7 +507,7 @@ class SnowflakeHook(DbApiHook):
436
507
  def get_autocommit(self, conn):
437
508
  return getattr(conn, "autocommit_mode", False)
438
509
 
439
- @overload # type: ignore[override]
510
+ @overload
440
511
  def run(
441
512
  self,
442
513
  sql: str | Iterable[str],
@@ -519,16 +590,16 @@ class SnowflakeHook(DbApiHook):
519
590
  results = []
520
591
  for sql_statement in sql_list:
521
592
  self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
522
- self._run_command(cur, sql_statement, parameters) # type: ignore[attr-defined]
593
+ self._run_command(cur, sql_statement, parameters)
523
594
 
524
595
  if handler is not None:
525
- result = self._make_common_data_structure(handler(cur)) # type: ignore[attr-defined]
596
+ result = self._make_common_data_structure(handler(cur))
526
597
  if return_single_query_results(sql, return_last, split_statements):
527
598
  _last_result = result
528
599
  _last_description = cur.description
529
600
  else:
530
601
  results.append(result)
531
- self.descriptions.append(cur.description) # type: ignore[has-type]
602
+ self.descriptions.append(cur.description)
532
603
 
533
604
  query_id = cur.sfqid
534
605
  self.log.info("Rows affected: %s", cur.rowcount)
@@ -592,10 +663,9 @@ class SnowflakeHook(DbApiHook):
592
663
 
593
664
  def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
594
665
  """
595
- Generate OpenLineage metadata for a Snowflake task instance based on executed query IDs.
666
+ Emit separate OpenLineage events for each Snowflake query, based on executed query IDs.
596
667
 
597
- If a single query ID is present, attach an `ExternalQueryRunFacet` to the lineage metadata.
598
- If multiple query IDs are present, emits separate OpenLineage events for each query.
668
+ If a single query ID is present, also add an `ExternalQueryRunFacet` to the returned lineage metadata.
599
669
 
600
670
  Note that `get_openlineage_database_specific_lineage` is usually called after task's execution,
601
671
  so if multiple query IDs are present, both START and COMPLETE event for each query will be emitted
@@ -616,13 +686,22 @@ class SnowflakeHook(DbApiHook):
616
686
  )
617
687
 
618
688
  if not self.query_ids:
619
- self.log.debug("openlineage: no snowflake query ids found.")
689
+ self.log.info("OpenLineage could not find snowflake query ids.")
620
690
  return None
621
691
 
622
692
  self.log.debug("openlineage: getting connection to get database info")
623
693
  connection = self.get_connection(self.get_conn_id())
624
694
  namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
625
695
 
696
+ self.log.info("Separate OpenLineage events will be emitted for each query_id.")
697
+ emit_openlineage_events_for_snowflake_queries(
698
+ task_instance=task_instance,
699
+ hook=self,
700
+ query_ids=self.query_ids,
701
+ query_for_extra_metadata=True,
702
+ query_source_namespace=namespace,
703
+ )
704
+
626
705
  if len(self.query_ids) == 1:
627
706
  self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.")
628
707
  return OperatorLineage(
@@ -633,20 +712,4 @@ class SnowflakeHook(DbApiHook):
633
712
  }
634
713
  )
635
714
 
636
- self.log.info("Multiple query_ids found. Separate OpenLineage event will be emitted for each query.")
637
- try:
638
- from airflow.providers.openlineage.utils.utils import should_use_external_connection
639
-
640
- use_external_connection = should_use_external_connection(self)
641
- except ImportError:
642
- # OpenLineage provider release < 1.8.0 - we always use connection
643
- use_external_connection = True
644
-
645
- emit_openlineage_events_for_snowflake_queries(
646
- query_ids=self.query_ids,
647
- query_source_namespace=namespace,
648
- task_instance=task_instance,
649
- hook=self if use_external_connection else None,
650
- )
651
-
652
715
  return None