apache-airflow-providers-microsoft-azure 12.0.0rc2__py3-none-any.whl → 12.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. airflow/providers/microsoft/azure/LICENSE +0 -52
  2. airflow/providers/microsoft/azure/__init__.py +1 -1
  3. airflow/providers/microsoft/azure/fs/adls.py +1 -2
  4. airflow/providers/microsoft/azure/get_provider_info.py +52 -46
  5. airflow/providers/microsoft/azure/hooks/adx.py +6 -7
  6. airflow/providers/microsoft/azure/hooks/asb.py +237 -8
  7. airflow/providers/microsoft/azure/hooks/base_azure.py +2 -3
  8. airflow/providers/microsoft/azure/hooks/batch.py +1 -2
  9. airflow/providers/microsoft/azure/hooks/container_instance.py +3 -4
  10. airflow/providers/microsoft/azure/hooks/container_registry.py +2 -3
  11. airflow/providers/microsoft/azure/hooks/container_volume.py +2 -3
  12. airflow/providers/microsoft/azure/hooks/cosmos.py +4 -5
  13. airflow/providers/microsoft/azure/hooks/data_factory.py +7 -7
  14. airflow/providers/microsoft/azure/hooks/data_lake.py +8 -9
  15. airflow/providers/microsoft/azure/hooks/fileshare.py +1 -2
  16. airflow/providers/microsoft/azure/hooks/msgraph.py +102 -35
  17. airflow/providers/microsoft/azure/hooks/synapse.py +4 -5
  18. airflow/providers/microsoft/azure/hooks/wasb.py +9 -9
  19. airflow/providers/microsoft/azure/log/wasb_task_handler.py +1 -2
  20. airflow/providers/microsoft/azure/operators/adx.py +1 -2
  21. airflow/providers/microsoft/azure/operators/asb.py +50 -62
  22. airflow/providers/microsoft/azure/operators/batch.py +1 -2
  23. airflow/providers/microsoft/azure/operators/container_instances.py +7 -7
  24. airflow/providers/microsoft/azure/operators/msgraph.py +44 -12
  25. airflow/providers/microsoft/azure/operators/powerbi.py +34 -5
  26. airflow/providers/microsoft/azure/operators/synapse.py +1 -2
  27. airflow/providers/microsoft/azure/secrets/key_vault.py +3 -4
  28. airflow/providers/microsoft/azure/sensors/msgraph.py +21 -5
  29. airflow/providers/microsoft/azure/triggers/data_factory.py +1 -2
  30. airflow/providers/microsoft/azure/triggers/msgraph.py +4 -0
  31. airflow/providers/microsoft/azure/triggers/powerbi.py +55 -11
  32. airflow/providers/microsoft/azure/utils.py +2 -1
  33. {apache_airflow_providers_microsoft_azure-12.0.0rc2.dist-info → apache_airflow_providers_microsoft_azure-12.2.0.dist-info}/METADATA +21 -38
  34. apache_airflow_providers_microsoft_azure-12.2.0.dist-info/RECORD +58 -0
  35. apache_airflow_providers_microsoft_azure-12.0.0rc2.dist-info/RECORD +0 -58
  36. {apache_airflow_providers_microsoft_azure-12.0.0rc2.dist-info → apache_airflow_providers_microsoft_azure-12.2.0.dist-info}/WHEEL +0 -0
  37. {apache_airflow_providers_microsoft_azure-12.0.0rc2.dist-info → apache_airflow_providers_microsoft_azure-12.2.0.dist-info}/entry_points.txt +0 -0
@@ -22,15 +22,14 @@ from __future__ import annotations
22
22
  from functools import cached_property
23
23
  from typing import Any
24
24
 
25
- from azure.mgmt.containerinstance.models import ImageRegistryCredential
26
- from azure.mgmt.containerregistry import ContainerRegistryManagementClient
27
-
28
25
  from airflow.hooks.base import BaseHook
29
26
  from airflow.providers.microsoft.azure.utils import (
30
27
  add_managed_identity_connection_widgets,
31
28
  get_field,
32
29
  get_sync_default_azure_credential,
33
30
  )
31
+ from azure.mgmt.containerinstance.models import ImageRegistryCredential
32
+ from azure.mgmt.containerregistry import ContainerRegistryManagementClient
34
33
 
35
34
 
36
35
  class AzureContainerRegistryHook(BaseHook):
@@ -18,15 +18,14 @@ from __future__ import annotations
18
18
 
19
19
  from typing import Any
20
20
 
21
- from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
22
- from azure.mgmt.storage import StorageManagementClient
23
-
24
21
  from airflow.hooks.base import BaseHook
25
22
  from airflow.providers.microsoft.azure.utils import (
26
23
  add_managed_identity_connection_widgets,
27
24
  get_field,
28
25
  get_sync_default_azure_credential,
29
26
  )
27
+ from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
28
+ from azure.mgmt.storage import StorageManagementClient
30
29
 
31
30
 
32
31
  class AzureContainerVolumeHook(BaseHook):
@@ -30,11 +30,6 @@ import uuid
30
30
  from typing import TYPE_CHECKING, Any, Union
31
31
  from urllib.parse import urlparse
32
32
 
33
- from azure.cosmos import PartitionKey
34
- from azure.cosmos.cosmos_client import CosmosClient
35
- from azure.cosmos.exceptions import CosmosHttpResponseError
36
- from azure.mgmt.cosmosdb import CosmosDBManagementClient
37
-
38
33
  from airflow.exceptions import AirflowBadRequest, AirflowException
39
34
  from airflow.hooks.base import BaseHook
40
35
  from airflow.providers.microsoft.azure.utils import (
@@ -42,6 +37,10 @@ from airflow.providers.microsoft.azure.utils import (
42
37
  get_field,
43
38
  get_sync_default_azure_credential,
44
39
  )
40
+ from azure.cosmos import PartitionKey
41
+ from azure.cosmos.cosmos_client import CosmosClient
42
+ from azure.cosmos.exceptions import CosmosHttpResponseError
43
+ from azure.mgmt.cosmosdb import CosmosDBManagementClient
45
44
 
46
45
  if TYPE_CHECKING:
47
46
  PartitionKeyType = Union[str, list[str]]
@@ -39,13 +39,6 @@ from functools import wraps
39
39
  from typing import IO, TYPE_CHECKING, Any, Callable, TypeVar, Union, cast
40
40
 
41
41
  from asgiref.sync import sync_to_async
42
- from azure.identity import ClientSecretCredential, DefaultAzureCredential
43
- from azure.identity.aio import (
44
- ClientSecretCredential as AsyncClientSecretCredential,
45
- DefaultAzureCredential as AsyncDefaultAzureCredential,
46
- )
47
- from azure.mgmt.datafactory import DataFactoryManagementClient
48
- from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient
49
42
 
50
43
  from airflow.exceptions import AirflowException
51
44
  from airflow.hooks.base import BaseHook
@@ -54,6 +47,13 @@ from airflow.providers.microsoft.azure.utils import (
54
47
  get_async_default_azure_credential,
55
48
  get_sync_default_azure_credential,
56
49
  )
50
+ from azure.identity import ClientSecretCredential, DefaultAzureCredential
51
+ from azure.identity.aio import (
52
+ ClientSecretCredential as AsyncClientSecretCredential,
53
+ DefaultAzureCredential as AsyncDefaultAzureCredential,
54
+ )
55
+ from azure.mgmt.datafactory import DataFactoryManagementClient
56
+ from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient
57
57
 
58
58
  if TYPE_CHECKING:
59
59
  from azure.core.polling import LROPoller
@@ -20,6 +20,14 @@ from __future__ import annotations
20
20
  from functools import cached_property
21
21
  from typing import Any, Union
22
22
 
23
+ from airflow.exceptions import AirflowException
24
+ from airflow.hooks.base import BaseHook
25
+ from airflow.providers.microsoft.azure.utils import (
26
+ AzureIdentityCredentialAdapter,
27
+ add_managed_identity_connection_widgets,
28
+ get_field,
29
+ get_sync_default_azure_credential,
30
+ )
23
31
  from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
24
32
  from azure.datalake.store import core, lib, multithread
25
33
  from azure.identity import ClientSecretCredential, DefaultAzureCredential
@@ -32,15 +40,6 @@ from azure.storage.filedatalake import (
32
40
  FileSystemProperties,
33
41
  )
34
42
 
35
- from airflow.exceptions import AirflowException
36
- from airflow.hooks.base import BaseHook
37
- from airflow.providers.microsoft.azure.utils import (
38
- AzureIdentityCredentialAdapter,
39
- add_managed_identity_connection_widgets,
40
- get_field,
41
- get_sync_default_azure_credential,
42
- )
43
-
44
43
  Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter, DefaultAzureCredential]
45
44
 
46
45
 
@@ -19,13 +19,12 @@ from __future__ import annotations
19
19
 
20
20
  from typing import IO, Any
21
21
 
22
- from azure.storage.fileshare import FileProperties, ShareDirectoryClient, ShareFileClient, ShareServiceClient
23
-
24
22
  from airflow.hooks.base import BaseHook
25
23
  from airflow.providers.microsoft.azure.utils import (
26
24
  add_managed_identity_connection_widgets,
27
25
  get_sync_default_azure_credential,
28
26
  )
27
+ from azure.storage.fileshare import FileProperties, ShareDirectoryClient, ShareFileClient, ShareServiceClient
29
28
 
30
29
 
31
30
  class AzureFileShareHook(BaseHook):
@@ -18,6 +18,7 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import json
21
+ from ast import literal_eval
21
22
  from contextlib import suppress
22
23
  from http import HTTPStatus
23
24
  from io import BytesIO
@@ -26,8 +27,7 @@ from typing import TYPE_CHECKING, Any
26
27
  from urllib.parse import quote, urljoin, urlparse
27
28
 
28
29
  import httpx
29
- from azure.identity import ClientSecretCredential
30
- from httpx import Timeout
30
+ from httpx import AsyncHTTPTransport, Timeout
31
31
  from kiota_abstractions.api_error import APIError
32
32
  from kiota_abstractions.method import Method
33
33
  from kiota_abstractions.request_information import RequestInformation
@@ -43,8 +43,14 @@ from kiota_serialization_text.text_parse_node_factory import TextParseNodeFactor
43
43
  from msgraph_core import APIVersion, GraphClientFactory
44
44
  from msgraph_core._enums import NationalClouds
45
45
 
46
- from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException
46
+ from airflow.exceptions import (
47
+ AirflowBadRequest,
48
+ AirflowConfigException,
49
+ AirflowException,
50
+ AirflowNotFoundException,
51
+ )
47
52
  from airflow.hooks.base import BaseHook
53
+ from azure.identity import CertificateCredential, ClientSecretCredential
48
54
 
49
55
  if TYPE_CHECKING:
50
56
  from kiota_abstractions.request_adapter import RequestAdapter
@@ -54,6 +60,7 @@ if TYPE_CHECKING:
54
60
  from kiota_http.httpx_request_adapter import ResponseType
55
61
 
56
62
  from airflow.models import Connection
63
+ from azure.identity._internal.client_credential_base import ClientCredentialBase
57
64
 
58
65
 
59
66
  class DefaultResponseHandler(ResponseHandler):
@@ -107,6 +114,7 @@ class KiotaRequestAdapterHook(BaseHook):
107
114
  """
108
115
 
109
116
  DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
117
+ DEFAULT_SCOPE = "https://graph.microsoft.com/.default"
110
118
  cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
111
119
  conn_type: str = "msgraph"
112
120
  conn_name_attr: str = "conn_id"
@@ -119,7 +127,7 @@ class KiotaRequestAdapterHook(BaseHook):
119
127
  timeout: float | None = None,
120
128
  proxies: dict | None = None,
121
129
  host: str = NationalClouds.Global.value,
122
- scopes: list[str] | None = None,
130
+ scopes: str | list[str] | None = None,
123
131
  api_version: APIVersion | str | None = None,
124
132
  ):
125
133
  super().__init__()
@@ -127,7 +135,10 @@ class KiotaRequestAdapterHook(BaseHook):
127
135
  self.timeout = timeout
128
136
  self.proxies = proxies
129
137
  self.host = host
130
- self.scopes = scopes or ["https://graph.microsoft.com/.default"]
138
+ if isinstance(scopes, str):
139
+ self.scopes = [scopes]
140
+ else:
141
+ self.scopes = scopes or [self.DEFAULT_SCOPE]
131
142
  self._api_version = self.resolve_api_version_from_value(api_version)
132
143
 
133
144
  @classmethod
@@ -140,20 +151,21 @@ class KiotaRequestAdapterHook(BaseHook):
140
151
  return {
141
152
  "tenant_id": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
142
153
  "api_version": StringField(
143
- lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default="v1.0"
154
+ lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default=APIVersion.v1.value
144
155
  ),
145
156
  "authority": StringField(lazy_gettext("Authority"), widget=BS3TextFieldWidget()),
157
+ "certificate_path": StringField(lazy_gettext("Certificate path"), widget=BS3TextFieldWidget()),
158
+ "certificate_data": StringField(lazy_gettext("Certificate data"), widget=BS3TextFieldWidget()),
146
159
  "scopes": StringField(
147
160
  lazy_gettext("Scopes"),
148
161
  widget=BS3TextFieldWidget(),
149
- default="https://graph.microsoft.com/.default",
162
+ default=cls.DEFAULT_SCOPE,
150
163
  ),
151
164
  "disable_instance_discovery": BooleanField(
152
165
  lazy_gettext("Disable instance discovery"), default=False
153
166
  ),
154
- "allowed_hosts": StringField(lazy_gettext("Allowed"), widget=BS3TextFieldWidget()),
167
+ "allowed_hosts": StringField(lazy_gettext("Allowed hosts"), widget=BS3TextFieldWidget()),
155
168
  "proxies": StringField(lazy_gettext("Proxies"), widget=BS3TextAreaFieldWidget()),
156
- "stream": BooleanField(lazy_gettext("Stream"), default=False),
157
169
  "verify": BooleanField(lazy_gettext("Verify"), default=True),
158
170
  "trust_env": BooleanField(lazy_gettext("Trust environment"), default=True),
159
171
  "base_url": StringField(lazy_gettext("Base URL"), widget=BS3TextFieldWidget()),
@@ -201,24 +213,25 @@ class KiotaRequestAdapterHook(BaseHook):
201
213
  @staticmethod
202
214
  def format_no_proxy_url(url: str) -> str:
203
215
  if "://" not in url:
204
- url = f"all://{url}"
216
+ return f"all://{url}"
205
217
  return url
206
218
 
207
219
  @classmethod
208
220
  def to_httpx_proxies(cls, proxies: dict) -> dict:
209
- proxies = proxies.copy()
210
- if proxies.get("http"):
211
- proxies["http://"] = proxies.pop("http")
212
- if proxies.get("https"):
213
- proxies["https://"] = proxies.pop("https")
214
- if proxies.get("no"):
215
- for url in proxies.pop("no", "").split(","):
216
- proxies[cls.format_no_proxy_url(url.strip())] = None
221
+ if proxies:
222
+ proxies = proxies.copy()
223
+ if proxies.get("http"):
224
+ proxies["http://"] = AsyncHTTPTransport(proxy=proxies.pop("http"))
225
+ if proxies.get("https"):
226
+ proxies["https://"] = AsyncHTTPTransport(proxy=proxies.pop("https"))
227
+ if proxies.get("no"):
228
+ for url in proxies.pop("no", "").split(","):
229
+ proxies[cls.format_no_proxy_url(url.strip())] = None
217
230
  return proxies
218
231
 
219
- def to_msal_proxies(self, authority: str | None, proxies: dict):
232
+ def to_msal_proxies(self, authority: str | None, proxies: dict) -> dict | None:
220
233
  self.log.debug("authority: %s", authority)
221
- if authority:
234
+ if authority and proxies:
222
235
  no_proxies = proxies.get("no")
223
236
  self.log.debug("no_proxies: %s", no_proxies)
224
237
  if no_proxies:
@@ -241,18 +254,17 @@ class KiotaRequestAdapterHook(BaseHook):
241
254
  client_id = connection.login
242
255
  client_secret = connection.password
243
256
  config = connection.extra_dejson if connection.extra else {}
244
- tenant_id = config.get("tenant_id") or config.get("tenantId")
245
257
  api_version = self.get_api_version(config)
246
258
  host = self.get_host(connection)
247
259
  base_url = config.get("base_url", urljoin(host, api_version))
248
260
  authority = config.get("authority")
249
- proxies = self.proxies or config.get("proxies", {})
250
- msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
261
+ proxies = self.get_proxies(config)
251
262
  httpx_proxies = self.to_httpx_proxies(proxies=proxies)
252
263
  scopes = config.get("scopes", self.scopes)
264
+ if isinstance(scopes, str):
265
+ scopes = scopes.split(",")
253
266
  verify = config.get("verify", True)
254
267
  trust_env = config.get("trust_env", False)
255
- disable_instance_discovery = config.get("disable_instance_discovery", False)
256
268
  allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")
257
269
 
258
270
  self.log.info(
@@ -262,7 +274,6 @@ class KiotaRequestAdapterHook(BaseHook):
262
274
  )
263
275
  self.log.info("Host: %s", host)
264
276
  self.log.info("Base URL: %s", base_url)
265
- self.log.info("Tenant id: %s", tenant_id)
266
277
  self.log.info("Client id: %s", client_id)
267
278
  self.log.info("Client secret: %s", client_secret)
268
279
  self.log.info("API version: %s", api_version)
@@ -271,24 +282,21 @@ class KiotaRequestAdapterHook(BaseHook):
271
282
  self.log.info("Timeout: %s", self.timeout)
272
283
  self.log.info("Trust env: %s", trust_env)
273
284
  self.log.info("Authority: %s", authority)
274
- self.log.info("Disable instance discovery: %s", disable_instance_discovery)
275
285
  self.log.info("Allowed hosts: %s", allowed_hosts)
276
286
  self.log.info("Proxies: %s", proxies)
277
- self.log.info("MSAL Proxies: %s", msal_proxies)
278
287
  self.log.info("HTTPX Proxies: %s", httpx_proxies)
279
- credentials = ClientSecretCredential(
280
- tenant_id=tenant_id, # type: ignore
281
- client_id=connection.login,
282
- client_secret=connection.password,
288
+ credentials = self.get_credentials(
289
+ login=connection.login,
290
+ password=connection.password,
291
+ config=config,
283
292
  authority=authority,
284
- proxies=msal_proxies,
285
- disable_instance_discovery=disable_instance_discovery,
286
- connection_verify=verify,
293
+ verify=verify,
294
+ proxies=proxies,
287
295
  )
288
296
  http_client = GraphClientFactory.create_with_default_middleware(
289
297
  api_version=api_version, # type: ignore
290
298
  client=httpx.AsyncClient(
291
- proxies=httpx_proxies,
299
+ mounts=httpx_proxies,
292
300
  timeout=Timeout(timeout=self.timeout),
293
301
  verify=verify,
294
302
  trust_env=trust_env,
@@ -313,6 +321,65 @@ class KiotaRequestAdapterHook(BaseHook):
313
321
  self._api_version = api_version
314
322
  return request_adapter
315
323
 
324
+ def get_proxies(self, config: dict) -> dict:
325
+ proxies = self.proxies or config.get("proxies", {})
326
+ if isinstance(proxies, str):
327
+ # TODO: Once provider depends on Airflow 2.10 or higher code below won't be needed anymore as
328
+ # we could then use the get_extra_dejson method on the connection which deserializes
329
+ # nested json. Make sure to use connection.get_extra_dejson(nested=True) instead of
330
+ # connection.extra_dejson.
331
+ with suppress(JSONDecodeError):
332
+ proxies = json.loads(proxies)
333
+ with suppress(Exception):
334
+ proxies = literal_eval(proxies)
335
+ if not isinstance(proxies, dict):
336
+ raise AirflowConfigException(
337
+ f"Proxies must be of type dict, got {type(proxies).__name__} instead!"
338
+ )
339
+ return proxies
340
+
341
+ def get_credentials(
342
+ self,
343
+ login: str | None,
344
+ password: str | None,
345
+ config,
346
+ authority: str | None,
347
+ verify: bool,
348
+ proxies: dict,
349
+ ) -> ClientCredentialBase:
350
+ tenant_id = config.get("tenant_id") or config.get("tenantId")
351
+ certificate_path = config.get("certificate_path")
352
+ certificate_data = config.get("certificate_data")
353
+ disable_instance_discovery = config.get("disable_instance_discovery", False)
354
+ msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
355
+ self.log.info("Tenant id: %s", tenant_id)
356
+ self.log.info("Certificate path: %s", certificate_path)
357
+ self.log.info("Certificate data: %s", certificate_data is not None)
358
+ self.log.info("Authority: %s", authority)
359
+ self.log.info("Disable instance discovery: %s", disable_instance_discovery)
360
+ self.log.info("MSAL Proxies: %s", msal_proxies)
361
+ if certificate_path or certificate_data:
362
+ return CertificateCredential(
363
+ tenant_id=tenant_id, # type: ignore
364
+ client_id=login, # type: ignore
365
+ password=password,
366
+ certificate_path=certificate_path,
367
+ certificate_data=certificate_data.encode() if certificate_data else None,
368
+ authority=authority,
369
+ proxies=msal_proxies,
370
+ disable_instance_discovery=disable_instance_discovery,
371
+ connection_verify=verify,
372
+ )
373
+ return ClientSecretCredential(
374
+ tenant_id=tenant_id, # type: ignore
375
+ client_id=login, # type: ignore
376
+ client_secret=password, # type: ignore
377
+ authority=authority,
378
+ proxies=msal_proxies,
379
+ disable_instance_discovery=disable_instance_discovery,
380
+ connection_verify=verify,
381
+ )
382
+
316
383
  def test_connection(self):
317
384
  """Test HTTP Connection."""
318
385
  try:
@@ -19,11 +19,6 @@ from __future__ import annotations
19
19
  import time
20
20
  from typing import TYPE_CHECKING, Any, Union
21
21
 
22
- from azure.core.exceptions import ServiceRequestError
23
- from azure.identity import ClientSecretCredential, DefaultAzureCredential
24
- from azure.synapse.artifacts import ArtifactsClient
25
- from azure.synapse.spark import SparkClient
26
-
27
22
  from airflow.exceptions import AirflowException, AirflowTaskTimeout
28
23
  from airflow.hooks.base import BaseHook
29
24
  from airflow.providers.microsoft.azure.utils import (
@@ -31,6 +26,10 @@ from airflow.providers.microsoft.azure.utils import (
31
26
  get_field,
32
27
  get_sync_default_azure_credential,
33
28
  )
29
+ from azure.core.exceptions import ServiceRequestError
30
+ from azure.identity import ClientSecretCredential, DefaultAzureCredential
31
+ from azure.synapse.artifacts import ArtifactsClient
32
+ from azure.synapse.spark import SparkClient
34
33
 
35
34
  if TYPE_CHECKING:
36
35
  from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun
@@ -32,6 +32,15 @@ from functools import cached_property
32
32
  from typing import TYPE_CHECKING, Any, Union
33
33
 
34
34
  from asgiref.sync import sync_to_async
35
+
36
+ from airflow.exceptions import AirflowException
37
+ from airflow.hooks.base import BaseHook
38
+ from airflow.providers.microsoft.azure.utils import (
39
+ add_managed_identity_connection_widgets,
40
+ get_async_default_azure_credential,
41
+ get_sync_default_azure_credential,
42
+ parse_blob_account_url,
43
+ )
35
44
  from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
36
45
  from azure.identity import ClientSecretCredential
37
46
  from azure.identity.aio import (
@@ -45,15 +54,6 @@ from azure.storage.blob.aio import (
45
54
  ContainerClient as AsyncContainerClient,
46
55
  )
47
56
 
48
- from airflow.exceptions import AirflowException
49
- from airflow.hooks.base import BaseHook
50
- from airflow.providers.microsoft.azure.utils import (
51
- add_managed_identity_connection_widgets,
52
- get_async_default_azure_credential,
53
- get_sync_default_azure_credential,
54
- parse_blob_account_url,
55
- )
56
-
57
57
  if TYPE_CHECKING:
58
58
  from azure.storage.blob._models import BlobProperties
59
59
 
@@ -23,11 +23,10 @@ from functools import cached_property
23
23
  from pathlib import Path
24
24
  from typing import TYPE_CHECKING
25
25
 
26
- from azure.core.exceptions import HttpResponseError
27
-
28
26
  from airflow.configuration import conf
29
27
  from airflow.utils.log.file_task_handler import FileTaskHandler
30
28
  from airflow.utils.log.logging_mixin import LoggingMixin
29
+ from azure.core.exceptions import HttpResponseError
31
30
 
32
31
  if TYPE_CHECKING:
33
32
  import logging
@@ -28,9 +28,8 @@ from airflow.models import BaseOperator
28
28
  from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
29
29
 
30
30
  if TYPE_CHECKING:
31
- from azure.kusto.data._models import KustoResultTable
32
-
33
31
  from airflow.utils.context import Context
32
+ from azure.kusto.data._models import KustoResultTable
34
33
 
35
34
 
36
35
  class AzureDataExplorerQueryOperator(BaseOperator):
@@ -19,18 +19,15 @@ from __future__ import annotations
19
19
  from collections.abc import Sequence
20
20
  from typing import TYPE_CHECKING, Any, Callable
21
21
 
22
- from azure.core.exceptions import ResourceNotFoundError
23
-
24
22
  from airflow.models import BaseOperator
25
23
  from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook
26
24
 
27
25
  if TYPE_CHECKING:
28
26
  import datetime
29
27
 
30
- from azure.servicebus import ServiceBusMessage
31
- from azure.servicebus.management._models import AuthorizationRule
32
-
33
28
  from airflow.utils.context import Context
29
+ from azure.servicebus import ServiceBusMessage
30
+ from azure.servicebus.management import AuthorizationRule, CorrelationRuleFilter, SqlRuleFilter
34
31
 
35
32
  MessageCallback = Callable[[ServiceBusMessage, Context], None]
36
33
 
@@ -313,33 +310,23 @@ class AzureServiceBusTopicCreateOperator(BaseOperator):
313
310
  # Create the hook
314
311
  hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
315
312
 
316
- with hook.get_conn() as service_mgmt_conn:
317
- try:
318
- topic_properties = service_mgmt_conn.get_topic(self.topic_name)
319
- except ResourceNotFoundError:
320
- topic_properties = None
321
- if topic_properties and topic_properties.name == self.topic_name:
322
- self.log.info("Topic name already exists")
323
- return topic_properties.name
324
- topic = service_mgmt_conn.create_topic(
325
- topic_name=self.topic_name,
326
- default_message_time_to_live=self.default_message_time_to_live,
327
- max_size_in_megabytes=self.max_size_in_megabytes,
328
- requires_duplicate_detection=self.requires_duplicate_detection,
329
- duplicate_detection_history_time_window=self.duplicate_detection_history_time_window,
330
- enable_batched_operations=self.enable_batched_operations,
331
- size_in_bytes=self.size_in_bytes,
332
- filtering_messages_before_publishing=self.filtering_messages_before_publishing,
333
- authorization_rules=self.authorization_rules,
334
- support_ordering=self.support_ordering,
335
- auto_delete_on_idle=self.auto_delete_on_idle,
336
- enable_partitioning=self.enable_partitioning,
337
- enable_express=self.enable_express,
338
- user_metadata=self.user_metadata,
339
- max_message_size_in_kilobytes=self.max_message_size_in_kilobytes,
340
- )
341
- self.log.info("Created Topic %s", topic.name)
342
- return topic.name
313
+ return hook.create_topic(
314
+ topic_name=self.topic_name,
315
+ default_message_time_to_live=self.default_message_time_to_live,
316
+ max_size_in_megabytes=self.max_size_in_megabytes,
317
+ requires_duplicate_detection=self.requires_duplicate_detection,
318
+ duplicate_detection_history_time_window=self.duplicate_detection_history_time_window,
319
+ enable_batched_operations=self.enable_batched_operations,
320
+ size_in_bytes=self.size_in_bytes,
321
+ filtering_messages_before_publishing=self.filtering_messages_before_publishing,
322
+ authorization_rules=self.authorization_rules,
323
+ support_ordering=self.support_ordering,
324
+ auto_delete_on_idle=self.auto_delete_on_idle,
325
+ enable_partitioning=self.enable_partitioning,
326
+ enable_express=self.enable_express,
327
+ user_metadata=self.user_metadata,
328
+ max_message_size_in_kilobytes=self.max_message_size_in_kilobytes,
329
+ )
343
330
 
344
331
 
345
332
  class AzureServiceBusSubscriptionCreateOperator(BaseOperator):
@@ -378,6 +365,8 @@ class AzureServiceBusSubscriptionCreateOperator(BaseOperator):
378
365
  :param auto_delete_on_idle: ISO 8601 time Span idle interval after which the subscription is
379
366
  automatically deleted. The minimum duration is 5 minutes. Input value of either
380
367
  type ~datetime.timedelta or string in ISO 8601 duration format like "PT300S" is accepted.
368
+ :param filter_rule: Optional correlation or SQL rule filter to apply on the messages.
369
+ :param filter_rule_name: Optional rule name to use applying the rule filter to the subscription
381
370
  :param azure_service_bus_conn_id: Reference to the
382
371
  :ref:`Azure Service Bus connection<howto/connection:azure_service_bus>`.
383
372
  """
@@ -402,6 +391,8 @@ class AzureServiceBusSubscriptionCreateOperator(BaseOperator):
402
391
  user_metadata: str | None = None,
403
392
  forward_dead_lettered_messages_to: str | None = None,
404
393
  auto_delete_on_idle: datetime.timedelta | str | None = None,
394
+ filter_rule: CorrelationRuleFilter | SqlRuleFilter | None = None,
395
+ filter_rule_name: str | None = None,
405
396
  **kwargs,
406
397
  ) -> None:
407
398
  super().__init__(**kwargs)
@@ -419,6 +410,8 @@ class AzureServiceBusSubscriptionCreateOperator(BaseOperator):
419
410
  self.forward_dead_lettered_messages_to = forward_dead_lettered_messages_to
420
411
  self.auto_delete_on_idle = auto_delete_on_idle
421
412
  self.azure_service_bus_conn_id = azure_service_bus_conn_id
413
+ self.filter_rule = filter_rule
414
+ self.filter_rule_name = filter_rule_name
422
415
 
423
416
  def execute(self, context: Context) -> None:
424
417
  """Create Subscription in Service Bus namespace, by connecting to Service Bus Admin client."""
@@ -429,24 +422,24 @@ class AzureServiceBusSubscriptionCreateOperator(BaseOperator):
429
422
  # Create the hook
430
423
  hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
431
424
 
432
- with hook.get_conn() as service_mgmt_conn:
433
- # create subscription with name
434
- subscription = service_mgmt_conn.create_subscription(
435
- topic_name=self.topic_name,
436
- subscription_name=self.subscription_name,
437
- lock_duration=self.lock_duration,
438
- requires_session=self.requires_session,
439
- default_message_time_to_live=self.default_message_time_to_live,
440
- dead_lettering_on_message_expiration=self.dl_on_message_expiration,
441
- dead_lettering_on_filter_evaluation_exceptions=self.dl_on_filter_evaluation_exceptions,
442
- max_delivery_count=self.max_delivery_count,
443
- enable_batched_operations=self.enable_batched_operations,
444
- forward_to=self.forward_to,
445
- user_metadata=self.user_metadata,
446
- forward_dead_lettered_messages_to=self.forward_dead_lettered_messages_to,
447
- auto_delete_on_idle=self.auto_delete_on_idle,
448
- )
449
- self.log.info("Created subscription %s", subscription.name)
425
+ subscription = hook.create_subscription(
426
+ topic_name=self.topic_name,
427
+ subscription_name=self.subscription_name,
428
+ lock_duration=self.lock_duration,
429
+ requires_session=self.requires_session,
430
+ default_message_time_to_live=self.default_message_time_to_live,
431
+ dead_lettering_on_message_expiration=self.dl_on_message_expiration,
432
+ dead_lettering_on_filter_evaluation_exceptions=self.dl_on_filter_evaluation_exceptions,
433
+ max_delivery_count=self.max_delivery_count,
434
+ enable_batched_operations=self.enable_batched_operations,
435
+ forward_to=self.forward_to,
436
+ user_metadata=self.user_metadata,
437
+ forward_dead_lettered_messages_to=self.forward_dead_lettered_messages_to,
438
+ auto_delete_on_idle=self.auto_delete_on_idle,
439
+ filter_rule=self.filter_rule,
440
+ filter_rule_name=self.filter_rule_name,
441
+ )
442
+ self.log.info("Created subscription %s", subscription.name)
450
443
 
451
444
 
452
445
  class AzureServiceBusUpdateSubscriptionOperator(BaseOperator):
@@ -495,18 +488,13 @@ class AzureServiceBusUpdateSubscriptionOperator(BaseOperator):
495
488
  """Update Subscription properties, by connecting to Service Bus Admin client."""
496
489
  hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
497
490
 
498
- with hook.get_conn() as service_mgmt_conn:
499
- subscription_prop = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name)
500
- if self.max_delivery_count:
501
- subscription_prop.max_delivery_count = self.max_delivery_count
502
- if self.dl_on_message_expiration is not None:
503
- subscription_prop.dead_lettering_on_message_expiration = self.dl_on_message_expiration
504
- if self.enable_batched_operations is not None:
505
- subscription_prop.enable_batched_operations = self.enable_batched_operations
506
- # update by updating the properties in the model
507
- service_mgmt_conn.update_subscription(self.topic_name, subscription_prop)
508
- updated_subscription = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name)
509
- self.log.info("Subscription Updated successfully %s", updated_subscription)
491
+ hook.update_subscription(
492
+ topic_name=self.topic_name,
493
+ subscription_name=self.subscription_name,
494
+ max_delivery_count=self.max_delivery_count,
495
+ dead_lettering_on_message_expiration=self.dl_on_message_expiration,
496
+ enable_batched_operations=self.enable_batched_operations,
497
+ )
510
498
 
511
499
 
512
500
  class ASBReceiveSubscriptionMessageOperator(BaseOperator):
@@ -21,11 +21,10 @@ from collections.abc import Sequence
21
21
  from functools import cached_property
22
22
  from typing import TYPE_CHECKING, Any
23
23
 
24
- from azure.batch import models as batch_models
25
-
26
24
  from airflow.exceptions import AirflowException
27
25
  from airflow.models import BaseOperator
28
26
  from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
27
+ from azure.batch import models as batch_models
29
28
 
30
29
  if TYPE_CHECKING:
31
30
  from airflow.utils.context import Context