apache-airflow-providers-microsoft-azure 10.0.0__py3-none-any.whl → 10.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/microsoft/azure/__init__.py +5 -8
- airflow/providers/microsoft/azure/fs/adls.py +56 -6
- airflow/providers/microsoft/azure/get_provider_info.py +27 -2
- airflow/providers/microsoft/azure/hooks/msgraph.py +353 -0
- airflow/providers/microsoft/azure/hooks/synapse.py +1 -0
- airflow/providers/microsoft/azure/hooks/wasb.py +3 -31
- airflow/providers/microsoft/azure/operators/container_instances.py +17 -0
- airflow/providers/microsoft/azure/operators/msgraph.py +281 -0
- airflow/providers/microsoft/azure/sensors/msgraph.py +177 -0
- airflow/providers/microsoft/azure/triggers/msgraph.py +231 -0
- airflow/providers/microsoft/azure/utils.py +34 -0
- {apache_airflow_providers_microsoft_azure-10.0.0.dist-info → apache_airflow_providers_microsoft_azure-10.1.0.dist-info}/METADATA +10 -8
- {apache_airflow_providers_microsoft_azure-10.0.0.dist-info → apache_airflow_providers_microsoft_azure-10.1.0.dist-info}/RECORD +15 -12
- airflow/providers/microsoft/azure/serialization/__init__.py +0 -16
- {apache_airflow_providers_microsoft_azure-10.0.0.dist-info → apache_airflow_providers_microsoft_azure-10.1.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_microsoft_azure-10.0.0.dist-info → apache_airflow_providers_microsoft_azure-10.1.0.dist-info}/entry_points.txt +0 -0
@@ -25,18 +25,15 @@ from __future__ import annotations
|
|
25
25
|
|
26
26
|
import packaging.version
|
27
27
|
|
28
|
-
|
28
|
+
from airflow import __version__ as airflow_version
|
29
29
|
|
30
|
-
|
30
|
+
__all__ = ["__version__"]
|
31
31
|
|
32
|
-
|
33
|
-
from airflow import __version__ as airflow_version
|
34
|
-
except ImportError:
|
35
|
-
from airflow.version import version as airflow_version
|
32
|
+
__version__ = "10.1.0"
|
36
33
|
|
37
34
|
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
|
38
|
-
"2.
|
35
|
+
"2.7.0"
|
39
36
|
):
|
40
37
|
raise RuntimeError(
|
41
|
-
f"The package `apache-airflow-providers-microsoft-azure:{__version__}` needs Apache Airflow 2.
|
38
|
+
f"The package `apache-airflow-providers-microsoft-azure:{__version__}` needs Apache Airflow 2.7.0+"
|
42
39
|
)
|
@@ -18,8 +18,10 @@ from __future__ import annotations
|
|
18
18
|
|
19
19
|
from typing import TYPE_CHECKING, Any
|
20
20
|
|
21
|
+
from azure.identity import ClientSecretCredential
|
22
|
+
|
21
23
|
from airflow.hooks.base import BaseHook
|
22
|
-
from airflow.providers.microsoft.azure.utils import get_field
|
24
|
+
from airflow.providers.microsoft.azure.utils import get_field, parse_blob_account_url
|
23
25
|
|
24
26
|
if TYPE_CHECKING:
|
25
27
|
from fsspec import AbstractFileSystem
|
@@ -35,13 +37,61 @@ def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) -
|
|
35
37
|
|
36
38
|
conn = BaseHook.get_connection(conn_id)
|
37
39
|
extras = conn.extra_dejson
|
40
|
+
conn_type = conn.conn_type or "azure_data_lake"
|
38
41
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
42
|
+
# connection string always overrides everything else
|
43
|
+
connection_string = get_field(
|
44
|
+
conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="connection_string"
|
45
|
+
)
|
46
|
+
|
47
|
+
if connection_string:
|
48
|
+
return AzureBlobFileSystem(connection_string=connection_string)
|
49
|
+
|
50
|
+
options: dict[str, Any] = {
|
51
|
+
"account_url": parse_blob_account_url(conn.host, conn.login),
|
52
|
+
}
|
53
|
+
|
54
|
+
# mirror handling of custom field "client_secret_auth_config" from extras. Ignore if missing as AzureBlobFileSystem can handle.
|
55
|
+
tenant_id = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="tenant_id")
|
56
|
+
login = conn.login or ""
|
57
|
+
password = conn.password or ""
|
58
|
+
# assumption (from WasbHook) that if tenant_id is set, we want service principal connection
|
59
|
+
if tenant_id:
|
60
|
+
client_secret_auth_config = get_field(
|
61
|
+
conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_secret_auth_config"
|
44
62
|
)
|
63
|
+
if login:
|
64
|
+
options["client_id"] = login
|
65
|
+
if password:
|
66
|
+
options["client_secret"] = password
|
67
|
+
if client_secret_auth_config and login and password:
|
68
|
+
options["credential"] = ClientSecretCredential(
|
69
|
+
tenant_id=tenant_id, client_id=login, client_secret=password, **client_secret_auth_config
|
70
|
+
)
|
71
|
+
|
72
|
+
# if not service principal, then password is taken to be account admin key
|
73
|
+
if tenant_id is None and password:
|
74
|
+
options["account_key"] = password
|
75
|
+
|
76
|
+
# now take any fields from extras and overlay on these
|
77
|
+
# add empty field to remove defaults
|
78
|
+
fields = [
|
79
|
+
"account_name",
|
80
|
+
"account_key",
|
81
|
+
"sas_token",
|
82
|
+
"tenant_id",
|
83
|
+
"managed_identity_client_id",
|
84
|
+
"workload_identity_client_id",
|
85
|
+
"workload_identity_tenant_id",
|
86
|
+
"anon",
|
87
|
+
]
|
88
|
+
for field in fields:
|
89
|
+
value = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=field)
|
90
|
+
if value is not None:
|
91
|
+
if value == "":
|
92
|
+
options.pop(field, "")
|
93
|
+
else:
|
94
|
+
options[field] = value
|
45
95
|
|
46
96
|
options.update(storage_options or {})
|
47
97
|
|
@@ -28,8 +28,9 @@ def get_provider_info():
|
|
28
28
|
"name": "Microsoft Azure",
|
29
29
|
"description": "`Microsoft Azure <https://azure.microsoft.com/>`__\n",
|
30
30
|
"state": "ready",
|
31
|
-
"source-date-epoch":
|
31
|
+
"source-date-epoch": 1715384449,
|
32
32
|
"versions": [
|
33
|
+
"10.1.0",
|
33
34
|
"10.0.0",
|
34
35
|
"9.0.1",
|
35
36
|
"9.0.0",
|
@@ -83,7 +84,7 @@ def get_provider_info():
|
|
83
84
|
"1.0.0",
|
84
85
|
],
|
85
86
|
"dependencies": [
|
86
|
-
"apache-airflow>=2.
|
87
|
+
"apache-airflow>=2.7.0",
|
87
88
|
"adlfs>=2023.10.0",
|
88
89
|
"azure-batch>=8.0.0",
|
89
90
|
"azure-cosmos>=4.6.0",
|
@@ -105,6 +106,7 @@ def get_provider_info():
|
|
105
106
|
"azure-mgmt-datafactory>=2.0.0",
|
106
107
|
"azure-mgmt-containerregistry>=8.0.0",
|
107
108
|
"azure-mgmt-containerinstance>=9.0.0",
|
109
|
+
"msgraph-core>=1.0.0",
|
108
110
|
],
|
109
111
|
"devel-dependencies": ["pywinrm"],
|
110
112
|
"integrations": [
|
@@ -194,6 +196,13 @@ def get_provider_info():
|
|
194
196
|
"logo": "/integration-logos/azure/Data Lake Storage.svg",
|
195
197
|
"tags": ["azure"],
|
196
198
|
},
|
199
|
+
{
|
200
|
+
"integration-name": "Microsoft Graph API",
|
201
|
+
"external-doc-url": "https://learn.microsoft.com/en-us/graph/use-the-api/",
|
202
|
+
"logo": "/integration-logos/azure/Microsoft-Graph-API.png",
|
203
|
+
"how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst"],
|
204
|
+
"tags": ["azure"],
|
205
|
+
},
|
197
206
|
],
|
198
207
|
"operators": [
|
199
208
|
{
|
@@ -232,6 +241,10 @@ def get_provider_info():
|
|
232
241
|
"integration-name": "Microsoft Azure Synapse",
|
233
242
|
"python-modules": ["airflow.providers.microsoft.azure.operators.synapse"],
|
234
243
|
},
|
244
|
+
{
|
245
|
+
"integration-name": "Microsoft Graph API",
|
246
|
+
"python-modules": ["airflow.providers.microsoft.azure.operators.msgraph"],
|
247
|
+
},
|
235
248
|
],
|
236
249
|
"sensors": [
|
237
250
|
{
|
@@ -246,6 +259,10 @@ def get_provider_info():
|
|
246
259
|
"integration-name": "Microsoft Azure Data Factory",
|
247
260
|
"python-modules": ["airflow.providers.microsoft.azure.sensors.data_factory"],
|
248
261
|
},
|
262
|
+
{
|
263
|
+
"integration-name": "Microsoft Graph API",
|
264
|
+
"python-modules": ["airflow.providers.microsoft.azure.sensors.msgraph"],
|
265
|
+
},
|
249
266
|
],
|
250
267
|
"filesystems": ["airflow.providers.microsoft.azure.fs.adls"],
|
251
268
|
"hooks": [
|
@@ -301,6 +318,10 @@ def get_provider_info():
|
|
301
318
|
"integration-name": "Microsoft Azure Synapse",
|
302
319
|
"python-modules": ["airflow.providers.microsoft.azure.hooks.synapse"],
|
303
320
|
},
|
321
|
+
{
|
322
|
+
"integration-name": "Microsoft Graph API",
|
323
|
+
"python-modules": ["airflow.providers.microsoft.azure.hooks.msgraph"],
|
324
|
+
},
|
304
325
|
],
|
305
326
|
"triggers": [
|
306
327
|
{
|
@@ -311,6 +332,10 @@ def get_provider_info():
|
|
311
332
|
"integration-name": "Microsoft Azure Blob Storage",
|
312
333
|
"python-modules": ["airflow.providers.microsoft.azure.triggers.wasb"],
|
313
334
|
},
|
335
|
+
{
|
336
|
+
"integration-name": "Microsoft Graph API",
|
337
|
+
"python-modules": ["airflow.providers.microsoft.azure.triggers.msgraph"],
|
338
|
+
},
|
314
339
|
],
|
315
340
|
"transfers": [
|
316
341
|
{
|
@@ -0,0 +1,353 @@
|
|
1
|
+
#
|
2
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
3
|
+
# or more contributor license agreements. See the NOTICE file
|
4
|
+
# distributed with this work for additional information
|
5
|
+
# regarding copyright ownership. The ASF licenses this file
|
6
|
+
# to you under the Apache License, Version 2.0 (the
|
7
|
+
# "License"); you may not use this file except in compliance
|
8
|
+
# with the License. You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing,
|
13
|
+
# software distributed under the License is distributed on an
|
14
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
15
|
+
# KIND, either express or implied. See the License for the
|
16
|
+
# specific language governing permissions and limitations
|
17
|
+
# under the License.
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import json
|
21
|
+
from contextlib import suppress
|
22
|
+
from http import HTTPStatus
|
23
|
+
from io import BytesIO
|
24
|
+
from json import JSONDecodeError
|
25
|
+
from typing import TYPE_CHECKING, Any
|
26
|
+
from urllib.parse import quote, urljoin, urlparse
|
27
|
+
|
28
|
+
import httpx
|
29
|
+
from azure.identity import ClientSecretCredential
|
30
|
+
from httpx import Timeout
|
31
|
+
from kiota_abstractions.api_error import APIError
|
32
|
+
from kiota_abstractions.method import Method
|
33
|
+
from kiota_abstractions.request_information import RequestInformation
|
34
|
+
from kiota_abstractions.response_handler import ResponseHandler
|
35
|
+
from kiota_authentication_azure.azure_identity_authentication_provider import (
|
36
|
+
AzureIdentityAuthenticationProvider,
|
37
|
+
)
|
38
|
+
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
|
39
|
+
from kiota_http.middleware.options import ResponseHandlerOption
|
40
|
+
from msgraph_core import APIVersion, GraphClientFactory
|
41
|
+
from msgraph_core._enums import NationalClouds
|
42
|
+
|
43
|
+
from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException
|
44
|
+
from airflow.hooks.base import BaseHook
|
45
|
+
|
46
|
+
if TYPE_CHECKING:
|
47
|
+
from kiota_abstractions.request_adapter import RequestAdapter
|
48
|
+
from kiota_abstractions.request_information import QueryParams
|
49
|
+
from kiota_abstractions.response_handler import NativeResponseType
|
50
|
+
from kiota_abstractions.serialization import ParsableFactory
|
51
|
+
from kiota_http.httpx_request_adapter import ResponseType
|
52
|
+
|
53
|
+
from airflow.models import Connection
|
54
|
+
|
55
|
+
|
56
|
+
class DefaultResponseHandler(ResponseHandler):
    """DefaultResponseHandler returns JSON payload or content in bytes or response headers."""

    # Status codes accepted as success; any other status is turned into an Airflow exception.
    SUCCESS_STATUS_CODES = frozenset({200, 201, 202, 204, 302})

    @staticmethod
    def get_value(response: NativeResponseType) -> Any:
        """
        Extract the payload of a response.

        :param response: The native response object (an httpx response when used through
            ``HttpxRequestAdapter``).
        :return: The decoded JSON body when the body parses as JSON, otherwise the raw body
            bytes, or a dict of the response headers when the body is empty.
        """
        # Binary bodies (e.g. downloaded files) can fail JSON decoding with a
        # UnicodeDecodeError rather than a JSONDecodeError; both simply mean "not JSON".
        with suppress(JSONDecodeError, UnicodeDecodeError):
            return response.json()
        content = response.content
        if not content:
            return dict(response.headers.items())
        return content

    async def handle_response_async(
        self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
    ) -> Any:
        """
        Invoke this callback method when a response is received.

        :param response: The type of the native response object.
        :param error_map: The error dict to use in case of a failed request (not used here).
        :return: The value extracted by ``get_value`` for successful responses.
        :raises AirflowBadRequest: on HTTP 400.
        :raises AirflowNotFoundException: on HTTP 404.
        :raises AirflowException: on any other status outside ``SUCCESS_STATUS_CODES``.
        """
        value = self.get_value(response)
        status_code = response.status_code
        if status_code not in self.SUCCESS_STATUS_CODES:
            message = value or response.reason_phrase
            # Compare the raw int against the IntEnum members instead of calling
            # HTTPStatus(status_code): that constructor raises ValueError for
            # non-standard codes (e.g. 499 or 520), which would hide the real error.
            if status_code == HTTPStatus.BAD_REQUEST:
                raise AirflowBadRequest(message)
            if status_code == HTTPStatus.NOT_FOUND:
                raise AirflowNotFoundException(message)
            raise AirflowException(message)
        return value
|
87
|
+
|
88
|
+
|
89
|
+
class KiotaRequestAdapterHook(BaseHook):
    """
    A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter.

    https://github.com/microsoftgraph/msgraph-sdk-python-core

    :param conn_id: The HTTP Connection ID to run the trigger against.
    :param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
        When no timeout is specified or set to None then no HTTP timeout is applied on each request.
    :param proxies: A Dict defining the HTTP proxies to be used (default is None).
    :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
        You can pass an enum named APIVersion which has 2 possible members v1 and beta,
        or you can pass a string as "v1.0" or "beta".
    """

    DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
    # Request adapters shared by every hook instance in the process, keyed by conn_id,
    # each stored together with the API version it was built for.
    cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
    default_conn_name: str = "msgraph_default"

    def __init__(
        self,
        conn_id: str = default_conn_name,
        timeout: float | None = None,
        proxies: dict | None = None,
        api_version: APIVersion | str | None = None,
    ):
        super().__init__()
        self.conn_id = conn_id
        self.timeout = timeout
        self.proxies = proxies
        # None means "not chosen by the caller": get_conn() then resolves it from the extras.
        self._api_version: APIVersion | None = self.resolve_api_version_from_value(api_version)

    @property
    def api_version(self) -> APIVersion:
        """Return the API version in use, loading the connection first so extras are honoured."""
        self.get_conn()  # Make sure config has been loaded through get_conn to have correct api version!
        return self._api_version

    @staticmethod
    def resolve_api_version_from_value(
        api_version: APIVersion | str | None, default: APIVersion | None = None
    ) -> APIVersion | None:
        """
        Convert an ``APIVersion`` member or its string value (e.g. "v1.0", "beta") to an ``APIVersion``.

        :param api_version: An ``APIVersion`` member, its string value, or None.
        :param default: Returned when ``api_version`` matches no ``APIVersion`` member.
        """
        if isinstance(api_version, APIVersion):
            return api_version
        return next((version for version in APIVersion if version.value == api_version), default)

    def get_api_version(self, config: dict) -> APIVersion:
        """
        Return the API version requested on the hook, else the one from ``config``, else v1.

        :param config: The connection extras.
        """
        if self._api_version is None:
            return self.resolve_api_version_from_value(
                api_version=config.get("api_version"), default=APIVersion.v1
            )
        return self._api_version

    @staticmethod
    def get_host(connection: Connection) -> str:
        """Return ``<schema>://<host>`` when both are set on the connection, else the global Graph cloud."""
        if connection.schema and connection.host:
            return f"{connection.schema}://{connection.host}"
        return NationalClouds.Global.value

    @staticmethod
    def format_no_proxy_url(url: str) -> str:
        """Prefix a scheme-less no-proxy entry with ``all://`` so it matches any scheme."""
        if "://" not in url:
            url = f"all://{url}"
        return url

    @classmethod
    def to_httpx_proxies(cls, proxies: dict) -> dict:
        """
        Convert a requests-style proxies dict (``http``/``https``/``no`` keys) to httpx form.

        ``http``/``https`` become ``http://``/``https://``; each comma-separated entry of
        ``no`` becomes its own key mapped to None (no proxy). The input dict is not modified.
        """
        proxies = proxies.copy()
        # Bare "http"/"https"/"no" keys are not URL-pattern keys as httpx expects, so they
        # are always popped; only non-empty values are carried over.
        for scheme in ("http", "https"):
            proxy = proxies.pop(scheme, None)
            if proxy:
                proxies[f"{scheme}://"] = proxy
        for url in (proxies.pop("no", None) or "").split(","):
            url = url.strip()
            # Skip blank entries (e.g. from a trailing comma): formatting "" would yield
            # "all://", which disables the proxy for every request.
            if url:
                proxies[cls.format_no_proxy_url(url)] = None
        return proxies

    def to_msal_proxies(self, authority: str | None, proxies: dict):
        """
        Return the proxies to give to the azure-identity credential.

        Returns None (no proxy) when the authority host matches an entry of the
        comma-separated ``proxies["no"]`` list, otherwise ``proxies`` unchanged.
        """
        self.log.debug("authority: %s", authority)
        if authority:
            # The authority may be configured as a bare host or as a full URL.
            authority_host = urlparse(authority).netloc or authority
            no_proxies = proxies.get("no")
            self.log.debug("no_proxies: %s", no_proxies)
            if no_proxies:
                for url in no_proxies.split(","):
                    url = url.strip()
                    if not url:
                        continue  # blank entry, e.g. from a trailing comma
                    self.log.debug("url: %s", url)
                    parsed_url = urlparse(url)
                    domain_name = (parsed_url.netloc or parsed_url.path).replace("*", "")
                    self.log.debug("domain_name: %s", domain_name)
                    # A lone "*" reduces to "" and therefore matches every authority.
                    if authority_host.endswith(domain_name):
                        return None
        return proxies

    def get_conn(self) -> RequestAdapter:
        """
        Return the Kiota request adapter for ``conn_id``, building and caching it if needed.

        The connection login/password are used as client id/secret of a
        ``ClientSecretCredential``; the extras supply tenant_id, api_version, base_url,
        authority, proxies, scopes, verify, trust_env, disable_instance_discovery and
        allowed_hosts.

        :raises AirflowException: if no ``conn_id`` was provided.
        """
        if not self.conn_id:
            raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")

        cached_api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))

        # The cache is keyed by conn_id only: reuse an adapter only when it was built for
        # the API version this hook wants (or the hook has no preference), otherwise a
        # hook created with api_version="beta" would silently be served a v1.0 adapter.
        if request_adapter and (self._api_version is None or self._api_version == cached_api_version):
            self._api_version = cached_api_version
            return request_adapter

        connection = self.get_connection(conn_id=self.conn_id)
        client_id = connection.login
        client_secret = connection.password
        config = connection.extra_dejson if connection.extra else {}
        tenant_id = config.get("tenant_id")
        api_version = self.get_api_version(config)
        host = self.get_host(connection)
        base_url = config.get("base_url", urljoin(host, api_version.value))
        authority = config.get("authority")
        # "or {}" also covers an explicit null "proxies" value in the extras.
        proxies = self.proxies or config.get("proxies") or {}
        msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
        httpx_proxies = self.to_httpx_proxies(proxies=proxies)
        scopes = config.get("scopes", ["https://graph.microsoft.com/.default"])
        verify = config.get("verify", True)
        trust_env = config.get("trust_env", False)
        disable_instance_discovery = config.get("disable_instance_discovery", False)
        # Blank entries (when neither allowed_hosts nor authority is set, or from a
        # trailing comma) can never match a host name, so they are dropped.
        allowed_hosts = [
            allowed_host.strip()
            for allowed_host in (config.get("allowed_hosts", authority) or "").split(",")
            if allowed_host.strip()
        ]

        self.log.info(
            "Creating Microsoft Graph SDK client %s for conn_id: %s",
            api_version.value,
            self.conn_id,
        )
        self.log.info("Host: %s", host)
        self.log.info("Base URL: %s", base_url)
        self.log.info("Tenant id: %s", tenant_id)
        self.log.info("Client id: %s", client_id)
        # The client secret is deliberately never logged.
        self.log.info("API version: %s", api_version.value)
        self.log.info("Scope: %s", scopes)
        self.log.info("Verify: %s", verify)
        self.log.info("Timeout: %s", self.timeout)
        self.log.info("Trust env: %s", trust_env)
        self.log.info("Authority: %s", authority)
        self.log.info("Disable instance discovery: %s", disable_instance_discovery)
        self.log.info("Allowed hosts: %s", allowed_hosts)
        self.log.info("Proxies: %s", proxies)
        self.log.info("MSAL Proxies: %s", msal_proxies)
        self.log.info("HTTPX Proxies: %s", httpx_proxies)
        credentials = ClientSecretCredential(
            tenant_id=tenant_id,  # type: ignore
            client_id=client_id,
            client_secret=client_secret,
            authority=authority,
            proxies=msal_proxies,
            disable_instance_discovery=disable_instance_discovery,
            connection_verify=verify,
        )
        http_client = GraphClientFactory.create_with_default_middleware(
            api_version=api_version,
            client=httpx.AsyncClient(
                proxies=httpx_proxies,
                timeout=Timeout(timeout=self.timeout),
                verify=verify,
                trust_env=trust_env,
            ),
            host=host,
        )
        auth_provider = AzureIdentityAuthenticationProvider(
            credentials=credentials,
            scopes=scopes,
            allowed_hosts=allowed_hosts,
        )
        request_adapter = HttpxRequestAdapter(
            authentication_provider=auth_provider,
            http_client=http_client,
            base_url=base_url,
        )
        self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
        self._api_version = api_version
        return request_adapter

    def test_connection(self):
        """
        Test HTTP Connection.

        ``run`` is a coroutine function, so it must be driven to completion here; merely
        calling it creates an un-awaited coroutine and reports success without any request.
        """
        import asyncio

        try:
            # NOTE(review): asyncio.run uses a fresh event loop while the request adapter
            # (and its httpx.AsyncClient) is cached class-wide; confirm reuse across
            # loops is acceptable for the processes that call this.
            asyncio.run(self.run())
            return True, "Connection successfully tested"
        except Exception as e:
            return False, str(e)

    async def run(
        self,
        url: str = "",
        response_type: ResponseType | None = None,
        path_parameters: dict[str, Any] | None = None,
        method: str = "GET",
        query_parameters: dict[str, QueryParams] | None = None,
        headers: dict[str, str] | None = None,
        data: dict[str, Any] | str | BytesIO | None = None,
    ):
        """
        Send a request through the request adapter and return the (handled) response.

        See ``request_information`` for how the arguments shape the request.
        """
        self.log.info("Executing url '%s' as '%s'", url, method)

        response = await self.get_conn().send_primitive_async(
            request_info=self.request_information(
                url=url,
                response_type=response_type,
                path_parameters=path_parameters,
                method=method,
                query_parameters=query_parameters,
                headers=headers,
                data=data,
            ),
            response_type=response_type,
            error_map=self.error_mapping(),
        )

        self.log.info("response: %s", response)

        return response

    def request_information(
        self,
        url: str,
        response_type: ResponseType | None = None,
        path_parameters: dict[str, Any] | None = None,
        method: str = "GET",
        query_parameters: dict[str, QueryParams] | None = None,
        headers: dict[str, str] | None = None,
        data: dict[str, Any] | str | BytesIO | None = None,
    ) -> RequestInformation:
        """
        Build the Kiota ``RequestInformation`` for a call.

        Absolute URLs (starting with "http") are used as-is; otherwise the URL is made
        relative to ``{+baseurl}`` with a query template listing the query parameter names.
        Without ``response_type`` the ``DefaultResponseHandler`` is installed. ``bytes``,
        ``str`` and ``BytesIO`` data are sent as-is; any other truthy data is JSON-encoded.
        """
        request_information = RequestInformation()
        request_information.path_parameters = path_parameters or {}
        request_information.http_method = Method(method.strip().upper())
        request_information.query_parameters = self.encoded_query_parameters(query_parameters)
        if url.startswith("http"):
            request_information.url = url
        elif request_information.query_parameters.keys():
            query = ",".join(request_information.query_parameters.keys())
            request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}{{?{query}}}"
        else:
            request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}"
        if not response_type:
            request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
                response_handler=DefaultResponseHandler()
            )
        headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS
        for header_name, header_value in headers.items():
            request_information.headers.try_add(header_name=header_name, header_value=header_value)
        if isinstance(data, (BytesIO, bytes, str)):
            request_information.content = data
        elif data:
            request_information.headers.try_add(
                header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
            )
            request_information.content = json.dumps(data).encode("utf-8")
        return request_information

    @staticmethod
    def normalize_url(url: str) -> str:
        """Strip a single leading "/" so the URL can be appended to ``{+baseurl}/``."""
        return url[1:] if url.startswith("/") else url

    @staticmethod
    def encoded_query_parameters(query_parameters: dict[str, Any] | None) -> dict:
        """Return the query parameters with percent-encoded names (values untouched)."""
        if query_parameters:
            return {quote(key): value for key, value in query_parameters.items()}
        return {}

    @staticmethod
    def error_mapping() -> dict[str, ParsableFactory | None]:
        """Map every 4XX and 5XX response to Kiota's ``APIError``."""
        return {
            "4XX": APIError,
            "5XX": APIError,
        }
|
@@ -312,6 +312,7 @@ class AzureSynapsePipelineHook(BaseAzureSynapseHook):
|
|
312
312
|
warnings.warn(
|
313
313
|
"The usage of `default_conn_name=azure_synapse_connection` is deprecated and will be removed in future. Please update your code to use the new default connection name: `default_conn_name=azure_synapse_default`. ",
|
314
314
|
AirflowProviderDeprecationWarning,
|
315
|
+
stacklevel=2,
|
315
316
|
)
|
316
317
|
self._conn: ArtifactsClient | None = None
|
317
318
|
self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint
|
@@ -30,7 +30,6 @@ import logging
|
|
30
30
|
import os
|
31
31
|
from functools import cached_property
|
32
32
|
from typing import TYPE_CHECKING, Any, Union
|
33
|
-
from urllib.parse import urlparse
|
34
33
|
|
35
34
|
from asgiref.sync import sync_to_async
|
36
35
|
from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
|
@@ -52,6 +51,7 @@ from airflow.providers.microsoft.azure.utils import (
|
|
52
51
|
add_managed_identity_connection_widgets,
|
53
52
|
get_async_default_azure_credential,
|
54
53
|
get_sync_default_azure_credential,
|
54
|
+
parse_blob_account_url,
|
55
55
|
)
|
56
56
|
|
57
57
|
if TYPE_CHECKING:
|
@@ -167,21 +167,7 @@ class WasbHook(BaseHook):
|
|
167
167
|
# connection_string auth takes priority
|
168
168
|
return BlobServiceClient.from_connection_string(connection_string, **extra)
|
169
169
|
|
170
|
-
account_url = conn.host
|
171
|
-
parsed_url = urlparse(account_url)
|
172
|
-
|
173
|
-
if not parsed_url.netloc:
|
174
|
-
if "." not in parsed_url.path:
|
175
|
-
# if there's no netloc and no dots in the path, then user only
|
176
|
-
# provided the Active Directory ID, not the full URL or DNS name
|
177
|
-
account_url = f"https://{conn.login}.blob.core.windows.net/"
|
178
|
-
else:
|
179
|
-
# if there's no netloc but there are dots in the path, then user
|
180
|
-
# provided the DNS name without the https:// prefix.
|
181
|
-
# Azure storage account name can only be 3 to 24 characters in length
|
182
|
-
# https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
|
183
|
-
acc_name = account_url.split(".")[0][:24]
|
184
|
-
account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
|
170
|
+
account_url = parse_blob_account_url(conn.host, conn.login)
|
185
171
|
|
186
172
|
tenant = self._get_field(extra, "tenant_id")
|
187
173
|
if tenant:
|
@@ -587,21 +573,7 @@ class WasbAsyncHook(WasbHook):
|
|
587
573
|
)
|
588
574
|
return self.blob_service_client
|
589
575
|
|
590
|
-
account_url = conn.host
|
591
|
-
parsed_url = urlparse(account_url)
|
592
|
-
|
593
|
-
if not parsed_url.netloc:
|
594
|
-
if "." not in parsed_url.path:
|
595
|
-
# if there's no netloc and no dots in the path, then user only
|
596
|
-
# provided the Active Directory ID, not the full URL or DNS name
|
597
|
-
account_url = f"https://{conn.login}.blob.core.windows.net/"
|
598
|
-
else:
|
599
|
-
# if there's no netloc but there are dots in the path, then user
|
600
|
-
# provided the DNS name without the https:// prefix.
|
601
|
-
# Azure storage account name can only be 3 to 24 characters in length
|
602
|
-
# https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
|
603
|
-
acc_name = account_url.split(".")[0][:24]
|
604
|
-
account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
|
576
|
+
account_url = parse_blob_account_url(conn.host, conn.login)
|
605
577
|
|
606
578
|
tenant = self._get_field(extra, "tenant_id")
|
607
579
|
if tenant:
|
@@ -25,8 +25,10 @@ from typing import TYPE_CHECKING, Any, Sequence
|
|
25
25
|
from azure.mgmt.containerinstance.models import (
|
26
26
|
Container,
|
27
27
|
ContainerGroup,
|
28
|
+
ContainerGroupDiagnostics,
|
28
29
|
ContainerGroupSubnetId,
|
29
30
|
ContainerPort,
|
31
|
+
DnsConfiguration,
|
30
32
|
EnvironmentVariable,
|
31
33
|
IpAddress,
|
32
34
|
ResourceRequests,
|
@@ -90,6 +92,8 @@ class AzureContainerInstancesOperator(BaseOperator):
|
|
90
92
|
Possible values include: 'Always', 'OnFailure', 'Never'
|
91
93
|
:param ip_address: The IP address type of the container group.
|
92
94
|
:param subnet_ids: The subnet resource IDs for a container group
|
95
|
+
:param dns_config: The DNS configuration for a container group.
|
96
|
+
:param diagnostics: Container group diagnostic information (Log Analytics).
|
93
97
|
|
94
98
|
**Example**::
|
95
99
|
|
@@ -113,6 +117,13 @@ class AzureContainerInstancesOperator(BaseOperator):
|
|
113
117
|
memory_in_gb=14.0,
|
114
118
|
cpu=4.0,
|
115
119
|
gpu=GpuResource(count=1, sku="K80"),
|
120
|
+
dns_config=["10.0.0.10", "10.0.0.11"],
|
121
|
+
diagnostics={
|
122
|
+
"log_analytics": {
|
123
|
+
"workspaceId": "workspaceid",
|
124
|
+
"workspaceKey": "workspaceKey",
|
125
|
+
}
|
126
|
+
},
|
116
127
|
command=["/bin/echo", "world"],
|
117
128
|
task_id="start_container",
|
118
129
|
)
|
@@ -145,6 +156,8 @@ class AzureContainerInstancesOperator(BaseOperator):
|
|
145
156
|
ip_address: IpAddress | None = None,
|
146
157
|
ports: list[ContainerPort] | None = None,
|
147
158
|
subnet_ids: list[ContainerGroupSubnetId] | None = None,
|
159
|
+
dns_config: DnsConfiguration | None = None,
|
160
|
+
diagnostics: ContainerGroupDiagnostics | None = None,
|
148
161
|
**kwargs,
|
149
162
|
) -> None:
|
150
163
|
super().__init__(**kwargs)
|
@@ -183,6 +196,8 @@ class AzureContainerInstancesOperator(BaseOperator):
|
|
183
196
|
self.ip_address = ip_address
|
184
197
|
self.ports = ports
|
185
198
|
self.subnet_ids = subnet_ids
|
199
|
+
self.dns_config = dns_config
|
200
|
+
self.diagnostics = diagnostics
|
186
201
|
|
187
202
|
def execute(self, context: Context) -> int:
|
188
203
|
# Check name again in case it was templated.
|
@@ -256,6 +271,8 @@ class AzureContainerInstancesOperator(BaseOperator):
|
|
256
271
|
tags=self.tags,
|
257
272
|
ip_address=self.ip_address,
|
258
273
|
subnet_ids=self.subnet_ids,
|
274
|
+
dns_config=self.dns_config,
|
275
|
+
diagnostics=self.diagnostics,
|
259
276
|
)
|
260
277
|
|
261
278
|
self._ci_hook.create_or_update(self.resource_group, self.name, container_group)
|