apache-airflow-providers-microsoft-azure 12.7.0__py3-none-any.whl → 12.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/microsoft/azure/__init__.py +1 -1
- airflow/providers/microsoft/azure/fs/adls.py +1 -1
- airflow/providers/microsoft/azure/fs/msgraph.py +111 -0
- airflow/providers/microsoft/azure/get_provider_info.py +12 -2
- airflow/providers/microsoft/azure/hooks/adx.py +1 -1
- airflow/providers/microsoft/azure/hooks/asb.py +1 -1
- airflow/providers/microsoft/azure/hooks/base_azure.py +93 -17
- airflow/providers/microsoft/azure/hooks/batch.py +1 -1
- airflow/providers/microsoft/azure/hooks/container_registry.py +1 -1
- airflow/providers/microsoft/azure/hooks/container_volume.py +1 -1
- airflow/providers/microsoft/azure/hooks/cosmos.py +1 -1
- airflow/providers/microsoft/azure/hooks/data_factory.py +2 -5
- airflow/providers/microsoft/azure/hooks/data_lake.py +1 -1
- airflow/providers/microsoft/azure/hooks/fileshare.py +1 -1
- airflow/providers/microsoft/azure/hooks/msgraph.py +18 -6
- airflow/providers/microsoft/azure/hooks/synapse.py +1 -1
- airflow/providers/microsoft/azure/hooks/wasb.py +1 -1
- airflow/providers/microsoft/azure/log/wasb_task_handler.py +7 -1
- airflow/providers/microsoft/azure/operators/adls.py +1 -1
- airflow/providers/microsoft/azure/operators/adx.py +1 -1
- airflow/providers/microsoft/azure/operators/asb.py +1 -1
- airflow/providers/microsoft/azure/operators/batch.py +1 -1
- airflow/providers/microsoft/azure/operators/container_instances.py +1 -1
- airflow/providers/microsoft/azure/operators/cosmos.py +1 -1
- airflow/providers/microsoft/azure/operators/data_factory.py +1 -13
- airflow/providers/microsoft/azure/operators/msgraph.py +2 -1
- airflow/providers/microsoft/azure/operators/powerbi.py +1 -8
- airflow/providers/microsoft/azure/operators/synapse.py +1 -13
- airflow/providers/microsoft/azure/operators/wasb_delete_blob.py +1 -1
- airflow/providers/microsoft/azure/sensors/cosmos.py +1 -6
- airflow/providers/microsoft/azure/sensors/data_factory.py +1 -6
- airflow/providers/microsoft/azure/sensors/msgraph.py +1 -6
- airflow/providers/microsoft/azure/sensors/wasb.py +1 -6
- airflow/providers/microsoft/azure/transfers/local_to_adls.py +1 -1
- airflow/providers/microsoft/azure/transfers/local_to_wasb.py +1 -1
- airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py +1 -1
- airflow/providers/microsoft/azure/transfers/s3_to_wasb.py +1 -1
- airflow/providers/microsoft/azure/transfers/sftp_to_wasb.py +1 -1
- airflow/providers/microsoft/azure/triggers/data_factory.py +4 -3
- airflow/providers/microsoft/azure/triggers/msgraph.py +1 -12
- airflow/providers/microsoft/azure/version_compat.py +0 -24
- {apache_airflow_providers_microsoft_azure-12.7.0.dist-info → apache_airflow_providers_microsoft_azure-12.8.1.dist-info}/METADATA +58 -42
- apache_airflow_providers_microsoft_azure-12.8.1.dist-info/RECORD +61 -0
- apache_airflow_providers_microsoft_azure-12.8.1.dist-info/licenses/NOTICE +5 -0
- apache_airflow_providers_microsoft_azure-12.7.0.dist-info/RECORD +0 -59
- {apache_airflow_providers_microsoft_azure-12.7.0.dist-info → apache_airflow_providers_microsoft_azure-12.8.1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_microsoft_azure-12.7.0.dist-info → apache_airflow_providers_microsoft_azure-12.8.1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/microsoft/azure → apache_airflow_providers_microsoft_azure-12.8.1.dist-info/licenses}/LICENSE +0 -0
|
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
|
|
|
29
29
|
|
|
30
30
|
__all__ = ["__version__"]
|
|
31
31
|
|
|
32
|
-
__version__ = "12.
|
|
32
|
+
__version__ = "12.8.1"
|
|
33
33
|
|
|
34
34
|
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
|
|
35
35
|
"2.10.0"
|
|
@@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, Any
|
|
|
20
20
|
|
|
21
21
|
from azure.identity import ClientSecretCredential
|
|
22
22
|
|
|
23
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
23
24
|
from airflow.providers.microsoft.azure.utils import get_field, parse_blob_account_url
|
|
24
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
25
25
|
|
|
26
26
|
if TYPE_CHECKING:
|
|
27
27
|
from fsspec import AbstractFileSystem
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
|
3
|
+
# distributed with this work for additional information
|
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
|
6
|
+
# "License"); you may not use this file except in compliance
|
|
7
|
+
# with the License. You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
|
12
|
+
# software distributed under the License is distributed on an
|
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
14
|
+
# KIND, either express or implied. See the License for the
|
|
15
|
+
# specific language governing permissions and limitations
|
|
16
|
+
# under the License.
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import TYPE_CHECKING, Any
|
|
21
|
+
|
|
22
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
23
|
+
from airflow.providers.microsoft.azure.utils import get_field
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from fsspec import AbstractFileSystem
|
|
27
|
+
|
|
28
|
+
schemes = ["msgraph", "sharepoint", "onedrive", "msgd"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) -> AbstractFileSystem:
|
|
32
|
+
from msgraphfs import MSGDriveFS
|
|
33
|
+
|
|
34
|
+
if conn_id is None:
|
|
35
|
+
return MSGDriveFS({})
|
|
36
|
+
|
|
37
|
+
conn = BaseHook.get_connection(conn_id)
|
|
38
|
+
extras = conn.extra_dejson
|
|
39
|
+
conn_type = conn.conn_type or "msgraph"
|
|
40
|
+
|
|
41
|
+
options: dict[str, Any] = {}
|
|
42
|
+
|
|
43
|
+
# Get authentication parameters with fallback handling
|
|
44
|
+
client_id = conn.login or get_field(
|
|
45
|
+
conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_id"
|
|
46
|
+
)
|
|
47
|
+
client_secret = conn.password or get_field(
|
|
48
|
+
conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_secret"
|
|
49
|
+
)
|
|
50
|
+
tenant_id = conn.host or get_field(
|
|
51
|
+
conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="tenant_id"
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
if client_id:
|
|
55
|
+
options["client_id"] = client_id
|
|
56
|
+
if client_secret:
|
|
57
|
+
options["client_secret"] = client_secret
|
|
58
|
+
if tenant_id:
|
|
59
|
+
options["tenant_id"] = tenant_id
|
|
60
|
+
|
|
61
|
+
# Process additional fields from extras
|
|
62
|
+
fields = [
|
|
63
|
+
"drive_id",
|
|
64
|
+
"scope",
|
|
65
|
+
"token_endpoint",
|
|
66
|
+
"redirect_uri",
|
|
67
|
+
"token_endpoint_auth_method",
|
|
68
|
+
"code_challenge_method",
|
|
69
|
+
"update_token",
|
|
70
|
+
"username",
|
|
71
|
+
"password",
|
|
72
|
+
]
|
|
73
|
+
for field in fields:
|
|
74
|
+
value = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=field)
|
|
75
|
+
if value is not None:
|
|
76
|
+
if value == "":
|
|
77
|
+
options.pop(field, "")
|
|
78
|
+
else:
|
|
79
|
+
options[field] = value
|
|
80
|
+
|
|
81
|
+
# Update with storage options
|
|
82
|
+
options.update(storage_options or {})
|
|
83
|
+
|
|
84
|
+
# Create oauth2 client parameters if authentication is provided
|
|
85
|
+
oauth2_client_params = {}
|
|
86
|
+
if options.get("client_id") and options.get("client_secret") and options.get("tenant_id"):
|
|
87
|
+
oauth2_client_params = {
|
|
88
|
+
"client_id": options["client_id"],
|
|
89
|
+
"client_secret": options["client_secret"],
|
|
90
|
+
"tenant_id": options["tenant_id"],
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
# Add additional oauth2 parameters supported by authlib
|
|
94
|
+
oauth2_params = [
|
|
95
|
+
"scope",
|
|
96
|
+
"token_endpoint",
|
|
97
|
+
"redirect_uri",
|
|
98
|
+
"token_endpoint_auth_method",
|
|
99
|
+
"code_challenge_method",
|
|
100
|
+
"update_token",
|
|
101
|
+
"username",
|
|
102
|
+
"password",
|
|
103
|
+
]
|
|
104
|
+
for param in oauth2_params:
|
|
105
|
+
if param in options:
|
|
106
|
+
oauth2_client_params[param] = options[param]
|
|
107
|
+
|
|
108
|
+
# Determine which filesystem to return based on drive_id
|
|
109
|
+
drive_id = options.get("drive_id")
|
|
110
|
+
|
|
111
|
+
return MSGDriveFS(drive_id=drive_id, oauth2_client_params=oauth2_client_params)
|
|
@@ -37,6 +37,7 @@ def get_provider_info():
|
|
|
37
37
|
{
|
|
38
38
|
"integration-name": "Microsoft Azure Blob Storage",
|
|
39
39
|
"external-doc-url": "https://azure.microsoft.com/en-us/services/storage/blobs/",
|
|
40
|
+
"how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/sensors/wasb_sensors.rst"],
|
|
40
41
|
"logo": "/docs/integration-logos/Blob-Storage.svg",
|
|
41
42
|
"tags": ["azure"],
|
|
42
43
|
},
|
|
@@ -49,6 +50,9 @@ def get_provider_info():
|
|
|
49
50
|
{
|
|
50
51
|
"integration-name": "Microsoft Azure Cosmos DB",
|
|
51
52
|
"external-doc-url": "https://azure.microsoft.com/en-us/services/cosmos-db/",
|
|
53
|
+
"how-to-guide": [
|
|
54
|
+
"/docs/apache-airflow-providers-microsoft-azure/sensors/cosmos_document_sensor.rst"
|
|
55
|
+
],
|
|
52
56
|
"logo": "/docs/integration-logos/Azure-Cosmos-DB.svg",
|
|
53
57
|
"tags": ["azure"],
|
|
54
58
|
},
|
|
@@ -117,7 +121,10 @@ def get_provider_info():
|
|
|
117
121
|
"integration-name": "Microsoft Graph API",
|
|
118
122
|
"external-doc-url": "https://learn.microsoft.com/en-us/graph/use-the-api/",
|
|
119
123
|
"logo": "/docs/integration-logos/Microsoft-Graph-API.png",
|
|
120
|
-
"how-to-guide": [
|
|
124
|
+
"how-to-guide": [
|
|
125
|
+
"/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst",
|
|
126
|
+
"/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst",
|
|
127
|
+
],
|
|
121
128
|
"tags": ["azure"],
|
|
122
129
|
},
|
|
123
130
|
{
|
|
@@ -191,7 +198,10 @@ def get_provider_info():
|
|
|
191
198
|
"python-modules": ["airflow.providers.microsoft.azure.sensors.msgraph"],
|
|
192
199
|
},
|
|
193
200
|
],
|
|
194
|
-
"filesystems": [
|
|
201
|
+
"filesystems": [
|
|
202
|
+
"airflow.providers.microsoft.azure.fs.adls",
|
|
203
|
+
"airflow.providers.microsoft.azure.fs.msgraphfs",
|
|
204
|
+
],
|
|
195
205
|
"hooks": [
|
|
196
206
|
{
|
|
197
207
|
"integration-name": "Microsoft Azure Container Instances",
|
|
@@ -34,11 +34,11 @@ from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnecti
|
|
|
34
34
|
from azure.kusto.data.exceptions import KustoServiceError
|
|
35
35
|
|
|
36
36
|
from airflow.exceptions import AirflowException
|
|
37
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
37
38
|
from airflow.providers.microsoft.azure.utils import (
|
|
38
39
|
add_managed_identity_connection_widgets,
|
|
39
40
|
get_sync_default_azure_credential,
|
|
40
41
|
)
|
|
41
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
42
42
|
|
|
43
43
|
if TYPE_CHECKING:
|
|
44
44
|
from azure.kusto.data.response import KustoResponseDataSet
|
|
@@ -37,12 +37,12 @@ from azure.servicebus.management import (
|
|
|
37
37
|
SubscriptionProperties,
|
|
38
38
|
)
|
|
39
39
|
|
|
40
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
40
41
|
from airflow.providers.microsoft.azure.utils import (
|
|
41
42
|
add_managed_identity_connection_widgets,
|
|
42
43
|
get_field,
|
|
43
44
|
get_sync_default_azure_credential,
|
|
44
45
|
)
|
|
45
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
46
46
|
|
|
47
47
|
if TYPE_CHECKING:
|
|
48
48
|
import datetime
|
|
@@ -16,17 +16,24 @@
|
|
|
16
16
|
# under the License.
|
|
17
17
|
from __future__ import annotations
|
|
18
18
|
|
|
19
|
-
from typing import Any
|
|
19
|
+
from typing import TYPE_CHECKING, Any
|
|
20
20
|
|
|
21
21
|
from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
|
|
22
22
|
from azure.common.credentials import ServicePrincipalCredentials
|
|
23
|
+
from azure.identity import ClientSecretCredential, DefaultAzureCredential
|
|
23
24
|
|
|
24
25
|
from airflow.exceptions import AirflowException
|
|
26
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
25
27
|
from airflow.providers.microsoft.azure.utils import (
|
|
26
28
|
AzureIdentityCredentialAdapter,
|
|
27
29
|
add_managed_identity_connection_widgets,
|
|
30
|
+
get_sync_default_azure_credential,
|
|
28
31
|
)
|
|
29
|
-
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from azure.core.credentials import AccessToken
|
|
35
|
+
|
|
36
|
+
from airflow.sdk import Connection
|
|
30
37
|
|
|
31
38
|
|
|
32
39
|
class AzureBaseHook(BaseHook):
|
|
@@ -85,7 +92,7 @@ class AzureBaseHook(BaseHook):
|
|
|
85
92
|
},
|
|
86
93
|
}
|
|
87
94
|
|
|
88
|
-
def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
|
|
95
|
+
def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
|
|
89
96
|
self.sdk_client = sdk_client
|
|
90
97
|
self.conn_id = conn_id
|
|
91
98
|
super().__init__()
|
|
@@ -96,8 +103,9 @@ class AzureBaseHook(BaseHook):
|
|
|
96
103
|
|
|
97
104
|
:return: the authenticated client.
|
|
98
105
|
"""
|
|
106
|
+
if not self.sdk_client:
|
|
107
|
+
raise ValueError("`sdk_client` must be provided to AzureBaseHook to use `get_conn` method.")
|
|
99
108
|
conn = self.get_connection(self.conn_id)
|
|
100
|
-
tenant = conn.extra_dejson.get("tenantId")
|
|
101
109
|
subscription_id = conn.extra_dejson.get("subscriptionId")
|
|
102
110
|
key_path = conn.extra_dejson.get("key_path")
|
|
103
111
|
if key_path:
|
|
@@ -111,22 +119,90 @@ class AzureBaseHook(BaseHook):
|
|
|
111
119
|
self.log.info("Getting connection using a JSON config.")
|
|
112
120
|
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
|
|
113
121
|
|
|
114
|
-
credentials
|
|
122
|
+
credentials = self.get_credential(conn=conn)
|
|
123
|
+
|
|
124
|
+
return self.sdk_client(
|
|
125
|
+
credentials=credentials,
|
|
126
|
+
subscription_id=subscription_id,
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
def get_credential(
|
|
130
|
+
self, *, conn: Connection | None = None
|
|
131
|
+
) -> (
|
|
132
|
+
ServicePrincipalCredentials
|
|
133
|
+
| AzureIdentityCredentialAdapter
|
|
134
|
+
| ClientSecretCredential
|
|
135
|
+
| DefaultAzureCredential
|
|
136
|
+
):
|
|
137
|
+
"""
|
|
138
|
+
Get Azure credential object for the connection.
|
|
139
|
+
|
|
140
|
+
Azure Identity based credential object (``ClientSecretCredential``, ``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` method.
|
|
141
|
+
Older Credential objects (``ServicePrincipalCredentials``, ``AzureIdentityCredentialAdapter``) are supported for backward compatibility.
|
|
142
|
+
|
|
143
|
+
:return: The Azure credential object
|
|
144
|
+
"""
|
|
145
|
+
if not conn:
|
|
146
|
+
conn = self.get_connection(self.conn_id)
|
|
147
|
+
tenant = conn.extra_dejson.get("tenantId")
|
|
148
|
+
credential: (
|
|
149
|
+
ServicePrincipalCredentials
|
|
150
|
+
| AzureIdentityCredentialAdapter
|
|
151
|
+
| ClientSecretCredential
|
|
152
|
+
| DefaultAzureCredential
|
|
153
|
+
)
|
|
115
154
|
if all([conn.login, conn.password, tenant]):
|
|
116
|
-
self.
|
|
117
|
-
credentials = ServicePrincipalCredentials(
|
|
118
|
-
client_id=conn.login, secret=conn.password, tenant=tenant
|
|
119
|
-
)
|
|
155
|
+
credential = self._get_client_secret_credential(conn)
|
|
120
156
|
else:
|
|
121
|
-
self.
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
157
|
+
credential = self._get_default_azure_credential(conn)
|
|
158
|
+
return credential
|
|
159
|
+
|
|
160
|
+
def _get_client_secret_credential(
|
|
161
|
+
self, conn: Connection
|
|
162
|
+
) -> ServicePrincipalCredentials | ClientSecretCredential:
|
|
163
|
+
self.log.info("Getting credentials using specific credentials and subscription_id.")
|
|
164
|
+
extra_dejson = conn.extra_dejson
|
|
165
|
+
tenant = extra_dejson.get("tenantId")
|
|
166
|
+
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
|
|
167
|
+
if use_azure_identity_object:
|
|
168
|
+
return ClientSecretCredential(
|
|
169
|
+
client_id=conn.login, # type: ignore[arg-type]
|
|
170
|
+
client_secret=conn.password, # type: ignore[arg-type]
|
|
171
|
+
tenant_id=tenant, # type: ignore[arg-type]
|
|
172
|
+
)
|
|
173
|
+
return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant)
|
|
174
|
+
|
|
175
|
+
def _get_default_azure_credential(
|
|
176
|
+
self, conn: Connection
|
|
177
|
+
) -> DefaultAzureCredential | AzureIdentityCredentialAdapter:
|
|
178
|
+
self.log.info("Using DefaultAzureCredential as credential")
|
|
179
|
+
extra_dejson = conn.extra_dejson
|
|
180
|
+
managed_identity_client_id = extra_dejson.get("managed_identity_client_id")
|
|
181
|
+
workload_identity_tenant_id = extra_dejson.get("workload_identity_tenant_id")
|
|
182
|
+
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
|
|
183
|
+
if use_azure_identity_object:
|
|
184
|
+
return get_sync_default_azure_credential(
|
|
125
185
|
managed_identity_client_id=managed_identity_client_id,
|
|
126
186
|
workload_identity_tenant_id=workload_identity_tenant_id,
|
|
127
187
|
)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
subscription_id=subscription_id,
|
|
188
|
+
return AzureIdentityCredentialAdapter(
|
|
189
|
+
managed_identity_client_id=managed_identity_client_id,
|
|
190
|
+
workload_identity_tenant_id=workload_identity_tenant_id,
|
|
132
191
|
)
|
|
192
|
+
|
|
193
|
+
def get_token(self, *scopes, **kwargs) -> AccessToken:
|
|
194
|
+
"""
|
|
195
|
+
Request an access token for `scopes`.
|
|
196
|
+
|
|
197
|
+
To use this method, set `use_azure_identity_object: True` in the connection extra field.
|
|
198
|
+
ServicePrincipalCredentials and AzureIdentityCredentialAdapter don't support `get_token` method.
|
|
199
|
+
"""
|
|
200
|
+
credential = self.get_credential()
|
|
201
|
+
if isinstance(credential, ServicePrincipalCredentials) or isinstance(
|
|
202
|
+
credential, AzureIdentityCredentialAdapter
|
|
203
|
+
):
|
|
204
|
+
raise AttributeError(
|
|
205
|
+
"ServicePrincipalCredentials and AzureIdentityCredentialAdapter don't support get_token method. "
|
|
206
|
+
"Please set `use_azure_identity_object: True` in the connection extra field to use credential that support get_token method."
|
|
207
|
+
)
|
|
208
|
+
return credential.get_token(*scopes, **kwargs)
|
|
@@ -25,12 +25,12 @@ from typing import TYPE_CHECKING, Any
|
|
|
25
25
|
from azure.batch import BatchServiceClient, batch_auth, models as batch_models
|
|
26
26
|
|
|
27
27
|
from airflow.exceptions import AirflowException
|
|
28
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
28
29
|
from airflow.providers.microsoft.azure.utils import (
|
|
29
30
|
AzureIdentityCredentialAdapter,
|
|
30
31
|
add_managed_identity_connection_widgets,
|
|
31
32
|
get_field,
|
|
32
33
|
)
|
|
33
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
34
34
|
from airflow.utils import timezone
|
|
35
35
|
|
|
36
36
|
if TYPE_CHECKING:
|
|
@@ -25,12 +25,12 @@ from typing import Any, cast
|
|
|
25
25
|
from azure.mgmt.containerinstance.models import ImageRegistryCredential
|
|
26
26
|
from azure.mgmt.containerregistry import ContainerRegistryManagementClient
|
|
27
27
|
|
|
28
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
28
29
|
from airflow.providers.microsoft.azure.utils import (
|
|
29
30
|
add_managed_identity_connection_widgets,
|
|
30
31
|
get_field,
|
|
31
32
|
get_sync_default_azure_credential,
|
|
32
33
|
)
|
|
33
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
class AzureContainerRegistryHook(BaseHook):
|
|
@@ -21,12 +21,12 @@ from typing import Any, cast
|
|
|
21
21
|
from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
|
|
22
22
|
from azure.mgmt.storage import StorageManagementClient
|
|
23
23
|
|
|
24
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
24
25
|
from airflow.providers.microsoft.azure.utils import (
|
|
25
26
|
add_managed_identity_connection_widgets,
|
|
26
27
|
get_field,
|
|
27
28
|
get_sync_default_azure_credential,
|
|
28
29
|
)
|
|
29
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class AzureContainerVolumeHook(BaseHook):
|
|
@@ -36,12 +36,12 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
|
|
|
36
36
|
from azure.mgmt.cosmosdb import CosmosDBManagementClient
|
|
37
37
|
|
|
38
38
|
from airflow.exceptions import AirflowBadRequest, AirflowException
|
|
39
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
39
40
|
from airflow.providers.microsoft.azure.utils import (
|
|
40
41
|
add_managed_identity_connection_widgets,
|
|
41
42
|
get_field,
|
|
42
43
|
get_sync_default_azure_credential,
|
|
43
44
|
)
|
|
44
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
45
45
|
|
|
46
46
|
if TYPE_CHECKING:
|
|
47
47
|
PartitionKeyType = str | list[str]
|
|
@@ -49,12 +49,12 @@ from azure.mgmt.datafactory import DataFactoryManagementClient
|
|
|
49
49
|
from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient
|
|
50
50
|
|
|
51
51
|
from airflow.exceptions import AirflowException
|
|
52
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
52
53
|
from airflow.providers.microsoft.azure.utils import (
|
|
53
54
|
add_managed_identity_connection_widgets,
|
|
54
55
|
get_async_default_azure_credential,
|
|
55
56
|
get_sync_default_azure_credential,
|
|
56
57
|
)
|
|
57
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
58
58
|
|
|
59
59
|
if TYPE_CHECKING:
|
|
60
60
|
from azure.core.polling import LROPoller
|
|
@@ -1214,7 +1214,4 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
|
|
|
1214
1214
|
:param config: Extra parameters for the ADF client.
|
|
1215
1215
|
"""
|
|
1216
1216
|
client = await self.get_async_conn()
|
|
1217
|
-
|
|
1218
|
-
await client.pipeline_runs.cancel(resource_group_name, factory_name, run_id)
|
|
1219
|
-
except Exception as e:
|
|
1220
|
-
raise AirflowException(e)
|
|
1217
|
+
await client.pipeline_runs.cancel(resource_group_name, factory_name, run_id)
|
|
@@ -33,13 +33,13 @@ from azure.storage.filedatalake import (
|
|
|
33
33
|
)
|
|
34
34
|
|
|
35
35
|
from airflow.exceptions import AirflowException
|
|
36
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
36
37
|
from airflow.providers.microsoft.azure.utils import (
|
|
37
38
|
AzureIdentityCredentialAdapter,
|
|
38
39
|
add_managed_identity_connection_widgets,
|
|
39
40
|
get_field,
|
|
40
41
|
get_sync_default_azure_credential,
|
|
41
42
|
)
|
|
42
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
43
43
|
|
|
44
44
|
Credentials = ClientSecretCredential | AzureIdentityCredentialAdapter | DefaultAzureCredential
|
|
45
45
|
|
|
@@ -21,11 +21,11 @@ from typing import IO, Any
|
|
|
21
21
|
|
|
22
22
|
from azure.storage.fileshare import FileProperties, ShareDirectoryClient, ShareFileClient, ShareServiceClient
|
|
23
23
|
|
|
24
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
24
25
|
from airflow.providers.microsoft.azure.utils import (
|
|
25
26
|
add_managed_identity_connection_widgets,
|
|
26
27
|
get_sync_default_azure_credential,
|
|
27
28
|
)
|
|
28
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
class AzureFileShareHook(BaseHook):
|
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
# under the License.
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
|
+
import asyncio
|
|
20
21
|
import json
|
|
21
22
|
import warnings
|
|
22
23
|
from ast import literal_eval
|
|
@@ -28,7 +29,6 @@ from typing import TYPE_CHECKING, Any, cast
|
|
|
28
29
|
from urllib.parse import quote, urljoin, urlparse
|
|
29
30
|
|
|
30
31
|
import httpx
|
|
31
|
-
from asgiref.sync import sync_to_async
|
|
32
32
|
from azure.identity import CertificateCredential, ClientSecretCredential
|
|
33
33
|
from httpx import AsyncHTTPTransport, Response, Timeout
|
|
34
34
|
from kiota_abstractions.api_error import APIError
|
|
@@ -53,7 +53,7 @@ from airflow.exceptions import (
|
|
|
53
53
|
AirflowNotFoundException,
|
|
54
54
|
AirflowProviderDeprecationWarning,
|
|
55
55
|
)
|
|
56
|
-
from airflow.providers.
|
|
56
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
57
57
|
|
|
58
58
|
if TYPE_CHECKING:
|
|
59
59
|
from azure.identity._internal.client_credential_base import ClientCredentialBase
|
|
@@ -61,7 +61,7 @@ if TYPE_CHECKING:
|
|
|
61
61
|
from kiota_abstractions.response_handler import NativeResponseType
|
|
62
62
|
from kiota_abstractions.serialization import ParsableFactory
|
|
63
63
|
|
|
64
|
-
from airflow.
|
|
64
|
+
from airflow.providers.common.compat.sdk import Connection
|
|
65
65
|
|
|
66
66
|
|
|
67
67
|
class DefaultResponseHandler(ResponseHandler):
|
|
@@ -152,6 +152,7 @@ class KiotaRequestAdapterHook(BaseHook):
|
|
|
152
152
|
|
|
153
153
|
return {
|
|
154
154
|
"tenant_id": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
|
|
155
|
+
"drive_id": StringField(lazy_gettext("Drive ID"), widget=BS3TextFieldWidget()),
|
|
155
156
|
"api_version": StringField(
|
|
156
157
|
lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default=APIVersion.v1.value
|
|
157
158
|
),
|
|
@@ -250,7 +251,9 @@ class KiotaRequestAdapterHook(BaseHook):
|
|
|
250
251
|
def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]:
|
|
251
252
|
client_id = connection.login
|
|
252
253
|
client_secret = connection.password
|
|
253
|
-
|
|
254
|
+
# TODO (#54350): do not use connection.extra_dejson until it's fixed in Airflow otherwise expect:
|
|
255
|
+
# RuntimeError: You cannot use AsyncToSync in the same thread as an async event loop.
|
|
256
|
+
config = json.loads(connection.extra) if connection.extra else {}
|
|
254
257
|
api_version = self.get_api_version(config)
|
|
255
258
|
host = self.get_host(connection) # type: ignore[arg-type]
|
|
256
259
|
base_url = self.get_base_url(host, api_version, config)
|
|
@@ -342,6 +345,15 @@ class KiotaRequestAdapterHook(BaseHook):
|
|
|
342
345
|
self.api_version = api_version
|
|
343
346
|
return request_adapter
|
|
344
347
|
|
|
348
|
+
@classmethod
|
|
349
|
+
async def get_async_connection(cls, conn_id: str) -> Connection:
|
|
350
|
+
if hasattr(BaseHook, "aget_connection"):
|
|
351
|
+
return await BaseHook.aget_connection(conn_id=conn_id)
|
|
352
|
+
|
|
353
|
+
from asgiref.sync import sync_to_async
|
|
354
|
+
|
|
355
|
+
return await sync_to_async(BaseHook.get_connection)(conn_id=conn_id)
|
|
356
|
+
|
|
345
357
|
async def get_async_conn(self) -> RequestAdapter:
|
|
346
358
|
"""Initiate a new RequestAdapter connection asynchronously."""
|
|
347
359
|
if not self.conn_id:
|
|
@@ -350,7 +362,7 @@ class KiotaRequestAdapterHook(BaseHook):
|
|
|
350
362
|
api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))
|
|
351
363
|
|
|
352
364
|
if not request_adapter:
|
|
353
|
-
connection = await
|
|
365
|
+
connection = await self.get_async_connection(conn_id=self.conn_id)
|
|
354
366
|
api_version, request_adapter = self._build_request_adapter(connection)
|
|
355
367
|
self.api_version = api_version
|
|
356
368
|
return request_adapter
|
|
@@ -417,7 +429,7 @@ class KiotaRequestAdapterHook(BaseHook):
|
|
|
417
429
|
def test_connection(self):
|
|
418
430
|
"""Test HTTP Connection."""
|
|
419
431
|
try:
|
|
420
|
-
self.run()
|
|
432
|
+
asyncio.run(self.run())
|
|
421
433
|
return True, "Connection successfully tested"
|
|
422
434
|
except Exception as e:
|
|
423
435
|
return False, str(e)
|
|
@@ -25,12 +25,12 @@ from azure.synapse.artifacts import ArtifactsClient
|
|
|
25
25
|
from azure.synapse.spark import SparkClient
|
|
26
26
|
|
|
27
27
|
from airflow.exceptions import AirflowException, AirflowTaskTimeout
|
|
28
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
28
29
|
from airflow.providers.microsoft.azure.utils import (
|
|
29
30
|
add_managed_identity_connection_widgets,
|
|
30
31
|
get_field,
|
|
31
32
|
get_sync_default_azure_credential,
|
|
32
33
|
)
|
|
33
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
34
34
|
|
|
35
35
|
if TYPE_CHECKING:
|
|
36
36
|
from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun
|
|
@@ -45,13 +45,13 @@ from azure.storage.blob.aio import (
|
|
|
45
45
|
)
|
|
46
46
|
|
|
47
47
|
from airflow.exceptions import AirflowException
|
|
48
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
48
49
|
from airflow.providers.microsoft.azure.utils import (
|
|
49
50
|
add_managed_identity_connection_widgets,
|
|
50
51
|
get_async_default_azure_credential,
|
|
51
52
|
get_sync_default_azure_credential,
|
|
52
53
|
parse_blob_account_url,
|
|
53
54
|
)
|
|
54
|
-
from airflow.providers.microsoft.azure.version_compat import BaseHook
|
|
55
55
|
|
|
56
56
|
if TYPE_CHECKING:
|
|
57
57
|
from azure.core.credentials import TokenCredential
|
|
@@ -188,9 +188,15 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
|
|
|
188
188
|
base_log_folder: str,
|
|
189
189
|
wasb_log_folder: str,
|
|
190
190
|
wasb_container: str,
|
|
191
|
+
max_bytes: int = 0,
|
|
192
|
+
backup_count: int = 0,
|
|
193
|
+
delay: bool = False,
|
|
191
194
|
**kwargs,
|
|
192
195
|
) -> None:
|
|
193
|
-
|
|
196
|
+
# support log file size handling of FileTaskHandler
|
|
197
|
+
super().__init__(
|
|
198
|
+
base_log_folder=base_log_folder, max_bytes=max_bytes, backup_count=backup_count, delay=delay
|
|
199
|
+
)
|
|
194
200
|
self.handler: logging.FileHandler | None = None
|
|
195
201
|
self.log_relative_path = ""
|
|
196
202
|
self.closed = False
|
|
@@ -19,8 +19,8 @@ from __future__ import annotations
|
|
|
19
19
|
from collections.abc import Iterable, Sequence
|
|
20
20
|
from typing import IO, TYPE_CHECKING, Any, AnyStr
|
|
21
21
|
|
|
22
|
+
from airflow.providers.common.compat.sdk import BaseOperator
|
|
22
23
|
from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook, AzureDataLakeStorageV2Hook
|
|
23
|
-
from airflow.providers.microsoft.azure.version_compat import BaseOperator
|
|
24
24
|
|
|
25
25
|
if TYPE_CHECKING:
|
|
26
26
|
from airflow.utils.context import Context
|
|
@@ -24,8 +24,8 @@ from functools import cached_property
|
|
|
24
24
|
from typing import TYPE_CHECKING
|
|
25
25
|
|
|
26
26
|
from airflow.configuration import conf
|
|
27
|
+
from airflow.providers.common.compat.sdk import BaseOperator
|
|
27
28
|
from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
|
|
28
|
-
from airflow.providers.microsoft.azure.version_compat import BaseOperator
|
|
29
29
|
|
|
30
30
|
if TYPE_CHECKING:
|
|
31
31
|
from azure.kusto.data._models import KustoResultTable
|
|
@@ -20,8 +20,8 @@ from collections.abc import Callable, Sequence
|
|
|
20
20
|
from typing import TYPE_CHECKING, Any
|
|
21
21
|
from uuid import UUID
|
|
22
22
|
|
|
23
|
+
from airflow.providers.common.compat.sdk import BaseOperator
|
|
23
24
|
from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook
|
|
24
|
-
from airflow.providers.microsoft.azure.version_compat import BaseOperator
|
|
25
25
|
|
|
26
26
|
if TYPE_CHECKING:
|
|
27
27
|
import datetime
|
|
@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Any
|
|
|
24
24
|
from azure.batch import models as batch_models
|
|
25
25
|
|
|
26
26
|
from airflow.exceptions import AirflowException
|
|
27
|
+
from airflow.providers.common.compat.sdk import BaseOperator
|
|
27
28
|
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
|
|
28
|
-
from airflow.providers.microsoft.azure.version_compat import BaseOperator
|
|
29
29
|
|
|
30
30
|
if TYPE_CHECKING:
|
|
31
31
|
from airflow.utils.context import Context
|
|
@@ -40,10 +40,10 @@ from azure.mgmt.containerinstance.models import (
|
|
|
40
40
|
from msrestazure.azure_exceptions import CloudError
|
|
41
41
|
|
|
42
42
|
from airflow.exceptions import AirflowException, AirflowTaskTimeout
|
|
43
|
+
from airflow.providers.common.compat.sdk import BaseOperator
|
|
43
44
|
from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook
|
|
44
45
|
from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook
|
|
45
46
|
from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook
|
|
46
|
-
from airflow.providers.microsoft.azure.version_compat import BaseOperator
|
|
47
47
|
|
|
48
48
|
if TYPE_CHECKING:
|
|
49
49
|
from airflow.utils.context import Context
|