apache-airflow-providers-databricks 6.7.0__tar.gz → 6.8.0__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/PKG-INFO +6 -6
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/README.rst +3 -3
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/__init__.py +1 -1
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/get_provider_info.py +8 -1
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/operators/databricks.py +120 -96
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/operators/databricks_workflow.py +5 -0
- apache_airflow_providers_databricks-6.8.0/airflow/providers/databricks/plugins/databricks_workflow.py +477 -0
- apache_airflow_providers_databricks-6.8.0/airflow/providers/databricks/utils/__init__.py +16 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/utils/databricks.py +2 -2
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/pyproject.toml +5 -3
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/LICENSE +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/hooks/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/hooks/databricks.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/hooks/databricks_base.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/hooks/databricks_sql.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/operators/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/operators/databricks_repos.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/operators/databricks_sql.py +0 -0
- {apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/sensors → apache_airflow_providers_databricks-6.8.0/airflow/providers/databricks/plugins}/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.7.0/airflow/providers/databricks/utils → apache_airflow_providers_databricks-6.8.0/airflow/providers/databricks/sensors}/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/sensors/databricks_partition.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/sensors/databricks_sql.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/triggers/__init__.py +0 -0
- {apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/triggers/databricks.py +0 -0
{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: apache-airflow-providers-databricks
-Version: 6.7.0
+Version: 6.8.0
 Summary: Provider package apache-airflow-providers-databricks for Apache Airflow
 Keywords: airflow-provider,databricks,airflow,integration
 Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -34,8 +34,8 @@ Requires-Dist: azure-identity>=1.3.1 ; extra == "azure-identity"
 Requires-Dist: apache-airflow-providers-common-sql ; extra == "common.sql"
 Requires-Dist: databricks-sdk==0.10.0 ; extra == "sdk"
 Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
-Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html
-Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0
+Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0/changelog.html
+Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0
 Project-URL: Slack Chat, https://s.apache.org/airflow-slack
 Project-URL: Source Code, https://github.com/apache/airflow
 Project-URL: Twitter, https://twitter.com/ApacheAirflow
@@ -88,7 +88,7 @@ Provides-Extra: sdk

 Package ``apache-airflow-providers-databricks``

-Release: ``6.7.0``
+Release: ``6.8.0``


 `Databricks <https://databricks.com/>`__
@@ -101,7 +101,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
 are in ``airflow.providers.databricks`` python package.

 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0/>`_.

 Installation
 ------------
@@ -149,4 +149,4 @@ Dependent package
 ============================================================================================================ ==============

 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0/changelog.html>`_.
{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/README.rst
RENAMED
@@ -42,7 +42,7 @@

 Package ``apache-airflow-providers-databricks``

-Release: ``6.7.0``
+Release: ``6.8.0``


 `Databricks <https://databricks.com/>`__
@@ -55,7 +55,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
 are in ``airflow.providers.databricks`` python package.

 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0/>`_.

 Installation
 ------------
@@ -103,4 +103,4 @@ Dependent package
 ============================================================================================================ ==============

 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0/changelog.html>`_.

{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/__init__.py
RENAMED
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "6.7.0"
+__version__ = "6.8.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"

{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/get_provider_info.py
RENAMED
@@ -28,8 +28,9 @@ def get_provider_info():
         "name": "Databricks",
         "description": "`Databricks <https://databricks.com/>`__\n",
         "state": "ready",
-        "source-date-epoch":
+        "source-date-epoch": 1722663644,
         "versions": [
+            "6.8.0",
             "6.7.0",
             "6.6.0",
             "6.5.0",
@@ -189,5 +190,11 @@ def get_provider_info():
                 "connection-type": "databricks",
             }
         ],
+        "plugins": [
+            {
+                "name": "databricks_workflow",
+                "plugin-class": "airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin",
+            }
+        ],
         "extra-links": ["airflow.providers.databricks.operators.databricks.DatabricksJobRunLink"],
     }
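The new ``plugins`` entry in the provider metadata points Airflow at the plugin class by dotted path. As a rough, hedged sketch of what that wiring amounts to (the loop below is illustrative, not the actual plugins-manager code), the declared class can be resolved with Airflow's ``import_string`` helper:

    from airflow.utils.module_loading import import_string

    declared_plugins = [
        {
            "name": "databricks_workflow",
            "plugin-class": "airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin",
        }
    ]

    for entry in declared_plugins:
        # Resolve the dotted path declared in the provider metadata to the class object.
        plugin_cls = import_string(entry["plugin-class"])
        # With apache-airflow-providers-databricks>=6.8.0 installed, this resolves to
        # DatabricksWorkflowPlugin, whose `name` attribute is "databricks_workflow".
        print(entry["name"], "->", plugin_cls.__name__)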
{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/operators/databricks.py
RENAMED
@@ -35,8 +35,12 @@ from airflow.providers.databricks.operators.databricks_workflow import (
     DatabricksWorkflowTaskGroup,
     WorkflowRunMetadata,
 )
+from airflow.providers.databricks.plugins.databricks_workflow import (
+    WorkflowJobRepairSingleTaskLink,
+    WorkflowJobRunLink,
+)
 from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
-from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -182,17 +186,6 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
     raise AirflowException(error_message)


-def _handle_overridden_json_params(operator):
-    for key, value in operator.overridden_json_params.items():
-        if value is not None:
-            operator.json[key] = value
-
-
-def normalise_json_content(operator):
-    if operator.json:
-        operator.json = _normalise_json_content(operator.json)
-
-
 class DatabricksJobRunLink(BaseOperatorLink):
     """Constructs a link to monitor a Databricks Job Run."""

@@ -296,21 +289,34 @@ class DatabricksCreateJobsOperator(BaseOperator):
         self.databricks_retry_limit = databricks_retry_limit
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_args = databricks_retry_args
-        self.overridden_json_params = {
-            "name": name,
-            "description": description,
-            "tags": tags,
-            "tasks": tasks,
-            "job_clusters": job_clusters,
-            "email_notifications": email_notifications,
-            "webhook_notifications": webhook_notifications,
-            "notification_settings": notification_settings,
-            "timeout_seconds": timeout_seconds,
-            "schedule": schedule,
-            "max_concurrent_runs": max_concurrent_runs,
-            "git_source": git_source,
-            "access_control_list": access_control_list,
-        }
+        if name is not None:
+            self.json["name"] = name
+        if description is not None:
+            self.json["description"] = description
+        if tags is not None:
+            self.json["tags"] = tags
+        if tasks is not None:
+            self.json["tasks"] = tasks
+        if job_clusters is not None:
+            self.json["job_clusters"] = job_clusters
+        if email_notifications is not None:
+            self.json["email_notifications"] = email_notifications
+        if webhook_notifications is not None:
+            self.json["webhook_notifications"] = webhook_notifications
+        if notification_settings is not None:
+            self.json["notification_settings"] = notification_settings
+        if timeout_seconds is not None:
+            self.json["timeout_seconds"] = timeout_seconds
+        if schedule is not None:
+            self.json["schedule"] = schedule
+        if max_concurrent_runs is not None:
+            self.json["max_concurrent_runs"] = max_concurrent_runs
+        if git_source is not None:
+            self.json["git_source"] = git_source
+        if access_control_list is not None:
+            self.json["access_control_list"] = access_control_list
+        if self.json:
+            self.json = normalise_json_content(self.json)

     @cached_property
     def _hook(self):
@@ -322,24 +328,16 @@ class DatabricksCreateJobsOperator(BaseOperator):
             caller="DatabricksCreateJobsOperator",
         )

-    def _setup_and_validate_json(self):
-        _handle_overridden_json_params(self)
-
+    def execute(self, context: Context) -> int:
         if "name" not in self.json:
             raise AirflowException("Missing required parameter: name")
-
-        normalise_json_content(self)
-
-    def execute(self, context: Context) -> int:
-        self._setup_and_validate_json()
-
         job_id = self._hook.find_job_id_by_name(self.json["name"])
         if job_id is None:
             return self._hook.create_job(self.json)
         self._hook.reset_job(str(job_id), self.json)
         if (access_control_list := self.json.get("access_control_list")) is not None:
             acl_json = {"access_control_list": access_control_list}
-            self._hook.update_job_permission(job_id, _normalise_json_content(acl_json))
+            self._hook.update_job_permission(job_id, normalise_json_content(acl_json))

         return job_id

@@ -512,23 +510,43 @@ class DatabricksSubmitRunOperator(BaseOperator):
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
         self.deferrable = deferrable
-        self.overridden_json_params = {
-            "tasks": tasks,
-            "spark_jar_task": spark_jar_task,
-            "notebook_task": notebook_task,
-            "spark_python_task": spark_python_task,
-            "spark_submit_task": spark_submit_task,
-            "pipeline_task": pipeline_task,
-            "dbt_task": dbt_task,
-            "new_cluster": new_cluster,
-            "existing_cluster_id": existing_cluster_id,
-            "libraries": libraries,
-            "run_name": run_name,
-            "timeout_seconds": timeout_seconds,
-            "idempotency_token": idempotency_token,
-            "access_control_list": access_control_list,
-            "git_source": git_source,
-        }
+        if tasks is not None:
+            self.json["tasks"] = tasks
+        if spark_jar_task is not None:
+            self.json["spark_jar_task"] = spark_jar_task
+        if notebook_task is not None:
+            self.json["notebook_task"] = notebook_task
+        if spark_python_task is not None:
+            self.json["spark_python_task"] = spark_python_task
+        if spark_submit_task is not None:
+            self.json["spark_submit_task"] = spark_submit_task
+        if pipeline_task is not None:
+            self.json["pipeline_task"] = pipeline_task
+        if dbt_task is not None:
+            self.json["dbt_task"] = dbt_task
+        if new_cluster is not None:
+            self.json["new_cluster"] = new_cluster
+        if existing_cluster_id is not None:
+            self.json["existing_cluster_id"] = existing_cluster_id
+        if libraries is not None:
+            self.json["libraries"] = libraries
+        if run_name is not None:
+            self.json["run_name"] = run_name
+        if timeout_seconds is not None:
+            self.json["timeout_seconds"] = timeout_seconds
+        if "run_name" not in self.json:
+            self.json["run_name"] = run_name or kwargs["task_id"]
+        if idempotency_token is not None:
+            self.json["idempotency_token"] = idempotency_token
+        if access_control_list is not None:
+            self.json["access_control_list"] = access_control_list
+        if git_source is not None:
+            self.json["git_source"] = git_source
+
+        if "dbt_task" in self.json and "git_source" not in self.json:
+            raise AirflowException("git_source is required for dbt_task")
+        if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
+            raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")

         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
@@ -547,25 +565,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
             caller=caller,
         )

-    def _setup_and_validate_json(self):
-        _handle_overridden_json_params(self)
-
-        if "run_name" not in self.json or self.json["run_name"] is None:
-            self.json["run_name"] = self.task_id
-
-        if "dbt_task" in self.json and "git_source" not in self.json:
-            raise AirflowException("git_source is required for dbt_task")
-        if (
-            "pipeline_task" in self.json
-            and "pipeline_id" in self.json["pipeline_task"]
-            and "pipeline_name" in self.json["pipeline_task"]
-        ):
-            raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
-
-        normalise_json_content(self)
-
     def execute(self, context: Context):
-        self._setup_and_validate_json()
         if (
             "pipeline_task" in self.json
             and self.json["pipeline_task"].get("pipeline_id") is None
@@ -575,7 +575,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
             pipeline_name = self.json["pipeline_task"]["pipeline_name"]
             self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
             del self.json["pipeline_task"]["pipeline_name"]
-        json_normalised = _normalise_json_content(self.json)
+        json_normalised = normalise_json_content(self.json)
         self.run_id = self._hook.submit_run(json_normalised)
         if self.deferrable:
             _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
@@ -611,7 +611,7 @@ class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):

     def execute(self, context):
         hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
-        json_normalised = _normalise_json_content(self.json)
+        json_normalised = normalise_json_content(self.json)
         self.run_id = hook.submit_run(json_normalised)
         _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)

@@ -811,16 +811,27 @@ class DatabricksRunNowOperator(BaseOperator):
         self.deferrable = deferrable
         self.repair_run = repair_run
         self.cancel_previous_runs = cancel_previous_runs
-        self.overridden_json_params = {
-            "job_id": job_id,
-            "job_name": job_name,
-            "notebook_params": notebook_params,
-            "python_params": python_params,
-            "python_named_params": python_named_params,
-            "jar_params": jar_params,
-            "spark_submit_params": spark_submit_params,
-            "idempotency_token": idempotency_token,
-        }
+
+        if job_id is not None:
+            self.json["job_id"] = job_id
+        if job_name is not None:
+            self.json["job_name"] = job_name
+        if "job_id" in self.json and "job_name" in self.json:
+            raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
+        if notebook_params is not None:
+            self.json["notebook_params"] = notebook_params
+        if python_params is not None:
+            self.json["python_params"] = python_params
+        if python_named_params is not None:
+            self.json["python_named_params"] = python_named_params
+        if jar_params is not None:
+            self.json["jar_params"] = jar_params
+        if spark_submit_params is not None:
+            self.json["spark_submit_params"] = spark_submit_params
+        if idempotency_token is not None:
+            self.json["idempotency_token"] = idempotency_token
+        if self.json:
+            self.json = normalise_json_content(self.json)
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
         self.do_xcom_push = do_xcom_push
@@ -838,16 +849,7 @@ class DatabricksRunNowOperator(BaseOperator):
             caller=caller,
         )

-    def _setup_and_validate_json(self):
-        _handle_overridden_json_params(self)
-
-        if "job_id" in self.json and "job_name" in self.json:
-            raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
-
-        normalise_json_content(self)
-
     def execute(self, context: Context):
-        self._setup_and_validate_json()
         hook = self._hook
         if "job_name" in self.json:
             job_id = hook.find_job_id_by_name(self.json["job_name"])
@@ -958,6 +960,15 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):

         super().__init__(**kwargs)

+        if self._databricks_workflow_task_group is not None:
+            self.operator_extra_links = (
+                WorkflowJobRunLink(),
+                WorkflowJobRepairSingleTaskLink(),
+            )
+        else:
+            # Databricks does not support repair for non-workflow tasks, hence do not show the repair link.
+            self.operator_extra_links = (DatabricksJobRunLink(),)
+
     @cached_property
     def _hook(self) -> DatabricksHook:
         return self._get_hook(caller=self.caller)
@@ -1016,12 +1027,17 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
             raise ValueError("Must specify either existing_cluster_id or new_cluster.")
         return run_json

-    def _launch_job(self) -> int:
+    def _launch_job(self, context: Context | None = None) -> int:
         """Launch the job on Databricks."""
         run_json = self._get_run_json()
         self.databricks_run_id = self._hook.submit_run(run_json)
         url = self._hook.get_run_page_url(self.databricks_run_id)
         self.log.info("Check the job run in Databricks: %s", url)
+
+        if self.do_xcom_push and context is not None:
+            context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=self.databricks_run_id)
+            context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=url)
+
         return self.databricks_run_id

     def _handle_terminal_run_state(self, run_state: RunState) -> None:
@@ -1040,7 +1056,15 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         """Retrieve the Databricks task corresponding to the current Airflow task."""
         if self.databricks_run_id is None:
             raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
-        return {task["task_key"]: task for task in self._hook.get_run(self.databricks_run_id)["tasks"]}[
+        tasks = self._hook.get_run(self.databricks_run_id)["tasks"]
+
+        # Because the task_key remains the same across multiple runs, and the Databricks API does not return
+        # tasks sorted by their attempts/start time, we sort the tasks by start time. This ensures that we
+        # map the latest attempt (whose status is to be monitored) of the task run to the task_key while
+        # building the {task_key: task} map below.
+        sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])
+
+        return {task["task_key"]: task for task in sorted_task_runs}[
             self._get_databricks_task_id(self.task_id)
         ]
@@ -1125,7 +1149,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
             self.databricks_run_id = workflow_run_metadata.run_id
             self.databricks_conn_id = workflow_run_metadata.conn_id
         else:
-            self._launch_job()
+            self._launch_job(context=context)
         if self.wait_for_termination:
             self.monitor_databricks_job()
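The constructor changes above mean that keyword arguments are merged into each operator's ``json`` payload and validated when the operator is instantiated (for example, passing both ``job_id`` and ``job_name`` to ``DatabricksRunNowOperator`` now fails in ``__init__``), rather than in a separate ``_setup_and_validate_json`` step at execution time. A minimal, hedged usage sketch — the connection id, cluster spec, notebook path and job id below are illustrative placeholders, not values taken from this diff:

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import (
        DatabricksRunNowOperator,
        DatabricksSubmitRunOperator,
    )

    with DAG(
        dag_id="example_databricks_6_8",  # hypothetical DAG id
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # Named arguments such as notebook_task/new_cluster are merged into the
        # operator's `json` payload in __init__ and normalised immediately.
        submit = DatabricksSubmitRunOperator(
            task_id="submit_notebook",
            databricks_conn_id="databricks_default",
            new_cluster={
                "spark_version": "13.3.x-scala2.12",  # illustrative cluster spec
                "node_type_id": "i3.xlarge",
                "num_workers": 2,
            },
            notebook_task={"notebook_path": "/Shared/example"},  # illustrative path
        )

        # Supplying both job_id and job_name would now raise AirflowException at
        # construction time, so only one of them is passed here.
        run_now = DatabricksRunNowOperator(
            task_id="run_existing_job",
            databricks_conn_id="databricks_default",
            job_id=1234,  # illustrative job id
            notebook_params={"date": "{{ ds }}"},
        )

        submit >> run_now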
{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/operators/databricks_workflow.py
RENAMED
@@ -28,6 +28,10 @@ from mergedeep import merge
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState
+from airflow.providers.databricks.plugins.databricks_workflow import (
+    WorkflowJobRepairAllFailedLink,
+    WorkflowJobRunLink,
+)
 from airflow.utils.task_group import TaskGroup

 if TYPE_CHECKING:
@@ -88,6 +92,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
     populated after instantiation using the `add_task` method.
     """

+    operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink())
     template_fields = ("notebook_params",)
     caller = "_CreateDatabricksWorkflowOperator"

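With the extra links attached above, the auto-generated ``launch`` task of a Databricks workflow task group exposes the "See Databricks Job Run" and "Repair All Failed Tasks" buttons in the Airflow UI, while tasks inside the group get the single-task repair link. A hedged sketch of such a task group — cluster spec, connection id and notebook paths are illustrative assumptions, not values from this diff:

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator
    from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

    JOB_CLUSTER_SPEC = [  # illustrative single job-cluster definition
        {
            "job_cluster_key": "workflow_cluster",
            "new_cluster": {
                "spark_version": "13.3.x-scala2.12",
                "node_type_id": "i3.xlarge",
                "num_workers": 1,
            },
        }
    ]

    with DAG(
        dag_id="example_databricks_workflow",  # hypothetical DAG id
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        with DatabricksWorkflowTaskGroup(
            group_id="databricks_workflow",
            databricks_conn_id="databricks_default",
            job_clusters=JOB_CLUSTER_SPEC,
        ):
            # Tasks inside the group receive WorkflowJobRunLink and
            # WorkflowJobRepairSingleTaskLink as operator extra links.
            first = DatabricksTaskOperator(
                task_id="notebook_one",
                databricks_conn_id="databricks_default",
                job_cluster_key="workflow_cluster",
                task_config={"notebook_task": {"notebook_path": "/Shared/notebook_one"}},
            )
            second = DatabricksTaskOperator(
                task_id="notebook_two",
                databricks_conn_id="databricks_default",
                job_cluster_key="workflow_cluster",
                task_config={"notebook_task": {"notebook_path": "/Shared/notebook_two"}},
            )
            first >> second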
apache_airflow_providers_databricks-6.8.0/airflow/providers/databricks/plugins/databricks_workflow.py
ADDED
@@ -0,0 +1,477 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import unquote

from flask import current_app, flash, redirect, request, url_for
from flask_appbuilder.api import expose
from packaging.version import Version

from airflow.exceptions import AirflowException, TaskInstanceNotFound
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.dag import DAG, clear_task_instances
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.models.xcom import XCom
from airflow.plugins_manager import AirflowPlugin
from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.security import permissions
from airflow.utils.airflow_flask_app import AirflowApp
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.version import version
from airflow.www import auth
from airflow.www.views import AirflowBaseView

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session


REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)

airflow_app = cast(AirflowApp, current_app)


def get_auth_decorator():
    # TODO: remove this if block when min_airflow_version is set to higher than 2.8.0
    if Version(version) < Version("2.8"):
        return auth.has_access(
            [
                (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
                (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
            ]
        )

    from airflow.auth.managers.models.resource_details import DagAccessEntity

    return auth.has_access_dag("POST", DagAccessEntity.RUN)


def _get_databricks_task_id(task: BaseOperator) -> str:
    """
    Get the databricks task ID using dag_id and task_id. removes illegal characters.

    :param task: The task to get the databricks task ID for.
    :return: The databricks task ID.
    """
    return f"{task.dag_id}__{task.task_id.replace('.', '__')}"


def get_databricks_task_ids(
    group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
) -> list[str]:
    """
    Return a list of all Databricks task IDs for a dictionary of Airflow tasks.

    :param group_id: The task group ID.
    :param task_map: A dictionary mapping task IDs to BaseOperator instances.
    :param log: The logger to use for logging.
    :return: A list of Databricks task IDs for the given task group.
    """
    task_ids = []
    log.debug("Getting databricks task ids for group %s", group_id)
    for task_id, task in task_map.items():
        if task_id == f"{group_id}.launch":
            continue
        databricks_task_id = _get_databricks_task_id(task)
        log.debug("databricks task id for task %s is %s", task_id, databricks_task_id)
        task_ids.append(databricks_task_id)
    return task_ids


@provide_session
def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun:
    """
    Retrieve the DagRun object associated with the specified DAG and run_id.

    :param dag: The DAG object associated with the DagRun to retrieve.
    :param run_id: The run_id associated with the DagRun to retrieve.
    :param session: The SQLAlchemy session to use for the query. If None, uses the default session.
    :return: The DagRun object associated with the specified DAG and run_id.
    """
    if not session:
        raise AirflowException("Session not provided.")

    return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()


@provide_session
def _clear_task_instances(
    dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None
) -> None:
    dag = airflow_app.dag_bag.get_dag(dag_id)
    log.debug("task_ids %s to clear", str(task_ids))
    dr: DagRun = _get_dagrun(dag, run_id, session=session)
    tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids]
    clear_task_instances(tis_to_clear, session)


def _repair_task(
    databricks_conn_id: str,
    databricks_run_id: int,
    tasks_to_repair: list[str],
    logger: logging.Logger,
) -> int:
    """
    Repair a Databricks task using the Databricks API.

    This function allows the Airflow retry function to create a repair job for Databricks.
    It uses the Databricks API to get the latest repair ID before sending the repair query.

    :param databricks_conn_id: The Databricks connection ID.
    :param databricks_run_id: The Databricks run ID.
    :param tasks_to_repair: A list of Databricks task IDs to repair.
    :param logger: The logger to use for logging.
    :return: None
    """
    hook = DatabricksHook(databricks_conn_id=databricks_conn_id)

    repair_history_id = hook.get_latest_repair_id(databricks_run_id)
    logger.debug("Latest repair ID is %s", repair_history_id)
    logger.debug(
        "Sending repair query for tasks %s on run %s",
        tasks_to_repair,
        databricks_run_id,
    )

    repair_json = {
        "run_id": databricks_run_id,
        "latest_repair_id": repair_history_id,
        "rerun_tasks": tasks_to_repair,
    }

    return hook.repair_run(repair_json)


def get_launch_task_id(task_group: TaskGroup) -> str:
    """
    Retrieve the launch task ID from the current task group or a parent task group, recursively.

    :param task_group: Task Group to be inspected
    :return: launch Task ID
    """
    try:
        launch_task_id = task_group.get_child_by_label("launch").task_id  # type: ignore[attr-defined]
    except KeyError as e:
        if not task_group.parent_group:
            raise AirflowException("No launch task can be found in the task group.") from e
        launch_task_id = get_launch_task_id(task_group.parent_group)

    return launch_task_id


def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> TaskInstanceKey:
    """
    Return the task key for the launch task.

    This allows us to gather databricks Metadata even if the current task has failed (since tasks only
    create xcom values if they succeed).

    :param current_task_key: The task key for the current task.
    :param task_id: The task ID for the current task.
    :return: The task key for the launch task.
    """
    if task_id:
        return TaskInstanceKey(
            dag_id=current_task_key.dag_id,
            task_id=task_id,
            run_id=current_task_key.run_id,
            try_number=current_task_key.try_number,
        )

    return current_task_key


@provide_session
def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
    dag_id = operator.dag.dag_id
    dag_run = DagRun.find(dag_id, execution_date=dttm)[0]
    ti = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.run_id == dag_run.run_id,
            TaskInstance.task_id == operator.task_id,
        )
        .one_or_none()
    )
    if not ti:
        raise TaskInstanceNotFound("Task instance not found")
    return ti


def get_xcom_result(
    ti_key: TaskInstanceKey,
    key: str,
) -> Any:
    result = XCom.get_value(
        ti_key=ti_key,
        key=key,
    )
    from airflow.providers.databricks.operators.databricks_workflow import WorkflowRunMetadata

    return WorkflowRunMetadata(**result)


class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin):
    """Constructs a link to monitor a Databricks Job Run."""

    name = "See Databricks Job Run"

    def get_link(
        self,
        operator: BaseOperator,
        dttm=None,
        *,
        ti_key: TaskInstanceKey | None = None,
    ) -> str:
        if not ti_key:
            ti = get_task_instance(operator, dttm)
            ti_key = ti.key
        task_group = operator.task_group

        if not task_group:
            raise AirflowException("Task group is required for generating Databricks Workflow Job Run Link.")

        dag = airflow_app.dag_bag.get_dag(ti_key.dag_id)
        dag.get_task(ti_key.task_id)
        self.log.info("Getting link for task %s", ti_key.task_id)
        if ".launch" not in ti_key.task_id:
            self.log.debug("Finding the launch task for job run metadata %s", ti_key.task_id)
            launch_task_id = get_launch_task_id(task_group)
            ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id)
        metadata = get_xcom_result(ti_key, "return_value")

        hook = DatabricksHook(metadata.conn_id)
        return f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}"


class WorkflowJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin):
    """Constructs a link to send a request to repair all failed tasks in the Databricks workflow."""

    name = "Repair All Failed Tasks"

    def get_link(
        self,
        operator,
        dttm=None,
        *,
        ti_key: TaskInstanceKey | None = None,
    ) -> str:
        if not ti_key:
            ti = get_task_instance(operator, dttm)
            ti_key = ti.key
        task_group = operator.task_group
        self.log.debug(
            "Creating link to repair all tasks for databricks job run %s",
            task_group.group_id,
        )

        metadata = get_xcom_result(ti_key, "return_value")

        tasks_str = self.get_tasks_to_run(ti_key, operator, self.log)
        self.log.debug("tasks to rerun: %s", tasks_str)

        query_params = {
            "dag_id": ti_key.dag_id,
            "databricks_conn_id": metadata.conn_id,
            "databricks_run_id": metadata.run_id,
            "run_id": ti_key.run_id,
            "tasks_to_repair": tasks_str,
        }

        return url_for("RepairDatabricksTasks.repair", **query_params)

    @classmethod
    def get_task_group_children(cls, task_group: TaskGroup) -> dict[str, BaseOperator]:
        """
        Given a TaskGroup, return children which are Tasks, inspecting recursively any TaskGroups within.

        :param task_group: An Airflow TaskGroup
        :return: Dictionary that contains Task IDs as keys and Tasks as values.
        """
        children: dict[str, Any] = {}
        for child_id, child in task_group.children.items():
            if isinstance(child, TaskGroup):
                child_children = cls.get_task_group_children(child)
                children = {**children, **child_children}
            else:
                children[child_id] = child
        return children

    def get_tasks_to_run(self, ti_key: TaskInstanceKey, operator: BaseOperator, log: logging.Logger) -> str:
        task_group = operator.task_group
        if not task_group:
            raise AirflowException("Task group is required for generating repair link.")
        if not task_group.group_id:
            raise AirflowException("Task group ID is required for generating repair link.")
        dag = airflow_app.dag_bag.get_dag(ti_key.dag_id)
        dr = _get_dagrun(dag, ti_key.run_id)
        log.debug("Getting failed and skipped tasks for dag run %s", dr.run_id)
        task_group_sub_tasks = self.get_task_group_children(task_group).items()
        failed_and_skipped_tasks = self._get_failed_and_skipped_tasks(dr)
        log.debug("Failed and skipped tasks: %s", failed_and_skipped_tasks)

        tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in failed_and_skipped_tasks}

        return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log))

    @staticmethod
    def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
        """
        Return a list of task IDs for tasks that have failed or have been skipped in the given DagRun.

        :param dr: The DagRun object for which to retrieve failed and skipped tasks.

        :return: A list of task IDs for tasks that have failed or have been skipped.
        """
        return [
            t.task_id
            for t in dr.get_task_instances(
                state=[
                    TaskInstanceState.FAILED,
                    TaskInstanceState.SKIPPED,
                    TaskInstanceState.UP_FOR_RETRY,
                    TaskInstanceState.UPSTREAM_FAILED,
                    None,
                ],
            )
        ]


class WorkflowJobRepairSingleTaskLink(BaseOperatorLink, LoggingMixin):
    """Construct a link to send a repair request for a single databricks task."""

    name = "Repair a single task"

    def get_link(
        self,
        operator,
        dttm=None,
        *,
        ti_key: TaskInstanceKey | None = None,
    ) -> str:
        if not ti_key:
            ti = get_task_instance(operator, dttm)
            ti_key = ti.key

        task_group = operator.task_group
        if not task_group:
            raise AirflowException("Task group is required for generating repair link.")

        self.log.info(
            "Creating link to repair a single task for databricks job run %s task %s",
            task_group.group_id,
            ti_key.task_id,
        )
        dag = airflow_app.dag_bag.get_dag(ti_key.dag_id)
        task = dag.get_task(ti_key.task_id)

        if ".launch" not in ti_key.task_id:
            launch_task_id = get_launch_task_id(task_group)
            ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id)
        metadata = get_xcom_result(ti_key, "return_value")

        query_params = {
            "dag_id": ti_key.dag_id,
            "databricks_conn_id": metadata.conn_id,
            "databricks_run_id": metadata.run_id,
            "run_id": ti_key.run_id,
            "tasks_to_repair": _get_databricks_task_id(task),
        }
        return url_for("RepairDatabricksTasks.repair", **query_params)


class RepairDatabricksTasks(AirflowBaseView, LoggingMixin):
    """Repair databricks tasks from Airflow."""

    default_view = "repair"

    @expose("/repair_databricks_job/<string:dag_id>/<string:run_id>", methods=("GET",))
    @get_auth_decorator()
    def repair(self, dag_id: str, run_id: str):
        return_url = self._get_return_url(dag_id, run_id)

        tasks_to_repair = request.values.get("tasks_to_repair")
        self.log.info("Tasks to repair: %s", tasks_to_repair)
        if not tasks_to_repair:
            flash("No tasks to repair. Not sending repair request.")
            return redirect(return_url)

        databricks_conn_id = request.values.get("databricks_conn_id")
        databricks_run_id = request.values.get("databricks_run_id")

        if not databricks_conn_id:
            flash("No Databricks connection ID provided. Cannot repair tasks.")
            return redirect(return_url)

        if not databricks_run_id:
            flash("No Databricks run ID provided. Cannot repair tasks.")
            return redirect(return_url)

        self.log.info("Repairing databricks job %s", databricks_run_id)
        res = _repair_task(
            databricks_conn_id=databricks_conn_id,
            databricks_run_id=int(databricks_run_id),
            tasks_to_repair=tasks_to_repair.split(","),
            logger=self.log,
        )
        self.log.info("Repairing databricks job query for run %s sent", databricks_run_id)

        self.log.info("Clearing tasks to rerun in airflow")

        run_id = unquote(run_id)
        _clear_task_instances(dag_id, run_id, tasks_to_repair.split(","), self.log)
        flash(f"Databricks repair job is starting!: {res}")
        return redirect(return_url)

    @staticmethod
    def _get_return_url(dag_id: str, run_id: str) -> str:
        return url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id)


repair_databricks_view = RepairDatabricksTasks()

repair_databricks_package = {
    "view": repair_databricks_view,
}


class DatabricksWorkflowPlugin(AirflowPlugin):
    """
    Databricks Workflows plugin for Airflow.

    .. seealso::
        For more information on how to use this plugin, take a look at the guide:
        :ref:`howto/plugin:DatabricksWorkflowPlugin`
    """

    name = "databricks_workflow"
    operator_extra_links = [
        WorkflowJobRepairAllFailedLink(),
        WorkflowJobRepairSingleTaskLink(),
        WorkflowJobRunLink(),
    ]
    appbuilder_views = [repair_databricks_package]
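The repair links above ultimately call ``DatabricksHook.get_latest_repair_id`` and ``DatabricksHook.repair_run`` with a list of Databricks task keys, each derived from the Airflow ``dag_id`` and ``task_id``. A hedged sketch condensing that flow — the helper names and example arguments below are illustrative, not part of the plugin's public API:

    from __future__ import annotations

    from airflow.providers.databricks.hooks.databricks import DatabricksHook


    def databricks_task_id(dag_id: str, task_id: str) -> str:
        # Mirrors _get_databricks_task_id above: dots in the Airflow task_id are
        # replaced because they are not legal in Databricks task keys.
        return f"{dag_id}__{task_id.replace('.', '__')}"


    def repair_failed_tasks(conn_id: str, databricks_run_id: int, dag_id: str, airflow_task_ids: list[str]) -> int:
        """Hedged sketch of the repair call the plugin view performs."""
        hook = DatabricksHook(databricks_conn_id=conn_id)
        repair_json = {
            "run_id": databricks_run_id,
            "latest_repair_id": hook.get_latest_repair_id(databricks_run_id),
            "rerun_tasks": [databricks_task_id(dag_id, t) for t in airflow_task_ids],
        }
        return hook.repair_run(repair_json)


    # Illustrative usage (connection id, run id, dag id and task ids are placeholders):
    # repair_failed_tasks("databricks_default", 123456, "example_databricks_workflow", ["workflow.notebook_one"])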
apache_airflow_providers_databricks-6.8.0/airflow/providers/databricks/utils/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/airflow/providers/databricks/utils/databricks.py
RENAMED
@@ -21,7 +21,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.databricks.hooks.databricks import RunState


-def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
+def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
     """
     Normalize content or all values of content if it is a dict to a string.

@@ -33,7 +33,7 @@ def _normalise_json_content(content, json_path: str = "json") -> str | bool | li
     The only one exception is when we have boolean values, they can not be converted
     to string type because databricks does not understand 'True' or 'False' values.
     """
-    normalise = _normalise_json_content
+    normalise = normalise_json_content
     if isinstance(content, (str, bool)):
         return content
     elif isinstance(content, (int, float)):
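The rename above makes the normalisation helper public; its behaviour is unchanged: non-boolean scalars are coerced to strings (Databricks does not accept 'True'/'False'), while lists and dicts are normalised recursively. A short illustration of the expected behaviour — the input payload is an arbitrary example:

    from airflow.providers.databricks.utils.databricks import normalise_json_content

    payload = {
        "existing_cluster_id": "1234-567890-abcde123",  # illustrative cluster id
        "timeout_seconds": 3600,
        "notebook_task": {"notebook_path": "/Shared/example", "base_parameters": {"retries": 2}},
        "do_truncate": True,
    }

    normalised = normalise_json_content(payload)
    # Integers and floats become strings, booleans are preserved, nesting is kept:
    # {'existing_cluster_id': '1234-567890-abcde123', 'timeout_seconds': '3600',
    #  'notebook_task': {'notebook_path': '/Shared/example', 'base_parameters': {'retries': '2'}},
    #  'do_truncate': True}
    print(normalised)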
{apache_airflow_providers_databricks-6.7.0 → apache_airflow_providers_databricks-6.8.0}/pyproject.toml
RENAMED
@@ -28,7 +28,7 @@ build-backend = "flit_core.buildapi"

 [project]
 name = "apache-airflow-providers-databricks"
-version = "6.7.0"
+version = "6.8.0"
 description = "Provider package apache-airflow-providers-databricks for Apache Airflow"
 readme = "README.rst"
 authors = [
@@ -68,8 +68,8 @@ dependencies = [
 ]

 [project.urls]
-"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0"
-"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.7.0/changelog.html"
+"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0"
+"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-databricks/6.8.0/changelog.html"
 "Bug Tracker" = "https://github.com/apache/airflow/issues"
 "Source Code" = "https://github.com/apache/airflow"
 "Slack Chat" = "https://s.apache.org/airflow-slack"
@@ -78,6 +78,8 @@ dependencies = [

 [project.entry-points."apache_airflow_provider"]
 provider_info = "airflow.providers.databricks.get_provider_info:get_provider_info"
+[project.entry-points."airflow.plugins"]
+databricks_workflow = "airflow.providers.databricks.plugins.databricks_workflow:DatabricksWorkflowPlugin"
 [project.optional-dependencies]
 "common.sql" = [
     "apache-airflow-providers-common-sql",
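Besides the ``plugins`` entry in the provider metadata, the plugin is also exposed through the standard ``airflow.plugins`` entry-point group added above, so it is discovered automatically once the distribution is installed. A hedged sketch of how that entry point can be inspected from the standard library (the group-filtering form of ``entry_points`` requires Python 3.10+):

    from importlib.metadata import entry_points

    # Filter the installed entry points down to the airflow.plugins group.
    for ep in entry_points(group="airflow.plugins"):
        if ep.name == "databricks_workflow":
            # ep.value is the "module:attr" string declared in pyproject.toml:
            # airflow.providers.databricks.plugins.databricks_workflow:DatabricksWorkflowPlugin
            print(ep.value)
            plugin_cls = ep.load()
            print(plugin_cls.name)  # "databricks_workflow"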