apache-airflow-providers-databricks 6.5.0__tar.gz → 6.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of apache-airflow-providers-databricks has been flagged as possibly problematic by the registry.

Files changed (22)
  1. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/PKG-INFO +17 -9
  2. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/README.rst +10 -6
  3. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/LICENSE +4 -4
  4. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/__init__.py +1 -1
  5. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/get_provider_info.py +18 -1
  6. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks.py +20 -0
  7. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks_base.py +8 -8
  8. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks.py +360 -109
  9. apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/operators/databricks_workflow.py +312 -0
  10. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/pyproject.toml +7 -3
  11. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/__init__.py +0 -0
  12. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks_sql.py +0 -0
  13. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/__init__.py +0 -0
  14. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks_repos.py +0 -0
  15. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks_sql.py +0 -0
  16. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/sensors/__init__.py +0 -0
  17. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/sensors/databricks_partition.py +0 -0
  18. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/sensors/databricks_sql.py +0 -0
  19. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/triggers/__init__.py +0 -0
  20. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/triggers/databricks.py +0 -0
  21. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/utils/__init__.py +0 -0
  22. {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/utils/databricks.py +0 -0
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: apache-airflow-providers-databricks
- Version: 6.5.0
+ Version: 6.6.0
  Summary: Provider package apache-airflow-providers-databricks for Apache Airflow
  Keywords: airflow-provider,databricks,airflow,integration
  Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -25,12 +25,16 @@ Requires-Dist: aiohttp>=3.9.2, <4
  Requires-Dist: apache-airflow-providers-common-sql>=1.10.0
  Requires-Dist: apache-airflow>=2.7.0
  Requires-Dist: databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
+ Requires-Dist: mergedeep>=1.3.4
+ Requires-Dist: pandas>=1.5.3,<2.2;python_version<"3.9"
+ Requires-Dist: pandas>=2.1.2,<2.2;python_version>="3.9"
+ Requires-Dist: pyarrow>=14.0.1
  Requires-Dist: requests>=2.27.0,<3
  Requires-Dist: apache-airflow-providers-common-sql ; extra == "common.sql"
  Requires-Dist: databricks-sdk==0.10.0 ; extra == "sdk"
  Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
- Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html
- Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0
+ Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html
+ Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0
  Project-URL: Slack Chat, https://s.apache.org/airflow-slack
  Project-URL: Source Code, https://github.com/apache/airflow
  Project-URL: Twitter, https://twitter.com/ApacheAirflow
@@ -82,7 +86,7 @@ Provides-Extra: sdk
 
  Package ``apache-airflow-providers-databricks``
 
- Release: ``6.5.0``
+ Release: ``6.6.0``
 
 
  `Databricks <https://databricks.com/>`__
@@ -95,7 +99,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
  are in ``airflow.providers.databricks`` python package.
 
  You can find package information and changelog for the provider
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/>`_.
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/>`_.
 
  Installation
  ------------
@@ -109,15 +113,19 @@ The package supports the following python versions: 3.8,3.9,3.10,3.11,3.12
  Requirements
  ------------
 
- ======================================= ==========================
+ ======================================= =========================================
  PIP package Version required
- ======================================= ==========================
+ ======================================= =========================================
  ``apache-airflow`` ``>=2.7.0``
  ``apache-airflow-providers-common-sql`` ``>=1.10.0``
  ``requests`` ``>=2.27.0,<3``
  ``databricks-sql-connector`` ``>=2.0.0,!=2.9.0,<3.0.0``
  ``aiohttp`` ``>=3.9.2,<4``
- ======================================= ==========================
+ ``mergedeep`` ``>=1.3.4``
+ ``pandas`` ``>=2.1.2,<2.2; python_version >= "3.9"``
+ ``pandas`` ``>=1.5.3,<2.2; python_version < "3.9"``
+ ``pyarrow`` ``>=14.0.1``
+ ======================================= =========================================
 
  Cross provider package dependencies
  -----------------------------------
@@ -139,4 +147,4 @@ Dependent package
  ============================================================================================================ ==============
 
  The changelog for the provider package can be found in the
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html>`_.
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html>`_.
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/README.rst
@@ -42,7 +42,7 @@
 
  Package ``apache-airflow-providers-databricks``
 
- Release: ``6.5.0``
+ Release: ``6.6.0``
 
 
  `Databricks <https://databricks.com/>`__
@@ -55,7 +55,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
  are in ``airflow.providers.databricks`` python package.
 
  You can find package information and changelog for the provider
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/>`_.
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/>`_.
 
  Installation
  ------------
@@ -69,15 +69,19 @@ The package supports the following python versions: 3.8,3.9,3.10,3.11,3.12
  Requirements
  ------------
 
- ======================================= ==========================
+ ======================================= =========================================
  PIP package Version required
- ======================================= ==========================
+ ======================================= =========================================
  ``apache-airflow`` ``>=2.7.0``
  ``apache-airflow-providers-common-sql`` ``>=1.10.0``
  ``requests`` ``>=2.27.0,<3``
  ``databricks-sql-connector`` ``>=2.0.0,!=2.9.0,<3.0.0``
  ``aiohttp`` ``>=3.9.2,<4``
- ======================================= ==========================
+ ``mergedeep`` ``>=1.3.4``
+ ``pandas`` ``>=2.1.2,<2.2; python_version >= "3.9"``
+ ``pandas`` ``>=1.5.3,<2.2; python_version < "3.9"``
+ ``pyarrow`` ``>=14.0.1``
+ ======================================= =========================================
 
  Cross provider package dependencies
  -----------------------------------
@@ -99,4 +103,4 @@ Dependent package
  ============================================================================================================ ==============
 
  The changelog for the provider package can be found in the
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html>`_.
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html>`_.
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/LICENSE
@@ -215,7 +215,7 @@ Third party Apache 2.0 licenses
 
  The following components are provided under the Apache 2.0 License.
  See project link for details. The text of each license is also included
- at licenses/LICENSE-[project].txt.
+ at 3rd-party-licenses/LICENSE-[project].txt.
 
  (ALv2 License) hue v4.3.0 (https://github.com/cloudera/hue/)
  (ALv2 License) jqclock v2.3.0 (https://github.com/JohnRDOrazio/jQuery-Clock-Plugin)
@@ -227,7 +227,7 @@ MIT licenses
  ========================================================================
 
  The following components are provided under the MIT License. See project link for details.
- The text of each license is also included at licenses/LICENSE-[project].txt.
+ The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
  (MIT License) jquery v3.5.1 (https://jquery.org/license/)
  (MIT License) dagre-d3 v0.6.4 (https://github.com/cpettitt/dagre-d3)
@@ -243,11 +243,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
  BSD 3-Clause licenses
  ========================================================================
  The following components are provided under the BSD 3-Clause license. See project links for details.
- The text of each license is also included at licenses/LICENSE-[project].txt.
+ The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
  (BSD 3 License) d3 v5.16.0 (https://d3js.org)
  (BSD 3 License) d3-shape v2.1.0 (https://github.com/d3/d3-shape)
  (BSD 3 License) cgroupspy 0.2.1 (https://github.com/cloudsigma/cgroupspy)
 
  ========================================================================
- See licenses/LICENSES-ui.txt for packages used in `/airflow/www`
+ See 3rd-party-licenses/LICENSES-ui.txt for packages used in `/airflow/www`
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
  __all__ = ["__version__"]
 
- __version__ = "6.5.0"
+ __version__ = "6.6.0"
 
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
  "2.7.0"
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/get_provider_info.py
@@ -28,8 +28,9 @@ def get_provider_info():
  "name": "Databricks",
  "description": "`Databricks <https://databricks.com/>`__\n",
  "state": "ready",
- "source-date-epoch": 1716287262,
+ "source-date-epoch": 1718604145,
  "versions": [
+ "6.6.0",
  "6.5.0",
  "6.4.0",
  "6.3.0",
@@ -74,6 +75,10 @@ def get_provider_info():
  "requests>=2.27.0,<3",
  "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
  "aiohttp>=3.9.2, <4",
+ "mergedeep>=1.3.4",
+ 'pandas>=2.1.2,<2.2;python_version>="3.9"',
+ 'pandas>=1.5.3,<2.2;python_version<"3.9"',
+ "pyarrow>=14.0.1",
  ],
  "additional-extras": [
  {
@@ -92,6 +97,7 @@ def get_provider_info():
  "/docs/apache-airflow-providers-databricks/operators/notebook.rst",
  "/docs/apache-airflow-providers-databricks/operators/submit_run.rst",
  "/docs/apache-airflow-providers-databricks/operators/run_now.rst",
+ "/docs/apache-airflow-providers-databricks/operators/task.rst",
  ],
  "logo": "/integration-logos/databricks/Databricks.png",
  "tags": ["service"],
@@ -117,6 +123,13 @@ def get_provider_info():
  "logo": "/integration-logos/databricks/Databricks.png",
  "tags": ["service"],
  },
+ {
+ "integration-name": "Databricks Workflow",
+ "external-doc-url": "https://docs.databricks.com/en/workflows/index.html",
+ "how-to-guide": ["/docs/apache-airflow-providers-databricks/operators/workflow.rst"],
+ "logo": "/integration-logos/databricks/Databricks.png",
+ "tags": ["service"],
+ },
  ],
  "operators": [
  {
@@ -131,6 +144,10 @@ def get_provider_info():
  "integration-name": "Databricks Repos",
  "python-modules": ["airflow.providers.databricks.operators.databricks_repos"],
  },
+ {
+ "integration-name": "Databricks Workflow",
+ "python-modules": ["airflow.providers.databricks.operators.databricks_workflow"],
+ },
  ],
  "hooks": [
  {
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks.py
@@ -29,6 +29,7 @@ or the ``api/2.1/jobs/runs/submit``
  from __future__ import annotations
 
  import json
+ from enum import Enum
  from typing import Any
 
  from requests import exceptions as requests_exceptions
@@ -63,6 +64,23 @@ WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
  SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
 
 
+ class RunLifeCycleState(Enum):
+ """Enum for the run life cycle state concept of Databricks runs.
+
+ See more information at: https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state
+ """
+
+ BLOCKED = "BLOCKED"
+ INTERNAL_ERROR = "INTERNAL_ERROR"
+ PENDING = "PENDING"
+ QUEUED = "QUEUED"
+ RUNNING = "RUNNING"
+ SKIPPED = "SKIPPED"
+ TERMINATED = "TERMINATED"
+ TERMINATING = "TERMINATING"
+ WAITING_FOR_RETRY = "WAITING_FOR_RETRY"
+
+
  class RunState:
  """Utility class for the run state concept of Databricks runs."""
 
@@ -238,6 +256,7 @@ class DatabricksHook(BaseDatabricksHook):
  expand_tasks: bool = False,
  job_name: str | None = None,
  page_token: str | None = None,
+ include_user_names: bool = False,
  ) -> list[dict[str, Any]]:
  """
  List the jobs in the Databricks Job Service.
@@ -257,6 +276,7 @@ class DatabricksHook(BaseDatabricksHook):
  payload: dict[str, Any] = {
  "limit": limit,
  "expand_tasks": expand_tasks,
+ "include_user_names": include_user_names,
  }
  payload["page_token"] = page_token
  if job_name:
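
Editor's note: the hunks above add a RunLifeCycleState enum and an include_user_names flag on DatabricksHook.list_jobs. A rough usage sketch follows; the connection id, job name, and run id are placeholders, not taken from the diff:

    from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState

    # Assumes a configured "databricks_default" Airflow connection (placeholder).
    hook = DatabricksHook(databricks_conn_id="databricks_default")

    # New in 6.6.0: include_user_names is forwarded in the jobs/list payload.
    jobs = hook.list_jobs(job_name="example_workflow_job", include_user_names=True)
    print([job["job_id"] for job in jobs])

    # RunLifeCycleState replaces bare strings such as "TERMINATED" when inspecting a run.
    run_state = hook.get_run_state(run_id=12345)  # placeholder run id
    if run_state.life_cycle_state == RunLifeCycleState.TERMINATED.value:
        print("Run finished with result:", run_state.result_state)
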
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks_base.py
@@ -499,21 +499,21 @@ class BaseDatabricksHook(BaseHook):
  )
  return self.databricks_conn.extra_dejson["token"]
  elif not self.databricks_conn.login and self.databricks_conn.password:
- self.log.info("Using token auth.")
+ self.log.debug("Using token auth.")
  return self.databricks_conn.password
  elif "azure_tenant_id" in self.databricks_conn.extra_dejson:
  if self.databricks_conn.login == "" or self.databricks_conn.password == "":
  raise AirflowException("Azure SPN credentials aren't provided")
- self.log.info("Using AAD Token for SPN.")
+ self.log.debug("Using AAD Token for SPN.")
  return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
  elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
- self.log.info("Using AAD Token for managed identity.")
+ self.log.debug("Using AAD Token for managed identity.")
  self._check_azure_metadata_service()
  return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
  elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False):
  if self.databricks_conn.login == "" or self.databricks_conn.password == "":
  raise AirflowException("Service Principal credentials aren't provided")
- self.log.info("Using Service Principal Token.")
+ self.log.debug("Using Service Principal Token.")
  return self._get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host))
  elif raise_error:
  raise AirflowException("Token authentication isn't configured")
@@ -527,21 +527,21 @@ class BaseDatabricksHook(BaseHook):
  )
  return self.databricks_conn.extra_dejson["token"]
  elif not self.databricks_conn.login and self.databricks_conn.password:
- self.log.info("Using token auth.")
+ self.log.debug("Using token auth.")
  return self.databricks_conn.password
  elif "azure_tenant_id" in self.databricks_conn.extra_dejson:
  if self.databricks_conn.login == "" or self.databricks_conn.password == "":
  raise AirflowException("Azure SPN credentials aren't provided")
- self.log.info("Using AAD Token for SPN.")
+ self.log.debug("Using AAD Token for SPN.")
  return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
  elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
- self.log.info("Using AAD Token for managed identity.")
+ self.log.debug("Using AAD Token for managed identity.")
  await self._a_check_azure_metadata_service()
  return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
  elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False):
  if self.databricks_conn.login == "" or self.databricks_conn.password == "":
  raise AirflowException("Service Principal credentials aren't provided")
- self.log.info("Using Service Principal Token.")
+ self.log.debug("Using Service Principal Token.")
  return await self._a_get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host))
  elif raise_error:
  raise AirflowException("Token authentication isn't configured")
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks.py
@@ -20,6 +20,7 @@
  from __future__ import annotations
 
  import time
+ from abc import ABC, abstractmethod
  from functools import cached_property
  from logging import Logger
  from typing import TYPE_CHECKING, Any, Sequence
@@ -29,13 +30,18 @@ from deprecated import deprecated
  from airflow.configuration import conf
  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
  from airflow.models import BaseOperator, BaseOperatorLink, XCom
- from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
+ from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState
+ from airflow.providers.databricks.operators.databricks_workflow import (
+ DatabricksWorkflowTaskGroup,
+ WorkflowRunMetadata,
+ )
  from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
  from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
 
  if TYPE_CHECKING:
  from airflow.models.taskinstancekey import TaskInstanceKey
  from airflow.utils.context import Context
+ from airflow.utils.task_group import TaskGroup
 
  DEFER_METHOD_NAME = "execute_complete"
  XCOM_RUN_ID_KEY = "run_id"
@@ -894,79 +900,64 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
  super().__init__(deferrable=True, *args, **kwargs)
 
 
- class DatabricksNotebookOperator(BaseOperator):
+ class DatabricksTaskBaseOperator(BaseOperator, ABC):
  """
- Runs a notebook on Databricks using an Airflow operator.
-
- The DatabricksNotebookOperator allows users to launch and monitor notebook
- job runs on Databricks as Airflow tasks.
+ Base class for operators that are run as Databricks job tasks or tasks within a Databricks workflow.
 
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:DatabricksNotebookOperator`
-
- :param notebook_path: The path to the notebook in Databricks.
- :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
- from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
- defined in git_source. If the value is empty, the task will use GIT if git_source is defined
- and WORKSPACE otherwise. For more information please visit
- https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
- :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
- :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
- notebook.
- :param new_cluster: Specs for a new cluster on which this task will be run.
+ :param caller: The name of the caller operator to be used in the logs.
+ :param databricks_conn_id: The name of the Airflow connection to use.
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param databricks_retry_delay: Number of seconds to wait between retries.
+ :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+ :param deferrable: Whether to run the operator in the deferrable mode.
  :param existing_cluster_id: ID for existing cluster on which to run this task.
  :param job_cluster_key: The key for the job cluster.
+ :param new_cluster: Specs for a new cluster on which this task will be run.
+ :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+ notebook.
+ :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
  :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
- :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
- :param databricks_retry_delay: Number of seconds to wait between retries.
- :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
  :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
- :param databricks_conn_id: The name of the Airflow connection to use.
- :param deferrable: Run operator in the deferrable mode.
+ :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within
+ a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow.
  """
 
- template_fields = ("notebook_params",)
- CALLER = "DatabricksNotebookOperator"
-
  def __init__(
  self,
- notebook_path: str,
- source: str,
- notebook_params: dict | None = None,
- notebook_packages: list[dict[str, Any]] | None = None,
- new_cluster: dict[str, Any] | None = None,
+ caller: str = "DatabricksTaskBaseOperator",
+ databricks_conn_id: str = "databricks_default",
+ databricks_retry_args: dict[Any, Any] | None = None,
+ databricks_retry_delay: int = 1,
+ databricks_retry_limit: int = 3,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
  existing_cluster_id: str = "",
  job_cluster_key: str = "",
+ new_cluster: dict[str, Any] | None = None,
  polling_period_seconds: int = 5,
- databricks_retry_limit: int = 3,
- databricks_retry_delay: int = 1,
- databricks_retry_args: dict[Any, Any] | None = None,
  wait_for_termination: bool = True,
- databricks_conn_id: str = "databricks_default",
- deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ workflow_run_metadata: dict[str, Any] | None = None,
  **kwargs: Any,
  ):
- self.notebook_path = notebook_path
- self.source = source
- self.notebook_params = notebook_params or {}
- self.notebook_packages = notebook_packages or []
- self.new_cluster = new_cluster or {}
+ self.caller = caller
+ self.databricks_conn_id = databricks_conn_id
+ self.databricks_retry_args = databricks_retry_args
+ self.databricks_retry_delay = databricks_retry_delay
+ self.databricks_retry_limit = databricks_retry_limit
+ self.deferrable = deferrable
  self.existing_cluster_id = existing_cluster_id
  self.job_cluster_key = job_cluster_key
+ self.new_cluster = new_cluster or {}
  self.polling_period_seconds = polling_period_seconds
- self.databricks_retry_limit = databricks_retry_limit
- self.databricks_retry_delay = databricks_retry_delay
- self.databricks_retry_args = databricks_retry_args
  self.wait_for_termination = wait_for_termination
- self.databricks_conn_id = databricks_conn_id
+ self.workflow_run_metadata = workflow_run_metadata
+
  self.databricks_run_id: int | None = None
- self.deferrable = deferrable
+
  super().__init__(**kwargs)
 
  @cached_property
  def _hook(self) -> DatabricksHook:
- return self._get_hook(caller=self.CALLER)
+ return self._get_hook(caller=self.caller)
 
  def _get_hook(self, caller: str) -> DatabricksHook:
  return DatabricksHook(
@@ -974,47 +965,37 @@ class DatabricksNotebookOperator(BaseOperator):
  retry_limit=self.databricks_retry_limit,
  retry_delay=self.databricks_retry_delay,
  retry_args=self.databricks_retry_args,
- caller=self.CALLER,
+ caller=caller,
  )
 
- def _get_task_timeout_seconds(self) -> int:
+ def _get_databricks_task_id(self, task_id: str) -> str:
+ """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
+ return f"{self.dag_id}__{task_id.replace('.', '__')}"
+
+ @property
+ def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
  """
- Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task.
+ Traverse up parent TaskGroups until the `is_databricks` flag associated with the root DatabricksWorkflowTaskGroup is found.
 
- By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when
- execution_timeout is not defined, the task continues to run indefinitely. Therefore,
- to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating
- that the job should run indefinitely. This aligns with the default behavior of Databricks jobs,
- where a timeout seconds value of 0 signifies an indefinite run duration.
- More details can be found in the Databricks documentation:
- See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds
+ If found, returns the task group. Otherwise, return None.
  """
- if self.execution_timeout is None:
- return 0
- execution_timeout_seconds = int(self.execution_timeout.total_seconds())
- if execution_timeout_seconds == 0:
- raise ValueError(
- "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to "
- "`None` if you desire the task to run indefinitely."
- )
- return execution_timeout_seconds
+ parent_tg: TaskGroup | DatabricksWorkflowTaskGroup | None = self.task_group
 
- def _get_task_base_json(self) -> dict[str, Any]:
- """Get task base json to be used for task submissions."""
- return {
- "timeout_seconds": self._get_task_timeout_seconds(),
- "email_notifications": {},
- "notebook_task": {
- "notebook_path": self.notebook_path,
- "source": self.source,
- "base_parameters": self.notebook_params,
- },
- "libraries": self.notebook_packages,
- }
+ while parent_tg:
+ if getattr(parent_tg, "is_databricks", False):
+ return parent_tg  # type: ignore[return-value]
 
- def _get_databricks_task_id(self, task_id: str) -> str:
- """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
- return f"{self.dag_id}__{task_id.replace('.', '__')}"
+ if getattr(parent_tg, "task_group", None):
+ parent_tg = parent_tg.task_group
+ else:
+ return None
+
+ return None
+
+ @abstractmethod
+ def _get_task_base_json(self) -> dict[str, Any]:
+ """Get the base json for the task."""
+ raise NotImplementedError()
 
  def _get_run_json(self) -> dict[str, Any]:
  """Get run json to be used for task submissions."""
@@ -1032,65 +1013,335 @@ class DatabricksNotebookOperator(BaseOperator):
  raise ValueError("Must specify either existing_cluster_id or new_cluster.")
  return run_json
 
- def launch_notebook_job(self) -> int:
+ def _launch_job(self) -> int:
+ """Launch the job on Databricks."""
  run_json = self._get_run_json()
  self.databricks_run_id = self._hook.submit_run(run_json)
  url = self._hook.get_run_page_url(self.databricks_run_id)
  self.log.info("Check the job run in Databricks: %s", url)
  return self.databricks_run_id
 
+ def _handle_terminal_run_state(self, run_state: RunState) -> None:
+ """Handle the terminal state of the run."""
+ if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value:
+ raise AirflowException(
+ f"Databricks job failed with state {run_state.life_cycle_state}. Message: {run_state.state_message}"
+ )
+ if not run_state.is_successful:
+ raise AirflowException(
+ f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
+ )
+ self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+
+ def _get_current_databricks_task(self) -> dict[str, Any]:
+ """Retrieve the Databricks task corresponding to the current Airflow task."""
+ if self.databricks_run_id is None:
+ raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
+ return {task["task_key"]: task for task in self._hook.get_run(self.databricks_run_id)["tasks"]}[
+ self._get_databricks_task_id(self.task_id)
+ ]
+
+ def _convert_to_databricks_workflow_task(
+ self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+ ) -> dict[str, object]:
+ """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
+ base_task_json = self._get_task_base_json()
+ result = {
+ "task_key": self._get_databricks_task_id(self.task_id),
+ "depends_on": [
+ {"task_key": self._get_databricks_task_id(task_id)}
+ for task_id in self.upstream_task_ids
+ if task_id in relevant_upstreams
+ ],
+ **base_task_json,
+ }
+
+ if self.existing_cluster_id and self.job_cluster_key:
+ raise ValueError(
+ "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task."
+ )
+ if self.existing_cluster_id:
+ result["existing_cluster_id"] = self.existing_cluster_id
+ elif self.job_cluster_key:
+ result["job_cluster_key"] = self.job_cluster_key
+
+ return result
+
  def monitor_databricks_job(self) -> None:
+ """
+ Monitor the Databricks job.
+
+ Wait for the job to terminate. If deferrable, defer the task.
+ """
  if self.databricks_run_id is None:
  raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
- run = self._hook.get_run(self.databricks_run_id)
+ current_task_run_id = self._get_current_databricks_task()["run_id"]
+ run = self._hook.get_run(current_task_run_id)
+ run_page_url = run["run_page_url"]
+ self.log.info("Check the task run in Databricks: %s", run_page_url)
  run_state = RunState(**run["state"])
- self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+ self.log.info(
+ "Current state of the the databricks task %s is %s",
+ self._get_databricks_task_id(self.task_id),
+ run_state.life_cycle_state,
+ )
  if self.deferrable and not run_state.is_terminal:
  self.defer(
  trigger=DatabricksExecutionTrigger(
- run_id=self.databricks_run_id,
+ run_id=current_task_run_id,
  databricks_conn_id=self.databricks_conn_id,
  polling_period_seconds=self.polling_period_seconds,
  retry_limit=self.databricks_retry_limit,
  retry_delay=self.databricks_retry_delay,
  retry_args=self.databricks_retry_args,
- caller=self.CALLER,
+ caller=self.caller,
  ),
  method_name=DEFER_METHOD_NAME,
  )
  while not run_state.is_terminal:
  time.sleep(self.polling_period_seconds)
- run = self._hook.get_run(self.databricks_run_id)
+ run = self._hook.get_run(current_task_run_id)
  run_state = RunState(**run["state"])
  self.log.info(
- "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state
- )
- self.log.info("Current state of the job: %s", run_state.life_cycle_state)
- if run_state.life_cycle_state != "TERMINATED":
- raise AirflowException(
- f"Databricks job failed with state {run_state.life_cycle_state}. "
- f"Message: {run_state.state_message}"
+ "Current state of the databricks task %s is %s",
+ self._get_databricks_task_id(self.task_id),
+ run_state.life_cycle_state,
  )
- if not run_state.is_successful:
- raise AirflowException(
- f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
- )
- self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+ self._handle_terminal_run_state(run_state)
 
  def execute(self, context: Context) -> None:
- self.launch_notebook_job()
+ """Execute the operator. Launch the job and monitor it if wait_for_termination is set to True."""
+ if self._databricks_workflow_task_group:
+ # If we are in a DatabricksWorkflowTaskGroup, we should have an upstream task launched.
+ if not self.workflow_run_metadata:
+ launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch"))
+ self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id)
+ workflow_run_metadata = WorkflowRunMetadata(  # type: ignore[arg-type]
+ **self.workflow_run_metadata
+ )
+ self.databricks_run_id = workflow_run_metadata.run_id
+ self.databricks_conn_id = workflow_run_metadata.conn_id
+ else:
+ self._launch_job()
  if self.wait_for_termination:
  self.monitor_databricks_job()
 
  def execute_complete(self, context: dict | None, event: dict) -> None:
  run_state = RunState.from_json(event["run_state"])
- if run_state.life_cycle_state != "TERMINATED":
- raise AirflowException(
- f"Databricks job failed with state {run_state.life_cycle_state}. "
- f"Message: {run_state.state_message}"
+ self._handle_terminal_run_state(run_state)
+
+
+ class DatabricksNotebookOperator(DatabricksTaskBaseOperator):
+ """
+ Runs a notebook on Databricks using an Airflow operator.
+
+ The DatabricksNotebookOperator allows users to launch and monitor notebook job runs on Databricks as
+ Airflow tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job
+ clusters, which allows users to run their tasks on cheaper clusters that can be shared between tasks.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DatabricksNotebookOperator`
+
+ :param notebook_path: The path to the notebook in Databricks.
+ :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
+ from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
+ defined in git_source. If the value is empty, the task will use GIT if git_source is defined
+ and WORKSPACE otherwise. For more information please visit
+ https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
+ :param databricks_conn_id: The name of the Airflow connection to use.
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param databricks_retry_delay: Number of seconds to wait between retries.
+ :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+ :param deferrable: Whether to run the operator in the deferrable mode.
+ :param existing_cluster_id: ID for existing cluster on which to run this task.
+ :param job_cluster_key: The key for the job cluster.
+ :param new_cluster: Specs for a new cluster on which this task will be run.
+ :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+ notebook.
+ :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
+ :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
+ :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+ :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within
+ a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow.
+ """
+
+ template_fields = (
+ "notebook_params",
+ "workflow_run_metadata",
+ )
+ CALLER = "DatabricksNotebookOperator"
+
+ def __init__(
+ self,
+ notebook_path: str,
+ source: str,
+ databricks_conn_id: str = "databricks_default",
+ databricks_retry_args: dict[Any, Any] | None = None,
+ databricks_retry_delay: int = 1,
+ databricks_retry_limit: int = 3,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ existing_cluster_id: str = "",
+ job_cluster_key: str = "",
+ new_cluster: dict[str, Any] | None = None,
+ notebook_packages: list[dict[str, Any]] | None = None,
+ notebook_params: dict | None = None,
+ polling_period_seconds: int = 5,
+ wait_for_termination: bool = True,
+ workflow_run_metadata: dict | None = None,
+ **kwargs: Any,
+ ):
+ self.notebook_path = notebook_path
+ self.source = source
+ self.notebook_packages = notebook_packages or []
+ self.notebook_params = notebook_params or {}
+
+ super().__init__(
+ caller=self.CALLER,
+ databricks_conn_id=databricks_conn_id,
+ databricks_retry_args=databricks_retry_args,
+ databricks_retry_delay=databricks_retry_delay,
+ databricks_retry_limit=databricks_retry_limit,
+ deferrable=deferrable,
+ existing_cluster_id=existing_cluster_id,
+ job_cluster_key=job_cluster_key,
+ new_cluster=new_cluster,
+ polling_period_seconds=polling_period_seconds,
+ wait_for_termination=wait_for_termination,
+ workflow_run_metadata=workflow_run_metadata,
+ **kwargs,
+ )
+
+ def _get_task_timeout_seconds(self) -> int:
+ """
+ Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task.
+
+ By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when
+ execution_timeout is not defined, the task continues to run indefinitely. Therefore,
+ to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating
+ that the job should run indefinitely. This aligns with the default behavior of Databricks jobs,
+ where a timeout seconds value of 0 signifies an indefinite run duration.
+ More details can be found in the Databricks documentation:
+ See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds
+ """
+ if self.execution_timeout is None:
+ return 0
+ execution_timeout_seconds = int(self.execution_timeout.total_seconds())
+ if execution_timeout_seconds == 0:
+ raise ValueError(
+ "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to "
+ "`None` if you desire the task to run indefinitely."
  )
- if not run_state.is_successful:
+ return execution_timeout_seconds
+
+ def _get_task_base_json(self) -> dict[str, Any]:
+ """Get task base json to be used for task submissions."""
+ return {
+ "timeout_seconds": self._get_task_timeout_seconds(),
+ "email_notifications": {},
+ "notebook_task": {
+ "notebook_path": self.notebook_path,
+ "source": self.source,
+ "base_parameters": self.notebook_params,
+ },
+ "libraries": self.notebook_packages,
+ }
+
+ def _extend_workflow_notebook_packages(
+ self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup
+ ) -> None:
+ """Extend the task group packages into the notebook's packages, without adding any duplicates."""
+ for task_group_package in databricks_workflow_task_group.notebook_packages:
+ exists = any(
+ task_group_package == existing_package for existing_package in self.notebook_packages
+ )
+ if not exists:
+ self.notebook_packages.append(task_group_package)
+
+ def _convert_to_databricks_workflow_task(
+ self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+ ) -> dict[str, object]:
+ """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
+ databricks_workflow_task_group = self._databricks_workflow_task_group
+ if not databricks_workflow_task_group:
  raise AirflowException(
- f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
+ "Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup."
  )
- self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+
+ if hasattr(databricks_workflow_task_group, "notebook_packages"):
+ self._extend_workflow_notebook_packages(databricks_workflow_task_group)
+
+ if hasattr(databricks_workflow_task_group, "notebook_params"):
+ self.notebook_params = {
+ **self.notebook_params,
+ **databricks_workflow_task_group.notebook_params,
+ }
+
+ return super()._convert_to_databricks_workflow_task(relevant_upstreams, context=context)
+
+
+ class DatabricksTaskOperator(DatabricksTaskBaseOperator):
+ """
+ Runs a task on Databricks using an Airflow operator.
+
+ The DatabricksTaskOperator allows users to launch and monitor task job runs on Databricks as Airflow
+ tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, which
+ allows users to run their tasks on cheaper clusters that can be shared between tasks.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DatabricksTaskOperator`
+
+ :param task_config: The configuration of the task to be run on Databricks.
+ :param databricks_conn_id: The name of the Airflow connection to use.
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param databricks_retry_delay: Number of seconds to wait between retries.
+ :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+ :param deferrable: Whether to run the operator in the deferrable mode.
+ :param existing_cluster_id: ID for existing cluster on which to run this task.
+ :param job_cluster_key: The key for the job cluster.
+ :param new_cluster: Specs for a new cluster on which this task will be run.
+ :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
+ :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+ """
+
+ CALLER = "DatabricksTaskOperator"
+ template_fields = ("workflow_run_metadata",)
+
+ def __init__(
+ self,
+ task_config: dict,
+ databricks_conn_id: str = "databricks_default",
+ databricks_retry_args: dict[Any, Any] | None = None,
+ databricks_retry_delay: int = 1,
+ databricks_retry_limit: int = 3,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ existing_cluster_id: str = "",
+ job_cluster_key: str = "",
+ new_cluster: dict[str, Any] | None = None,
+ polling_period_seconds: int = 5,
+ wait_for_termination: bool = True,
+ workflow_run_metadata: dict | None = None,
+ **kwargs,
+ ):
+ self.task_config = task_config
+
+ super().__init__(
+ caller=self.CALLER,
+ databricks_conn_id=databricks_conn_id,
+ databricks_retry_args=databricks_retry_args,
+ databricks_retry_delay=databricks_retry_delay,
+ databricks_retry_limit=databricks_retry_limit,
+ deferrable=deferrable,
+ existing_cluster_id=existing_cluster_id,
+ job_cluster_key=job_cluster_key,
+ new_cluster=new_cluster,
+ polling_period_seconds=polling_period_seconds,
+ wait_for_termination=wait_for_termination,
+ workflow_run_metadata=workflow_run_metadata,
+ **kwargs,
+ )
+
+ def _get_task_base_json(self) -> dict[str, Any]:
+ """Get task base json to be used for task submissions."""
+ return self.task_config
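
Editor's note: the diff above also introduces DatabricksTaskOperator, which wraps an arbitrary Databricks Jobs API task definition. A minimal standalone sketch (connection id, cluster id, and notebook path are placeholders, not taken from the diff):

    from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator

    # task_config mirrors one task entry of the Databricks Jobs 2.1 API; when used outside a
    # DatabricksWorkflowTaskGroup, an existing_cluster_id or new_cluster spec is required.
    example_task = DatabricksTaskOperator(
        task_id="example_notebook_task",
        databricks_conn_id="databricks_default",
        existing_cluster_id="1234-567890-abcde123",  # placeholder cluster id
        task_config={
            "notebook_task": {
                "notebook_path": "/Shared/example_notebook",
                "source": "WORKSPACE",
            },
        },
    )
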
apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/operators/databricks_workflow.py (new file)
@@ -0,0 +1,312 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+
+ from __future__ import annotations
+
+ import json
+ import time
+ from dataclasses import dataclass
+ from functools import cached_property
+ from typing import TYPE_CHECKING, Any
+
+ from mergedeep import merge
+
+ from airflow.exceptions import AirflowException
+ from airflow.models import BaseOperator
+ from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState
+ from airflow.utils.task_group import TaskGroup
+
+ if TYPE_CHECKING:
+ from types import TracebackType
+
+ from airflow.models.taskmixin import DAGNode
+ from airflow.utils.context import Context
+
+
+ @dataclass
+ class WorkflowRunMetadata:
+ """
+ Metadata for a Databricks workflow run.
+
+ :param run_id: The ID of the Databricks workflow run.
+ :param job_id: The ID of the Databricks workflow job.
+ :param conn_id: The connection ID used to connect to Databricks.
+ """
+
+ conn_id: str
+ job_id: str
+ run_id: int
+
+
+ def _flatten_node(
+ node: TaskGroup | BaseOperator | DAGNode, tasks: list[BaseOperator] | None = None
+ ) -> list[BaseOperator]:
+ """Flatten a node (either a TaskGroup or Operator) to a list of nodes."""
+ if tasks is None:
+ tasks = []
+ if isinstance(node, BaseOperator):
+ return [node]
+
+ if isinstance(node, TaskGroup):
+ new_tasks = []
+ for _, child in node.children.items():
+ new_tasks += _flatten_node(child, tasks)
+
+ return tasks + new_tasks
+
+ return tasks
+
+
+ class _CreateDatabricksWorkflowOperator(BaseOperator):
+ """
+ Creates a Databricks workflow from a DatabricksWorkflowTaskGroup specified in a DAG.
+
+ :param task_id: The task_id of the operator
+ :param databricks_conn_id: The connection ID to use when connecting to Databricks.
+ :param existing_clusters: A list of existing clusters to use for the workflow.
+ :param extra_job_params: A dictionary of extra properties which will override the default Databricks
+ Workflow Job definitions.
+ :param job_clusters: A list of job clusters to use for the workflow.
+ :param max_concurrent_runs: The maximum number of concurrent runs for the workflow.
+ :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
+ will be passed to all notebooks in the workflow.
+ :param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be
+ populated after instantiation using the `add_task` method.
+ """
+
+ template_fields = ("notebook_params",)
+ caller = "_CreateDatabricksWorkflowOperator"
+
+ def __init__(
+ self,
+ task_id: str,
+ databricks_conn_id: str,
+ existing_clusters: list[str] | None = None,
+ extra_job_params: dict[str, Any] | None = None,
+ job_clusters: list[dict[str, object]] | None = None,
+ max_concurrent_runs: int = 1,
+ notebook_params: dict | None = None,
+ tasks_to_convert: list[BaseOperator] | None = None,
+ **kwargs,
+ ):
+ self.databricks_conn_id = databricks_conn_id
+ self.existing_clusters = existing_clusters or []
+ self.extra_job_params = extra_job_params or {}
+ self.job_clusters = job_clusters or []
+ self.max_concurrent_runs = max_concurrent_runs
+ self.notebook_params = notebook_params or {}
+ self.tasks_to_convert = tasks_to_convert or []
+ self.relevant_upstreams = [task_id]
+ super().__init__(task_id=task_id, **kwargs)
+
+ def _get_hook(self, caller: str) -> DatabricksHook:
+ return DatabricksHook(
+ self.databricks_conn_id,
+ caller=caller,
+ )
+
+ @cached_property
+ def _hook(self) -> DatabricksHook:
+ return self._get_hook(caller=self.caller)
+
+ def add_task(self, task: BaseOperator) -> None:
+ """Add a task to the list of tasks to convert to a Databricks workflow."""
+ self.tasks_to_convert.append(task)
+
+ @property
+ def job_name(self) -> str:
+ if not self.task_group:
+ raise AirflowException("Task group must be set before accessing job_name")
+ return f"{self.dag_id}.{self.task_group.group_id}"
+
+ def create_workflow_json(self, context: Context | None = None) -> dict[str, object]:
+ """Create a workflow json to be used in the Databricks API."""
+ task_json = [
+ task._convert_to_databricks_workflow_task(  # type: ignore[attr-defined]
+ relevant_upstreams=self.relevant_upstreams, context=context
+ )
+ for task in self.tasks_to_convert
+ ]
+
+ default_json = {
+ "name": self.job_name,
+ "email_notifications": {"no_alert_for_skipped_runs": False},
+ "timeout_seconds": 0,
+ "tasks": task_json,
+ "format": "MULTI_TASK",
+ "job_clusters": self.job_clusters,
+ "max_concurrent_runs": self.max_concurrent_runs,
+ }
+ return merge(default_json, self.extra_job_params)
+
+ def _create_or_reset_job(self, context: Context) -> int:
+ job_spec = self.create_workflow_json(context=context)
+ existing_jobs = self._hook.list_jobs(job_name=self.job_name)
+ job_id = existing_jobs[0]["job_id"] if existing_jobs else None
+ if job_id:
+ self.log.info(
+ "Updating existing Databricks workflow job %s with spec %s",
+ self.job_name,
+ json.dumps(job_spec, indent=2),
+ )
+ self._hook.reset_job(job_id, job_spec)
+ else:
+ self.log.info(
+ "Creating new Databricks workflow job %s with spec %s",
+ self.job_name,
+ json.dumps(job_spec, indent=2),
+ )
+ job_id = self._hook.create_job(job_spec)
+ return job_id
+
+ def _wait_for_job_to_start(self, run_id: int) -> None:
+ run_url = self._hook.get_run_page_url(run_id)
+ self.log.info("Check the progress of the Databricks job at %s", run_url)
+ life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
+ if life_cycle_state not in (
+ RunLifeCycleState.PENDING.value,
+ RunLifeCycleState.RUNNING.value,
+ RunLifeCycleState.BLOCKED.value,
+ ):
+ raise AirflowException(f"Could not start the workflow job. State: {life_cycle_state}")
+ while life_cycle_state in (RunLifeCycleState.PENDING.value, RunLifeCycleState.BLOCKED.value):
+ self.log.info("Waiting for the Databricks job to start running")
+ time.sleep(5)
+ life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
+ self.log.info("Databricks job started. State: %s", life_cycle_state)
+
+ def execute(self, context: Context) -> Any:
+ if not isinstance(self.task_group, DatabricksWorkflowTaskGroup):
+ raise AirflowException("Task group must be a DatabricksWorkflowTaskGroup")
+
+ job_id = self._create_or_reset_job(context)
+
+ run_id = self._hook.run_now(
+ {
+ "job_id": job_id,
+ "jar_params": self.task_group.jar_params,
+ "notebook_params": self.notebook_params,
+ "python_params": self.task_group.python_params,
+ "spark_submit_params": self.task_group.spark_submit_params,
+ }
+ )
+
+ self._wait_for_job_to_start(run_id)
+
+ return {
+ "conn_id": self.databricks_conn_id,
+ "job_id": job_id,
+ "run_id": run_id,
+ }
+
+
+ class DatabricksWorkflowTaskGroup(TaskGroup):
+ """
+ A task group that takes a list of tasks and creates a databricks workflow.
+
+ The DatabricksWorkflowTaskGroup takes a list of tasks and creates a databricks workflow
+ based on the metadata produced by those tasks. For a task to be eligible for this
+ TaskGroup, it must contain the ``_convert_to_databricks_workflow_task`` method. If any tasks
+ do not contain this method then the Taskgroup will raise an error at parse time.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DatabricksWorkflowTaskGroup`
+
+ :param databricks_conn_id: The name of the databricks connection to use.
+ :param existing_clusters: A list of existing clusters to use for this workflow.
+ :param extra_job_params: A dictionary containing properties which will override the default
+ Databricks Workflow Job definitions.
+ :param jar_params: A list of jar parameters to pass to the workflow. These parameters will be passed to all jar
+ tasks in the workflow.
+ :param job_clusters: A list of job clusters to use for this workflow.
+ :param max_concurrent_runs: The maximum number of concurrent runs for this workflow.
+ :param notebook_packages: A list of dictionary of Python packages to be installed. Packages defined
+ at the workflow task group level are installed for each of the notebook tasks under it. And
+ packages defined at the notebook task level are installed specific for the notebook task.
+ :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
+ will be passed to all notebook tasks in the workflow.
+ :param python_params: A list of python parameters to pass to the workflow. These parameters will be passed to
+ all python tasks in the workflow.
+ :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters
+ will be passed to all spark submit tasks.
+ """
+
+ is_databricks = True
+
+ def __init__(
+ self,
+ databricks_conn_id: str,
+ existing_clusters: list[str] | None = None,
+ extra_job_params: dict[str, Any] | None = None,
+ jar_params: list[str] | None = None,
+ job_clusters: list[dict] | None = None,
+ max_concurrent_runs: int = 1,
+ notebook_packages: list[dict[str, Any]] | None = None,
+ notebook_params: dict | None = None,
+ python_params: list | None = None,
+ spark_submit_params: list | None = None,
+ **kwargs,
+ ):
+ self.databricks_conn_id = databricks_conn_id
+ self.existing_clusters = existing_clusters or []
+ self.extra_job_params = extra_job_params or {}
+ self.jar_params = jar_params or []
+ self.job_clusters = job_clusters or []
+ self.max_concurrent_runs = max_concurrent_runs
+ self.notebook_packages = notebook_packages or []
+ self.notebook_params = notebook_params or {}
+ self.python_params = python_params or []
+ self.spark_submit_params = spark_submit_params or []
+ super().__init__(**kwargs)
+
+ def __exit__(
+ self, _type: type[BaseException] | None, _value: BaseException | None, _tb: TracebackType | None
+ ) -> None:
+ """Exit the context manager and add tasks to a single ``_CreateDatabricksWorkflowOperator``."""
+ roots = list(self.get_roots())
+ tasks = _flatten_node(self)
+
+ create_databricks_workflow_task = _CreateDatabricksWorkflowOperator(
+ dag=self.dag,
+ task_group=self,
+ task_id="launch",
+ databricks_conn_id=self.databricks_conn_id,
+ existing_clusters=self.existing_clusters,
+ extra_job_params=self.extra_job_params,
+ job_clusters=self.job_clusters,
+ max_concurrent_runs=self.max_concurrent_runs,
+ notebook_params=self.notebook_params,
+ )
+
+ for task in tasks:
+ if not (
+ hasattr(task, "_convert_to_databricks_workflow_task")
+ and callable(task._convert_to_databricks_workflow_task)
+ ):
+ raise AirflowException(
+ f"Task {task.task_id} does not support conversion to databricks workflow task."
+ )
+
+ task.workflow_run_metadata = create_databricks_workflow_task.output
+ create_databricks_workflow_task.relevant_upstreams.append(task.task_id)
+ create_databricks_workflow_task.add_task(task)
+
+ for root_task in roots:
+ root_task.set_upstream(create_databricks_workflow_task)
+
+ super().__exit__(_type, _value, _tb)
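
Editor's note: taken together with the operator changes above, this new module lets Airflow tasks share Databricks job clusters. A rough sketch of how the task group might be used (connection id, cluster spec, and notebook paths are placeholders; the provider's workflow.rst how-to guide is the authoritative example):

    from __future__ import annotations

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
    from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

    # Placeholder job cluster spec; any valid Databricks cluster spec works here.
    job_cluster_spec = [
        {
            "job_cluster_key": "shared_cluster",
            "new_cluster": {
                "spark_version": "14.3.x-scala2.12",
                "node_type_id": "i3.xlarge",
                "num_workers": 2,
            },
        }
    ]

    with DAG("example_databricks_workflow", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        with DatabricksWorkflowTaskGroup(
            group_id="databricks_workflow",
            databricks_conn_id="databricks_default",
            job_clusters=job_cluster_spec,
            notebook_params={"environment": "dev"},
        ) as workflow:
            notebook_1 = DatabricksNotebookOperator(
                task_id="notebook_1",
                notebook_path="/Shared/notebook_1",
                source="WORKSPACE",
                job_cluster_key="shared_cluster",
            )
            notebook_2 = DatabricksNotebookOperator(
                task_id="notebook_2",
                notebook_path="/Shared/notebook_2",
                source="WORKSPACE",
                job_cluster_key="shared_cluster",
            )
            notebook_1 >> notebook_2

On __exit__ the task group injects a "launch" task that creates or resets the Databricks job and triggers it; the notebook tasks then only monitor their own task runs inside that job, as shown in the operators/databricks.py changes above.
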
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/pyproject.toml
@@ -28,7 +28,7 @@ build-backend = "flit_core.buildapi"
 
  [project]
  name = "apache-airflow-providers-databricks"
- version = "6.5.0"
+ version = "6.6.0"
  description = "Provider package apache-airflow-providers-databricks for Apache Airflow"
  readme = "README.rst"
  authors = [
@@ -60,12 +60,16 @@ dependencies = [
  "apache-airflow-providers-common-sql>=1.10.0",
  "apache-airflow>=2.7.0",
  "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
+ "mergedeep>=1.3.4",
+ "pandas>=1.5.3,<2.2;python_version<\"3.9\"",
+ "pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
+ "pyarrow>=14.0.1",
  "requests>=2.27.0,<3",
  ]
 
  [project.urls]
- "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0"
- "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html"
+ "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0"
+ "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html"
  "Bug Tracker" = "https://github.com/apache/airflow/issues"
  "Source Code" = "https://github.com/apache/airflow"
  "Slack Chat" = "https://s.apache.org/airflow-slack"