apache-airflow-providers-snowflake 6.3.0__py3-none-any.whl → 6.8.0rc1__py3-none-any.whl
This diff reflects the content of publicly available package versions as published to their respective public registries and is provided for informational purposes only.
- airflow/providers/snowflake/__init__.py +3 -3
- airflow/providers/snowflake/decorators/snowpark.py +2 -12
- airflow/providers/snowflake/get_provider_info.py +16 -0
- airflow/providers/snowflake/hooks/snowflake.py +100 -37
- airflow/providers/snowflake/hooks/snowflake_sql_api.py +226 -29
- airflow/providers/snowflake/operators/snowflake.py +37 -27
- airflow/providers/snowflake/operators/snowpark.py +2 -2
- airflow/providers/snowflake/transfers/copy_into_snowflake.py +13 -4
- airflow/providers/snowflake/triggers/snowflake_trigger.py +1 -4
- airflow/providers/snowflake/utils/openlineage.py +141 -93
- airflow/providers/snowflake/utils/snowpark.py +2 -1
- airflow/providers/snowflake/version_compat.py +4 -0
- {apache_airflow_providers_snowflake-6.3.0.dist-info → apache_airflow_providers_snowflake-6.8.0rc1.dist-info}/METADATA +61 -38
- apache_airflow_providers_snowflake-6.8.0rc1.dist-info/RECORD +26 -0
- apache_airflow_providers_snowflake-6.8.0rc1.dist-info/licenses/NOTICE +5 -0
- apache_airflow_providers_snowflake-6.3.0.dist-info/RECORD +0 -25
- {apache_airflow_providers_snowflake-6.3.0.dist-info → apache_airflow_providers_snowflake-6.8.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_snowflake-6.3.0.dist-info → apache_airflow_providers_snowflake-6.8.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/snowflake → apache_airflow_providers_snowflake-6.8.0rc1.dist-info/licenses}/LICENSE +0 -0

airflow/providers/snowflake/__init__.py
@@ -29,11 +29,11 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "6.3.0"
+__version__ = "6.8.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
-    "2.
+    "2.11.0"
 ):
     raise RuntimeError(
-        f"The package `apache-airflow-providers-snowflake:{__version__}` needs Apache Airflow 2.
+        f"The package `apache-airflow-providers-snowflake:{__version__}` needs Apache Airflow 2.11.0+"
     )

airflow/providers/snowflake/decorators/snowpark.py
@@ -17,22 +17,12 @@

 from __future__ import annotations

-from collections.abc import Sequence
-from typing import TYPE_CHECKING, Callable
-
-from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk.bases.decorator import DecoratedOperator, task_decorator_factory
-else:
-    from airflow.decorators.base import DecoratedOperator, task_decorator_factory  # type: ignore[no-redef]
+from collections.abc import Callable, Sequence

+from airflow.providers.common.compat.sdk import DecoratedOperator, TaskDecorator, task_decorator_factory
 from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
 from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs

-if TYPE_CHECKING:
-    from airflow.sdk.bases.decorator import TaskDecorator
-

 class _SnowparkDecoratedOperator(DecoratedOperator, SnowparkOperator):
     """
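
Note: the imports above feed the @task.snowpark decorator. A minimal usage sketch, assuming the callable accepts the Snowpark `session` that inject_session_into_op_kwargs provides (connection id and table name are placeholders):

    from airflow.decorators import task

    @task.snowpark(snowflake_conn_id="snowflake_default")  # placeholder conn id
    def count_rows(session):
        # `session` is the Snowpark Session injected via inject_session_into_op_kwargs.
        return session.sql("SELECT COUNT(*) FROM my_table").collect()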

airflow/providers/snowflake/get_provider_info.py
@@ -73,11 +73,13 @@ def get_provider_info():
                 "source-integration-name": "Google Cloud Storage (GCS)",
                 "target-integration-name": "Snowflake",
                 "python-module": "airflow.providers.snowflake.transfers.copy_into_snowflake",
+                "how-to-guide": "/docs/apache-airflow-providers-snowflake/operators/copy_into_snowflake.rst",
             },
             {
                 "source-integration-name": "Microsoft Azure Blob Storage",
                 "target-integration-name": "Snowflake",
                 "python-module": "airflow.providers.snowflake.transfers.copy_into_snowflake",
+                "how-to-guide": "/docs/apache-airflow-providers-snowflake/operators/copy_into_snowflake.rst",
             },
         ],
         "connection-types": [
@@ -92,4 +94,18 @@ def get_provider_info():
                 "python-modules": ["airflow.providers.snowflake.triggers.snowflake_trigger"],
             }
         ],
+        "config": {
+            "snowflake": {
+                "description": "Configuration for Snowflake hooks and operators.\n",
+                "options": {
+                    "azure_oauth_scope": {
+                        "description": "The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication.\n",
+                        "version_added": "6.6.0",
+                        "type": "string",
+                        "example": None,
+                        "default": "api://snowflake_oauth_server/.default",
+                    }
+                },
+            }
+        },
     }
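
Note: the new [snowflake] azure_oauth_scope option behaves like any other Airflow config entry. A minimal sketch of overriding and reading it, assuming the standard AIRFLOW__SECTION__KEY environment-variable convention (the application id is a placeholder):

    import os

    os.environ["AIRFLOW__SNOWFLAKE__AZURE_OAUTH_SCOPE"] = "api://my-snowflake-app/.default"

    from airflow.configuration import conf

    # Mirrors how the hook reads the option, falling back to the provider default.
    scope = conf.get("snowflake", "azure_oauth_scope", fallback="api://snowflake_oauth_server/.default")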

airflow/providers/snowflake/hooks/snowflake.py
@@ -19,12 +19,12 @@ from __future__ import annotations

 import base64
 import os
-from collections.abc import Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping
 from contextlib import closing, contextmanager
 from functools import cached_property
 from io import StringIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, TypeVar, overload
 from urllib.parse import urlparse

 import requests
@@ -36,7 +36,9 @@ from snowflake.connector import DictCursor, SnowflakeConnection, util_text
 from snowflake.sqlalchemy import URL
 from sqlalchemy import create_engine

-from airflow.
+from airflow.configuration import conf
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.common.compat.sdk import AirflowException, Connection
 from airflow.providers.common.sql.hooks.handlers import return_single_query_results
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
@@ -94,6 +96,7 @@ class SnowflakeHook(DbApiHook):
     hook_name = "Snowflake"
     supports_autocommit = True
     _test_connection_sql = "select 1"
+    default_azure_oauth_scope = "api://snowflake_oauth_server/.default"

     @classmethod
     def get_connection_form_widgets(cls) -> dict[str, Any]:
@@ -136,6 +139,10 @@ class SnowflakeHook(DbApiHook):
                         "session_parameters": "session parameters",
                         "client_request_mfa_token": "client request mfa token",
                         "client_store_temporary_credential": "client store temporary credential (externalbrowser mode)",
+                        "grant_type": "refresh_token client_credentials",
+                        "token_endpoint": "token endpoint",
+                        "refresh_token": "refresh token",
+                        "scope": "scope",
                     },
                     indent=1,
                 ),
@@ -200,18 +207,37 @@ class SnowflakeHook(DbApiHook):

         return account_identifier

-    def get_oauth_token(self, conn_config: dict | None = None) -> str:
+    def get_oauth_token(
+        self,
+        conn_config: dict | None = None,
+        token_endpoint: str | None = None,
+        grant_type: str = "refresh_token",
+    ) -> str:
         """Generate temporary OAuth access token using refresh token in connection details."""
         if conn_config is None:
             conn_config = self._get_conn_params

-        url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
+        url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"

         data = {
-            "grant_type": "refresh_token",
-            "refresh_token": conn_config["refresh_token"],
+            "grant_type": grant_type,
             "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
         }
+
+        scope = conn_config.get("scope")
+
+        if scope:
+            data["scope"] = scope
+
+        if grant_type == "refresh_token":
+            data |= {
+                "refresh_token": conn_config["refresh_token"],
+            }
+        elif grant_type == "client_credentials":
+            pass  # no setup necessary for client credentials grant.
+        else:
+            raise ValueError(f"Unknown grant_type: {grant_type}")
+
         response = requests.post(
             url,
             data=data,
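
Note: an illustrative sketch of the two grant types the reworked get_oauth_token accepts. The connection id and token endpoint are placeholders; the client credentials for the request itself still come from the connection, as the _get_conn_params hunk further down shows:

    from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

    hook = SnowflakeHook(snowflake_conn_id="snowflake_default")

    # Default behaviour: refresh_token grant against Snowflake's /oauth/token-request endpoint.
    token = hook.get_oauth_token()

    # New: client_credentials grant against an external identity provider.
    token = hook.get_oauth_token(
        token_endpoint="https://login.example.com/oauth2/v2.0/token",
        grant_type="client_credentials",
    )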
@@ -226,7 +252,40 @@ class SnowflakeHook(DbApiHook):
         except requests.exceptions.HTTPError as e:  # pragma: no cover
             msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
             raise AirflowException(msg)
-        return response.json()["access_token"]
+        token = response.json()["access_token"]
+        return token
+
+    def get_azure_oauth_token(self, azure_conn_id: str) -> str:
+        """
+        Generate OAuth access token using Azure connection id.
+
+        This uses AzureBaseHook on the connection id to retrieve the token. Scope for the OAuth token can be
+        set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``.
+
+        :param azure_conn_id: The connection id for the Azure connection that will be used to fetch the token.
+        :raises AttributeError: If AzureBaseHook does not have a get_token method which happens when
+            package apache-airflow-providers-microsoft-azure<12.8.0.
+        :returns: The OAuth access token string.
+        """
+        if TYPE_CHECKING:
+            from airflow.providers.microsoft.azure.hooks.azure_base import AzureBaseHook
+
+        try:
+            azure_conn = Connection.get(azure_conn_id)
+        except AttributeError:
+            azure_conn = Connection.get_connection_from_secrets(azure_conn_id)  # type: ignore[attr-defined]
+        try:
+            azure_base_hook: AzureBaseHook = azure_conn.get_hook()
+        except TypeError as e:
+            if "required positional argument: 'sdk_client'" in str(e):
+                raise AirflowOptionalProviderFeatureException(
+                    "Getting azure token is not supported by current version of 'AzureBaseHook'. "
+                    "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0"
+                ) from e
+            raise
+        scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
+        token = azure_base_hook.get_token(scope).token
+        return token

     @cached_property
     def _get_conn_params(self) -> dict[str, str | None]:
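
Note: a hedged sketch of the new Azure token path. Both connection ids are placeholders, and this path needs apache-airflow-providers-microsoft-azure>=12.8.0, as the error message above spells out:

    from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

    hook = SnowflakeHook(snowflake_conn_id="snowflake_default")
    # "azure_default" must point at an Azure connection whose hook exposes get_token().
    token = hook.get_azure_oauth_token(azure_conn_id="azure_default")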
@@ -235,7 +294,7 @@ class SnowflakeHook(DbApiHook):

         This is used in ``get_uri()`` and ``get_connection()``.
         """
-        conn = self.get_connection(self.
+        conn = self.get_connection(self.get_conn_id())
         extra_dict = conn.extra_dejson
         account = self._get_field(extra_dict, "account") or ""
         warehouse = self._get_field(extra_dict, "warehouse") or ""
@@ -329,14 +388,26 @@ class SnowflakeHook(DbApiHook):
         if refresh_token:
             conn_config["refresh_token"] = refresh_token
             conn_config["authenticator"] = "oauth"
-
-
+
+        if conn_config.get("authenticator") == "oauth":
+            if extra_dict.get("azure_conn_id"):
+                conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"])
+            else:
+                token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
+                conn_config["scope"] = self._get_field(extra_dict, "scope")
+                conn_config["client_id"] = conn.login
+                conn_config["client_secret"] = conn.password
+
+                conn_config["token"] = self.get_oauth_token(
+                    conn_config=conn_config,
+                    token_endpoint=token_endpoint,
+                    grant_type=extra_dict.get("grant_type", "refresh_token"),
+                )
+
             conn_config.pop("login", None)
             conn_config.pop("user", None)
             conn_config.pop("password", None)

-            conn_config["token"] = self.get_oauth_token(conn_config=conn_config)
-
         # configure custom target hostname and port, if specified
         snowflake_host = extra_dict.get("host")
         snowflake_port = extra_dict.get("port")
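
Note: pulling the new fields together, a hedged sketch of what an OAuth-authenticated Snowflake connection could look like under this release. All values are placeholders; login/password act as client_id/client_secret, and setting azure_conn_id switches to the Azure token path instead:

    import json

    from airflow.models.connection import Connection

    snowflake_conn = Connection(
        conn_id="snowflake_oauth",
        conn_type="snowflake",
        login="my-client-id",          # used as client_id
        password="my-client-secret",   # used as client_secret
        extra=json.dumps(
            {
                "account": "my_account",
                "authenticator": "oauth",
                "grant_type": "client_credentials",
                "token_endpoint": "https://login.example.com/oauth2/v2.0/token",
                "scope": "api://my-snowflake-app/.default",
            }
        ),
    )

    # Alternative extras: {"account": "...", "authenticator": "oauth", "azure_conn_id": "azure_default"}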
@@ -436,7 +507,7 @@ class SnowflakeHook(DbApiHook):
     def get_autocommit(self, conn):
         return getattr(conn, "autocommit_mode", False)

-    @overload
+    @overload
     def run(
         self,
         sql: str | Iterable[str],
@@ -519,16 +590,16 @@ class SnowflakeHook(DbApiHook):
                 results = []
                 for sql_statement in sql_list:
                     self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
-                    self._run_command(cur, sql_statement, parameters)
+                    self._run_command(cur, sql_statement, parameters)

                     if handler is not None:
-                        result = self._make_common_data_structure(handler(cur))
+                        result = self._make_common_data_structure(handler(cur))
                         if return_single_query_results(sql, return_last, split_statements):
                             _last_result = result
                             _last_description = cur.description
                         else:
                             results.append(result)
-                            self.descriptions.append(cur.description)
+                            self.descriptions.append(cur.description)

                 query_id = cur.sfqid
                 self.log.info("Rows affected: %s", cur.rowcount)
@@ -592,10 +663,9 @@ class SnowflakeHook(DbApiHook):

     def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
         """
-
+        Emit separate OpenLineage events for each Snowflake query, based on executed query IDs.

-        If a single query ID is present,
-        If multiple query IDs are present, emits separate OpenLineage events for each query.
+        If a single query ID is present, also add an `ExternalQueryRunFacet` to the returned lineage metadata.

         Note that `get_openlineage_database_specific_lineage` is usually called after task's execution,
         so if multiple query IDs are present, both START and COMPLETE event for each query will be emitted
@@ -616,13 +686,22 @@ class SnowflakeHook(DbApiHook):
         )

         if not self.query_ids:
-            self.log.
+            self.log.info("OpenLineage could not find snowflake query ids.")
             return None

         self.log.debug("openlineage: getting connection to get database info")
         connection = self.get_connection(self.get_conn_id())
         namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))

+        self.log.info("Separate OpenLineage events will be emitted for each query_id.")
+        emit_openlineage_events_for_snowflake_queries(
+            task_instance=task_instance,
+            hook=self,
+            query_ids=self.query_ids,
+            query_for_extra_metadata=True,
+            query_source_namespace=namespace,
+        )
+
         if len(self.query_ids) == 1:
             self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.")
             return OperatorLineage(
|
|
|
633
712
|
}
|
|
634
713
|
)
|
|
635
714
|
|
|
636
|
-
self.log.info("Multiple query_ids found. Separate OpenLineage event will be emitted for each query.")
|
|
637
|
-
try:
|
|
638
|
-
from airflow.providers.openlineage.utils.utils import should_use_external_connection
|
|
639
|
-
|
|
640
|
-
use_external_connection = should_use_external_connection(self)
|
|
641
|
-
except ImportError:
|
|
642
|
-
# OpenLineage provider release < 1.8.0 - we always use connection
|
|
643
|
-
use_external_connection = True
|
|
644
|
-
|
|
645
|
-
emit_openlineage_events_for_snowflake_queries(
|
|
646
|
-
query_ids=self.query_ids,
|
|
647
|
-
query_source_namespace=namespace,
|
|
648
|
-
task_instance=task_instance,
|
|
649
|
-
hook=self if use_external_connection else None,
|
|
650
|
-
)
|
|
651
|
-
|
|
652
715
|
return None
|