apache-airflow-providers-microsoft-azure 12.7.0__py3-none-any.whl → 12.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. airflow/providers/microsoft/azure/__init__.py +1 -1
  2. airflow/providers/microsoft/azure/fs/adls.py +1 -1
  3. airflow/providers/microsoft/azure/fs/msgraph.py +111 -0
  4. airflow/providers/microsoft/azure/get_provider_info.py +12 -2
  5. airflow/providers/microsoft/azure/hooks/adx.py +1 -1
  6. airflow/providers/microsoft/azure/hooks/asb.py +1 -1
  7. airflow/providers/microsoft/azure/hooks/base_azure.py +93 -17
  8. airflow/providers/microsoft/azure/hooks/batch.py +1 -1
  9. airflow/providers/microsoft/azure/hooks/container_registry.py +1 -1
  10. airflow/providers/microsoft/azure/hooks/container_volume.py +1 -1
  11. airflow/providers/microsoft/azure/hooks/cosmos.py +1 -1
  12. airflow/providers/microsoft/azure/hooks/data_factory.py +2 -5
  13. airflow/providers/microsoft/azure/hooks/data_lake.py +1 -1
  14. airflow/providers/microsoft/azure/hooks/fileshare.py +1 -1
  15. airflow/providers/microsoft/azure/hooks/msgraph.py +18 -6
  16. airflow/providers/microsoft/azure/hooks/synapse.py +1 -1
  17. airflow/providers/microsoft/azure/hooks/wasb.py +1 -1
  18. airflow/providers/microsoft/azure/log/wasb_task_handler.py +7 -1
  19. airflow/providers/microsoft/azure/operators/adls.py +1 -1
  20. airflow/providers/microsoft/azure/operators/adx.py +1 -1
  21. airflow/providers/microsoft/azure/operators/asb.py +1 -1
  22. airflow/providers/microsoft/azure/operators/batch.py +1 -1
  23. airflow/providers/microsoft/azure/operators/container_instances.py +1 -1
  24. airflow/providers/microsoft/azure/operators/cosmos.py +1 -1
  25. airflow/providers/microsoft/azure/operators/data_factory.py +1 -13
  26. airflow/providers/microsoft/azure/operators/msgraph.py +2 -1
  27. airflow/providers/microsoft/azure/operators/powerbi.py +1 -8
  28. airflow/providers/microsoft/azure/operators/synapse.py +1 -13
  29. airflow/providers/microsoft/azure/operators/wasb_delete_blob.py +1 -1
  30. airflow/providers/microsoft/azure/sensors/cosmos.py +1 -6
  31. airflow/providers/microsoft/azure/sensors/data_factory.py +1 -6
  32. airflow/providers/microsoft/azure/sensors/msgraph.py +1 -6
  33. airflow/providers/microsoft/azure/sensors/wasb.py +1 -6
  34. airflow/providers/microsoft/azure/transfers/local_to_adls.py +1 -1
  35. airflow/providers/microsoft/azure/transfers/local_to_wasb.py +1 -1
  36. airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py +1 -1
  37. airflow/providers/microsoft/azure/transfers/s3_to_wasb.py +1 -1
  38. airflow/providers/microsoft/azure/transfers/sftp_to_wasb.py +1 -1
  39. airflow/providers/microsoft/azure/triggers/data_factory.py +4 -3
  40. airflow/providers/microsoft/azure/triggers/msgraph.py +1 -12
  41. airflow/providers/microsoft/azure/version_compat.py +0 -24
  42. {apache_airflow_providers_microsoft_azure-12.7.0.dist-info → apache_airflow_providers_microsoft_azure-12.8.1.dist-info}/METADATA +58 -42
  43. apache_airflow_providers_microsoft_azure-12.8.1.dist-info/RECORD +61 -0
  44. apache_airflow_providers_microsoft_azure-12.8.1.dist-info/licenses/NOTICE +5 -0
  45. apache_airflow_providers_microsoft_azure-12.7.0.dist-info/RECORD +0 -59
  46. {apache_airflow_providers_microsoft_azure-12.7.0.dist-info → apache_airflow_providers_microsoft_azure-12.8.1.dist-info}/WHEEL +0 -0
  47. {apache_airflow_providers_microsoft_azure-12.7.0.dist-info → apache_airflow_providers_microsoft_azure-12.8.1.dist-info}/entry_points.txt +0 -0
  48. {airflow/providers/microsoft/azure → apache_airflow_providers_microsoft_azure-12.8.1.dist-info/licenses}/LICENSE +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
29
29
 
30
30
  __all__ = ["__version__"]
31
31
 
32
- __version__ = "12.7.0"
32
+ __version__ = "12.8.1"
33
33
 
34
34
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
35
35
  "2.10.0"
@@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, Any
20
20
 
21
21
  from azure.identity import ClientSecretCredential
22
22
 
23
+ from airflow.providers.common.compat.sdk import BaseHook
23
24
  from airflow.providers.microsoft.azure.utils import get_field, parse_blob_account_url
24
- from airflow.providers.microsoft.azure.version_compat import BaseHook
25
25
 
26
26
  if TYPE_CHECKING:
27
27
  from fsspec import AbstractFileSystem
@@ -0,0 +1,111 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import TYPE_CHECKING, Any
21
+
22
+ from airflow.providers.common.compat.sdk import BaseHook
23
+ from airflow.providers.microsoft.azure.utils import get_field
24
+
25
+ if TYPE_CHECKING:
26
+ from fsspec import AbstractFileSystem
27
+
28
+ schemes = ["msgraph", "sharepoint", "onedrive", "msgd"]
29
+
30
+
31
+ def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) -> AbstractFileSystem:
32
+ from msgraphfs import MSGDriveFS
33
+
34
+ if conn_id is None:
35
+ return MSGDriveFS({})
36
+
37
+ conn = BaseHook.get_connection(conn_id)
38
+ extras = conn.extra_dejson
39
+ conn_type = conn.conn_type or "msgraph"
40
+
41
+ options: dict[str, Any] = {}
42
+
43
+ # Get authentication parameters with fallback handling
44
+ client_id = conn.login or get_field(
45
+ conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_id"
46
+ )
47
+ client_secret = conn.password or get_field(
48
+ conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_secret"
49
+ )
50
+ tenant_id = conn.host or get_field(
51
+ conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="tenant_id"
52
+ )
53
+
54
+ if client_id:
55
+ options["client_id"] = client_id
56
+ if client_secret:
57
+ options["client_secret"] = client_secret
58
+ if tenant_id:
59
+ options["tenant_id"] = tenant_id
60
+
61
+ # Process additional fields from extras
62
+ fields = [
63
+ "drive_id",
64
+ "scope",
65
+ "token_endpoint",
66
+ "redirect_uri",
67
+ "token_endpoint_auth_method",
68
+ "code_challenge_method",
69
+ "update_token",
70
+ "username",
71
+ "password",
72
+ ]
73
+ for field in fields:
74
+ value = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=field)
75
+ if value is not None:
76
+ if value == "":
77
+ options.pop(field, "")
78
+ else:
79
+ options[field] = value
80
+
81
+ # Update with storage options
82
+ options.update(storage_options or {})
83
+
84
+ # Create oauth2 client parameters if authentication is provided
85
+ oauth2_client_params = {}
86
+ if options.get("client_id") and options.get("client_secret") and options.get("tenant_id"):
87
+ oauth2_client_params = {
88
+ "client_id": options["client_id"],
89
+ "client_secret": options["client_secret"],
90
+ "tenant_id": options["tenant_id"],
91
+ }
92
+
93
+ # Add additional oauth2 parameters supported by authlib
94
+ oauth2_params = [
95
+ "scope",
96
+ "token_endpoint",
97
+ "redirect_uri",
98
+ "token_endpoint_auth_method",
99
+ "code_challenge_method",
100
+ "update_token",
101
+ "username",
102
+ "password",
103
+ ]
104
+ for param in oauth2_params:
105
+ if param in options:
106
+ oauth2_client_params[param] = options[param]
107
+
108
+ # Determine which filesystem to return based on drive_id
109
+ drive_id = options.get("drive_id")
110
+
111
+ return MSGDriveFS(drive_id=drive_id, oauth2_client_params=oauth2_client_params)
@@ -37,6 +37,7 @@ def get_provider_info():
37
37
  {
38
38
  "integration-name": "Microsoft Azure Blob Storage",
39
39
  "external-doc-url": "https://azure.microsoft.com/en-us/services/storage/blobs/",
40
+ "how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/sensors/wasb_sensors.rst"],
40
41
  "logo": "/docs/integration-logos/Blob-Storage.svg",
41
42
  "tags": ["azure"],
42
43
  },
@@ -49,6 +50,9 @@ def get_provider_info():
49
50
  {
50
51
  "integration-name": "Microsoft Azure Cosmos DB",
51
52
  "external-doc-url": "https://azure.microsoft.com/en-us/services/cosmos-db/",
53
+ "how-to-guide": [
54
+ "/docs/apache-airflow-providers-microsoft-azure/sensors/cosmos_document_sensor.rst"
55
+ ],
52
56
  "logo": "/docs/integration-logos/Azure-Cosmos-DB.svg",
53
57
  "tags": ["azure"],
54
58
  },
@@ -117,7 +121,10 @@ def get_provider_info():
117
121
  "integration-name": "Microsoft Graph API",
118
122
  "external-doc-url": "https://learn.microsoft.com/en-us/graph/use-the-api/",
119
123
  "logo": "/docs/integration-logos/Microsoft-Graph-API.png",
120
- "how-to-guide": ["/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst"],
124
+ "how-to-guide": [
125
+ "/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst",
126
+ "/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst",
127
+ ],
121
128
  "tags": ["azure"],
122
129
  },
123
130
  {
@@ -191,7 +198,10 @@ def get_provider_info():
191
198
  "python-modules": ["airflow.providers.microsoft.azure.sensors.msgraph"],
192
199
  },
193
200
  ],
194
- "filesystems": ["airflow.providers.microsoft.azure.fs.adls"],
201
+ "filesystems": [
202
+ "airflow.providers.microsoft.azure.fs.adls",
203
+ "airflow.providers.microsoft.azure.fs.msgraphfs",
204
+ ],
195
205
  "hooks": [
196
206
  {
197
207
  "integration-name": "Microsoft Azure Container Instances",
@@ -34,11 +34,11 @@ from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnecti
34
34
  from azure.kusto.data.exceptions import KustoServiceError
35
35
 
36
36
  from airflow.exceptions import AirflowException
37
+ from airflow.providers.common.compat.sdk import BaseHook
37
38
  from airflow.providers.microsoft.azure.utils import (
38
39
  add_managed_identity_connection_widgets,
39
40
  get_sync_default_azure_credential,
40
41
  )
41
- from airflow.providers.microsoft.azure.version_compat import BaseHook
42
42
 
43
43
  if TYPE_CHECKING:
44
44
  from azure.kusto.data.response import KustoResponseDataSet
@@ -37,12 +37,12 @@ from azure.servicebus.management import (
37
37
  SubscriptionProperties,
38
38
  )
39
39
 
40
+ from airflow.providers.common.compat.sdk import BaseHook
40
41
  from airflow.providers.microsoft.azure.utils import (
41
42
  add_managed_identity_connection_widgets,
42
43
  get_field,
43
44
  get_sync_default_azure_credential,
44
45
  )
45
- from airflow.providers.microsoft.azure.version_compat import BaseHook
46
46
 
47
47
  if TYPE_CHECKING:
48
48
  import datetime
@@ -16,17 +16,24 @@
16
16
  # under the License.
17
17
  from __future__ import annotations
18
18
 
19
- from typing import Any
19
+ from typing import TYPE_CHECKING, Any
20
20
 
21
21
  from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
22
22
  from azure.common.credentials import ServicePrincipalCredentials
23
+ from azure.identity import ClientSecretCredential, DefaultAzureCredential
23
24
 
24
25
  from airflow.exceptions import AirflowException
26
+ from airflow.providers.common.compat.sdk import BaseHook
25
27
  from airflow.providers.microsoft.azure.utils import (
26
28
  AzureIdentityCredentialAdapter,
27
29
  add_managed_identity_connection_widgets,
30
+ get_sync_default_azure_credential,
28
31
  )
29
- from airflow.providers.microsoft.azure.version_compat import BaseHook
32
+
33
+ if TYPE_CHECKING:
34
+ from azure.core.credentials import AccessToken
35
+
36
+ from airflow.sdk import Connection
30
37
 
31
38
 
32
39
  class AzureBaseHook(BaseHook):
@@ -85,7 +92,7 @@ class AzureBaseHook(BaseHook):
85
92
  },
86
93
  }
87
94
 
88
- def __init__(self, sdk_client: Any, conn_id: str = "azure_default"):
95
+ def __init__(self, sdk_client: Any = None, conn_id: str = "azure_default"):
89
96
  self.sdk_client = sdk_client
90
97
  self.conn_id = conn_id
91
98
  super().__init__()
@@ -96,8 +103,9 @@ class AzureBaseHook(BaseHook):
96
103
 
97
104
  :return: the authenticated client.
98
105
  """
106
+ if not self.sdk_client:
107
+ raise ValueError("`sdk_client` must be provided to AzureBaseHook to use `get_conn` method.")
99
108
  conn = self.get_connection(self.conn_id)
100
- tenant = conn.extra_dejson.get("tenantId")
101
109
  subscription_id = conn.extra_dejson.get("subscriptionId")
102
110
  key_path = conn.extra_dejson.get("key_path")
103
111
  if key_path:
@@ -111,22 +119,90 @@ class AzureBaseHook(BaseHook):
111
119
  self.log.info("Getting connection using a JSON config.")
112
120
  return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
113
121
 
114
- credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
122
+ credentials = self.get_credential(conn=conn)
123
+
124
+ return self.sdk_client(
125
+ credentials=credentials,
126
+ subscription_id=subscription_id,
127
+ )
128
+
129
+ def get_credential(
130
+ self, *, conn: Connection | None = None
131
+ ) -> (
132
+ ServicePrincipalCredentials
133
+ | AzureIdentityCredentialAdapter
134
+ | ClientSecretCredential
135
+ | DefaultAzureCredential
136
+ ):
137
+ """
138
+ Get Azure credential object for the connection.
139
+
140
+ Azure Identity based credential object (``ClientSecretCredential``, ``DefaultAzureCredential``) can be used to get OAuth token using ``get_token`` method.
141
+ Older Credential objects (``ServicePrincipalCredentials``, ``AzureIdentityCredentialAdapter``) are supported for backward compatibility.
142
+
143
+ :return: The Azure credential object
144
+ """
145
+ if not conn:
146
+ conn = self.get_connection(self.conn_id)
147
+ tenant = conn.extra_dejson.get("tenantId")
148
+ credential: (
149
+ ServicePrincipalCredentials
150
+ | AzureIdentityCredentialAdapter
151
+ | ClientSecretCredential
152
+ | DefaultAzureCredential
153
+ )
115
154
  if all([conn.login, conn.password, tenant]):
116
- self.log.info("Getting connection using specific credentials and subscription_id.")
117
- credentials = ServicePrincipalCredentials(
118
- client_id=conn.login, secret=conn.password, tenant=tenant
119
- )
155
+ credential = self._get_client_secret_credential(conn)
120
156
  else:
121
- self.log.info("Using DefaultAzureCredential as credential")
122
- managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
123
- workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
124
- credentials = AzureIdentityCredentialAdapter(
157
+ credential = self._get_default_azure_credential(conn)
158
+ return credential
159
+
160
+ def _get_client_secret_credential(
161
+ self, conn: Connection
162
+ ) -> ServicePrincipalCredentials | ClientSecretCredential:
163
+ self.log.info("Getting credentials using specific credentials and subscription_id.")
164
+ extra_dejson = conn.extra_dejson
165
+ tenant = extra_dejson.get("tenantId")
166
+ use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
167
+ if use_azure_identity_object:
168
+ return ClientSecretCredential(
169
+ client_id=conn.login, # type: ignore[arg-type]
170
+ client_secret=conn.password, # type: ignore[arg-type]
171
+ tenant_id=tenant, # type: ignore[arg-type]
172
+ )
173
+ return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant)
174
+
175
+ def _get_default_azure_credential(
176
+ self, conn: Connection
177
+ ) -> DefaultAzureCredential | AzureIdentityCredentialAdapter:
178
+ self.log.info("Using DefaultAzureCredential as credential")
179
+ extra_dejson = conn.extra_dejson
180
+ managed_identity_client_id = extra_dejson.get("managed_identity_client_id")
181
+ workload_identity_tenant_id = extra_dejson.get("workload_identity_tenant_id")
182
+ use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
183
+ if use_azure_identity_object:
184
+ return get_sync_default_azure_credential(
125
185
  managed_identity_client_id=managed_identity_client_id,
126
186
  workload_identity_tenant_id=workload_identity_tenant_id,
127
187
  )
128
-
129
- return self.sdk_client(
130
- credentials=credentials,
131
- subscription_id=subscription_id,
188
+ return AzureIdentityCredentialAdapter(
189
+ managed_identity_client_id=managed_identity_client_id,
190
+ workload_identity_tenant_id=workload_identity_tenant_id,
132
191
  )
192
+
193
+ def get_token(self, *scopes, **kwargs) -> AccessToken:
194
+ """
195
+ Request an access token for `scopes`.
196
+
197
+ To use this method, set `use_azure_identity_object: True` in the connection extra field.
198
+ ServicePrincipalCredentials and AzureIdentityCredentialAdapter don't support `get_token` method.
199
+ """
200
+ credential = self.get_credential()
201
+ if isinstance(credential, ServicePrincipalCredentials) or isinstance(
202
+ credential, AzureIdentityCredentialAdapter
203
+ ):
204
+ raise AttributeError(
205
+ "ServicePrincipalCredentials and AzureIdentityCredentialAdapter don't support get_token method. "
206
+ "Please set `use_azure_identity_object: True` in the connection extra field to use credential that support get_token method."
207
+ )
208
+ return credential.get_token(*scopes, **kwargs)
@@ -25,12 +25,12 @@ from typing import TYPE_CHECKING, Any
25
25
  from azure.batch import BatchServiceClient, batch_auth, models as batch_models
26
26
 
27
27
  from airflow.exceptions import AirflowException
28
+ from airflow.providers.common.compat.sdk import BaseHook
28
29
  from airflow.providers.microsoft.azure.utils import (
29
30
  AzureIdentityCredentialAdapter,
30
31
  add_managed_identity_connection_widgets,
31
32
  get_field,
32
33
  )
33
- from airflow.providers.microsoft.azure.version_compat import BaseHook
34
34
  from airflow.utils import timezone
35
35
 
36
36
  if TYPE_CHECKING:
@@ -25,12 +25,12 @@ from typing import Any, cast
25
25
  from azure.mgmt.containerinstance.models import ImageRegistryCredential
26
26
  from azure.mgmt.containerregistry import ContainerRegistryManagementClient
27
27
 
28
+ from airflow.providers.common.compat.sdk import BaseHook
28
29
  from airflow.providers.microsoft.azure.utils import (
29
30
  add_managed_identity_connection_widgets,
30
31
  get_field,
31
32
  get_sync_default_azure_credential,
32
33
  )
33
- from airflow.providers.microsoft.azure.version_compat import BaseHook
34
34
 
35
35
 
36
36
  class AzureContainerRegistryHook(BaseHook):
@@ -21,12 +21,12 @@ from typing import Any, cast
21
21
  from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
22
22
  from azure.mgmt.storage import StorageManagementClient
23
23
 
24
+ from airflow.providers.common.compat.sdk import BaseHook
24
25
  from airflow.providers.microsoft.azure.utils import (
25
26
  add_managed_identity_connection_widgets,
26
27
  get_field,
27
28
  get_sync_default_azure_credential,
28
29
  )
29
- from airflow.providers.microsoft.azure.version_compat import BaseHook
30
30
 
31
31
 
32
32
  class AzureContainerVolumeHook(BaseHook):
@@ -36,12 +36,12 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
36
36
  from azure.mgmt.cosmosdb import CosmosDBManagementClient
37
37
 
38
38
  from airflow.exceptions import AirflowBadRequest, AirflowException
39
+ from airflow.providers.common.compat.sdk import BaseHook
39
40
  from airflow.providers.microsoft.azure.utils import (
40
41
  add_managed_identity_connection_widgets,
41
42
  get_field,
42
43
  get_sync_default_azure_credential,
43
44
  )
44
- from airflow.providers.microsoft.azure.version_compat import BaseHook
45
45
 
46
46
  if TYPE_CHECKING:
47
47
  PartitionKeyType = str | list[str]
@@ -49,12 +49,12 @@ from azure.mgmt.datafactory import DataFactoryManagementClient
49
49
  from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient
50
50
 
51
51
  from airflow.exceptions import AirflowException
52
+ from airflow.providers.common.compat.sdk import BaseHook
52
53
  from airflow.providers.microsoft.azure.utils import (
53
54
  add_managed_identity_connection_widgets,
54
55
  get_async_default_azure_credential,
55
56
  get_sync_default_azure_credential,
56
57
  )
57
- from airflow.providers.microsoft.azure.version_compat import BaseHook
58
58
 
59
59
  if TYPE_CHECKING:
60
60
  from azure.core.polling import LROPoller
@@ -1214,7 +1214,4 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
1214
1214
  :param config: Extra parameters for the ADF client.
1215
1215
  """
1216
1216
  client = await self.get_async_conn()
1217
- try:
1218
- await client.pipeline_runs.cancel(resource_group_name, factory_name, run_id)
1219
- except Exception as e:
1220
- raise AirflowException(e)
1217
+ await client.pipeline_runs.cancel(resource_group_name, factory_name, run_id)
@@ -33,13 +33,13 @@ from azure.storage.filedatalake import (
33
33
  )
34
34
 
35
35
  from airflow.exceptions import AirflowException
36
+ from airflow.providers.common.compat.sdk import BaseHook
36
37
  from airflow.providers.microsoft.azure.utils import (
37
38
  AzureIdentityCredentialAdapter,
38
39
  add_managed_identity_connection_widgets,
39
40
  get_field,
40
41
  get_sync_default_azure_credential,
41
42
  )
42
- from airflow.providers.microsoft.azure.version_compat import BaseHook
43
43
 
44
44
  Credentials = ClientSecretCredential | AzureIdentityCredentialAdapter | DefaultAzureCredential
45
45
 
@@ -21,11 +21,11 @@ from typing import IO, Any
21
21
 
22
22
  from azure.storage.fileshare import FileProperties, ShareDirectoryClient, ShareFileClient, ShareServiceClient
23
23
 
24
+ from airflow.providers.common.compat.sdk import BaseHook
24
25
  from airflow.providers.microsoft.azure.utils import (
25
26
  add_managed_identity_connection_widgets,
26
27
  get_sync_default_azure_credential,
27
28
  )
28
- from airflow.providers.microsoft.azure.version_compat import BaseHook
29
29
 
30
30
 
31
31
  class AzureFileShareHook(BaseHook):
@@ -17,6 +17,7 @@
17
17
  # under the License.
18
18
  from __future__ import annotations
19
19
 
20
+ import asyncio
20
21
  import json
21
22
  import warnings
22
23
  from ast import literal_eval
@@ -28,7 +29,6 @@ from typing import TYPE_CHECKING, Any, cast
28
29
  from urllib.parse import quote, urljoin, urlparse
29
30
 
30
31
  import httpx
31
- from asgiref.sync import sync_to_async
32
32
  from azure.identity import CertificateCredential, ClientSecretCredential
33
33
  from httpx import AsyncHTTPTransport, Response, Timeout
34
34
  from kiota_abstractions.api_error import APIError
@@ -53,7 +53,7 @@ from airflow.exceptions import (
53
53
  AirflowNotFoundException,
54
54
  AirflowProviderDeprecationWarning,
55
55
  )
56
- from airflow.providers.microsoft.azure.version_compat import BaseHook
56
+ from airflow.providers.common.compat.sdk import BaseHook
57
57
 
58
58
  if TYPE_CHECKING:
59
59
  from azure.identity._internal.client_credential_base import ClientCredentialBase
@@ -61,7 +61,7 @@ if TYPE_CHECKING:
61
61
  from kiota_abstractions.response_handler import NativeResponseType
62
62
  from kiota_abstractions.serialization import ParsableFactory
63
63
 
64
- from airflow.models import Connection
64
+ from airflow.providers.common.compat.sdk import Connection
65
65
 
66
66
 
67
67
  class DefaultResponseHandler(ResponseHandler):
@@ -152,6 +152,7 @@ class KiotaRequestAdapterHook(BaseHook):
152
152
 
153
153
  return {
154
154
  "tenant_id": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
155
+ "drive_id": StringField(lazy_gettext("Drive ID"), widget=BS3TextFieldWidget()),
155
156
  "api_version": StringField(
156
157
  lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default=APIVersion.v1.value
157
158
  ),
@@ -250,7 +251,9 @@ class KiotaRequestAdapterHook(BaseHook):
250
251
  def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]:
251
252
  client_id = connection.login
252
253
  client_secret = connection.password
253
- config = connection.extra_dejson if connection.extra else {}
254
+ # TODO (#54350): do not use connection.extra_dejson until it's fixed in Airflow otherwise expect:
255
+ # RuntimeError: You cannot use AsyncToSync in the same thread as an async event loop.
256
+ config = json.loads(connection.extra) if connection.extra else {}
254
257
  api_version = self.get_api_version(config)
255
258
  host = self.get_host(connection) # type: ignore[arg-type]
256
259
  base_url = self.get_base_url(host, api_version, config)
@@ -342,6 +345,15 @@ class KiotaRequestAdapterHook(BaseHook):
342
345
  self.api_version = api_version
343
346
  return request_adapter
344
347
 
348
+ @classmethod
349
+ async def get_async_connection(cls, conn_id: str) -> Connection:
350
+ if hasattr(BaseHook, "aget_connection"):
351
+ return await BaseHook.aget_connection(conn_id=conn_id)
352
+
353
+ from asgiref.sync import sync_to_async
354
+
355
+ return await sync_to_async(BaseHook.get_connection)(conn_id=conn_id)
356
+
345
357
  async def get_async_conn(self) -> RequestAdapter:
346
358
  """Initiate a new RequestAdapter connection asynchronously."""
347
359
  if not self.conn_id:
@@ -350,7 +362,7 @@ class KiotaRequestAdapterHook(BaseHook):
350
362
  api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))
351
363
 
352
364
  if not request_adapter:
353
- connection = await sync_to_async(self.get_connection)(conn_id=self.conn_id)
365
+ connection = await self.get_async_connection(conn_id=self.conn_id)
354
366
  api_version, request_adapter = self._build_request_adapter(connection)
355
367
  self.api_version = api_version
356
368
  return request_adapter
@@ -417,7 +429,7 @@ class KiotaRequestAdapterHook(BaseHook):
417
429
  def test_connection(self):
418
430
  """Test HTTP Connection."""
419
431
  try:
420
- self.run()
432
+ asyncio.run(self.run())
421
433
  return True, "Connection successfully tested"
422
434
  except Exception as e:
423
435
  return False, str(e)
@@ -25,12 +25,12 @@ from azure.synapse.artifacts import ArtifactsClient
25
25
  from azure.synapse.spark import SparkClient
26
26
 
27
27
  from airflow.exceptions import AirflowException, AirflowTaskTimeout
28
+ from airflow.providers.common.compat.sdk import BaseHook
28
29
  from airflow.providers.microsoft.azure.utils import (
29
30
  add_managed_identity_connection_widgets,
30
31
  get_field,
31
32
  get_sync_default_azure_credential,
32
33
  )
33
- from airflow.providers.microsoft.azure.version_compat import BaseHook
34
34
 
35
35
  if TYPE_CHECKING:
36
36
  from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun
@@ -45,13 +45,13 @@ from azure.storage.blob.aio import (
45
45
  )
46
46
 
47
47
  from airflow.exceptions import AirflowException
48
+ from airflow.providers.common.compat.sdk import BaseHook
48
49
  from airflow.providers.microsoft.azure.utils import (
49
50
  add_managed_identity_connection_widgets,
50
51
  get_async_default_azure_credential,
51
52
  get_sync_default_azure_credential,
52
53
  parse_blob_account_url,
53
54
  )
54
- from airflow.providers.microsoft.azure.version_compat import BaseHook
55
55
 
56
56
  if TYPE_CHECKING:
57
57
  from azure.core.credentials import TokenCredential
@@ -188,9 +188,15 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
188
188
  base_log_folder: str,
189
189
  wasb_log_folder: str,
190
190
  wasb_container: str,
191
+ max_bytes: int = 0,
192
+ backup_count: int = 0,
193
+ delay: bool = False,
191
194
  **kwargs,
192
195
  ) -> None:
193
- super().__init__(base_log_folder)
196
+ # support log file size handling of FileTaskHandler
197
+ super().__init__(
198
+ base_log_folder=base_log_folder, max_bytes=max_bytes, backup_count=backup_count, delay=delay
199
+ )
194
200
  self.handler: logging.FileHandler | None = None
195
201
  self.log_relative_path = ""
196
202
  self.closed = False
@@ -19,8 +19,8 @@ from __future__ import annotations
19
19
  from collections.abc import Iterable, Sequence
20
20
  from typing import IO, TYPE_CHECKING, Any, AnyStr
21
21
 
22
+ from airflow.providers.common.compat.sdk import BaseOperator
22
23
  from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook, AzureDataLakeStorageV2Hook
23
- from airflow.providers.microsoft.azure.version_compat import BaseOperator
24
24
 
25
25
  if TYPE_CHECKING:
26
26
  from airflow.utils.context import Context
@@ -24,8 +24,8 @@ from functools import cached_property
24
24
  from typing import TYPE_CHECKING
25
25
 
26
26
  from airflow.configuration import conf
27
+ from airflow.providers.common.compat.sdk import BaseOperator
27
28
  from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
28
- from airflow.providers.microsoft.azure.version_compat import BaseOperator
29
29
 
30
30
  if TYPE_CHECKING:
31
31
  from azure.kusto.data._models import KustoResultTable
@@ -20,8 +20,8 @@ from collections.abc import Callable, Sequence
20
20
  from typing import TYPE_CHECKING, Any
21
21
  from uuid import UUID
22
22
 
23
+ from airflow.providers.common.compat.sdk import BaseOperator
23
24
  from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook
24
- from airflow.providers.microsoft.azure.version_compat import BaseOperator
25
25
 
26
26
  if TYPE_CHECKING:
27
27
  import datetime
@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Any
24
24
  from azure.batch import models as batch_models
25
25
 
26
26
  from airflow.exceptions import AirflowException
27
+ from airflow.providers.common.compat.sdk import BaseOperator
27
28
  from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
28
- from airflow.providers.microsoft.azure.version_compat import BaseOperator
29
29
 
30
30
  if TYPE_CHECKING:
31
31
  from airflow.utils.context import Context
@@ -40,10 +40,10 @@ from azure.mgmt.containerinstance.models import (
40
40
  from msrestazure.azure_exceptions import CloudError
41
41
 
42
42
  from airflow.exceptions import AirflowException, AirflowTaskTimeout
43
+ from airflow.providers.common.compat.sdk import BaseOperator
43
44
  from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook
44
45
  from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook
45
46
  from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook
46
- from airflow.providers.microsoft.azure.version_compat import BaseOperator
47
47
 
48
48
  if TYPE_CHECKING:
49
49
  from airflow.utils.context import Context