apache-airflow-providers-microsoft-azure 10.0.0__py3-none-any.whl → 10.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,18 +25,15 @@ from __future__ import annotations
25
25
 
26
26
  import packaging.version
27
27
 
28
- __all__ = ["__version__"]
28
+ from airflow import __version__ as airflow_version
29
29
 
30
- __version__ = "10.0.0"
30
+ __all__ = ["__version__"]
31
31
 
32
- try:
33
- from airflow import __version__ as airflow_version
34
- except ImportError:
35
- from airflow.version import version as airflow_version
32
+ __version__ = "10.1.0"
36
33
 
37
34
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
38
- "2.6.0"
35
+ "2.7.0"
39
36
  ):
40
37
  raise RuntimeError(
41
- f"The package `apache-airflow-providers-microsoft-azure:{__version__}` needs Apache Airflow 2.6.0+"
38
+ f"The package `apache-airflow-providers-microsoft-azure:{__version__}` needs Apache Airflow 2.7.0+"
42
39
  )
@@ -18,8 +18,10 @@ from __future__ import annotations
18
18
 
19
19
  from typing import TYPE_CHECKING, Any
20
20
 
21
+ from azure.identity import ClientSecretCredential
22
+
21
23
  from airflow.hooks.base import BaseHook
22
- from airflow.providers.microsoft.azure.utils import get_field
24
+ from airflow.providers.microsoft.azure.utils import get_field, parse_blob_account_url
23
25
 
24
26
  if TYPE_CHECKING:
25
27
  from fsspec import AbstractFileSystem
@@ -35,13 +37,61 @@ def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) -
35
37
 
36
38
  conn = BaseHook.get_connection(conn_id)
37
39
  extras = conn.extra_dejson
40
+ conn_type = conn.conn_type or "azure_data_lake"
38
41
 
39
- options = {}
40
- fields = ["connection_string", "account_name", "account_key", "sas_token", "tenant"]
41
- for field in fields:
42
- options[field] = get_field(
43
- conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name=field
42
+ # connection string always overrides everything else
43
+ connection_string = get_field(
44
+ conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="connection_string"
45
+ )
46
+
47
+ if connection_string:
48
+ return AzureBlobFileSystem(connection_string=connection_string)
49
+
50
+ options: dict[str, Any] = {
51
+ "account_url": parse_blob_account_url(conn.host, conn.login),
52
+ }
53
+
54
+ # mirror handling of custom field "client_secret_auth_config" from extras. Ignore if missing as AzureBlobFileSystem can handle.
55
+ tenant_id = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="tenant_id")
56
+ login = conn.login or ""
57
+ password = conn.password or ""
58
+ # assumption (from WasbHook) that if tenant_id is set, we want service principal connection
59
+ if tenant_id:
60
+ client_secret_auth_config = get_field(
61
+ conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_secret_auth_config"
44
62
  )
63
+ if login:
64
+ options["client_id"] = login
65
+ if password:
66
+ options["client_secret"] = password
67
+ if client_secret_auth_config and login and password:
68
+ options["credential"] = ClientSecretCredential(
69
+ tenant_id=tenant_id, client_id=login, client_secret=password, **client_secret_auth_config
70
+ )
71
+
72
+ # if not service principal, then password is taken to be account admin key
73
+ if tenant_id is None and password:
74
+ options["account_key"] = password
75
+
76
+ # now take any fields from extras and overlay on these
77
+ # add empty field to remove defaults
78
+ fields = [
79
+ "account_name",
80
+ "account_key",
81
+ "sas_token",
82
+ "tenant_id",
83
+ "managed_identity_client_id",
84
+ "workload_identity_client_id",
85
+ "workload_identity_tenant_id",
86
+ "anon",
87
+ ]
88
+ for field in fields:
89
+ value = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=field)
90
+ if value is not None:
91
+ if value == "":
92
+ options.pop(field, "")
93
+ else:
94
+ options[field] = value
45
95
 
46
96
  options.update(storage_options or {})
47
97
 
@@ -28,8 +28,9 @@ def get_provider_info():
28
28
  "name": "Microsoft Azure",
29
29
  "description": "`Microsoft Azure <https://azure.microsoft.com/>`__\n",
30
30
  "state": "ready",
31
- "source-date-epoch": 1709555852,
31
+ "source-date-epoch": 1715384449,
32
32
  "versions": [
33
+ "10.1.0",
33
34
  "10.0.0",
34
35
  "9.0.1",
35
36
  "9.0.0",
@@ -83,7 +84,7 @@ def get_provider_info():
83
84
  "1.0.0",
84
85
  ],
85
86
  "dependencies": [
86
- "apache-airflow>=2.6.0",
87
+ "apache-airflow>=2.7.0",
87
88
  "adlfs>=2023.10.0",
88
89
  "azure-batch>=8.0.0",
89
90
  "azure-cosmos>=4.6.0",
@@ -105,6 +106,7 @@ def get_provider_info():
105
106
  "azure-mgmt-datafactory>=2.0.0",
106
107
  "azure-mgmt-containerregistry>=8.0.0",
107
108
  "azure-mgmt-containerinstance>=9.0.0",
109
+ "msgraph-core>=1.0.0",
108
110
  ],
109
111
  "devel-dependencies": ["pywinrm"],
110
112
  "integrations": [
@@ -194,6 +196,13 @@ def get_provider_info():
194
196
  "logo": "/integration-logos/azure/Data Lake Storage.svg",
195
197
  "tags": ["azure"],
196
198
  },
199
+ {
200
+ "integration-name": "Microsoft Graph API",
201
+ "external-doc-url": "https://learn.microsoft.com/en-us/graph/use-the-api/",
202
+ "logo": "/integration-logos/azure/Microsoft-Graph-API.png",
203
+ "how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst"],
204
+ "tags": ["azure"],
205
+ },
197
206
  ],
198
207
  "operators": [
199
208
  {
@@ -232,6 +241,10 @@ def get_provider_info():
232
241
  "integration-name": "Microsoft Azure Synapse",
233
242
  "python-modules": ["airflow.providers.microsoft.azure.operators.synapse"],
234
243
  },
244
+ {
245
+ "integration-name": "Microsoft Graph API",
246
+ "python-modules": ["airflow.providers.microsoft.azure.operators.msgraph"],
247
+ },
235
248
  ],
236
249
  "sensors": [
237
250
  {
@@ -246,6 +259,10 @@ def get_provider_info():
246
259
  "integration-name": "Microsoft Azure Data Factory",
247
260
  "python-modules": ["airflow.providers.microsoft.azure.sensors.data_factory"],
248
261
  },
262
+ {
263
+ "integration-name": "Microsoft Graph API",
264
+ "python-modules": ["airflow.providers.microsoft.azure.sensors.msgraph"],
265
+ },
249
266
  ],
250
267
  "filesystems": ["airflow.providers.microsoft.azure.fs.adls"],
251
268
  "hooks": [
@@ -301,6 +318,10 @@ def get_provider_info():
301
318
  "integration-name": "Microsoft Azure Synapse",
302
319
  "python-modules": ["airflow.providers.microsoft.azure.hooks.synapse"],
303
320
  },
321
+ {
322
+ "integration-name": "Microsoft Graph API",
323
+ "python-modules": ["airflow.providers.microsoft.azure.hooks.msgraph"],
324
+ },
304
325
  ],
305
326
  "triggers": [
306
327
  {
@@ -311,6 +332,10 @@ def get_provider_info():
311
332
  "integration-name": "Microsoft Azure Blob Storage",
312
333
  "python-modules": ["airflow.providers.microsoft.azure.triggers.wasb"],
313
334
  },
335
+ {
336
+ "integration-name": "Microsoft Graph API",
337
+ "python-modules": ["airflow.providers.microsoft.azure.triggers.msgraph"],
338
+ },
314
339
  ],
315
340
  "transfers": [
316
341
  {
@@ -0,0 +1,353 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ from contextlib import suppress
22
+ from http import HTTPStatus
23
+ from io import BytesIO
24
+ from json import JSONDecodeError
25
+ from typing import TYPE_CHECKING, Any
26
+ from urllib.parse import quote, urljoin, urlparse
27
+
28
+ import httpx
29
+ from azure.identity import ClientSecretCredential
30
+ from httpx import Timeout
31
+ from kiota_abstractions.api_error import APIError
32
+ from kiota_abstractions.method import Method
33
+ from kiota_abstractions.request_information import RequestInformation
34
+ from kiota_abstractions.response_handler import ResponseHandler
35
+ from kiota_authentication_azure.azure_identity_authentication_provider import (
36
+ AzureIdentityAuthenticationProvider,
37
+ )
38
+ from kiota_http.httpx_request_adapter import HttpxRequestAdapter
39
+ from kiota_http.middleware.options import ResponseHandlerOption
40
+ from msgraph_core import APIVersion, GraphClientFactory
41
+ from msgraph_core._enums import NationalClouds
42
+
43
+ from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException
44
+ from airflow.hooks.base import BaseHook
45
+
46
+ if TYPE_CHECKING:
47
+ from kiota_abstractions.request_adapter import RequestAdapter
48
+ from kiota_abstractions.request_information import QueryParams
49
+ from kiota_abstractions.response_handler import NativeResponseType
50
+ from kiota_abstractions.serialization import ParsableFactory
51
+ from kiota_http.httpx_request_adapter import ResponseType
52
+
53
+ from airflow.models import Connection
54
+
55
+
56
class DefaultResponseHandler(ResponseHandler):
    """DefaultResponseHandler returns JSON payload or content in bytes or response headers."""

    @staticmethod
    def get_value(response: NativeResponseType) -> Any:
        """
        Extract the most useful representation of a native response.

        :param response: The native response object.
        :return: The decoded JSON body when it can be parsed; otherwise the raw
            content, or the response headers as a dict when the content is empty.
        """
        # A body that is not valid JSON is not an error here: fall through to raw content.
        with suppress(JSONDecodeError):
            return response.json()
        content = response.content
        if not content:
            return dict(response.headers.items())
        return content

    async def handle_response_async(
        self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
    ) -> Any:
        """
        Invoke this callback method when a response is received.

        :param response: The native response object.
        :param error_map: The error dict to use in case of a failed request (not used by this handler).
        :return: The value extracted by :meth:`get_value` for an accepted status code.
        :raises AirflowBadRequest: When the status code is 400.
        :raises AirflowNotFoundException: When the status code is 404.
        :raises AirflowException: For any other status code outside 200, 201, 202, 204 and 302.
        """
        value = self.get_value(response)
        if response.status_code not in {200, 201, 202, 204, 302}:
            message = value or response.reason_phrase
            # Compare the integer code directly (HTTPStatus members are ints). Building
            # HTTPStatus(response.status_code) raises ValueError for a code the enum does
            # not know, which would hide the actual HTTP failure from the caller.
            if response.status_code == HTTPStatus.BAD_REQUEST:
                raise AirflowBadRequest(message)
            if response.status_code == HTTPStatus.NOT_FOUND:
                raise AirflowNotFoundException(message)
            raise AirflowException(message)
        return value
87
+
88
+
89
class KiotaRequestAdapterHook(BaseHook):
    """
    A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter.

    https://github.com/microsoftgraph/msgraph-sdk-python-core

    :param conn_id: The HTTP Connection ID to run the trigger against.
    :param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
        When no timeout is specified or set to None then no HTTP timeout is applied on each request.
    :param proxies: A Dict defining the HTTP proxies to be used (default is None).
    :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
        You can pass an enum named APIVersion which has 2 possible members v1 and beta,
        or you can pass a string as "v1.0" or "beta".
    """

    DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
    # Class-level cache shared by all hook instances, keyed by connection id only.
    cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
    default_conn_name: str = "msgraph_default"

    def __init__(
        self,
        conn_id: str = default_conn_name,
        timeout: float | None = None,
        proxies: dict | None = None,
        api_version: APIVersion | str | None = None,
    ):
        super().__init__()
        self.conn_id = conn_id
        self.timeout = timeout
        self.proxies = proxies
        # Stays None when no (or an unknown) version is given; get_conn resolves it later.
        self._api_version = self.resolve_api_version_from_value(api_version)

    @property
    def api_version(self) -> APIVersion:
        """Return the API version in use, loading the connection first if needed."""
        self.get_conn()  # Make sure config has been loaded through get_conn to have correct api version!
        return self._api_version

    @staticmethod
    def resolve_api_version_from_value(
        api_version: APIVersion | str | None, default: APIVersion | None = None
    ) -> APIVersion | None:
        """
        Resolve an ``APIVersion`` member from an enum member or its string value.

        :param api_version: An ``APIVersion`` member, a string equal to a member's value, or None.
        :param default: Value returned when ``api_version`` matches no member.
        :return: The matching ``APIVersion`` member, otherwise ``default``.
        """
        if isinstance(api_version, APIVersion):
            return api_version
        return next(
            (version for version in APIVersion if version.value == api_version),
            default,
        )

    def get_api_version(self, config: dict) -> APIVersion:
        """
        Return the API version given to the hook, else the one from the connection config.

        :param config: The connection extras; the ``api_version`` key is read from it.
        :return: The explicitly configured version, or the config value with ``APIVersion.v1`` as fallback.
        """
        if self._api_version is None:
            return self.resolve_api_version_from_value(
                api_version=config.get("api_version"), default=APIVersion.v1
            )
        return self._api_version

    @staticmethod
    def get_host(connection: Connection) -> str:
        """
        Build the host URL from the connection, falling back to the global national cloud.

        :param connection: The Airflow connection; both ``schema`` and ``host`` must be set to be used.
        """
        if connection.schema and connection.host:
            return f"{connection.schema}://{connection.host}"
        return NationalClouds.Global.value

    @staticmethod
    def format_no_proxy_url(url: str) -> str:
        """Prefix ``url`` with ``all://`` when it does not already carry a scheme."""
        if "://" not in url:
            url = f"all://{url}"
        return url

    @classmethod
    def to_httpx_proxies(cls, proxies: dict) -> dict:
        """
        Convert a proxies dict with ``http``/``https``/``no`` keys into httpx-style keys.

        :param proxies: The proxies dict; it is copied, never mutated.
        :return: A new dict where ``http``/``https`` became ``http://``/``https://`` and each
            entry of the comma separated ``no`` value became a key mapped to None.
        """
        proxies = proxies.copy()
        if proxies.get("http"):
            proxies["http://"] = proxies.pop("http")
        if proxies.get("https"):
            proxies["https://"] = proxies.pop("https")
        if proxies.get("no"):
            for url in proxies.pop("no", "").split(","):
                entry = url.strip()
                # Skip blanks (e.g. from a trailing comma): an empty entry would become the
                # key "all://" mapped to None and thereby disable proxying for every host.
                if entry:
                    proxies[cls.format_no_proxy_url(entry)] = None
        return proxies

    def to_msal_proxies(self, authority: str | None, proxies: dict):
        """
        Return the proxies to hand to the credential, or None when the authority is excluded.

        :param authority: The authority host, if any.
        :param proxies: The proxies dict; its comma separated ``no`` value lists excluded hosts.
        :return: None when ``authority`` ends with one of the ``no`` entries, otherwise ``proxies``.
        """
        self.log.debug("authority: %s", authority)
        if authority:
            no_proxies = proxies.get("no")
            self.log.debug("no_proxies: %s", no_proxies)
            if no_proxies:
                for url in no_proxies.split(","):
                    entry = url.strip()
                    if not entry:
                        continue
                    parsed = urlparse(entry)
                    # An entry with a scheme has its host in netloc and an empty path; only
                    # scheme-less entries end up in path.
                    domain_name = (parsed.netloc or parsed.path).replace("*", "")
                    self.log.debug("domain_name: %s", domain_name)
                    # Guard against an empty domain: str.endswith("") is always True and
                    # would silently drop the proxies for every authority.
                    if domain_name and authority.endswith(domain_name):
                        return None
        return proxies

    def get_conn(self) -> RequestAdapter:
        """
        Return the request adapter for ``conn_id``, creating and caching it on first use.

        :raises AirflowException: When no ``conn_id`` is set on the hook.
        """
        if not self.conn_id:
            raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")

        api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))

        if not request_adapter:
            connection = self.get_connection(conn_id=self.conn_id)
            client_id = connection.login
            client_secret = connection.password
            config = connection.extra_dejson if connection.extra else {}
            tenant_id = config.get("tenant_id")
            api_version = self.get_api_version(config)
            host = self.get_host(connection)
            base_url = config.get("base_url", urljoin(host, api_version.value))
            authority = config.get("authority")
            # "or {}" also covers an explicit null in the extras, which the helpers cannot handle.
            proxies = self.proxies or config.get("proxies") or {}
            msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
            httpx_proxies = self.to_httpx_proxies(proxies=proxies)
            scopes = config.get("scopes", ["https://graph.microsoft.com/.default"])
            verify = config.get("verify", True)
            trust_env = config.get("trust_env", False)
            disable_instance_discovery = config.get("disable_instance_discovery", False)
            allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")

            self.log.info(
                "Creating Microsoft Graph SDK client %s for conn_id: %s",
                api_version.value,
                self.conn_id,
            )
            self.log.info("Host: %s", host)
            self.log.info("Base URL: %s", base_url)
            self.log.info("Tenant id: %s", tenant_id)
            self.log.info("Client id: %s", client_id)
            # Never write the secret itself to the logs; only report whether one is configured.
            self.log.info("Client secret: %s", "***" if client_secret else None)
            self.log.info("API version: %s", api_version.value)
            self.log.info("Scope: %s", scopes)
            self.log.info("Verify: %s", verify)
            self.log.info("Timeout: %s", self.timeout)
            self.log.info("Trust env: %s", trust_env)
            self.log.info("Authority: %s", authority)
            self.log.info("Disable instance discovery: %s", disable_instance_discovery)
            self.log.info("Allowed hosts: %s", allowed_hosts)
            self.log.info("Proxies: %s", proxies)
            self.log.info("MSAL Proxies: %s", msal_proxies)
            self.log.info("HTTPX Proxies: %s", httpx_proxies)
            credentials = ClientSecretCredential(
                tenant_id=tenant_id,  # type: ignore
                client_id=client_id,
                client_secret=client_secret,
                authority=authority,
                proxies=msal_proxies,
                disable_instance_discovery=disable_instance_discovery,
                connection_verify=verify,
            )
            http_client = GraphClientFactory.create_with_default_middleware(
                api_version=api_version,
                client=httpx.AsyncClient(
                    proxies=httpx_proxies,
                    timeout=Timeout(timeout=self.timeout),
                    verify=verify,
                    trust_env=trust_env,
                ),
                host=host,
            )
            auth_provider = AzureIdentityAuthenticationProvider(
                credentials=credentials,
                scopes=scopes,
                allowed_hosts=allowed_hosts,
            )
            request_adapter = HttpxRequestAdapter(
                authentication_provider=auth_provider,
                http_client=http_client,
                base_url=base_url,
            )
            self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
        # Also applies on a cache hit, so the hook reports the version the cached adapter was built with.
        self._api_version = api_version
        return request_adapter

    def test_connection(self):
        """
        Test HTTP Connection.

        :return: A ``(success, message)`` tuple; on failure the message is the error text.
        """
        import asyncio

        try:
            # ``run`` is a coroutine function: calling it without driving it to completion
            # sends no request at all and would always report success.
            # NOTE(review): asyncio.run cannot be used from an already running event loop;
            # in that case the RuntimeError is reported as a failed test - confirm callers are sync.
            asyncio.run(self.run())
            return True, "Connection successfully tested"
        except Exception as e:
            return False, str(e)

    async def run(
        self,
        url: str = "",
        response_type: ResponseType | None = None,
        path_parameters: dict[str, Any] | None = None,
        method: str = "GET",
        query_parameters: dict[str, QueryParams] | None = None,
        headers: dict[str, str] | None = None,
        data: dict[str, Any] | str | BytesIO | None = None,
    ):
        """
        Send a request through the request adapter and return its response.

        :param url: Absolute URL, or a path appended to the adapter's base URL.
        :param response_type: Expected response type; when not set, DefaultResponseHandler is used.
        :param path_parameters: Path parameters for the request.
        :param method: The HTTP method name (default is GET).
        :param query_parameters: Query parameters for the request.
        :param headers: Extra headers, merged over ``DEFAULT_HEADERS``.
        :param data: The request payload.
        :return: Whatever the request adapter returns for the request.
        """
        self.log.info("Executing url '%s' as '%s'", url, method)

        response = await self.get_conn().send_primitive_async(
            request_info=self.request_information(
                url=url,
                response_type=response_type,
                path_parameters=path_parameters,
                method=method,
                query_parameters=query_parameters,
                headers=headers,
                data=data,
            ),
            response_type=response_type,
            error_map=self.error_mapping(),
        )

        self.log.info("response: %s", response)

        return response

    def request_information(
        self,
        url: str,
        response_type: ResponseType | None = None,
        path_parameters: dict[str, Any] | None = None,
        method: str = "GET",
        query_parameters: dict[str, QueryParams] | None = None,
        headers: dict[str, str] | None = None,
        data: dict[str, Any] | str | BytesIO | None = None,
    ) -> RequestInformation:
        """
        Build the ``RequestInformation`` describing a single request.

        :param url: Absolute URL (starting with "http"), or a path relative to the base URL.
        :param response_type: When not set, a DefaultResponseHandler is attached to the request.
        :param path_parameters: Path parameters for the request.
        :param method: The HTTP method name; it is stripped and upper-cased.
        :param query_parameters: Query parameters; their keys are URL-quoted.
        :param headers: Extra headers, merged over ``DEFAULT_HEADERS``.
        :param data: BytesIO, bytes and str payloads are used as-is; any other truthy value is
            JSON-encoded and sent with an ``application/json`` content type.
        """
        request_information = RequestInformation()
        request_information.path_parameters = path_parameters or {}
        request_information.http_method = Method(method.strip().upper())
        request_information.query_parameters = self.encoded_query_parameters(query_parameters)
        if url.startswith("http"):
            request_information.url = url
        elif request_information.query_parameters.keys():
            query = ",".join(request_information.query_parameters.keys())
            request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}{{?{query}}}"
        else:
            request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}"
        if not response_type:
            request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
                response_handler=DefaultResponseHandler()
            )
        headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS
        for header_name, header_value in headers.items():
            request_information.headers.try_add(header_name=header_name, header_value=header_value)
        if isinstance(data, (BytesIO, bytes, str)):
            request_information.content = data
        elif data:
            request_information.headers.try_add(
                header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
            )
            request_information.content = json.dumps(data).encode("utf-8")
        return request_information

    @staticmethod
    def normalize_url(url: str) -> str:
        """Return ``url`` without its single leading slash, if it has one."""
        if url.startswith("/"):
            return url[1:]
        return url

    @staticmethod
    def encoded_query_parameters(query_parameters) -> dict:
        """Return a copy of ``query_parameters`` with URL-quoted keys; an empty dict when not given."""
        if query_parameters:
            return {quote(key): value for key, value in query_parameters.items()}
        return {}

    @staticmethod
    def error_mapping() -> dict[str, ParsableFactory | None]:
        """Map the 4XX and 5XX status code classes to ``APIError``."""
        return {
            "4XX": APIError,
            "5XX": APIError,
        }
@@ -312,6 +312,7 @@ class AzureSynapsePipelineHook(BaseAzureSynapseHook):
312
312
  warnings.warn(
313
313
  "The usage of `default_conn_name=azure_synapse_connection` is deprecated and will be removed in future. Please update your code to use the new default connection name: `default_conn_name=azure_synapse_default`. ",
314
314
  AirflowProviderDeprecationWarning,
315
+ stacklevel=2,
315
316
  )
316
317
  self._conn: ArtifactsClient | None = None
317
318
  self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint
@@ -30,7 +30,6 @@ import logging
30
30
  import os
31
31
  from functools import cached_property
32
32
  from typing import TYPE_CHECKING, Any, Union
33
- from urllib.parse import urlparse
34
33
 
35
34
  from asgiref.sync import sync_to_async
36
35
  from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
@@ -52,6 +51,7 @@ from airflow.providers.microsoft.azure.utils import (
52
51
  add_managed_identity_connection_widgets,
53
52
  get_async_default_azure_credential,
54
53
  get_sync_default_azure_credential,
54
+ parse_blob_account_url,
55
55
  )
56
56
 
57
57
  if TYPE_CHECKING:
@@ -167,21 +167,7 @@ class WasbHook(BaseHook):
167
167
  # connection_string auth takes priority
168
168
  return BlobServiceClient.from_connection_string(connection_string, **extra)
169
169
 
170
- account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
171
- parsed_url = urlparse(account_url)
172
-
173
- if not parsed_url.netloc:
174
- if "." not in parsed_url.path:
175
- # if there's no netloc and no dots in the path, then user only
176
- # provided the Active Directory ID, not the full URL or DNS name
177
- account_url = f"https://{conn.login}.blob.core.windows.net/"
178
- else:
179
- # if there's no netloc but there are dots in the path, then user
180
- # provided the DNS name without the https:// prefix.
181
- # Azure storage account name can only be 3 to 24 characters in length
182
- # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
183
- acc_name = account_url.split(".")[0][:24]
184
- account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
170
+ account_url = parse_blob_account_url(conn.host, conn.login)
185
171
 
186
172
  tenant = self._get_field(extra, "tenant_id")
187
173
  if tenant:
@@ -587,21 +573,7 @@ class WasbAsyncHook(WasbHook):
587
573
  )
588
574
  return self.blob_service_client
589
575
 
590
- account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
591
- parsed_url = urlparse(account_url)
592
-
593
- if not parsed_url.netloc:
594
- if "." not in parsed_url.path:
595
- # if there's no netloc and no dots in the path, then user only
596
- # provided the Active Directory ID, not the full URL or DNS name
597
- account_url = f"https://{conn.login}.blob.core.windows.net/"
598
- else:
599
- # if there's no netloc but there are dots in the path, then user
600
- # provided the DNS name without the https:// prefix.
601
- # Azure storage account name can only be 3 to 24 characters in length
602
- # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
603
- acc_name = account_url.split(".")[0][:24]
604
- account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
576
+ account_url = parse_blob_account_url(conn.host, conn.login)
605
577
 
606
578
  tenant = self._get_field(extra, "tenant_id")
607
579
  if tenant:
@@ -25,8 +25,10 @@ from typing import TYPE_CHECKING, Any, Sequence
25
25
  from azure.mgmt.containerinstance.models import (
26
26
  Container,
27
27
  ContainerGroup,
28
+ ContainerGroupDiagnostics,
28
29
  ContainerGroupSubnetId,
29
30
  ContainerPort,
31
+ DnsConfiguration,
30
32
  EnvironmentVariable,
31
33
  IpAddress,
32
34
  ResourceRequests,
@@ -90,6 +92,8 @@ class AzureContainerInstancesOperator(BaseOperator):
90
92
  Possible values include: 'Always', 'OnFailure', 'Never'
91
93
  :param ip_address: The IP address type of the container group.
92
94
  :param subnet_ids: The subnet resource IDs for a container group
95
+ :param dns_config: The DNS configuration for a container group.
96
+ :param diagnostics: Container group diagnostic information (Log Analytics).
93
97
 
94
98
  **Example**::
95
99
 
@@ -113,6 +117,13 @@ class AzureContainerInstancesOperator(BaseOperator):
113
117
  memory_in_gb=14.0,
114
118
  cpu=4.0,
115
119
  gpu=GpuResource(count=1, sku="K80"),
120
+ dns_config=["10.0.0.10", "10.0.0.11"],
121
+ diagnostics={
122
+ "log_analytics": {
123
+ "workspaceId": "workspaceid",
124
+ "workspaceKey": "workspaceKey",
125
+ }
126
+ },
116
127
  command=["/bin/echo", "world"],
117
128
  task_id="start_container",
118
129
  )
@@ -145,6 +156,8 @@ class AzureContainerInstancesOperator(BaseOperator):
145
156
  ip_address: IpAddress | None = None,
146
157
  ports: list[ContainerPort] | None = None,
147
158
  subnet_ids: list[ContainerGroupSubnetId] | None = None,
159
+ dns_config: DnsConfiguration | None = None,
160
+ diagnostics: ContainerGroupDiagnostics | None = None,
148
161
  **kwargs,
149
162
  ) -> None:
150
163
  super().__init__(**kwargs)
@@ -183,6 +196,8 @@ class AzureContainerInstancesOperator(BaseOperator):
183
196
  self.ip_address = ip_address
184
197
  self.ports = ports
185
198
  self.subnet_ids = subnet_ids
199
+ self.dns_config = dns_config
200
+ self.diagnostics = diagnostics
186
201
 
187
202
  def execute(self, context: Context) -> int:
188
203
  # Check name again in case it was templated.
@@ -256,6 +271,8 @@ class AzureContainerInstancesOperator(BaseOperator):
256
271
  tags=self.tags,
257
272
  ip_address=self.ip_address,
258
273
  subnet_ids=self.subnet_ids,
274
+ dns_config=self.dns_config,
275
+ diagnostics=self.diagnostics,
259
276
  )
260
277
 
261
278
  self._ci_hook.create_or_update(self.resource_group, self.name, container_group)