apache-airflow-providers-databricks 6.5.0__tar.gz → 6.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of apache-airflow-providers-databricks might be problematic.
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/PKG-INFO +17 -9
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/README.rst +10 -6
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/LICENSE +4 -4
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/__init__.py +1 -1
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/get_provider_info.py +18 -1
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks.py +20 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks_base.py +8 -8
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks.py +360 -109
- apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/operators/databricks_workflow.py +312 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/pyproject.toml +7 -3
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks_sql.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks_repos.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks_sql.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/sensors/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/sensors/databricks_partition.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/sensors/databricks_sql.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/triggers/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/triggers/databricks.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/utils/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/utils/databricks.py +0 -0
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: apache-airflow-providers-databricks
-Version: 6.5.0
+Version: 6.6.0
 Summary: Provider package apache-airflow-providers-databricks for Apache Airflow
 Keywords: airflow-provider,databricks,airflow,integration
 Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -25,12 +25,16 @@ Requires-Dist: aiohttp>=3.9.2, <4
 Requires-Dist: apache-airflow-providers-common-sql>=1.10.0
 Requires-Dist: apache-airflow>=2.7.0
 Requires-Dist: databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
+Requires-Dist: mergedeep>=1.3.4
+Requires-Dist: pandas>=1.5.3,<2.2;python_version<"3.9"
+Requires-Dist: pandas>=2.1.2,<2.2;python_version>="3.9"
+Requires-Dist: pyarrow>=14.0.1
 Requires-Dist: requests>=2.27.0,<3
 Requires-Dist: apache-airflow-providers-common-sql ; extra == "common.sql"
 Requires-Dist: databricks-sdk==0.10.0 ; extra == "sdk"
 Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
-Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html
-Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0
+Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html
+Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0
 Project-URL: Slack Chat, https://s.apache.org/airflow-slack
 Project-URL: Source Code, https://github.com/apache/airflow
 Project-URL: Twitter, https://twitter.com/ApacheAirflow
@@ -82,7 +86,7 @@ Provides-Extra: sdk
 
 Package ``apache-airflow-providers-databricks``
 
-Release: ``6.5.0``
+Release: ``6.6.0``
 
 
 `Databricks <https://databricks.com/>`__
@@ -95,7 +99,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
 are in ``airflow.providers.databricks`` python package.
 
 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/>`_.
 
 Installation
 ------------
@@ -109,15 +113,19 @@ The package supports the following python versions: 3.8,3.9,3.10,3.11,3.12
 Requirements
 ------------
 
-=======================================
+=======================================  =========================================
 PIP package                              Version required
-=======================================
+=======================================  =========================================
 ``apache-airflow``                       ``>=2.7.0``
 ``apache-airflow-providers-common-sql``  ``>=1.10.0``
 ``requests``                             ``>=2.27.0,<3``
 ``databricks-sql-connector``             ``>=2.0.0,!=2.9.0,<3.0.0``
 ``aiohttp``                              ``>=3.9.2,<4``
-=======================================
+``mergedeep``                            ``>=1.3.4``
+``pandas``                               ``>=2.1.2,<2.2; python_version >= "3.9"``
+``pandas``                               ``>=1.5.3,<2.2; python_version < "3.9"``
+``pyarrow``                              ``>=14.0.1``
+=======================================  =========================================
 
 Cross provider package dependencies
 -----------------------------------
@@ -139,4 +147,4 @@ Dependent package
 ============================================================================================================  ==============
 
 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html>`_.
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/README.rst
RENAMED
@@ -42,7 +42,7 @@
 
 Package ``apache-airflow-providers-databricks``
 
-Release: ``6.5.0``
+Release: ``6.6.0``
 
 
 `Databricks <https://databricks.com/>`__
@@ -55,7 +55,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
 are in ``airflow.providers.databricks`` python package.
 
 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/>`_.
 
 Installation
 ------------
@@ -69,15 +69,19 @@ The package supports the following python versions: 3.8,3.9,3.10,3.11,3.12
 Requirements
 ------------
 
-=======================================
+=======================================  =========================================
 PIP package                              Version required
-=======================================
+=======================================  =========================================
 ``apache-airflow``                       ``>=2.7.0``
 ``apache-airflow-providers-common-sql``  ``>=1.10.0``
 ``requests``                             ``>=2.27.0,<3``
 ``databricks-sql-connector``             ``>=2.0.0,!=2.9.0,<3.0.0``
 ``aiohttp``                              ``>=3.9.2,<4``
-=======================================
+``mergedeep``                            ``>=1.3.4``
+``pandas``                               ``>=2.1.2,<2.2; python_version >= "3.9"``
+``pandas``                               ``>=1.5.3,<2.2; python_version < "3.9"``
+``pyarrow``                              ``>=14.0.1``
+=======================================  =========================================
 
 Cross provider package dependencies
 -----------------------------------
@@ -99,4 +103,4 @@ Dependent package
 ============================================================================================================  ==============
 
 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html>`_.
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/LICENSE
RENAMED
@@ -215,7 +215,7 @@ Third party Apache 2.0 licenses
 
 The following components are provided under the Apache 2.0 License.
 See project link for details. The text of each license is also included
-at licenses/LICENSE-[project].txt.
+at 3rd-party-licenses/LICENSE-[project].txt.
 
 (ALv2 License) hue v4.3.0 (https://github.com/cloudera/hue/)
 (ALv2 License) jqclock v2.3.0 (https://github.com/JohnRDOrazio/jQuery-Clock-Plugin)
@@ -227,7 +227,7 @@ MIT licenses
 ========================================================================
 
 The following components are provided under the MIT License. See project link for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (MIT License) jquery v3.5.1 (https://jquery.org/license/)
 (MIT License) dagre-d3 v0.6.4 (https://github.com/cpettitt/dagre-d3)
@@ -243,11 +243,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
 BSD 3-Clause licenses
 ========================================================================
 The following components are provided under the BSD 3-Clause license. See project links for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (BSD 3 License) d3 v5.16.0 (https://d3js.org)
 (BSD 3 License) d3-shape v2.1.0 (https://github.com/d3/d3-shape)
 (BSD 3 License) cgroupspy 0.2.1 (https://github.com/cloudsigma/cgroupspy)
 
 ========================================================================
-See licenses/LICENSES-ui.txt for packages used in `/airflow/www`
+See 3rd-party-licenses/LICENSES-ui.txt for packages used in `/airflow/www`
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/__init__.py
RENAMED
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "6.5.0"
+__version__ = "6.6.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/get_provider_info.py
RENAMED
@@ -28,8 +28,9 @@ def get_provider_info():
         "name": "Databricks",
         "description": "`Databricks <https://databricks.com/>`__\n",
         "state": "ready",
-        "source-date-epoch":
+        "source-date-epoch": 1718604145,
         "versions": [
+            "6.6.0",
             "6.5.0",
             "6.4.0",
             "6.3.0",
@@ -74,6 +75,10 @@ def get_provider_info():
             "requests>=2.27.0,<3",
             "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
             "aiohttp>=3.9.2, <4",
+            "mergedeep>=1.3.4",
+            'pandas>=2.1.2,<2.2;python_version>="3.9"',
+            'pandas>=1.5.3,<2.2;python_version<"3.9"',
+            "pyarrow>=14.0.1",
         ],
         "additional-extras": [
             {
@@ -92,6 +97,7 @@ def get_provider_info():
                     "/docs/apache-airflow-providers-databricks/operators/notebook.rst",
                     "/docs/apache-airflow-providers-databricks/operators/submit_run.rst",
                     "/docs/apache-airflow-providers-databricks/operators/run_now.rst",
+                    "/docs/apache-airflow-providers-databricks/operators/task.rst",
                 ],
                 "logo": "/integration-logos/databricks/Databricks.png",
                 "tags": ["service"],
@@ -117,6 +123,13 @@ def get_provider_info():
                 "logo": "/integration-logos/databricks/Databricks.png",
                 "tags": ["service"],
             },
+            {
+                "integration-name": "Databricks Workflow",
+                "external-doc-url": "https://docs.databricks.com/en/workflows/index.html",
+                "how-to-guide": ["/docs/apache-airflow-providers-databricks/operators/workflow.rst"],
+                "logo": "/integration-logos/databricks/Databricks.png",
+                "tags": ["service"],
+            },
         ],
         "operators": [
             {
@@ -131,6 +144,10 @@ def get_provider_info():
                 "integration-name": "Databricks Repos",
                 "python-modules": ["airflow.providers.databricks.operators.databricks_repos"],
            },
+            {
+                "integration-name": "Databricks Workflow",
+                "python-modules": ["airflow.providers.databricks.operators.databricks_workflow"],
+            },
         ],
         "hooks": [
             {
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks.py
RENAMED
@@ -29,6 +29,7 @@ or the ``api/2.1/jobs/runs/submit``
 from __future__ import annotations
 
 import json
+from enum import Enum
 from typing import Any
 
 from requests import exceptions as requests_exceptions
@@ -63,6 +64,23 @@ WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
 SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
 
 
+class RunLifeCycleState(Enum):
+    """Enum for the run life cycle state concept of Databricks runs.
+
+    See more information at: https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state
+    """
+
+    BLOCKED = "BLOCKED"
+    INTERNAL_ERROR = "INTERNAL_ERROR"
+    PENDING = "PENDING"
+    QUEUED = "QUEUED"
+    RUNNING = "RUNNING"
+    SKIPPED = "SKIPPED"
+    TERMINATED = "TERMINATED"
+    TERMINATING = "TERMINATING"
+    WAITING_FOR_RETRY = "WAITING_FOR_RETRY"
+
+
 class RunState:
     """Utility class for the run state concept of Databricks runs."""
 
@@ -238,6 +256,7 @@ class DatabricksHook(BaseDatabricksHook):
         expand_tasks: bool = False,
         job_name: str | None = None,
         page_token: str | None = None,
+        include_user_names: bool = False,
     ) -> list[dict[str, Any]]:
         """
         List the jobs in the Databricks Job Service.
@@ -257,6 +276,7 @@ class DatabricksHook(BaseDatabricksHook):
         payload: dict[str, Any] = {
             "limit": limit,
             "expand_tasks": expand_tasks,
+            "include_user_names": include_user_names,
         }
         payload["page_token"] = page_token
         if job_name:
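Both hook changes above are additive and can be exercised directly. A minimal sketch, assuming an Airflow connection named ``databricks_default`` is configured; the job name and run id are placeholders, not values taken from this release:

    from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState

    hook = DatabricksHook(databricks_conn_id="databricks_default")

    # New keyword in 6.6.0: include_user_names is forwarded in the api/2.1/jobs/list payload.
    jobs = hook.list_jobs(job_name="my_workflow", include_user_names=True)

    # New enum in 6.6.0: compare life-cycle states against enum members instead of raw strings.
    run_state = hook.get_run_state(run_id=12345)  # placeholder run id
    if run_state.life_cycle_state == RunLifeCycleState.TERMINATED.value:
        print("Run finished with result state:", run_state.result_state)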
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/hooks/databricks_base.py
RENAMED
@@ -499,21 +499,21 @@ class BaseDatabricksHook(BaseHook):
             )
             return self.databricks_conn.extra_dejson["token"]
         elif not self.databricks_conn.login and self.databricks_conn.password:
-            self.log.info("Using token auth.")
+            self.log.debug("Using token auth.")
             return self.databricks_conn.password
         elif "azure_tenant_id" in self.databricks_conn.extra_dejson:
             if self.databricks_conn.login == "" or self.databricks_conn.password == "":
                 raise AirflowException("Azure SPN credentials aren't provided")
-            self.log.info("Using AAD Token for SPN.")
+            self.log.debug("Using AAD Token for SPN.")
             return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
         elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
-            self.log.info("Using AAD Token for managed identity.")
+            self.log.debug("Using AAD Token for managed identity.")
             self._check_azure_metadata_service()
             return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
         elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False):
             if self.databricks_conn.login == "" or self.databricks_conn.password == "":
                 raise AirflowException("Service Principal credentials aren't provided")
-            self.log.info("Using Service Principal Token.")
+            self.log.debug("Using Service Principal Token.")
             return self._get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host))
         elif raise_error:
             raise AirflowException("Token authentication isn't configured")
@@ -527,21 +527,21 @@ class BaseDatabricksHook(BaseHook):
             )
             return self.databricks_conn.extra_dejson["token"]
         elif not self.databricks_conn.login and self.databricks_conn.password:
-            self.log.info("Using token auth.")
+            self.log.debug("Using token auth.")
             return self.databricks_conn.password
         elif "azure_tenant_id" in self.databricks_conn.extra_dejson:
             if self.databricks_conn.login == "" or self.databricks_conn.password == "":
                 raise AirflowException("Azure SPN credentials aren't provided")
-            self.log.info("Using AAD Token for SPN.")
+            self.log.debug("Using AAD Token for SPN.")
             return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
         elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
-            self.log.info("Using AAD Token for managed identity.")
+            self.log.debug("Using AAD Token for managed identity.")
             await self._a_check_azure_metadata_service()
             return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
         elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False):
             if self.databricks_conn.login == "" or self.databricks_conn.password == "":
                 raise AirflowException("Service Principal credentials aren't provided")
-            self.log.info("Using Service Principal Token.")
+            self.log.debug("Using Service Principal Token.")
             return await self._a_get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host))
         elif raise_error:
             raise AirflowException("Token authentication isn't configured")
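The branches above only change log level; they select the auth mechanism from the connection fields: a password with no login is treated as a personal access token, while ``azure_tenant_id``, ``use_azure_managed_identity`` and ``service_principal_oauth`` in the connection extras switch to the Azure AD and service-principal flows. A minimal sketch of the token-auth case, with placeholder host and token values (one of several ways to register a connection):

    import os

    from airflow.models.connection import Connection

    # Personal-access-token auth: password set, login empty -> "Using token auth." branch above.
    conn = Connection(
        conn_id="databricks_default",
        conn_type="databricks",
        host="https://example-workspace.cloud.databricks.com",  # placeholder workspace URL
        password="dapi-example-token",  # placeholder personal access token
    )
    # Airflow reads connections defined as AIRFLOW_CONN_<CONN_ID> environment variables.
    os.environ["AIRFLOW_CONN_DATABRICKS_DEFAULT"] = conn.get_uri()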
{apache_airflow_providers_databricks-6.5.0 → apache_airflow_providers_databricks-6.6.0}/airflow/providers/databricks/operators/databricks.py
RENAMED
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 import time
+from abc import ABC, abstractmethod
 from functools import cached_property
 from logging import Logger
 from typing import TYPE_CHECKING, Any, Sequence
@@ -29,13 +30,18 @@ from deprecated import deprecated
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState
+from airflow.providers.databricks.operators.databricks_workflow import (
+    DatabricksWorkflowTaskGroup,
+    WorkflowRunMetadata,
+)
 from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
 from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context
+    from airflow.utils.task_group import TaskGroup
 
 DEFER_METHOD_NAME = "execute_complete"
 XCOM_RUN_ID_KEY = "run_id"
@@ -894,79 +900,64 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
         super().__init__(deferrable=True, *args, **kwargs)
 
 
-class DatabricksNotebookOperator(BaseOperator):
+class DatabricksTaskBaseOperator(BaseOperator, ABC):
     """
-    Runs a notebook on Databricks using an Airflow operator.
-
-    The DatabricksNotebookOperator allows users to launch and monitor notebook
-    job runs on Databricks as Airflow tasks.
+    Base class for operators that are run as Databricks job tasks or tasks within a Databricks workflow.
 
-    .. seealso::
-        For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:DatabricksNotebookOperator`
-
-    :param notebook_path: The path to the notebook in Databricks.
-    :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
-        from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
-        defined in git_source. If the value is empty, the task will use GIT if git_source is defined
-        and WORKSPACE otherwise. For more information please visit
-        https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
-    :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
-    :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
-        notebook.
-    :param new_cluster: Specs for a new cluster on which this task will be run.
+    :param caller: The name of the caller operator to be used in the logs.
+    :param databricks_conn_id: The name of the Airflow connection to use.
+    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    :param databricks_retry_delay: Number of seconds to wait between retries.
+    :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+    :param deferrable: Whether to run the operator in the deferrable mode.
     :param existing_cluster_id: ID for existing cluster on which to run this task.
     :param job_cluster_key: The key for the job cluster.
+    :param new_cluster: Specs for a new cluster on which this task will be run.
+    :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+        notebook.
+    :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
     :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
-    :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
-    :param databricks_retry_delay: Number of seconds to wait between retries.
-    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
-    :param databricks_conn_id: The name of the Airflow connection to use.
-    :param deferrable: Whether to run the operator in the deferrable mode.
+    :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within
+        a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow.
     """
 
-    template_fields = ("notebook_params",)
-    CALLER = "DatabricksNotebookOperator"
-
     def __init__(
         self,
-        notebook_path: str,
-        source: str,
-        notebook_params: dict | None = None,
-        notebook_packages: list[dict[str, Any]] | None = None,
-        new_cluster: dict[str, Any] | None = None,
+        caller: str = "DatabricksTaskBaseOperator",
+        databricks_conn_id: str = "databricks_default",
+        databricks_retry_args: dict[Any, Any] | None = None,
+        databricks_retry_delay: int = 1,
+        databricks_retry_limit: int = 3,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         existing_cluster_id: str = "",
         job_cluster_key: str = "",
+        new_cluster: dict[str, Any] | None = None,
         polling_period_seconds: int = 5,
-        databricks_retry_limit: int = 3,
-        databricks_retry_delay: int = 1,
-        databricks_retry_args: dict[Any, Any] | None = None,
         wait_for_termination: bool = True,
-        databricks_conn_id: str = "databricks_default",
-        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        workflow_run_metadata: dict[str, Any] | None = None,
         **kwargs: Any,
     ):
-        self.notebook_path = notebook_path
-        self.source = source
-        self.notebook_params = notebook_params or {}
-        self.notebook_packages = notebook_packages or []
-        self.new_cluster = new_cluster or {}
+        self.caller = caller
+        self.databricks_conn_id = databricks_conn_id
+        self.databricks_retry_args = databricks_retry_args
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_limit = databricks_retry_limit
+        self.deferrable = deferrable
         self.existing_cluster_id = existing_cluster_id
         self.job_cluster_key = job_cluster_key
+        self.new_cluster = new_cluster or {}
         self.polling_period_seconds = polling_period_seconds
-        self.databricks_retry_limit = databricks_retry_limit
-        self.databricks_retry_delay = databricks_retry_delay
-        self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
-        self.databricks_conn_id = databricks_conn_id
+        self.workflow_run_metadata = workflow_run_metadata
+
         self.databricks_run_id: int | None = None
-        self.deferrable = deferrable
+
         super().__init__(**kwargs)
 
     @cached_property
     def _hook(self) -> DatabricksHook:
-        return self._get_hook(caller=self.CALLER)
+        return self._get_hook(caller=self.caller)
 
     def _get_hook(self, caller: str) -> DatabricksHook:
         return DatabricksHook(
@@ -974,47 +965,37 @@ class DatabricksNotebookOperator(BaseOperator):
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
             retry_args=self.databricks_retry_args,
-            caller=self.CALLER,
+            caller=caller,
         )
 
-    def _get_task_timeout_seconds(self) -> int:
+    def _get_databricks_task_id(self, task_id: str) -> str:
+        """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
+        return f"{self.dag_id}__{task_id.replace('.', '__')}"
+
+    @property
+    def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
         """
-        Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task.
+        Traverse up parent TaskGroups until the `is_databricks` flag associated with the root DatabricksWorkflowTaskGroup is found.
 
-        By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when
-        execution_timeout is not defined, the task continues to run indefinitely. Therefore,
-        to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating
-        that the job should run indefinitely. This aligns with the default behavior of Databricks jobs,
-        where a timeout seconds value of 0 signifies an indefinite run duration.
-        More details can be found in the Databricks documentation:
-        See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds
+        If found, returns the task group. Otherwise, return None.
         """
-        if self.execution_timeout is None:
-            return 0
-        execution_timeout_seconds = int(self.execution_timeout.total_seconds())
-        if execution_timeout_seconds == 0:
-            raise ValueError(
-                "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to "
-                "`None` if you desire the task to run indefinitely."
-            )
-        return execution_timeout_seconds
+        parent_tg: TaskGroup | DatabricksWorkflowTaskGroup | None = self.task_group
 
-    def _get_task_base_json(self) -> dict[str, Any]:
-        """Get task base json to be used for task submissions."""
-        return {
-            "timeout_seconds": self._get_task_timeout_seconds(),
-            "email_notifications": {},
-            "notebook_task": {
-                "notebook_path": self.notebook_path,
-                "source": self.source,
-                "base_parameters": self.notebook_params,
-            },
-            "libraries": self.notebook_packages,
-        }
+        while parent_tg:
+            if getattr(parent_tg, "is_databricks", False):
+                return parent_tg  # type: ignore[return-value]
 
-    def _get_databricks_task_id(self, task_id: str) -> str:
-        """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
-        return f"{self.dag_id}__{task_id.replace('.', '__')}"
+            if getattr(parent_tg, "task_group", None):
+                parent_tg = parent_tg.task_group
+            else:
+                return None
+
+        return None
+
+    @abstractmethod
+    def _get_task_base_json(self) -> dict[str, Any]:
+        """Get the base json for the task."""
+        raise NotImplementedError()
 
     def _get_run_json(self) -> dict[str, Any]:
         """Get run json to be used for task submissions."""
@@ -1032,65 +1013,335 @@
             raise ValueError("Must specify either existing_cluster_id or new_cluster.")
         return run_json
 
-    def launch_notebook_job(self) -> int:
+    def _launch_job(self) -> int:
+        """Launch the job on Databricks."""
         run_json = self._get_run_json()
         self.databricks_run_id = self._hook.submit_run(run_json)
         url = self._hook.get_run_page_url(self.databricks_run_id)
         self.log.info("Check the job run in Databricks: %s", url)
         return self.databricks_run_id
 
+    def _handle_terminal_run_state(self, run_state: RunState) -> None:
+        """Handle the terminal state of the run."""
+        if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value:
+            raise AirflowException(
+                f"Databricks job failed with state {run_state.life_cycle_state}. Message: {run_state.state_message}"
+            )
+        if not run_state.is_successful:
+            raise AirflowException(
+                f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
+            )
+        self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+
+    def _get_current_databricks_task(self) -> dict[str, Any]:
+        """Retrieve the Databricks task corresponding to the current Airflow task."""
+        if self.databricks_run_id is None:
+            raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
+        return {task["task_key"]: task for task in self._hook.get_run(self.databricks_run_id)["tasks"]}[
+            self._get_databricks_task_id(self.task_id)
+        ]
+
+    def _convert_to_databricks_workflow_task(
+        self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+    ) -> dict[str, object]:
+        """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
+        base_task_json = self._get_task_base_json()
+        result = {
+            "task_key": self._get_databricks_task_id(self.task_id),
+            "depends_on": [
+                {"task_key": self._get_databricks_task_id(task_id)}
+                for task_id in self.upstream_task_ids
+                if task_id in relevant_upstreams
+            ],
+            **base_task_json,
+        }
+
+        if self.existing_cluster_id and self.job_cluster_key:
+            raise ValueError(
+                "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task."
+            )
+        if self.existing_cluster_id:
+            result["existing_cluster_id"] = self.existing_cluster_id
+        elif self.job_cluster_key:
+            result["job_cluster_key"] = self.job_cluster_key
+
+        return result
+
     def monitor_databricks_job(self) -> None:
+        """
+        Monitor the Databricks job.
+
+        Wait for the job to terminate. If deferrable, defer the task.
+        """
         if self.databricks_run_id is None:
             raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
-        run = self._hook.get_run(self.databricks_run_id)
+        current_task_run_id = self._get_current_databricks_task()["run_id"]
+        run = self._hook.get_run(current_task_run_id)
+        run_page_url = run["run_page_url"]
+        self.log.info("Check the task run in Databricks: %s", run_page_url)
         run_state = RunState(**run["state"])
-        self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        self.log.info(
+            "Current state of the the databricks task %s is %s",
+            self._get_databricks_task_id(self.task_id),
+            run_state.life_cycle_state,
+        )
         if self.deferrable and not run_state.is_terminal:
             self.defer(
                 trigger=DatabricksExecutionTrigger(
-                    run_id=self.databricks_run_id,
+                    run_id=current_task_run_id,
                     databricks_conn_id=self.databricks_conn_id,
                     polling_period_seconds=self.polling_period_seconds,
                     retry_limit=self.databricks_retry_limit,
                     retry_delay=self.databricks_retry_delay,
                     retry_args=self.databricks_retry_args,
-                    caller=self.CALLER,
+                    caller=self.caller,
                 ),
                 method_name=DEFER_METHOD_NAME,
             )
         while not run_state.is_terminal:
             time.sleep(self.polling_period_seconds)
-            run = self._hook.get_run(self.databricks_run_id)
+            run = self._hook.get_run(current_task_run_id)
             run_state = RunState(**run["state"])
             self.log.info(
-                "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state
+                "Current state of the databricks task %s is %s",
+                self._get_databricks_task_id(self.task_id),
+                run_state.life_cycle_state,
             )
-
-        if run_state.life_cycle_state != "TERMINATED":
-            raise AirflowException(
-                f"Databricks job failed with state {run_state.life_cycle_state}. "
-                f"Message: {run_state.state_message}"
-            )
-        if not run_state.is_successful:
-            raise AirflowException(
-                f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
-            )
-        self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+        self._handle_terminal_run_state(run_state)
 
     def execute(self, context: Context) -> None:
-        self.launch_notebook_job()
+        """Execute the operator. Launch the job and monitor it if wait_for_termination is set to True."""
+        if self._databricks_workflow_task_group:
+            # If we are in a DatabricksWorkflowTaskGroup, we should have an upstream task launched.
+            if not self.workflow_run_metadata:
+                launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch"))
+                self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id)
+            workflow_run_metadata = WorkflowRunMetadata(  # type: ignore[arg-type]
+                **self.workflow_run_metadata
+            )
+            self.databricks_run_id = workflow_run_metadata.run_id
+            self.databricks_conn_id = workflow_run_metadata.conn_id
+        else:
+            self._launch_job()
         if self.wait_for_termination:
             self.monitor_databricks_job()
 
     def execute_complete(self, context: dict | None, event: dict) -> None:
         run_state = RunState.from_json(event["run_state"])
-        if run_state.life_cycle_state != "TERMINATED":
-            raise AirflowException(
-                f"Databricks job failed with state {run_state.life_cycle_state}. "
-                f"Message: {run_state.state_message}"
-            )
-        if not run_state.is_successful:
-            raise AirflowException(
-                f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
-            )
-        self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+        self._handle_terminal_run_state(run_state)
+
+
+class DatabricksNotebookOperator(DatabricksTaskBaseOperator):
+    """
+    Runs a notebook on Databricks using an Airflow operator.
+
+    The DatabricksNotebookOperator allows users to launch and monitor notebook job runs on Databricks as
+    Airflow tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job
+    clusters, which allows users to run their tasks on cheaper clusters that can be shared between tasks.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DatabricksNotebookOperator`
+
+    :param notebook_path: The path to the notebook in Databricks.
+    :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
+        from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
+        defined in git_source. If the value is empty, the task will use GIT if git_source is defined
+        and WORKSPACE otherwise. For more information please visit
+        https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
+    :param databricks_conn_id: The name of the Airflow connection to use.
+    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    :param databricks_retry_delay: Number of seconds to wait between retries.
+    :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+    :param deferrable: Whether to run the operator in the deferrable mode.
+    :param existing_cluster_id: ID for existing cluster on which to run this task.
+    :param job_cluster_key: The key for the job cluster.
+    :param new_cluster: Specs for a new cluster on which this task will be run.
+    :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+        notebook.
+    :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
+    :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
+    :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+    :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within
+        a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow.
+    """
+
+    template_fields = (
+        "notebook_params",
+        "workflow_run_metadata",
+    )
+    CALLER = "DatabricksNotebookOperator"
+
+    def __init__(
+        self,
+        notebook_path: str,
+        source: str,
+        databricks_conn_id: str = "databricks_default",
+        databricks_retry_args: dict[Any, Any] | None = None,
+        databricks_retry_delay: int = 1,
+        databricks_retry_limit: int = 3,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        existing_cluster_id: str = "",
+        job_cluster_key: str = "",
+        new_cluster: dict[str, Any] | None = None,
+        notebook_packages: list[dict[str, Any]] | None = None,
+        notebook_params: dict | None = None,
+        polling_period_seconds: int = 5,
+        wait_for_termination: bool = True,
+        workflow_run_metadata: dict | None = None,
+        **kwargs: Any,
+    ):
+        self.notebook_path = notebook_path
+        self.source = source
+        self.notebook_packages = notebook_packages or []
+        self.notebook_params = notebook_params or {}
+
+        super().__init__(
+            caller=self.CALLER,
+            databricks_conn_id=databricks_conn_id,
+            databricks_retry_args=databricks_retry_args,
+            databricks_retry_delay=databricks_retry_delay,
+            databricks_retry_limit=databricks_retry_limit,
+            deferrable=deferrable,
+            existing_cluster_id=existing_cluster_id,
+            job_cluster_key=job_cluster_key,
+            new_cluster=new_cluster,
+            polling_period_seconds=polling_period_seconds,
+            wait_for_termination=wait_for_termination,
+            workflow_run_metadata=workflow_run_metadata,
+            **kwargs,
+        )
+
+    def _get_task_timeout_seconds(self) -> int:
+        """
+        Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task.
+
+        By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when
+        execution_timeout is not defined, the task continues to run indefinitely. Therefore,
+        to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating
+        that the job should run indefinitely. This aligns with the default behavior of Databricks jobs,
+        where a timeout seconds value of 0 signifies an indefinite run duration.
+        More details can be found in the Databricks documentation:
+        See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds
+        """
+        if self.execution_timeout is None:
+            return 0
+        execution_timeout_seconds = int(self.execution_timeout.total_seconds())
+        if execution_timeout_seconds == 0:
+            raise ValueError(
+                "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to "
+                "`None` if you desire the task to run indefinitely."
+            )
+        return execution_timeout_seconds
+
+    def _get_task_base_json(self) -> dict[str, Any]:
+        """Get task base json to be used for task submissions."""
+        return {
+            "timeout_seconds": self._get_task_timeout_seconds(),
+            "email_notifications": {},
+            "notebook_task": {
+                "notebook_path": self.notebook_path,
+                "source": self.source,
+                "base_parameters": self.notebook_params,
+            },
+            "libraries": self.notebook_packages,
+        }
+
+    def _extend_workflow_notebook_packages(
+        self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup
+    ) -> None:
+        """Extend the task group packages into the notebook's packages, without adding any duplicates."""
+        for task_group_package in databricks_workflow_task_group.notebook_packages:
+            exists = any(
+                task_group_package == existing_package for existing_package in self.notebook_packages
+            )
+            if not exists:
+                self.notebook_packages.append(task_group_package)
+
+    def _convert_to_databricks_workflow_task(
+        self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+    ) -> dict[str, object]:
+        """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
+        databricks_workflow_task_group = self._databricks_workflow_task_group
+        if not databricks_workflow_task_group:
+            raise AirflowException(
+                "Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup."
+            )
+
+        if hasattr(databricks_workflow_task_group, "notebook_packages"):
+            self._extend_workflow_notebook_packages(databricks_workflow_task_group)
+
+        if hasattr(databricks_workflow_task_group, "notebook_params"):
+            self.notebook_params = {
+                **self.notebook_params,
+                **databricks_workflow_task_group.notebook_params,
+            }
+
+        return super()._convert_to_databricks_workflow_task(relevant_upstreams, context=context)
+
+
+class DatabricksTaskOperator(DatabricksTaskBaseOperator):
+    """
+    Runs a task on Databricks using an Airflow operator.
+
+    The DatabricksTaskOperator allows users to launch and monitor task job runs on Databricks as Airflow
+    tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, which
+    allows users to run their tasks on cheaper clusters that can be shared between tasks.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DatabricksTaskOperator`
+
+    :param task_config: The configuration of the task to be run on Databricks.
+    :param databricks_conn_id: The name of the Airflow connection to use.
+    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    :param databricks_retry_delay: Number of seconds to wait between retries.
+    :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+    :param deferrable: Whether to run the operator in the deferrable mode.
+    :param existing_cluster_id: ID for existing cluster on which to run this task.
+    :param job_cluster_key: The key for the job cluster.
+    :param new_cluster: Specs for a new cluster on which this task will be run.
+    :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
+    :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+    """
+
+    CALLER = "DatabricksTaskOperator"
+    template_fields = ("workflow_run_metadata",)
+
+    def __init__(
+        self,
+        task_config: dict,
+        databricks_conn_id: str = "databricks_default",
+        databricks_retry_args: dict[Any, Any] | None = None,
+        databricks_retry_delay: int = 1,
+        databricks_retry_limit: int = 3,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        existing_cluster_id: str = "",
+        job_cluster_key: str = "",
+        new_cluster: dict[str, Any] | None = None,
+        polling_period_seconds: int = 5,
+        wait_for_termination: bool = True,
+        workflow_run_metadata: dict | None = None,
+        **kwargs,
+    ):
+        self.task_config = task_config
+
+        super().__init__(
+            caller=self.CALLER,
+            databricks_conn_id=databricks_conn_id,
+            databricks_retry_args=databricks_retry_args,
+            databricks_retry_delay=databricks_retry_delay,
+            databricks_retry_limit=databricks_retry_limit,
+            deferrable=deferrable,
+            existing_cluster_id=existing_cluster_id,
+            job_cluster_key=job_cluster_key,
+            new_cluster=new_cluster,
+            polling_period_seconds=polling_period_seconds,
+            wait_for_termination=wait_for_termination,
+            workflow_run_metadata=workflow_run_metadata,
+            **kwargs,
+        )
+
+    def _get_task_base_json(self) -> dict[str, Any]:
+        """Get task base json to be used for task submissions."""
+        return self.task_config
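Taken together, the refactored operators above and the new databricks_workflow module shown in the next section let several Airflow tasks share a single Databricks multi-task job. A minimal DAG sketch under stated assumptions: the connection id, cluster spec, notebook path, query id and warehouse id are placeholders, and the ``job_clusters`` argument on the task group is inferred from the parameters of ``_CreateDatabricksWorkflowOperator`` rather than spelled out in this diff:

    from __future__ import annotations

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.databricks.operators.databricks import (
        DatabricksNotebookOperator,
        DatabricksTaskOperator,
    )
    from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

    # Placeholder job-cluster spec; adjust spark_version/node_type_id for your workspace.
    JOB_CLUSTERS = [
        {
            "job_cluster_key": "workflow_cluster",
            "new_cluster": {
                "spark_version": "13.3.x-scala2.12",
                "node_type_id": "i3.xlarge",
                "num_workers": 1,
            },
        }
    ]

    with DAG(
        dag_id="example_databricks_workflow",
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # The task group creates (or resets) one multi-task Databricks job and launches it;
        # the operators inside are converted into tasks of that job via
        # _convert_to_databricks_workflow_task and run on the shared job cluster.
        with DatabricksWorkflowTaskGroup(
            group_id="databricks_workflow",
            databricks_conn_id="databricks_default",
            job_clusters=JOB_CLUSTERS,
        ):
            notebook = DatabricksNotebookOperator(
                task_id="run_notebook",
                notebook_path="/Shared/example_notebook",  # placeholder workspace path
                source="WORKSPACE",
                job_cluster_key="workflow_cluster",
            )
            sql = DatabricksTaskOperator(
                task_id="run_sql",
                task_config={
                    "sql_task": {
                        "query": {"query_id": "example-query-id"},  # placeholder
                        "warehouse_id": "example-warehouse-id",  # placeholder
                    },
                },
            )
            notebook >> sql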
apache_airflow_providers_databricks-6.6.0/airflow/providers/databricks/operators/databricks_workflow.py
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import time
+from dataclasses import dataclass
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from mergedeep import merge
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState
+from airflow.utils.task_group import TaskGroup
+
+if TYPE_CHECKING:
+    from types import TracebackType
+
+    from airflow.models.taskmixin import DAGNode
+    from airflow.utils.context import Context
+
+
+@dataclass
+class WorkflowRunMetadata:
+    """
+    Metadata for a Databricks workflow run.
+
+    :param run_id: The ID of the Databricks workflow run.
+    :param job_id: The ID of the Databricks workflow job.
+    :param conn_id: The connection ID used to connect to Databricks.
+    """
+
+    conn_id: str
+    job_id: str
+    run_id: int
+
+
+def _flatten_node(
+    node: TaskGroup | BaseOperator | DAGNode, tasks: list[BaseOperator] | None = None
+) -> list[BaseOperator]:
+    """Flatten a node (either a TaskGroup or Operator) to a list of nodes."""
+    if tasks is None:
+        tasks = []
+    if isinstance(node, BaseOperator):
+        return [node]
+
+    if isinstance(node, TaskGroup):
+        new_tasks = []
+        for _, child in node.children.items():
+            new_tasks += _flatten_node(child, tasks)
+
+        return tasks + new_tasks
+
+    return tasks
+
+
+class _CreateDatabricksWorkflowOperator(BaseOperator):
+    """
+    Creates a Databricks workflow from a DatabricksWorkflowTaskGroup specified in a DAG.
+
+    :param task_id: The task_id of the operator
+    :param databricks_conn_id: The connection ID to use when connecting to Databricks.
+    :param existing_clusters: A list of existing clusters to use for the workflow.
+    :param extra_job_params: A dictionary of extra properties which will override the default Databricks
+        Workflow Job definitions.
+    :param job_clusters: A list of job clusters to use for the workflow.
+    :param max_concurrent_runs: The maximum number of concurrent runs for the workflow.
+    :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
+        will be passed to all notebooks in the workflow.
+    :param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be
+        populated after instantiation using the `add_task` method.
+    """
+
+    template_fields = ("notebook_params",)
+    caller = "_CreateDatabricksWorkflowOperator"
+
+    def __init__(
+        self,
+        task_id: str,
+        databricks_conn_id: str,
+        existing_clusters: list[str] | None = None,
+        extra_job_params: dict[str, Any] | None = None,
+        job_clusters: list[dict[str, object]] | None = None,
+        max_concurrent_runs: int = 1,
+        notebook_params: dict | None = None,
+        tasks_to_convert: list[BaseOperator] | None = None,
+        **kwargs,
+    ):
+        self.databricks_conn_id = databricks_conn_id
+        self.existing_clusters = existing_clusters or []
+        self.extra_job_params = extra_job_params or {}
+        self.job_clusters = job_clusters or []
+        self.max_concurrent_runs = max_concurrent_runs
+        self.notebook_params = notebook_params or {}
+        self.tasks_to_convert = tasks_to_convert or []
+        self.relevant_upstreams = [task_id]
+        super().__init__(task_id=task_id, **kwargs)
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            caller=caller,
+        )
+
+    @cached_property
+    def _hook(self) -> DatabricksHook:
+        return self._get_hook(caller=self.caller)
+
+    def add_task(self, task: BaseOperator) -> None:
+        """Add a task to the list of tasks to convert to a Databricks workflow."""
+        self.tasks_to_convert.append(task)
+
+    @property
+    def job_name(self) -> str:
+        if not self.task_group:
+            raise AirflowException("Task group must be set before accessing job_name")
+        return f"{self.dag_id}.{self.task_group.group_id}"
+
+    def create_workflow_json(self, context: Context | None = None) -> dict[str, object]:
+        """Create a workflow json to be used in the Databricks API."""
+        task_json = [
+            task._convert_to_databricks_workflow_task(  # type: ignore[attr-defined]
+                relevant_upstreams=self.relevant_upstreams, context=context
+            )
+            for task in self.tasks_to_convert
+        ]
+
+        default_json = {
+            "name": self.job_name,
+            "email_notifications": {"no_alert_for_skipped_runs": False},
+            "timeout_seconds": 0,
+            "tasks": task_json,
+            "format": "MULTI_TASK",
+            "job_clusters": self.job_clusters,
+            "max_concurrent_runs": self.max_concurrent_runs,
+        }
+        return merge(default_json, self.extra_job_params)
+
+    def _create_or_reset_job(self, context: Context) -> int:
+        job_spec = self.create_workflow_json(context=context)
+        existing_jobs = self._hook.list_jobs(job_name=self.job_name)
+        job_id = existing_jobs[0]["job_id"] if existing_jobs else None
+        if job_id:
+            self.log.info(
+                "Updating existing Databricks workflow job %s with spec %s",
+                self.job_name,
+                json.dumps(job_spec, indent=2),
+            )
+            self._hook.reset_job(job_id, job_spec)
+        else:
+            self.log.info(
+                "Creating new Databricks workflow job %s with spec %s",
+                self.job_name,
+                json.dumps(job_spec, indent=2),
+            )
+            job_id = self._hook.create_job(job_spec)
+        return job_id
+
+    def _wait_for_job_to_start(self, run_id: int) -> None:
+        run_url = self._hook.get_run_page_url(run_id)
+        self.log.info("Check the progress of the Databricks job at %s", run_url)
+        life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
+        if life_cycle_state not in (
+            RunLifeCycleState.PENDING.value,
+            RunLifeCycleState.RUNNING.value,
+            RunLifeCycleState.BLOCKED.value,
+        ):
+            raise AirflowException(f"Could not start the workflow job. State: {life_cycle_state}")
+        while life_cycle_state in (RunLifeCycleState.PENDING.value, RunLifeCycleState.BLOCKED.value):
+            self.log.info("Waiting for the Databricks job to start running")
+            time.sleep(5)
+            life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
+        self.log.info("Databricks job started. State: %s", life_cycle_state)
+
+    def execute(self, context: Context) -> Any:
+        if not isinstance(self.task_group, DatabricksWorkflowTaskGroup):
+            raise AirflowException("Task group must be a DatabricksWorkflowTaskGroup")
+
+        job_id = self._create_or_reset_job(context)
+
+        run_id = self._hook.run_now(
+            {
+                "job_id": job_id,
+                "jar_params": self.task_group.jar_params,
+                "notebook_params": self.notebook_params,
+                "python_params": self.task_group.python_params,
+                "spark_submit_params": self.task_group.spark_submit_params,
+            }
+        )
+
+        self._wait_for_job_to_start(run_id)
+
+        return {
+            "conn_id": self.databricks_conn_id,
+            "job_id": job_id,
+            "run_id": run_id,
+        }
+
+
+class DatabricksWorkflowTaskGroup(TaskGroup):
+    """
+    A task group that takes a list of tasks and creates a databricks workflow.
+
+    The DatabricksWorkflowTaskGroup takes a list of tasks and creates a databricks workflow
+    based on the metadata produced by those tasks. For a task to be eligible for this
+    TaskGroup, it must contain the ``_convert_to_databricks_workflow_task`` method. If any tasks
+    do not contain this method then the Taskgroup will raise an error at parse time.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DatabricksWorkflowTaskGroup`
+
+    :param databricks_conn_id: The name of the databricks connection to use.
+    :param existing_clusters: A list of existing clusters to use for this workflow.
|
|
232
|
+
:param extra_job_params: A dictionary containing properties which will override the default
|
|
233
|
+
Databricks Workflow Job definitions.
|
|
234
|
+
:param jar_params: A list of jar parameters to pass to the workflow. These parameters will be passed to all jar
|
|
235
|
+
tasks in the workflow.
|
|
236
|
+
:param job_clusters: A list of job clusters to use for this workflow.
|
|
237
|
+
:param max_concurrent_runs: The maximum number of concurrent runs for this workflow.
|
|
238
|
+
:param notebook_packages: A list of dictionary of Python packages to be installed. Packages defined
|
|
239
|
+
at the workflow task group level are installed for each of the notebook tasks under it. And
|
|
240
|
+
packages defined at the notebook task level are installed specific for the notebook task.
|
|
241
|
+
:param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
|
|
242
|
+
will be passed to all notebook tasks in the workflow.
|
|
243
|
+
:param python_params: A list of python parameters to pass to the workflow. These parameters will be passed to
|
|
244
|
+
all python tasks in the workflow.
|
|
245
|
+
:param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters
|
|
246
|
+
will be passed to all spark submit tasks.
|
|
247
|
+
"""
|
|
248
|
+
|
|
249
|
+
is_databricks = True
|
|
250
|
+
|
|
251
|
+
def __init__(
|
|
252
|
+
self,
|
|
253
|
+
databricks_conn_id: str,
|
|
254
|
+
existing_clusters: list[str] | None = None,
|
|
255
|
+
extra_job_params: dict[str, Any] | None = None,
|
|
256
|
+
jar_params: list[str] | None = None,
|
|
257
|
+
job_clusters: list[dict] | None = None,
|
|
258
|
+
max_concurrent_runs: int = 1,
|
|
259
|
+
notebook_packages: list[dict[str, Any]] | None = None,
|
|
260
|
+
notebook_params: dict | None = None,
|
|
261
|
+
python_params: list | None = None,
|
|
262
|
+
spark_submit_params: list | None = None,
|
|
263
|
+
**kwargs,
|
|
264
|
+
):
|
|
265
|
+
self.databricks_conn_id = databricks_conn_id
|
|
266
|
+
self.existing_clusters = existing_clusters or []
|
|
267
|
+
self.extra_job_params = extra_job_params or {}
|
|
268
|
+
self.jar_params = jar_params or []
|
|
269
|
+
self.job_clusters = job_clusters or []
|
|
270
|
+
self.max_concurrent_runs = max_concurrent_runs
|
|
271
|
+
self.notebook_packages = notebook_packages or []
|
|
272
|
+
self.notebook_params = notebook_params or {}
|
|
273
|
+
self.python_params = python_params or []
|
|
274
|
+
self.spark_submit_params = spark_submit_params or []
|
|
275
|
+
super().__init__(**kwargs)
|
|
276
|
+
|
|
277
|
+
def __exit__(
|
|
278
|
+
self, _type: type[BaseException] | None, _value: BaseException | None, _tb: TracebackType | None
|
|
279
|
+
) -> None:
|
|
280
|
+
"""Exit the context manager and add tasks to a single ``_CreateDatabricksWorkflowOperator``."""
|
|
281
|
+
roots = list(self.get_roots())
|
|
282
|
+
tasks = _flatten_node(self)
|
|
283
|
+
|
|
284
|
+
create_databricks_workflow_task = _CreateDatabricksWorkflowOperator(
|
|
285
|
+
dag=self.dag,
|
|
286
|
+
task_group=self,
|
|
287
|
+
task_id="launch",
|
|
288
|
+
databricks_conn_id=self.databricks_conn_id,
|
|
289
|
+
existing_clusters=self.existing_clusters,
|
|
290
|
+
extra_job_params=self.extra_job_params,
|
|
291
|
+
job_clusters=self.job_clusters,
|
|
292
|
+
max_concurrent_runs=self.max_concurrent_runs,
|
|
293
|
+
notebook_params=self.notebook_params,
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
for task in tasks:
|
|
297
|
+
if not (
|
|
298
|
+
hasattr(task, "_convert_to_databricks_workflow_task")
|
|
299
|
+
and callable(task._convert_to_databricks_workflow_task)
|
|
300
|
+
):
|
|
301
|
+
raise AirflowException(
|
|
302
|
+
f"Task {task.task_id} does not support conversion to databricks workflow task."
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
task.workflow_run_metadata = create_databricks_workflow_task.output
|
|
306
|
+
create_databricks_workflow_task.relevant_upstreams.append(task.task_id)
|
|
307
|
+
create_databricks_workflow_task.add_task(task)
|
|
308
|
+
|
|
309
|
+
for root_task in roots:
|
|
310
|
+
root_task.set_upstream(create_databricks_workflow_task)
|
|
311
|
+
|
|
312
|
+
super().__exit__(_type, _value, _tb)
|
|
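Editorial note, not part of the diff: a minimal usage sketch of the new task group. It assumes a notebook-style operator such as DatabricksNotebookOperator that implements ``_convert_to_databricks_workflow_task``; that operator name and its ``notebook_path``/``source``/``job_cluster_key`` parameters are assumptions here (they are not shown in this section), and the cluster spec and connection values are placeholders. When the ``with`` block exits, ``__exit__`` creates the ``launch`` task and sets it upstream of every root task in the group, so the Databricks job is created and triggered before the converted tasks run.

    # Hypothetical sketch only: operator name/parameters and all values are assumptions.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
    from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

    job_cluster_spec = [
        {
            "job_cluster_key": "workflow_cluster",  # referenced by the notebook tasks below
            "new_cluster": {
                "spark_version": "13.3.x-scala2.12",  # placeholder values
                "node_type_id": "i3.xlarge",
                "num_workers": 2,
            },
        }
    ]

    with DAG(dag_id="example_databricks_workflow", start_date=datetime(2024, 1, 1), schedule=None):
        with DatabricksWorkflowTaskGroup(
            group_id="databricks_workflow",
            databricks_conn_id="databricks_default",
            job_clusters=job_cluster_spec,
            notebook_params={"environment": "dev"},  # passed to every notebook task in the group
        ):
            first = DatabricksNotebookOperator(
                task_id="notebook_1",
                databricks_conn_id="databricks_default",
                notebook_path="/Shared/notebook_1",  # placeholder path
                source="WORKSPACE",
                job_cluster_key="workflow_cluster",
            )
            second = DatabricksNotebookOperator(
                task_id="notebook_2",
                databricks_conn_id="databricks_default",
                notebook_path="/Shared/notebook_2",
                source="WORKSPACE",
                job_cluster_key="workflow_cluster",
            )
            first >> second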
@@ -28,7 +28,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "apache-airflow-providers-databricks"
-version = "6.5.0"
+version = "6.6.0"
 description = "Provider package apache-airflow-providers-databricks for Apache Airflow"
 readme = "README.rst"
 authors = [
@@ -60,12 +60,16 @@ dependencies = [
     "apache-airflow-providers-common-sql>=1.10.0",
     "apache-airflow>=2.7.0",
     "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
+    "mergedeep>=1.3.4",
+    "pandas>=1.5.3,<2.2;python_version<\"3.9\"",
+    "pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
+    "pyarrow>=14.0.1",
     "requests>=2.27.0,<3",
 ]
 
 [project.urls]
-"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0"
-"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.5.0/changelog.html"
+"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0"
+"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.6.0/changelog.html"
 "Bug Tracker" = "https://github.com/apache/airflow/issues"
 "Source Code" = "https://github.com/apache/airflow"
 "Slack Chat" = "https://s.apache.org/airflow-slack"