apache-airflow-providers-databricks 6.6.0__tar.gz → 6.7.0__tar.gz

This diff shows the changes between package versions as they were published to the supported public registries. It is provided for informational purposes only.

Potentially problematic release: this version of apache-airflow-providers-databricks might be problematic.

Files changed (22)
  1. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/PKG-INFO +8 -6
  2. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/README.rst +3 -3
  3. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/__init__.py +1 -1
  4. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/get_provider_info.py +8 -2
  5. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/hooks/databricks.py +6 -3
  6. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/hooks/databricks_base.py +34 -70
  7. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/hooks/databricks_sql.py +2 -1
  8. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/operators/databricks.py +95 -92
  9. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/utils/databricks.py +2 -2
  10. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/pyproject.toml +6 -3
  11. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/LICENSE +0 -0
  12. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/hooks/__init__.py +0 -0
  13. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/operators/__init__.py +0 -0
  14. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/operators/databricks_repos.py +0 -0
  15. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/operators/databricks_sql.py +0 -0
  16. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/operators/databricks_workflow.py +0 -0
  17. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/sensors/__init__.py +0 -0
  18. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/sensors/databricks_partition.py +0 -0
  19. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/sensors/databricks_sql.py +0 -0
  20. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/triggers/__init__.py +0 -0
  21. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/triggers/databricks.py +0 -0
  22. {apache_airflow_providers_databricks-6.6.0 → apache_airflow_providers_databricks-6.7.0}/airflow/providers/databricks/utils/__init__.py +0 -0
--- apache_airflow_providers_databricks-6.6.0/PKG-INFO
+++ apache_airflow_providers_databricks-6.7.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: apache-airflow-providers-databricks
-Version: 6.6.0
+Version: 6.7.0
 Summary: Provider package apache-airflow-providers-databricks for Apache Airflow
 Keywords: airflow-provider,databricks,airflow,integration
 Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -30,15 +30,17 @@ Requires-Dist: pandas>=1.5.3,<2.2;python_version<"3.9"
 Requires-Dist: pandas>=2.1.2,<2.2;python_version>="3.9"
 Requires-Dist: pyarrow>=14.0.1
 Requires-Dist: requests>=2.27.0,<3
+Requires-Dist: azure-identity>=1.3.1 ; extra == "azure-identity"
 Requires-Dist: apache-airflow-providers-common-sql ; extra == "common.sql"
 Requires-Dist: databricks-sdk==0.10.0 ; extra == "sdk"
 Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
-Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html
-Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0
+Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html
+Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0
 Project-URL: Slack Chat, https://s.apache.org/airflow-slack
 Project-URL: Source Code, https://github.com/apache/airflow
 Project-URL: Twitter, https://twitter.com/ApacheAirflow
 Project-URL: YouTube, https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/
+Provides-Extra: azure-identity
 Provides-Extra: common.sql
 Provides-Extra: sdk
 
@@ -86,7 +88,7 @@ Provides-Extra: sdk
 
 Package ``apache-airflow-providers-databricks``
 
-Release: ``6.6.0``
+Release: ``6.7.0``
 
 
 `Databricks <https://databricks.com/>`__
@@ -99,7 +101,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
 are in ``airflow.providers.databricks`` python package.
 
 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/>`_.
 
 Installation
 ------------
@@ -147,4 +149,4 @@ Dependent package
 ============================================================================================================ ==============
 
 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html>`_.
--- apache_airflow_providers_databricks-6.6.0/README.rst
+++ apache_airflow_providers_databricks-6.7.0/README.rst
@@ -42,7 +42,7 @@
 
 Package ``apache-airflow-providers-databricks``
 
-Release: ``6.6.0``
+Release: ``6.7.0``
 
 
 `Databricks <https://databricks.com/>`__
@@ -55,7 +55,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
 are in ``airflow.providers.databricks`` python package.
 
 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/>`_.
 
 Installation
 ------------
@@ -103,4 +103,4 @@ Dependent package
 ============================================================================================================ ==============
 
 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html>`_.
--- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/__init__.py
+++ apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "6.6.0"
+__version__ = "6.7.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"
--- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/get_provider_info.py
+++ apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/get_provider_info.py
@@ -28,8 +28,9 @@ def get_provider_info():
         "name": "Databricks",
         "description": "`Databricks <https://databricks.com/>`__\n",
         "state": "ready",
-        "source-date-epoch": 1718604145,
+        "source-date-epoch": 1720422668,
         "versions": [
+            "6.7.0",
             "6.6.0",
             "6.5.0",
             "6.4.0",
@@ -85,7 +86,12 @@ def get_provider_info():
                 "name": "sdk",
                 "description": "Install Databricks SDK",
                 "dependencies": ["databricks-sdk==0.10.0"],
-            }
+            },
+            {
+                "name": "azure-identity",
+                "description": "Install Azure Identity client library",
+                "dependencies": ["azure-identity>=1.3.1"],
+            },
         ],
         "devel-dependencies": ["deltalake>=0.12.0"],
         "integrations": [
--- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/hooks/databricks.py
+++ apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/hooks/databricks.py
@@ -65,7 +65,8 @@ SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
 
 
 class RunLifeCycleState(Enum):
-    """Enum for the run life cycle state concept of Databricks runs.
+    """
+    Enum for the run life cycle state concept of Databricks runs.
 
     See more information at: https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state
     """
@@ -215,7 +216,8 @@ class DatabricksHook(BaseDatabricksHook):
         super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller)
 
     def create_job(self, json: dict) -> int:
-        """Call the ``api/2.1/jobs/create`` endpoint.
+        """
+        Call the ``api/2.1/jobs/create`` endpoint.
 
         :param json: The data used in the body of the request to the ``create`` endpoint.
         :return: the job_id as an int
@@ -224,7 +226,8 @@ class DatabricksHook(BaseDatabricksHook):
         return response["job_id"]
 
     def reset_job(self, job_id: str, json: dict) -> None:
-        """Call the ``api/2.1/jobs/reset`` endpoint.
+        """
+        Call the ``api/2.1/jobs/reset`` endpoint.
 
         :param json: The data used in the new_settings of the request to the ``reset`` endpoint.
         """
--- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/hooks/databricks_base.py
+++ apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/hooks/databricks_base.py
@@ -47,17 +47,13 @@ from tenacity import (
 )
 
 from airflow import __version__
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException
 from airflow.hooks.base import BaseHook
 from airflow.providers_manager import ProvidersManager
 
 if TYPE_CHECKING:
     from airflow.models import Connection
 
-# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
-# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
-AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
-AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token"
 # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token
 AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token"
 AZURE_METADATA_SERVICE_INSTANCE_URL = "http://169.254.169.254/metadata/instance"
@@ -301,46 +297,29 @@ class BaseDatabricksHook(BaseHook):
 
         self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
         try:
+            from azure.identity import ClientSecretCredential, ManagedIdentityCredential
+
             for attempt in self._get_retry_object():
                 with attempt:
                     if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
-                        params = {
-                            "api-version": "2018-02-01",
-                            "resource": resource,
-                        }
-                        resp = requests.get(
-                            AZURE_METADATA_SERVICE_TOKEN_URL,
-                            params=params,
-                            headers={**self.user_agent_header, "Metadata": "true"},
-                            timeout=self.token_timeout_seconds,
-                        )
+                        token = ManagedIdentityCredential().get_token(f"{resource}/.default")
                     else:
-                        tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
-                        data = {
-                            "grant_type": "client_credentials",
-                            "client_id": self.databricks_conn.login,
-                            "resource": resource,
-                            "client_secret": self.databricks_conn.password,
-                        }
-                        azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
-                            "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
-                        )
-                        resp = requests.post(
-                            AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
-                            data=data,
-                            headers={
-                                **self.user_agent_header,
-                                "Content-Type": "application/x-www-form-urlencoded",
-                            },
-                            timeout=self.token_timeout_seconds,
+                        credential = ClientSecretCredential(
+                            client_id=self.databricks_conn.login,
+                            client_secret=self.databricks_conn.password,
+                            tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
                         )
-
-                    resp.raise_for_status()
-                    jsn = resp.json()
-
+                        token = credential.get_token(f"{resource}/.default")
+                    jsn = {
+                        "access_token": token.token,
+                        "token_type": "Bearer",
+                        "expires_on": token.expires_on,
+                    }
                     self._is_oauth_token_valid(jsn)
                     self.oauth_tokens[resource] = jsn
                     break
+        except ImportError as e:
+            raise AirflowOptionalProviderFeatureException(e)
         except RetryError:
            raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
         except requests_exceptions.HTTPError as e:
@@ -362,47 +341,32 @@ class BaseDatabricksHook(BaseHook):
 
         self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
         try:
+            from azure.identity.aio import (
+                ClientSecretCredential as AsyncClientSecretCredential,
+                ManagedIdentityCredential as AsyncManagedIdentityCredential,
+            )
+
             async for attempt in self._a_get_retry_object():
                 with attempt:
                     if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
-                        params = {
-                            "api-version": "2018-02-01",
-                            "resource": resource,
-                        }
-                        async with self._session.get(
-                            url=AZURE_METADATA_SERVICE_TOKEN_URL,
-                            params=params,
-                            headers={**self.user_agent_header, "Metadata": "true"},
-                            timeout=self.token_timeout_seconds,
-                        ) as resp:
-                            resp.raise_for_status()
-                            jsn = await resp.json()
+                        token = await AsyncManagedIdentityCredential().get_token(f"{resource}/.default")
                     else:
-                        tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
-                        data = {
-                            "grant_type": "client_credentials",
-                            "client_id": self.databricks_conn.login,
-                            "resource": resource,
-                            "client_secret": self.databricks_conn.password,
-                        }
-                        azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
-                            "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
+                        credential = AsyncClientSecretCredential(
+                            client_id=self.databricks_conn.login,
+                            client_secret=self.databricks_conn.password,
+                            tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
                         )
-                        async with self._session.post(
-                            url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
-                            data=data,
-                            headers={
-                                **self.user_agent_header,
-                                "Content-Type": "application/x-www-form-urlencoded",
-                            },
-                            timeout=self.token_timeout_seconds,
-                        ) as resp:
-                            resp.raise_for_status()
-                            jsn = await resp.json()
-
+                        token = await credential.get_token(f"{resource}/.default")
+                    jsn = {
+                        "access_token": token.token,
+                        "token_type": "Bearer",
+                        "expires_on": token.expires_on,
+                    }
                     self._is_oauth_token_valid(jsn)
                     self.oauth_tokens[resource] = jsn
                     break
+        except ImportError as e:
+            raise AirflowOptionalProviderFeatureException(e)
         except RetryError:
             raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
         except aiohttp.ClientResponseError as err:
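
The databricks_base.py hunks above drop the hand-rolled requests/aiohttp calls to the Azure AD token endpoints and delegate token acquisition to the azure-identity credential classes, imported lazily inside the try block so the dependency stays optional; if the import fails, the hook now raises AirflowOptionalProviderFeatureException. Below is a rough standalone sketch of that credential flow, not the hook itself: it assumes the new extra is installed (pip install "apache-airflow-providers-databricks[azure-identity]"), the tenant/client/secret values are placeholders, and the resource ID constant is the well-known Azure Databricks application ID, which is an assumption rather than something taken from this diff.

```python
# Hedged sketch of the credential flow used by the new hook code; not the hook itself.
# Requires: pip install "apache-airflow-providers-databricks[azure-identity]"
from azure.identity import ClientSecretCredential, ManagedIdentityCredential

# Well-known Azure Databricks application (resource) ID -- assumption, not part of this diff.
AZURE_DATABRICKS_RESOURCE_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"


def get_aad_token(resource: str, use_managed_identity: bool = False) -> dict:
    """Return a token dict shaped like the one the hook stores in oauth_tokens[resource]."""
    if use_managed_identity:
        # On an Azure VM / AKS pod with a managed identity attached.
        token = ManagedIdentityCredential().get_token(f"{resource}/.default")
    else:
        # Service-principal flow; the placeholders map to the Databricks connection's
        # login (client_id), password (client_secret) and azure_tenant_id extra.
        credential = ClientSecretCredential(
            tenant_id="<azure_tenant_id>",
            client_id="<service_principal_client_id>",
            client_secret="<service_principal_secret>",
        )
        token = credential.get_token(f"{resource}/.default")
    return {"access_token": token.token, "token_type": "Bearer", "expires_on": token.expires_on}


if __name__ == "__main__":
    # Needs real credentials to succeed; shown only to illustrate the call shape.
    print(get_aad_token(AZURE_DATABRICKS_RESOURCE_ID))
```

One consequence of this change visible in the hunks is that the azure_ad_endpoint connection extra and the AZURE_DEFAULT_AD_ENDPOINT/AZURE_TOKEN_SERVICE_URL constants are gone; endpoint selection is now left to azure-identity's defaults.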
--- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/hooks/databricks_sql.py
+++ apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/hooks/databricks_sql.py
@@ -50,7 +50,8 @@ T = TypeVar("T")
 
 
 class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
-    """Hook to interact with Databricks SQL.
+    """
+    Hook to interact with Databricks SQL.
 
     :param databricks_conn_id: Reference to the
         :ref:`Databricks connection <howto/connection:databricks>`.
--- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/operators/databricks.py
+++ apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/operators/databricks.py
@@ -36,7 +36,7 @@ from airflow.providers.databricks.operators.databricks_workflow import (
     WorkflowRunMetadata,
 )
 from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
-from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -182,6 +182,17 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
     raise AirflowException(error_message)
 
 
+def _handle_overridden_json_params(operator):
+    for key, value in operator.overridden_json_params.items():
+        if value is not None:
+            operator.json[key] = value
+
+
+def normalise_json_content(operator):
+    if operator.json:
+        operator.json = _normalise_json_content(operator.json)
+
+
 class DatabricksJobRunLink(BaseOperatorLink):
     """Constructs a link to monitor a Databricks Job Run."""
 
@@ -197,7 +208,8 @@ class DatabricksJobRunLink(BaseOperatorLink):
 
 
 class DatabricksCreateJobsOperator(BaseOperator):
-    """Creates (or resets) a Databricks job using the API endpoint.
+    """
+    Creates (or resets) a Databricks job using the API endpoint.
 
     .. seealso::
         https://docs.databricks.com/api/workspace/jobs/create
@@ -284,34 +296,21 @@ class DatabricksCreateJobsOperator(BaseOperator):
         self.databricks_retry_limit = databricks_retry_limit
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_args = databricks_retry_args
-        if name is not None:
-            self.json["name"] = name
-        if description is not None:
-            self.json["description"] = description
-        if tags is not None:
-            self.json["tags"] = tags
-        if tasks is not None:
-            self.json["tasks"] = tasks
-        if job_clusters is not None:
-            self.json["job_clusters"] = job_clusters
-        if email_notifications is not None:
-            self.json["email_notifications"] = email_notifications
-        if webhook_notifications is not None:
-            self.json["webhook_notifications"] = webhook_notifications
-        if notification_settings is not None:
-            self.json["notification_settings"] = notification_settings
-        if timeout_seconds is not None:
-            self.json["timeout_seconds"] = timeout_seconds
-        if schedule is not None:
-            self.json["schedule"] = schedule
-        if max_concurrent_runs is not None:
-            self.json["max_concurrent_runs"] = max_concurrent_runs
-        if git_source is not None:
-            self.json["git_source"] = git_source
-        if access_control_list is not None:
-            self.json["access_control_list"] = access_control_list
-        if self.json:
-            self.json = normalise_json_content(self.json)
+        self.overridden_json_params = {
+            "name": name,
+            "description": description,
+            "tags": tags,
+            "tasks": tasks,
+            "job_clusters": job_clusters,
+            "email_notifications": email_notifications,
+            "webhook_notifications": webhook_notifications,
+            "notification_settings": notification_settings,
+            "timeout_seconds": timeout_seconds,
+            "schedule": schedule,
+            "max_concurrent_runs": max_concurrent_runs,
+            "git_source": git_source,
+            "access_control_list": access_control_list,
+        }
 
     @cached_property
     def _hook(self):
@@ -323,16 +322,24 @@ class DatabricksCreateJobsOperator(BaseOperator):
             caller="DatabricksCreateJobsOperator",
         )
 
-    def execute(self, context: Context) -> int:
+    def _setup_and_validate_json(self):
+        _handle_overridden_json_params(self)
+
         if "name" not in self.json:
             raise AirflowException("Missing required parameter: name")
+
+        normalise_json_content(self)
+
+    def execute(self, context: Context) -> int:
+        self._setup_and_validate_json()
+
         job_id = self._hook.find_job_id_by_name(self.json["name"])
         if job_id is None:
             return self._hook.create_job(self.json)
         self._hook.reset_job(str(job_id), self.json)
         if (access_control_list := self.json.get("access_control_list")) is not None:
             acl_json = {"access_control_list": access_control_list}
-            self._hook.update_job_permission(job_id, normalise_json_content(acl_json))
+            self._hook.update_job_permission(job_id, _normalise_json_content(acl_json))
 
         return job_id
 
@@ -505,43 +512,23 @@ class DatabricksSubmitRunOperator(BaseOperator):
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
         self.deferrable = deferrable
-        if tasks is not None:
-            self.json["tasks"] = tasks
-        if spark_jar_task is not None:
-            self.json["spark_jar_task"] = spark_jar_task
-        if notebook_task is not None:
-            self.json["notebook_task"] = notebook_task
-        if spark_python_task is not None:
-            self.json["spark_python_task"] = spark_python_task
-        if spark_submit_task is not None:
-            self.json["spark_submit_task"] = spark_submit_task
-        if pipeline_task is not None:
-            self.json["pipeline_task"] = pipeline_task
-        if dbt_task is not None:
-            self.json["dbt_task"] = dbt_task
-        if new_cluster is not None:
-            self.json["new_cluster"] = new_cluster
-        if existing_cluster_id is not None:
-            self.json["existing_cluster_id"] = existing_cluster_id
-        if libraries is not None:
-            self.json["libraries"] = libraries
-        if run_name is not None:
-            self.json["run_name"] = run_name
-        if timeout_seconds is not None:
-            self.json["timeout_seconds"] = timeout_seconds
-        if "run_name" not in self.json:
-            self.json["run_name"] = run_name or kwargs["task_id"]
-        if idempotency_token is not None:
-            self.json["idempotency_token"] = idempotency_token
-        if access_control_list is not None:
-            self.json["access_control_list"] = access_control_list
-        if git_source is not None:
-            self.json["git_source"] = git_source
-
-        if "dbt_task" in self.json and "git_source" not in self.json:
-            raise AirflowException("git_source is required for dbt_task")
-        if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
-            raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
+        self.overridden_json_params = {
+            "tasks": tasks,
+            "spark_jar_task": spark_jar_task,
+            "notebook_task": notebook_task,
+            "spark_python_task": spark_python_task,
+            "spark_submit_task": spark_submit_task,
+            "pipeline_task": pipeline_task,
+            "dbt_task": dbt_task,
+            "new_cluster": new_cluster,
+            "existing_cluster_id": existing_cluster_id,
+            "libraries": libraries,
+            "run_name": run_name,
+            "timeout_seconds": timeout_seconds,
+            "idempotency_token": idempotency_token,
+            "access_control_list": access_control_list,
+            "git_source": git_source,
+        }
 
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
@@ -560,7 +547,25 @@ class DatabricksSubmitRunOperator(BaseOperator):
             caller=caller,
         )
 
+    def _setup_and_validate_json(self):
+        _handle_overridden_json_params(self)
+
+        if "run_name" not in self.json or self.json["run_name"] is None:
+            self.json["run_name"] = self.task_id
+
+        if "dbt_task" in self.json and "git_source" not in self.json:
+            raise AirflowException("git_source is required for dbt_task")
+        if (
+            "pipeline_task" in self.json
+            and "pipeline_id" in self.json["pipeline_task"]
+            and "pipeline_name" in self.json["pipeline_task"]
+        ):
+            raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
+
+        normalise_json_content(self)
+
     def execute(self, context: Context):
+        self._setup_and_validate_json()
         if (
             "pipeline_task" in self.json
             and self.json["pipeline_task"].get("pipeline_id") is None
@@ -570,7 +575,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
             pipeline_name = self.json["pipeline_task"]["pipeline_name"]
             self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
             del self.json["pipeline_task"]["pipeline_name"]
-        json_normalised = normalise_json_content(self.json)
+        json_normalised = _normalise_json_content(self.json)
         self.run_id = self._hook.submit_run(json_normalised)
         if self.deferrable:
             _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
@@ -606,7 +611,7 @@ class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
 
     def execute(self, context):
         hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
-        json_normalised = normalise_json_content(self.json)
+        json_normalised = _normalise_json_content(self.json)
         self.run_id = hook.submit_run(json_normalised)
         _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
 
@@ -806,27 +811,16 @@ class DatabricksRunNowOperator(BaseOperator):
         self.deferrable = deferrable
         self.repair_run = repair_run
         self.cancel_previous_runs = cancel_previous_runs
-
-        if job_id is not None:
-            self.json["job_id"] = job_id
-        if job_name is not None:
-            self.json["job_name"] = job_name
-        if "job_id" in self.json and "job_name" in self.json:
-            raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
-        if notebook_params is not None:
-            self.json["notebook_params"] = notebook_params
-        if python_params is not None:
-            self.json["python_params"] = python_params
-        if python_named_params is not None:
-            self.json["python_named_params"] = python_named_params
-        if jar_params is not None:
-            self.json["jar_params"] = jar_params
-        if spark_submit_params is not None:
-            self.json["spark_submit_params"] = spark_submit_params
-        if idempotency_token is not None:
-            self.json["idempotency_token"] = idempotency_token
-        if self.json:
-            self.json = normalise_json_content(self.json)
+        self.overridden_json_params = {
+            "job_id": job_id,
+            "job_name": job_name,
+            "notebook_params": notebook_params,
+            "python_params": python_params,
+            "python_named_params": python_named_params,
+            "jar_params": jar_params,
+            "spark_submit_params": spark_submit_params,
+            "idempotency_token": idempotency_token,
+        }
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
         self.do_xcom_push = do_xcom_push
@@ -844,7 +838,16 @@ class DatabricksRunNowOperator(BaseOperator):
             caller=caller,
         )
 
+    def _setup_and_validate_json(self):
+        _handle_overridden_json_params(self)
+
+        if "job_id" in self.json and "job_name" in self.json:
+            raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
+
+        normalise_json_content(self)
+
     def execute(self, context: Context):
+        self._setup_and_validate_json()
         hook = self._hook
         if "job_name" in self.json:
             job_id = hook.find_job_id_by_name(self.json["job_name"])
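
The operators/databricks.py hunks above change how the request body is built: instead of copying each constructor argument into self.json inside __init__, the operators now record the arguments in overridden_json_params and merge, validate, and normalise them in _setup_and_validate_json(), which execute() calls first. The utils helper normalise_json_content becomes the private _normalise_json_content, with a small operator-level wrapper of the old name added in this module. A minimal, hypothetical sketch of the pattern follows; the class and argument names are invented for illustration and only the control flow mirrors the diff.

```python
# Hypothetical illustration of the deferred-JSON pattern introduced above;
# not the provider's operator classes.
from __future__ import annotations

from typing import Any


class FakeRunNowOperator:
    def __init__(self, json: dict | None = None, job_id: int | None = None, job_name: str | None = None):
        self.json: dict[str, Any] = json or {}
        # Nothing is merged or validated yet; constructor values are only remembered.
        self.overridden_json_params = {"job_id": job_id, "job_name": job_name}

    def _setup_and_validate_json(self) -> None:
        # Merge only the explicitly provided arguments, then run the checks.
        for key, value in self.overridden_json_params.items():
            if value is not None:
                self.json[key] = value
        if "job_id" in self.json and "job_name" in self.json:
            raise ValueError("Argument 'job_name' is not allowed with argument 'job_id'")

    def execute(self) -> dict[str, Any]:
        self._setup_and_validate_json()
        return self.json


print(FakeRunNowOperator(job_id=42).execute())  # {'job_id': 42}
```

Since the merge now happens at execute() time, it runs after Airflow has rendered any templated fields, which is presumably the point of deferring the JSON assembly and validation.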
--- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/utils/databricks.py
+++ apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/utils/databricks.py
@@ -21,7 +21,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.databricks.hooks.databricks import RunState
 
 
-def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
+def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
     """
     Normalize content or all values of content if it is a dict to a string.
 
@@ -33,7 +33,7 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis
     The only one exception is when we have boolean values, they can not be converted
     to string type because databricks does not understand 'True' or 'False' values.
     """
-    normalise = normalise_json_content
+    normalise = _normalise_json_content
     if isinstance(content, (str, bool)):
         return content
     elif isinstance(content, (int, float)):
--- apache_airflow_providers_databricks-6.6.0/pyproject.toml
+++ apache_airflow_providers_databricks-6.7.0/pyproject.toml
@@ -28,7 +28,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "apache-airflow-providers-databricks"
-version = "6.6.0"
+version = "6.7.0"
 description = "Provider package apache-airflow-providers-databricks for Apache Airflow"
 readme = "README.rst"
 authors = [
@@ -68,8 +68,8 @@ dependencies = [
 ]
 
 [project.urls]
-"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0"
-"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html"
+"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0"
+"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html"
 "Bug Tracker" = "https://github.com/apache/airflow/issues"
 "Source Code" = "https://github.com/apache/airflow"
 "Slack Chat" = "https://s.apache.org/airflow-slack"
@@ -85,6 +85,9 @@ provider_info = "airflow.providers.databricks.get_provider_info:get_provider_inf
 "sdk" = [
     "databricks-sdk==0.10.0",
 ]
+"azure-identity" = [
+    "azure-identity>=1.3.1",
+]
 
 [tool.flit.module]
 name = "airflow.providers.databricks"