apache-airflow-providers-microsoft-azure 10.0.0rc1__py3-none-any.whl → 10.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,7 +27,7 @@ import packaging.version
27
27
 
28
28
  __all__ = ["__version__"]
29
29
 
30
- __version__ = "10.0.0"
30
+ __version__ = "10.1.0"
31
31
 
32
32
  try:
33
33
  from airflow import __version__ as airflow_version
@@ -35,8 +35,8 @@ except ImportError:
35
35
  from airflow.version import version as airflow_version
36
36
 
37
37
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
38
- "2.6.0"
38
+ "2.7.0"
39
39
  ):
40
40
  raise RuntimeError(
41
- f"The package `apache-airflow-providers-microsoft-azure:{__version__}` needs Apache Airflow 2.6.0+"
41
+ f"The package `apache-airflow-providers-microsoft-azure:{__version__}` needs Apache Airflow 2.7.0+"
42
42
  )
@@ -18,8 +18,10 @@ from __future__ import annotations
18
18
 
19
19
  from typing import TYPE_CHECKING, Any
20
20
 
21
+ from azure.identity import ClientSecretCredential
22
+
21
23
  from airflow.hooks.base import BaseHook
22
- from airflow.providers.microsoft.azure.utils import get_field
24
+ from airflow.providers.microsoft.azure.utils import get_field, parse_blob_account_url
23
25
 
24
26
  if TYPE_CHECKING:
25
27
  from fsspec import AbstractFileSystem
@@ -35,13 +37,61 @@ def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) -
35
37
 
36
38
  conn = BaseHook.get_connection(conn_id)
37
39
  extras = conn.extra_dejson
40
+ conn_type = conn.conn_type or "azure_data_lake"
38
41
 
39
- options = {}
40
- fields = ["connection_string", "account_name", "account_key", "sas_token", "tenant"]
41
- for field in fields:
42
- options[field] = get_field(
43
- conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name=field
42
+ # connection string always overrides everything else
43
+ connection_string = get_field(
44
+ conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="connection_string"
45
+ )
46
+
47
+ if connection_string:
48
+ return AzureBlobFileSystem(connection_string=connection_string)
49
+
50
+ options: dict[str, Any] = {
51
+ "account_url": parse_blob_account_url(conn.host, conn.login),
52
+ }
53
+
54
+ # mirror handling of custom field "client_secret_auth_config" from extras. Ignore if missing as AzureBlobFileSystem can handle.
55
+ tenant_id = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="tenant_id")
56
+ login = conn.login or ""
57
+ password = conn.password or ""
58
+ # assumption (from WasbHook) that if tenant_id is set, we want service principal connection
59
+ if tenant_id:
60
+ client_secret_auth_config = get_field(
61
+ conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_secret_auth_config"
44
62
  )
63
+ if login:
64
+ options["client_id"] = login
65
+ if password:
66
+ options["client_secret"] = password
67
+ if client_secret_auth_config and login and password:
68
+ options["credential"] = ClientSecretCredential(
69
+ tenant_id=tenant_id, client_id=login, client_secret=password, **client_secret_auth_config
70
+ )
71
+
72
+ # if not service principal, then password is taken to be account admin key
73
+ if tenant_id is None and password:
74
+ options["account_key"] = password
75
+
76
+ # now take any fields from extras and overlay on these
77
+ # add empty field to remove defaults
78
+ fields = [
79
+ "account_name",
80
+ "account_key",
81
+ "sas_token",
82
+ "tenant_id",
83
+ "managed_identity_client_id",
84
+ "workload_identity_client_id",
85
+ "workload_identity_tenant_id",
86
+ "anon",
87
+ ]
88
+ for field in fields:
89
+ value = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=field)
90
+ if value is not None:
91
+ if value == "":
92
+ options.pop(field, "")
93
+ else:
94
+ options[field] = value
45
95
 
46
96
  options.update(storage_options or {})
47
97
 
@@ -28,8 +28,9 @@ def get_provider_info():
28
28
  "name": "Microsoft Azure",
29
29
  "description": "`Microsoft Azure <https://azure.microsoft.com/>`__\n",
30
30
  "state": "ready",
31
- "source-date-epoch": 1709555852,
31
+ "source-date-epoch": 1714476738,
32
32
  "versions": [
33
+ "10.1.0",
33
34
  "10.0.0",
34
35
  "9.0.1",
35
36
  "9.0.0",
@@ -83,7 +84,7 @@ def get_provider_info():
83
84
  "1.0.0",
84
85
  ],
85
86
  "dependencies": [
86
- "apache-airflow>=2.6.0",
87
+ "apache-airflow>=2.7.0",
87
88
  "adlfs>=2023.10.0",
88
89
  "azure-batch>=8.0.0",
89
90
  "azure-cosmos>=4.6.0",
@@ -105,6 +106,7 @@ def get_provider_info():
105
106
  "azure-mgmt-datafactory>=2.0.0",
106
107
  "azure-mgmt-containerregistry>=8.0.0",
107
108
  "azure-mgmt-containerinstance>=9.0.0",
109
+ "msgraph-core>=1.0.0",
108
110
  ],
109
111
  "devel-dependencies": ["pywinrm"],
110
112
  "integrations": [
@@ -194,6 +196,13 @@ def get_provider_info():
194
196
  "logo": "/integration-logos/azure/Data Lake Storage.svg",
195
197
  "tags": ["azure"],
196
198
  },
199
+ {
200
+ "integration-name": "Microsoft Graph API",
201
+ "external-doc-url": "https://learn.microsoft.com/en-us/graph/use-the-api/",
202
+ "logo": "/integration-logos/azure/Microsoft-Graph-API.png",
203
+ "how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst"],
204
+ "tags": ["azure"],
205
+ },
197
206
  ],
198
207
  "operators": [
199
208
  {
@@ -232,6 +241,10 @@ def get_provider_info():
232
241
  "integration-name": "Microsoft Azure Synapse",
233
242
  "python-modules": ["airflow.providers.microsoft.azure.operators.synapse"],
234
243
  },
244
+ {
245
+ "integration-name": "Microsoft Graph API",
246
+ "python-modules": ["airflow.providers.microsoft.azure.operators.msgraph"],
247
+ },
235
248
  ],
236
249
  "sensors": [
237
250
  {
@@ -246,6 +259,10 @@ def get_provider_info():
246
259
  "integration-name": "Microsoft Azure Data Factory",
247
260
  "python-modules": ["airflow.providers.microsoft.azure.sensors.data_factory"],
248
261
  },
262
+ {
263
+ "integration-name": "Microsoft Graph API",
264
+ "python-modules": ["airflow.providers.microsoft.azure.sensors.msgraph"],
265
+ },
249
266
  ],
250
267
  "filesystems": ["airflow.providers.microsoft.azure.fs.adls"],
251
268
  "hooks": [
@@ -301,6 +318,10 @@ def get_provider_info():
301
318
  "integration-name": "Microsoft Azure Synapse",
302
319
  "python-modules": ["airflow.providers.microsoft.azure.hooks.synapse"],
303
320
  },
321
+ {
322
+ "integration-name": "Microsoft Graph API",
323
+ "python-modules": ["airflow.providers.microsoft.azure.hooks.msgraph"],
324
+ },
304
325
  ],
305
326
  "triggers": [
306
327
  {
@@ -311,6 +332,10 @@ def get_provider_info():
311
332
  "integration-name": "Microsoft Azure Blob Storage",
312
333
  "python-modules": ["airflow.providers.microsoft.azure.triggers.wasb"],
313
334
  },
335
+ {
336
+ "integration-name": "Microsoft Graph API",
337
+ "python-modules": ["airflow.providers.microsoft.azure.triggers.msgraph"],
338
+ },
314
339
  ],
315
340
  "transfers": [
316
341
  {
@@ -0,0 +1,358 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ from http import HTTPStatus
22
+ from io import BytesIO
23
+ from typing import TYPE_CHECKING, Any, Callable
24
+ from urllib.parse import quote, urljoin, urlparse
25
+
26
+ import httpx
27
+ from azure.identity import ClientSecretCredential
28
+ from httpx import Timeout
29
+ from kiota_abstractions.api_error import APIError
30
+ from kiota_abstractions.method import Method
31
+ from kiota_abstractions.request_information import RequestInformation
32
+ from kiota_abstractions.response_handler import ResponseHandler
33
+ from kiota_authentication_azure.azure_identity_authentication_provider import (
34
+ AzureIdentityAuthenticationProvider,
35
+ )
36
+ from kiota_http.httpx_request_adapter import HttpxRequestAdapter
37
+ from kiota_http.middleware.options import ResponseHandlerOption
38
+ from msgraph_core import APIVersion, GraphClientFactory
39
+ from msgraph_core._enums import NationalClouds
40
+
41
+ from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException
42
+ from airflow.hooks.base import BaseHook
43
+
44
+ if TYPE_CHECKING:
45
+ from kiota_abstractions.request_adapter import RequestAdapter
46
+ from kiota_abstractions.request_information import QueryParams
47
+ from kiota_abstractions.response_handler import NativeResponseType
48
+ from kiota_abstractions.serialization import ParsableFactory
49
+ from kiota_http.httpx_request_adapter import ResponseType
50
+
51
+ from airflow.models import Connection
52
+
53
+
54
+ class CallableResponseHandler(ResponseHandler):
55
+ """
56
+ CallableResponseHandler executes the passed callable_function with response as parameter.
57
+
58
+ :param callable_function: Function that is applied to the response.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any],
64
+ ):
65
+ self.callable_function = callable_function
66
+
67
+ async def handle_response_async(
68
+ self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
69
+ ) -> Any:
70
+ """
71
+ Invoke this callback method when a response is received.
72
+
73
+ :param response: The type of the native response object.
74
+ param error_map: The error dict to use in case of a failed request.
75
+ """
76
+ value = self.callable_function(response, error_map)
77
+ if response.status_code not in {200, 201, 202, 204, 302}:
78
+ message = value or response.reason_phrase
79
+ status_code = HTTPStatus(response.status_code)
80
+ if status_code == HTTPStatus.BAD_REQUEST:
81
+ raise AirflowBadRequest(message)
82
+ elif status_code == HTTPStatus.NOT_FOUND:
83
+ raise AirflowNotFoundException(message)
84
+ raise AirflowException(message)
85
+ return value
86
+
87
+
88
+ class KiotaRequestAdapterHook(BaseHook):
89
+ """
90
+ A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter.
91
+
92
+ https://github.com/microsoftgraph/msgraph-sdk-python-core
93
+
94
+ :param conn_id: The HTTP Connection ID to run the trigger against.
95
+ :param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
96
+ When no timeout is specified or set to None then no HTTP timeout is applied on each request.
97
+ :param proxies: A Dict defining the HTTP proxies to be used (default is None).
98
+ :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
99
+ You can pass an enum named APIVersion which has 2 possible members v1 and beta,
100
+ or you can pass a string as "v1.0" or "beta".
101
+ """
102
+
103
+ DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
104
+ cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
105
+ default_conn_name: str = "msgraph_default"
106
+
107
+ def __init__(
108
+ self,
109
+ conn_id: str = default_conn_name,
110
+ timeout: float | None = None,
111
+ proxies: dict | None = None,
112
+ api_version: APIVersion | str | None = None,
113
+ ):
114
+ super().__init__()
115
+ self.conn_id = conn_id
116
+ self.timeout = timeout
117
+ self.proxies = proxies
118
+ self._api_version = self.resolve_api_version_from_value(api_version)
119
+
120
+ @property
121
+ def api_version(self) -> APIVersion:
122
+ self.get_conn() # Make sure config has been loaded through get_conn to have correct api version!
123
+ return self._api_version
124
+
125
+ @staticmethod
126
+ def resolve_api_version_from_value(
127
+ api_version: APIVersion | str, default: APIVersion | None = None
128
+ ) -> APIVersion:
129
+ if isinstance(api_version, APIVersion):
130
+ return api_version
131
+ return next(
132
+ filter(lambda version: version.value == api_version, APIVersion),
133
+ default,
134
+ )
135
+
136
+ def get_api_version(self, config: dict) -> APIVersion:
137
+ if self._api_version is None:
138
+ return self.resolve_api_version_from_value(
139
+ api_version=config.get("api_version"), default=APIVersion.v1
140
+ )
141
+ return self._api_version
142
+
143
+ @staticmethod
144
+ def get_host(connection: Connection) -> str:
145
+ if connection.schema and connection.host:
146
+ return f"{connection.schema}://{connection.host}"
147
+ return NationalClouds.Global.value
148
+
149
+ @staticmethod
150
+ def format_no_proxy_url(url: str) -> str:
151
+ if "://" not in url:
152
+ url = f"all://{url}"
153
+ return url
154
+
155
+ @classmethod
156
+ def to_httpx_proxies(cls, proxies: dict) -> dict:
157
+ proxies = proxies.copy()
158
+ if proxies.get("http"):
159
+ proxies["http://"] = proxies.pop("http")
160
+ if proxies.get("https"):
161
+ proxies["https://"] = proxies.pop("https")
162
+ if proxies.get("no"):
163
+ for url in proxies.pop("no", "").split(","):
164
+ proxies[cls.format_no_proxy_url(url.strip())] = None
165
+ return proxies
166
+
167
+ def to_msal_proxies(self, authority: str | None, proxies: dict):
168
+ self.log.info("authority: %s", authority)
169
+ if authority:
170
+ no_proxies = proxies.get("no")
171
+ self.log.info("no_proxies: %s", no_proxies)
172
+ if no_proxies:
173
+ for url in no_proxies.split(","):
174
+ self.log.info("url: %s", url)
175
+ domain_name = urlparse(url).path.replace("*", "")
176
+ self.log.info("domain_name: %s", domain_name)
177
+ if authority.endswith(domain_name):
178
+ return None
179
+ return proxies
180
+
181
+ def get_conn(self) -> RequestAdapter:
182
+ if not self.conn_id:
183
+ raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")
184
+
185
+ api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))
186
+
187
+ if not request_adapter:
188
+ connection = self.get_connection(conn_id=self.conn_id)
189
+ client_id = connection.login
190
+ client_secret = connection.password
191
+ config = connection.extra_dejson if connection.extra else {}
192
+ tenant_id = config.get("tenant_id")
193
+ api_version = self.get_api_version(config)
194
+ host = self.get_host(connection)
195
+ base_url = config.get("base_url", urljoin(host, api_version.value))
196
+ authority = config.get("authority")
197
+ proxies = self.proxies or config.get("proxies", {})
198
+ msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
199
+ httpx_proxies = self.to_httpx_proxies(proxies=proxies)
200
+ scopes = config.get("scopes", ["https://graph.microsoft.com/.default"])
201
+ verify = config.get("verify", True)
202
+ trust_env = config.get("trust_env", False)
203
+ disable_instance_discovery = config.get("disable_instance_discovery", False)
204
+ allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")
205
+
206
+ self.log.info(
207
+ "Creating Microsoft Graph SDK client %s for conn_id: %s",
208
+ api_version.value,
209
+ self.conn_id,
210
+ )
211
+ self.log.info("Host: %s", host)
212
+ self.log.info("Base URL: %s", base_url)
213
+ self.log.info("Tenant id: %s", tenant_id)
214
+ self.log.info("Client id: %s", client_id)
215
+ self.log.info("Client secret: %s", client_secret)
216
+ self.log.info("API version: %s", api_version.value)
217
+ self.log.info("Scope: %s", scopes)
218
+ self.log.info("Verify: %s", verify)
219
+ self.log.info("Timeout: %s", self.timeout)
220
+ self.log.info("Trust env: %s", trust_env)
221
+ self.log.info("Authority: %s", authority)
222
+ self.log.info("Disable instance discovery: %s", disable_instance_discovery)
223
+ self.log.info("Allowed hosts: %s", allowed_hosts)
224
+ self.log.info("Proxies: %s", proxies)
225
+ self.log.info("MSAL Proxies: %s", msal_proxies)
226
+ self.log.info("HTTPX Proxies: %s", httpx_proxies)
227
+ credentials = ClientSecretCredential(
228
+ tenant_id=tenant_id, # type: ignore
229
+ client_id=connection.login,
230
+ client_secret=connection.password,
231
+ authority=authority,
232
+ proxies=msal_proxies,
233
+ disable_instance_discovery=disable_instance_discovery,
234
+ connection_verify=verify,
235
+ )
236
+ http_client = GraphClientFactory.create_with_default_middleware(
237
+ api_version=api_version,
238
+ client=httpx.AsyncClient(
239
+ proxies=httpx_proxies,
240
+ timeout=Timeout(timeout=self.timeout),
241
+ verify=verify,
242
+ trust_env=trust_env,
243
+ ),
244
+ host=host,
245
+ )
246
+ auth_provider = AzureIdentityAuthenticationProvider(
247
+ credentials=credentials,
248
+ scopes=scopes,
249
+ allowed_hosts=allowed_hosts,
250
+ )
251
+ request_adapter = HttpxRequestAdapter(
252
+ authentication_provider=auth_provider,
253
+ http_client=http_client,
254
+ base_url=base_url,
255
+ )
256
+ self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
257
+ self._api_version = api_version
258
+ return request_adapter
259
+
260
+ def test_connection(self):
261
+ """Test HTTP Connection."""
262
+ try:
263
+ self.run()
264
+ return True, "Connection successfully tested"
265
+ except Exception as e:
266
+ return False, str(e)
267
+
268
+ async def run(
269
+ self,
270
+ url: str = "",
271
+ response_type: ResponseType | None = None,
272
+ response_handler: Callable[
273
+ [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
274
+ ] = lambda response, error_map: response.json(),
275
+ path_parameters: dict[str, Any] | None = None,
276
+ method: str = "GET",
277
+ query_parameters: dict[str, QueryParams] | None = None,
278
+ headers: dict[str, str] | None = None,
279
+ data: dict[str, Any] | str | BytesIO | None = None,
280
+ ):
281
+ response = await self.get_conn().send_primitive_async(
282
+ request_info=self.request_information(
283
+ url=url,
284
+ response_type=response_type,
285
+ response_handler=response_handler,
286
+ path_parameters=path_parameters,
287
+ method=method,
288
+ query_parameters=query_parameters,
289
+ headers=headers,
290
+ data=data,
291
+ ),
292
+ response_type=response_type,
293
+ error_map=self.error_mapping(),
294
+ )
295
+
296
+ self.log.debug("response: %s", response)
297
+
298
+ return response
299
+
300
+ def request_information(
301
+ self,
302
+ url: str,
303
+ response_type: ResponseType | None = None,
304
+ response_handler: Callable[
305
+ [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
306
+ ] = lambda response, error_map: response.json(),
307
+ path_parameters: dict[str, Any] | None = None,
308
+ method: str = "GET",
309
+ query_parameters: dict[str, QueryParams] | None = None,
310
+ headers: dict[str, str] | None = None,
311
+ data: dict[str, Any] | str | BytesIO | None = None,
312
+ ) -> RequestInformation:
313
+ request_information = RequestInformation()
314
+ request_information.path_parameters = path_parameters or {}
315
+ request_information.http_method = Method(method.strip().upper())
316
+ request_information.query_parameters = self.encoded_query_parameters(query_parameters)
317
+ if url.startswith("http"):
318
+ request_information.url = url
319
+ elif request_information.query_parameters.keys():
320
+ query = ",".join(request_information.query_parameters.keys())
321
+ request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}{{?{query}}}"
322
+ else:
323
+ request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}"
324
+ if not response_type:
325
+ request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
326
+ response_handler=CallableResponseHandler(response_handler)
327
+ )
328
+ headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS
329
+ for header_name, header_value in headers.items():
330
+ request_information.headers.try_add(header_name=header_name, header_value=header_value)
331
+ self.log.info("data: %s", data)
332
+ if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str):
333
+ request_information.content = data
334
+ elif data:
335
+ request_information.headers.try_add(
336
+ header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
337
+ )
338
+ request_information.content = json.dumps(data).encode("utf-8")
339
+ return request_information
340
+
341
+ @staticmethod
342
+ def normalize_url(url: str) -> str | None:
343
+ if url.startswith("/"):
344
+ return url.replace("/", "", 1)
345
+ return url
346
+
347
+ @staticmethod
348
+ def encoded_query_parameters(query_parameters) -> dict:
349
+ if query_parameters:
350
+ return {quote(key): value for key, value in query_parameters.items()}
351
+ return {}
352
+
353
+ @staticmethod
354
+ def error_mapping() -> dict[str, ParsableFactory | None]:
355
+ return {
356
+ "4XX": APIError,
357
+ "5XX": APIError,
358
+ }
@@ -312,6 +312,7 @@ class AzureSynapsePipelineHook(BaseAzureSynapseHook):
312
312
  warnings.warn(
313
313
  "The usage of `default_conn_name=azure_synapse_connection` is deprecated and will be removed in future. Please update your code to use the new default connection name: `default_conn_name=azure_synapse_default`. ",
314
314
  AirflowProviderDeprecationWarning,
315
+ stacklevel=2,
315
316
  )
316
317
  self._conn: ArtifactsClient | None = None
317
318
  self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint
@@ -30,7 +30,6 @@ import logging
30
30
  import os
31
31
  from functools import cached_property
32
32
  from typing import TYPE_CHECKING, Any, Union
33
- from urllib.parse import urlparse
34
33
 
35
34
  from asgiref.sync import sync_to_async
36
35
  from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
@@ -52,6 +51,7 @@ from airflow.providers.microsoft.azure.utils import (
52
51
  add_managed_identity_connection_widgets,
53
52
  get_async_default_azure_credential,
54
53
  get_sync_default_azure_credential,
54
+ parse_blob_account_url,
55
55
  )
56
56
 
57
57
  if TYPE_CHECKING:
@@ -167,21 +167,7 @@ class WasbHook(BaseHook):
167
167
  # connection_string auth takes priority
168
168
  return BlobServiceClient.from_connection_string(connection_string, **extra)
169
169
 
170
- account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
171
- parsed_url = urlparse(account_url)
172
-
173
- if not parsed_url.netloc:
174
- if "." not in parsed_url.path:
175
- # if there's no netloc and no dots in the path, then user only
176
- # provided the Active Directory ID, not the full URL or DNS name
177
- account_url = f"https://{conn.login}.blob.core.windows.net/"
178
- else:
179
- # if there's no netloc but there are dots in the path, then user
180
- # provided the DNS name without the https:// prefix.
181
- # Azure storage account name can only be 3 to 24 characters in length
182
- # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
183
- acc_name = account_url.split(".")[0][:24]
184
- account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
170
+ account_url = parse_blob_account_url(conn.host, conn.login)
185
171
 
186
172
  tenant = self._get_field(extra, "tenant_id")
187
173
  if tenant:
@@ -587,21 +573,7 @@ class WasbAsyncHook(WasbHook):
587
573
  )
588
574
  return self.blob_service_client
589
575
 
590
- account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
591
- parsed_url = urlparse(account_url)
592
-
593
- if not parsed_url.netloc:
594
- if "." not in parsed_url.path:
595
- # if there's no netloc and no dots in the path, then user only
596
- # provided the Active Directory ID, not the full URL or DNS name
597
- account_url = f"https://{conn.login}.blob.core.windows.net/"
598
- else:
599
- # if there's no netloc but there are dots in the path, then user
600
- # provided the DNS name without the https:// prefix.
601
- # Azure storage account name can only be 3 to 24 characters in length
602
- # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
603
- acc_name = account_url.split(".")[0][:24]
604
- account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
576
+ account_url = parse_blob_account_url(conn.host, conn.login)
605
577
 
606
578
  tenant = self._get_field(extra, "tenant_id")
607
579
  if tenant:
@@ -25,8 +25,10 @@ from typing import TYPE_CHECKING, Any, Sequence
25
25
  from azure.mgmt.containerinstance.models import (
26
26
  Container,
27
27
  ContainerGroup,
28
+ ContainerGroupDiagnostics,
28
29
  ContainerGroupSubnetId,
29
30
  ContainerPort,
31
+ DnsConfiguration,
30
32
  EnvironmentVariable,
31
33
  IpAddress,
32
34
  ResourceRequests,
@@ -90,6 +92,8 @@ class AzureContainerInstancesOperator(BaseOperator):
90
92
  Possible values include: 'Always', 'OnFailure', 'Never'
91
93
  :param ip_address: The IP address type of the container group.
92
94
  :param subnet_ids: The subnet resource IDs for a container group
95
+ :param dns_config: The DNS configuration for a container group.
96
+ :param diagnostics: Container group diagnostic information (Log Analytics).
93
97
 
94
98
  **Example**::
95
99
 
@@ -113,6 +117,13 @@ class AzureContainerInstancesOperator(BaseOperator):
113
117
  memory_in_gb=14.0,
114
118
  cpu=4.0,
115
119
  gpu=GpuResource(count=1, sku="K80"),
120
+ dns_config=["10.0.0.10", "10.0.0.11"],
121
+ diagnostics={
122
+ "log_analytics": {
123
+ "workspaceId": "workspaceid",
124
+ "workspaceKey": "workspaceKey",
125
+ }
126
+ },
116
127
  command=["/bin/echo", "world"],
117
128
  task_id="start_container",
118
129
  )
@@ -145,6 +156,8 @@ class AzureContainerInstancesOperator(BaseOperator):
145
156
  ip_address: IpAddress | None = None,
146
157
  ports: list[ContainerPort] | None = None,
147
158
  subnet_ids: list[ContainerGroupSubnetId] | None = None,
159
+ dns_config: DnsConfiguration | None = None,
160
+ diagnostics: ContainerGroupDiagnostics | None = None,
148
161
  **kwargs,
149
162
  ) -> None:
150
163
  super().__init__(**kwargs)
@@ -183,6 +196,8 @@ class AzureContainerInstancesOperator(BaseOperator):
183
196
  self.ip_address = ip_address
184
197
  self.ports = ports
185
198
  self.subnet_ids = subnet_ids
199
+ self.dns_config = dns_config
200
+ self.diagnostics = diagnostics
186
201
 
187
202
  def execute(self, context: Context) -> int:
188
203
  # Check name again in case it was templated.
@@ -256,6 +271,8 @@ class AzureContainerInstancesOperator(BaseOperator):
256
271
  tags=self.tags,
257
272
  ip_address=self.ip_address,
258
273
  subnet_ids=self.subnet_ids,
274
+ dns_config=self.dns_config,
275
+ diagnostics=self.diagnostics,
259
276
  )
260
277
 
261
278
  self._ci_hook.create_or_update(self.resource_group, self.name, container_group)