apache-airflow-providers-databricks 7.2.2rc1__py3-none-any.whl → 7.3.0rc1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- airflow/providers/databricks/__init__.py +1 -1
- airflow/providers/databricks/get_provider_info.py +3 -2
- airflow/providers/databricks/hooks/databricks.py +110 -0
- airflow/providers/databricks/hooks/databricks_sql.py +5 -4
- airflow/providers/databricks/operators/databricks.py +209 -2
- airflow/providers/databricks/triggers/databricks.py +100 -0
- airflow/providers/databricks/utils/databricks.py +1 -1
- {apache_airflow_providers_databricks-7.2.2rc1.dist-info → apache_airflow_providers_databricks-7.3.0rc1.dist-info}/METADATA +8 -8
- {apache_airflow_providers_databricks-7.2.2rc1.dist-info → apache_airflow_providers_databricks-7.3.0rc1.dist-info}/RECORD +11 -11
- {apache_airflow_providers_databricks-7.2.2rc1.dist-info → apache_airflow_providers_databricks-7.3.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_databricks-7.2.2rc1.dist-info → apache_airflow_providers_databricks-7.3.0rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/databricks/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "7.2.2"
+__version__ = "7.3.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.9.0"

airflow/providers/databricks/get_provider_info.py
@@ -27,8 +27,9 @@ def get_provider_info():
         "name": "Databricks",
         "description": "`Databricks <https://databricks.com/>`__\n",
         "state": "ready",
-        "source-date-epoch":
+        "source-date-epoch": 1743835987,
         "versions": [
+            "7.3.0",
             "7.2.2",
             "7.2.1",
             "7.2.0",
@@ -181,7 +182,7 @@ def get_provider_info():
         "dependencies": [
             "apache-airflow>=2.9.0",
             "apache-airflow-providers-common-sql>=1.20.0",
-            "requests>=2.
+            "requests>=2.31.0,<3",
             "databricks-sql-connector>=3.0.0",
             "aiohttp>=3.9.2, <4",
             "mergedeep>=1.3.4",

airflow/providers/databricks/hooks/databricks.py
@@ -63,6 +63,7 @@ LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")
 WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")

 SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
+SQL_STATEMENTS_ENDPOINT = "api/2.0/sql/statements"


 class RunLifeCycleState(Enum):

@@ -189,6 +190,67 @@ class ClusterState:
         return ClusterState(**json.loads(data))


+class SQLStatementState:
+    """Utility class for the SQL statement state concept of Databricks statements."""
+
+    SQL_STATEMENT_LIFE_CYCLE_STATES = [
+        "PENDING",
+        "RUNNING",
+        "SUCCEEDED",
+        "FAILED",
+        "CANCELED",
+        "CLOSED",
+    ]
+
+    def __init__(
+        self, state: str = "", error_code: str = "", error_message: str = "", *args, **kwargs
+    ) -> None:
+        if state not in self.SQL_STATEMENT_LIFE_CYCLE_STATES:
+            raise AirflowException(
+                f"Unexpected SQL statement life cycle state: {state}: If the state has "
+                "been introduced recently, please check the Databricks user "
+                "guide for troubleshooting information"
+            )
+
+        self.state = state
+        self.error_code = error_code
+        self.error_message = error_message
+
+    @property
+    def is_terminal(self) -> bool:
+        """True if the current state is a terminal state."""
+        return self.state in ("SUCCEEDED", "FAILED", "CANCELED", "CLOSED")
+
+    @property
+    def is_running(self) -> bool:
+        """True if the current state is running."""
+        return self.state in ("PENDING", "RUNNING")
+
+    @property
+    def is_successful(self) -> bool:
+        """True if the state is SUCCEEDED."""
+        return self.state == "SUCCEEDED"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SQLStatementState):
+            return NotImplemented
+        return (
+            self.state == other.state
+            and self.error_code == other.error_code
+            and self.error_message == other.error_message
+        )
+
+    def __repr__(self) -> str:
+        return str(self.__dict__)
+
+    def to_json(self) -> str:
+        return json.dumps(self.__dict__)
+
+    @classmethod
+    def from_json(cls, data: str) -> SQLStatementState:
+        return SQLStatementState(**json.loads(data))
+
+
 class DatabricksHook(BaseDatabricksHook):
     """
     Interact with Databricks.
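
For orientation (not part of the diff): a minimal sketch of how the new SQLStatementState helper added above behaves; the error values are made up.

from airflow.providers.databricks.hooks.databricks import SQLStatementState

# A terminal, failed state round-tripped through the JSON helpers.
state = SQLStatementState(state="FAILED", error_code="SYNTAX_ERROR", error_message="bad SQL")
assert state.is_terminal and not state.is_successful and not state.is_running
restored = SQLStatementState.from_json(state.to_json())
assert restored == state  # __eq__ compares state, error_code and error_message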

@@ -709,6 +771,54 @@ class DatabricksHook(BaseDatabricksHook):
         """
         return self._do_api_call(("PATCH", f"api/2.0/permissions/jobs/{job_id}"), json)

+    def post_sql_statement(self, json: dict[str, Any]) -> str:
+        """
+        Submit a SQL statement to the Databricks SQL Statements endpoint.
+
+        :param json: The data used in the body of the request to the SQL Statements endpoint.
+        :return: The statement_id as a string.
+        """
+        response = self._do_api_call(("POST", f"{SQL_STATEMENTS_ENDPOINT}"), json)
+        return response["statement_id"]
+
+    def get_sql_statement_state(self, statement_id: str) -> SQLStatementState:
+        """
+        Retrieve run state of the SQL statement.
+
+        :param statement_id: ID of the SQL statement.
+        :return: state of the SQL statement.
+        """
+        get_statement_endpoint = ("GET", f"{SQL_STATEMENTS_ENDPOINT}/{statement_id}")
+        response = self._do_api_call(get_statement_endpoint)
+        state = response["status"]["state"]
+        error_code = response["status"].get("error", {}).get("error_code", "")
+        error_message = response["status"].get("error", {}).get("message", "")
+        return SQLStatementState(state, error_code, error_message)
+
+    async def a_get_sql_statement_state(self, statement_id: str) -> SQLStatementState:
+        """
+        Async version of `get_sql_statement_state`.
+
+        :param statement_id: ID of the SQL statement
+        :return: state of the SQL statement
+        """
+        get_sql_statement_endpoint = ("GET", f"{SQL_STATEMENTS_ENDPOINT}/{statement_id}")
+        response = await self._a_do_api_call(get_sql_statement_endpoint)
+        state = response["status"]["state"]
+        error_code = response["status"].get("error", {}).get("error_code", "")
+        error_message = response["status"].get("error", {}).get("message", "")
+        return SQLStatementState(state, error_code, error_message)
+
+    def cancel_sql_statement(self, statement_id: str) -> None:
+        """
+        Cancel the SQL statement.
+
+        :param statement_id: ID of the SQL statement
+        """
+        self.log.info("Canceling SQL statement with ID: %s", statement_id)
+        cancel_sql_statement_endpoint = ("POST", f"{SQL_STATEMENTS_ENDPOINT}/{statement_id}/cancel")
+        self._do_api_call(cancel_sql_statement_endpoint)
+
     def test_connection(self) -> tuple[bool, str]:
         """Test the Databricks connectivity from UI."""
         hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id)
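
For orientation (not part of the diff): a short sketch of how the new hook methods fit together; the connection id, warehouse id and SQL text are placeholders.

from airflow.providers.databricks.hooks.databricks import DatabricksHook

hook = DatabricksHook(databricks_conn_id="databricks_default")
# Submit the statement; every other call is keyed off the returned statement_id.
statement_id = hook.post_sql_statement(
    {"statement": "SELECT 1", "warehouse_id": "my-warehouse-id", "wait_timeout": "0s"}
)
state = hook.get_sql_statement_state(statement_id)
if state.is_running:
    hook.cancel_sql_statement(statement_id)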

airflow/providers/databricks/hooks/databricks_sql.py
@@ -35,7 +35,6 @@ from databricks import sql # type: ignore[attr-defined]
 from databricks.sql.types import Row

 from airflow.exceptions import AirflowException
-from airflow.models.connection import Connection as AirflowConnection
 from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
 from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
@@ -43,6 +42,8 @@ from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
 if TYPE_CHECKING:
     from databricks.sql.client import Connection

+    from airflow.models.connection import Connection as AirflowConnection
+

 LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

@@ -167,7 +168,7 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):

         if self._sql_conn is None:
             raise AirflowException("SQL connection is not initialized")
-        return cast(AirflowConnection, self._sql_conn)
+        return cast("AirflowConnection", self._sql_conn)

     @overload # type: ignore[override]
     def run(
@@ -295,11 +296,11 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
                 return []
             rows_fields = tuple(rows[0].__fields__)
             rows_object = namedtuple("Row", rows_fields, rename=True) # type: ignore
-            return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows])
+            return cast("list[tuple[Any, ...]]", [rows_object(*row) for row in rows])
         elif isinstance(result, Row):
             row_fields = tuple(result.__fields__)
             row_object = namedtuple("Row", row_fields, rename=True) # type: ignore
-            return cast(tuple[Any, ...], row_object(*result))
+            return cast("tuple[Any, ...]", row_object(*result))
         else:
             raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}")

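
The hunks above move the Airflow Connection import under TYPE_CHECKING and quote the cast() targets, so the name is only needed by the type checker. A generic illustration of that pattern (the function name here is illustrative, not from the provider):

from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from airflow.models.connection import Connection as AirflowConnection


def as_connection(obj: Any) -> AirflowConnection:
    # The quoted type keeps cast() working even though AirflowConnection
    # is not bound at runtime in this module.
    return cast("AirflowConnection", obj)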

airflow/providers/databricks/operators/databricks.py
@@ -30,7 +30,12 @@ from typing import TYPE_CHECKING, Any
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.databricks.hooks.databricks import
+from airflow.providers.databricks.hooks.databricks import (
+    DatabricksHook,
+    RunLifeCycleState,
+    RunState,
+    SQLStatementState,
+)
 from airflow.providers.databricks.operators.databricks_workflow import (
     DatabricksWorkflowTaskGroup,
     WorkflowRunMetadata,
@@ -39,7 +44,10 @@ from airflow.providers.databricks.plugins.databricks_workflow import (
     WorkflowJobRepairSingleTaskLink,
     WorkflowJobRunLink,
 )
-from airflow.providers.databricks.triggers.databricks import
+from airflow.providers.databricks.triggers.databricks import (
+    DatabricksExecutionTrigger,
+    DatabricksSQLStatementExecutionTrigger,
+)
 from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
 from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS

@@ -59,6 +67,7 @@ DEFER_METHOD_NAME = "execute_complete"
 XCOM_RUN_ID_KEY = "run_id"
 XCOM_JOB_ID_KEY = "job_id"
 XCOM_RUN_PAGE_URL_KEY = "run_page_url"
+XCOM_STATEMENT_ID_KEY = "statement_id"


 def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
@@ -969,6 +978,204 @@ class DatabricksRunNowOperator(BaseOperator):
         self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id)


+class DatabricksSQLStatementsOperator(BaseOperator):
+    """
+    Submits a Databricks SQL Statement to Databricks using the api/2.0/sql/statements/ API endpoint.
+
+    See: https://docs.databricks.com/api/workspace/statementexecution
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DatabricksSQLStatementsOperator`
+
+    :param statement: The SQL statement to execute. The statement can optionally be parameterized, see parameters.
+    :param warehouse_id: Warehouse upon which to execute a statement.
+    :param catalog: Sets default catalog for statement execution, similar to USE CATALOG in SQL.
+    :param schema: Sets default schema for statement execution, similar to USE SCHEMA in SQL.
+    :param parameters: A list of parameters to pass into a SQL statement containing parameter markers.
+
+        .. seealso::
+            https://docs.databricks.com/api/workspace/statementexecution/executestatement#parameters
+    :param wait_for_termination: if we should wait for termination of the statement execution. ``True`` by default.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
+        By default and in the common case this will be ``databricks_default``. To use
+        token based authentication, provide the key ``token`` in the extra field for the
+        connection and create the key ``host`` and leave the ``host`` field empty. (templated)
+    :param polling_period_seconds: Controls the rate which we poll for the result of
+        this statement. By default the operator will poll every 30 seconds.
+    :param databricks_retry_limit: Amount of times retry if the Databricks backend is
+        unreachable. Its value must be greater than or equal to 1.
+    :param databricks_retry_delay: Number of seconds to wait between retries (it
+        might be a floating point number).
+    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    :param do_xcom_push: Whether we should push statement_id to xcom.
+    :param timeout: The timeout for the Airflow task executing the SQL statement. By default a value of 3600 seconds is used.
+    :param deferrable: Run operator in the deferrable mode.
+    """
+
+    # Used in airflow.models.BaseOperator
+    template_fields: Sequence[str] = ("databricks_conn_id",)
+    template_ext: Sequence[str] = (".json-tpl",)
+    # Databricks brand color (blue) under white text
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        statement: str,
+        warehouse_id: str,
+        *,
+        catalog: str | None = None,
+        schema: str | None = None,
+        parameters: list[dict[str, Any]] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        do_xcom_push: bool = True,
+        wait_for_termination: bool = True,
+        timeout: float = 3600,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ) -> None:
+        """Create a new ``DatabricksSubmitRunOperator``."""
+        super().__init__(**kwargs)
+        self.statement = statement
+        self.warehouse_id = warehouse_id
+        self.catalog = catalog
+        self.schema = schema
+        self.parameters = parameters
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
+
+        # This variable will be used in case our task gets killed.
+        self.statement_id: str | None = None
+
+        self.timeout = timeout
+        self.do_xcom_push = do_xcom_push
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSQLStatementsOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def _handle_operator_execution(self) -> None:
+        end_time = time.time() + self.timeout
+        while end_time > time.time():
+            statement_state = self._hook.get_sql_statement_state(self.statement_id)
+            if statement_state.is_terminal:
+                if statement_state.is_successful:
+                    self.log.info("%s completed successfully.", self.task_id)
+                    return
+                error_message = (
+                    f"{self.task_id} failed with terminal state: {statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+            self.log.info("%s in run state: %s", self.task_id, statement_state.state)
+            self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
+            time.sleep(self.polling_period_seconds)
+
+        self._hook.cancel_sql_statement(self.statement_id)
+        raise AirflowException(
+            f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state.state}",
+        )
+
+    def _handle_deferrable_operator_execution(self) -> None:
+        statement_state = self._hook.get_sql_statement_state(self.statement_id)
+        end_time = time.time() + self.timeout
+        if not statement_state.is_terminal:
+            if not self.statement_id:
+                raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.")
+            self.defer(
+                trigger=DatabricksSQLStatementExecutionTrigger(
+                    statement_id=self.statement_id,
+                    databricks_conn_id=self.databricks_conn_id,
+                    end_time=end_time,
+                    polling_period_seconds=self.polling_period_seconds,
+                    retry_limit=self.databricks_retry_limit,
+                    retry_delay=self.databricks_retry_delay,
+                    retry_args=self.databricks_retry_args,
+                ),
+                method_name=DEFER_METHOD_NAME,
+            )
+        else:
+            if statement_state.is_successful:
+                self.log.info("%s completed successfully.", self.task_id)
+            else:
+                error_message = (
+                    f"{self.task_id} failed with terminal state: {statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+    def execute(self, context: Context):
+        json = {
+            "statement": self.statement,
+            "warehouse_id": self.warehouse_id,
+            "catalog": self.catalog,
+            "schema": self.schema,
+            "parameters": self.parameters,
+            # We set the wait timeout to 0s as that seems the appropriate way for our deferrable version
+            # support of the operator. For synchronous version, we still poll on the statement
+            # execution state.
+            "wait_timeout": "0s",
+        }
+        self.statement_id = self._hook.post_sql_statement(json)
+        if self.do_xcom_push and context is not None:
+            context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, value=self.statement_id)
+
+        self.log.info("SQL Statement submitted with statement_id: %s", self.statement_id)
+        if not self.wait_for_termination:
+            return
+        if self.deferrable:
+            self._handle_deferrable_operator_execution()
+        else:
+            self._handle_operator_execution()
+
+    def on_kill(self):
+        if self.statement_id:
+            self._hook.cancel_sql_statement(self.statement_id)
+            self.log.info(
+                "Task: %s with statement ID: %s was requested to be cancelled.",
+                self.task_id,
+                self.statement_id,
+            )
+        else:
+            self.log.error(
+                "Error: Task: %s with invalid statement_id was requested to be cancelled.", self.task_id
+            )
+
+    def execute_complete(self, context: dict | None, event: dict):
+        statement_state = SQLStatementState.from_json(event["state"])
+        error = event["error"]
+        statement_id = event["statement_id"]
+
+        if statement_state.is_successful:
+            self.log.info("SQL Statement with ID %s completed successfully.", statement_id)
+            return
+
+        error_message = f"SQL Statement execution failed with terminal state: {statement_state} and with the error {error}"
+        raise AirflowException(error_message)
+
+
 class DatabricksTaskBaseOperator(BaseOperator, ABC):
     """
     Base class for operators that are run as Databricks job tasks or tasks within a Databricks workflow.
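
For orientation (not part of the diff): a hedged DAG sketch using the new operator; the dag id, warehouse id, SQL text and start date are placeholders.

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.databricks.operators.databricks import DatabricksSQLStatementsOperator

with DAG(
    dag_id="example_databricks_sql_statement",  # placeholder
    start_date=datetime(2025, 1, 1),
    schedule=None,
) as dag:
    run_statement = DatabricksSQLStatementsOperator(
        task_id="run_statement",
        statement="SELECT COUNT(*) FROM some_catalog.some_schema.some_table",  # placeholder
        warehouse_id="my-warehouse-id",  # placeholder
        # With deferrable=True, polling is handed off to
        # DatabricksSQLStatementExecutionTrigger instead of blocking a worker slot.
        deferrable=True,
    )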

airflow/providers/databricks/triggers/databricks.py
@@ -18,6 +18,7 @@
 from __future__ import annotations

 import asyncio
+import time
 from typing import Any

 from airflow.providers.databricks.hooks.databricks import DatabricksHook
@@ -119,3 +120,102 @@ class DatabricksExecutionTrigger(BaseTrigger):
                     }
                 )
                 return
+
+
+class DatabricksSQLStatementExecutionTrigger(BaseTrigger):
+    """
+    The trigger handles the logic of async communication with DataBricks SQL Statements API.
+
+    :param statement_id: ID of the SQL statement.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
+    :param end_time: The end time (set based on timeout supplied for the operator) for the SQL statement execution.
+    :param polling_period_seconds: Controls the rate of the poll for the result of this run.
+        By default, the trigger will poll every 30 seconds.
+    :param retry_limit: The number of times to retry the connection in case of service outages.
+    :param retry_delay: The number of seconds to wait between retries.
+    :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    """
+
+    def __init__(
+        self,
+        statement_id: str,
+        databricks_conn_id: str,
+        end_time: float,
+        polling_period_seconds: int = 30,
+        retry_limit: int = 3,
+        retry_delay: int = 10,
+        retry_args: dict[Any, Any] | None = None,
+        caller: str = "DatabricksSQLStatementExecutionTrigger",
+    ) -> None:
+        super().__init__()
+        self.statement_id = statement_id
+        self.databricks_conn_id = databricks_conn_id
+        self.end_time = end_time
+        self.polling_period_seconds = polling_period_seconds
+        self.retry_limit = retry_limit
+        self.retry_delay = retry_delay
+        self.retry_args = retry_args
+        self.hook = DatabricksHook(
+            databricks_conn_id,
+            retry_limit=self.retry_limit,
+            retry_delay=self.retry_delay,
+            retry_args=retry_args,
+            caller=caller,
+        )
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.databricks.triggers.databricks.DatabricksSQLStatementExecutionTrigger",
+            {
+                "statement_id": self.statement_id,
+                "databricks_conn_id": self.databricks_conn_id,
+                "polling_period_seconds": self.polling_period_seconds,
+                "end_time": self.end_time,
+                "retry_limit": self.retry_limit,
+                "retry_delay": self.retry_delay,
+                "retry_args": self.retry_args,
+            },
+        )
+
+    async def run(self):
+        async with self.hook:
+            while self.end_time > time.time():
+                statement_state = await self.hook.a_get_sql_statement_state(self.statement_id)
+                if not statement_state.is_terminal:
+                    self.log.info(
+                        "Statement ID %s is in state %s. sleeping for %s seconds",
+                        self.statement_id,
+                        statement_state,
+                        self.polling_period_seconds,
+                    )
+                    await asyncio.sleep(self.polling_period_seconds)
+                    continue
+
+                error = {}
+                if statement_state.error_code:
+                    error = {
+                        "error_code": statement_state.error_code,
+                        "error_message": statement_state.error_message,
+                    }
+                yield TriggerEvent(
+                    {
+                        "statement_id": self.statement_id,
+                        "state": statement_state.to_json(),
+                        "error": error,
+                    }
+                )
+                return
+
+            # If we reach here, it means the statement should be timed out as per the end_time.
+            self.hook.cancel_sql_statement(self.statement_id)
+            yield TriggerEvent(
+                {
+                    "statement_id": self.statement_id,
+                    "state": statement_state.to_json(),
+                    "error": {
+                        "error_code": "TIMEOUT",
+                        "error_message": f"Statement ID {self.statement_id} timed out after set end time {self.end_time}",
+                    },
+                }
+            )
+            return
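
For orientation (not part of the diff): the trigger above yields a TriggerEvent whose payload the operator's execute_complete() consumes; an illustrative payload built with the helpers added in this release.

from airflow.providers.databricks.hooks.databricks import SQLStatementState

event = {
    "statement_id": "example-statement-id",  # placeholder
    "state": SQLStatementState("SUCCEEDED").to_json(),
    "error": {},  # populated with error_code/error_message on failure or timeout
}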

airflow/providers/databricks/utils/databricks.py
@@ -63,4 +63,4 @@ def validate_trigger_event(event: dict):
     try:
         RunState.from_json(event["run_state"])
     except Exception:
-        raise AirflowException(f
+        raise AirflowException(f"Run state returned by the Trigger is incorrect: {event['run_state']}")

apache_airflow_providers_databricks-7.3.0rc1.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: apache-airflow-providers-databricks
-Version: 7.2.2rc1
+Version: 7.3.0rc1
 Summary: Provider package apache-airflow-providers-databricks for Apache Airflow
 Keywords: airflow-provider,databricks,airflow,integration
 Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -22,7 +22,7 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Topic :: System :: Monitoring
 Requires-Dist: apache-airflow>=2.9.0rc0
 Requires-Dist: apache-airflow-providers-common-sql>=1.20.0rc0
-Requires-Dist: requests>=2.
+Requires-Dist: requests>=2.31.0,<3
 Requires-Dist: databricks-sql-connector>=3.0.0
 Requires-Dist: aiohttp>=3.9.2, <4
 Requires-Dist: mergedeep>=1.3.4
@@ -33,8 +33,8 @@ Requires-Dist: apache-airflow-providers-fab ; extra == "fab"
 Requires-Dist: databricks-sdk==0.10.0 ; extra == "sdk"
 Requires-Dist: apache-airflow-providers-standard ; extra == "standard"
 Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
-Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.2.2/changelog.html
-Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.2.2
+Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.3.0/changelog.html
+Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.3.0
 Project-URL: Mastodon, https://fosstodon.org/@airflow
 Project-URL: Slack Chat, https://s.apache.org/airflow-slack
 Project-URL: Source Code, https://github.com/apache/airflow
@@ -69,7 +69,7 @@ Provides-Extra: standard

 Package ``apache-airflow-providers-databricks``

-Release: ``7.2.2``
+Release: ``7.3.0``

 `Databricks <https://databricks.com/>`__
@@ -82,7 +82,7 @@ This is a provider package for ``databricks`` provider. All classes for this provider package
 are in ``airflow.providers.databricks`` python package.

 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.2.2/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.3.0/>`_.

 Installation
 ------------
@@ -101,7 +101,7 @@ PIP package Version required
 ======================================= ==================
 ``apache-airflow`` ``>=2.9.0``
 ``apache-airflow-providers-common-sql`` ``>=1.20.0``
-``requests`` ``>=2.
+``requests`` ``>=2.31.0,<3``
 ``databricks-sql-connector`` ``>=3.0.0``
 ``aiohttp`` ``>=3.9.2,<4``
 ``mergedeep`` ``>=1.3.4``
@@ -130,5 +130,5 @@ Dependent package
 ============================================================================================================ ==============

 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.2.2/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.3.0/changelog.html>`_.


apache_airflow_providers_databricks-7.3.0rc1.dist-info/RECORD
@@ -1,14 +1,14 @@
 airflow/providers/databricks/LICENSE,sha256=gXPVwptPlW1TJ4HSuG5OMPg-a3h43OGMkZRR1rpwfJA,10850
-airflow/providers/databricks/__init__.py,sha256
+airflow/providers/databricks/__init__.py,sha256=WqZExeHAUK-Gw7a9kjzNnUEcsc7-ku5xymJ7QUIZe1g,1497
 airflow/providers/databricks/exceptions.py,sha256=85RklmLOI_PnTzfXNIUd5fAu2aMMUhelwumQAX0wANE,1261
-airflow/providers/databricks/get_provider_info.py,sha256=
+airflow/providers/databricks/get_provider_info.py,sha256=GVM7LE1BQ4VW1mRxsXMzlqb_qi7ydOmC4yVtGviCnvI,7392
 airflow/providers/databricks/version_compat.py,sha256=aHg90_DtgoSnQvILFICexMyNlHlALBdaeWqkX3dFDug,1605
 airflow/providers/databricks/hooks/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
-airflow/providers/databricks/hooks/databricks.py,sha256
+airflow/providers/databricks/hooks/databricks.py,sha256=MSVcURTF7g_B_a3Oi9katAim5ILqlm7MHANGvfRC5yY,28921
 airflow/providers/databricks/hooks/databricks_base.py,sha256=8KVRF-ty20UQpJP3kgE6RDLAYqXk7ZjI07ZpwFIcGB8,34917
-airflow/providers/databricks/hooks/databricks_sql.py,sha256=
+airflow/providers/databricks/hooks/databricks_sql.py,sha256=x-8Ua6LdKXjP_ZJDvA4dr6Lda1Iv0abgVwkvMrtGuzw,13130
 airflow/providers/databricks/operators/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
-airflow/providers/databricks/operators/databricks.py,sha256=
+airflow/providers/databricks/operators/databricks.py,sha256=bBVHXmS5SIYYQN_KI3hLBt6GnY7O_bdDAOjYysmFwTk,79403
 airflow/providers/databricks/operators/databricks_repos.py,sha256=m_72OnnU9df7UB-8SK2Tp5VjfNyjYeAnil3dCKs9SbA,13282
 airflow/providers/databricks/operators/databricks_sql.py,sha256=thBHpt9_LMLJZ0PN-eLCI3AaT8IFq3NAHLDWDFP-Jiw,17031
 airflow/providers/databricks/operators/databricks_workflow.py,sha256=0vFu4w6O4tlStZ_Jhk1iswKFcTk-g_dthGFeDpXGZlw,14146
@@ -18,10 +18,10 @@ airflow/providers/databricks/sensors/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOF
 airflow/providers/databricks/sensors/databricks_partition.py,sha256=hS6Q2fb84_vASb7Ai50-pmjVukX6G8xIwdaZVIE17oM,10045
 airflow/providers/databricks/sensors/databricks_sql.py,sha256=jIA9oGBUCAlXzyrqigxlg7JQDsBFuNIF8ZUEJM8gPxg,5766
 airflow/providers/databricks/triggers/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
-airflow/providers/databricks/triggers/databricks.py,sha256=
+airflow/providers/databricks/triggers/databricks.py,sha256=dSogx6GlcJfZ4CFhtlMeWs9sYFEYthP82S_U8-tM2Tk,9240
 airflow/providers/databricks/utils/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
-airflow/providers/databricks/utils/databricks.py,sha256=
-apache_airflow_providers_databricks-7.
-apache_airflow_providers_databricks-7.
-apache_airflow_providers_databricks-7.
-apache_airflow_providers_databricks-7.
+airflow/providers/databricks/utils/databricks.py,sha256=9LLgqYAS68s_PTnIez1HfN8xCKPK9D_Dt5SDF4wlbzQ,2890
+apache_airflow_providers_databricks-7.3.0rc1.dist-info/entry_points.txt,sha256=hjmZm3ab2cteTR4t9eE28oKixHwNIKtLCThd6sx3XRQ,227
+apache_airflow_providers_databricks-7.3.0rc1.dist-info/WHEEL,sha256=_2ozNFCLWc93bK4WKHCO-eDUENDlo-dgc9cU3qokYO4,82
+apache_airflow_providers_databricks-7.3.0rc1.dist-info/METADATA,sha256=38GRGYkOqVPE-eRVvDXhKEHJRLAEKwh0n_c_doJ_oAM,6088
+apache_airflow_providers_databricks-7.3.0rc1.dist-info/RECORD,,