apache-airflow-providers-databricks 7.4.0__py3-none-any.whl → 7.5.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release: this version of apache-airflow-providers-databricks might be problematic.

airflow/providers/databricks/__init__.py

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "7.4.0"
+__version__ = "7.5.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"
airflow/providers/databricks/get_provider_info.py

@@ -107,6 +107,7 @@ def get_provider_info():
             {
                 "integration-name": "Databricks",
                 "python-modules": [
+                    "airflow.providers.databricks.sensors.databricks",
                     "airflow.providers.databricks.sensors.databricks_sql",
                     "airflow.providers.databricks.sensors.databricks_partition",
                 ],
airflow/providers/databricks/operators/databricks.py

@@ -34,7 +34,6 @@ from airflow.providers.databricks.hooks.databricks import (
     DatabricksHook,
     RunLifeCycleState,
     RunState,
-    SQLStatementState,
 )
 from airflow.providers.databricks.operators.databricks_workflow import (
     DatabricksWorkflowTaskGroup,
@@ -46,13 +45,14 @@ from airflow.providers.databricks.plugins.databricks_workflow import (
 )
 from airflow.providers.databricks.triggers.databricks import (
     DatabricksExecutionTrigger,
-    DatabricksSQLStatementExecutionTrigger,
 )
 from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
 from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
     from airflow.utils.task_group import TaskGroup
 
@@ -978,7 +978,7 @@ class DatabricksRunNowOperator(BaseOperator):
         self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id)
 
 
-class DatabricksSQLStatementsOperator(BaseOperator):
+class DatabricksSQLStatementsOperator(DatabricksSQLStatementsMixin, BaseOperator):
     """
     Submits a Databricks SQL Statement to Databricks using the api/2.0/sql/statements/ API endpoint.
 
@@ -1073,59 +1073,6 @@ class DatabricksSQLStatementsOperator(BaseOperator):
             caller=caller,
         )
 
-    def _handle_operator_execution(self) -> None:
-        end_time = time.time() + self.timeout
-        while end_time > time.time():
-            statement_state = self._hook.get_sql_statement_state(self.statement_id)
-            if statement_state.is_terminal:
-                if statement_state.is_successful:
-                    self.log.info("%s completed successfully.", self.task_id)
-                    return
-                error_message = (
-                    f"{self.task_id} failed with terminal state: {statement_state.state} "
-                    f"and with the error code {statement_state.error_code} "
-                    f"and error message {statement_state.error_message}"
-                )
-                raise AirflowException(error_message)
-
-            self.log.info("%s in run state: %s", self.task_id, statement_state.state)
-            self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
-            time.sleep(self.polling_period_seconds)
-
-        self._hook.cancel_sql_statement(self.statement_id)
-        raise AirflowException(
-            f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state.state}",
-        )
-
-    def _handle_deferrable_operator_execution(self) -> None:
-        statement_state = self._hook.get_sql_statement_state(self.statement_id)
-        end_time = time.time() + self.timeout
-        if not statement_state.is_terminal:
-            if not self.statement_id:
-                raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.")
-            self.defer(
-                trigger=DatabricksSQLStatementExecutionTrigger(
-                    statement_id=self.statement_id,
-                    databricks_conn_id=self.databricks_conn_id,
-                    end_time=end_time,
-                    polling_period_seconds=self.polling_period_seconds,
-                    retry_limit=self.databricks_retry_limit,
-                    retry_delay=self.databricks_retry_delay,
-                    retry_args=self.databricks_retry_args,
-                ),
-                method_name=DEFER_METHOD_NAME,
-            )
-        else:
-            if statement_state.is_successful:
-                self.log.info("%s completed successfully.", self.task_id)
-            else:
-                error_message = (
-                    f"{self.task_id} failed with terminal state: {statement_state.state} "
-                    f"and with the error code {statement_state.error_code} "
-                    f"and error message {statement_state.error_message}"
-                )
-                raise AirflowException(error_message)
-
     def execute(self, context: Context):
         json = {
             "statement": self.statement,
@@ -1146,34 +1093,65 @@ class DatabricksSQLStatementsOperator(BaseOperator):
         if not self.wait_for_termination:
             return
         if self.deferrable:
-            self._handle_deferrable_operator_execution()
+            self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME)  # type: ignore[misc]
         else:
-            self._handle_operator_execution()
+            self._handle_execution()  # type: ignore[misc]
 
-    def on_kill(self):
+    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+        """Implement _on_complete because we use statement_id."""
+        from airflow.providers.common.compat.openlineage.facet import (
+            ExternalQueryRunFacet,
+            SQLJobFacet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo, SQLParser
+
+        db_info = DatabaseInfo(
+            scheme="databricks",
+            authority=self._hook.host,
+            database=self.catalog,
+            is_uppercase_names=False,
+            # Other args will not be used as we'll not query DB for details, we only do sql parsing.
+        )
+
+        sql_parser = SQLParser(
+            dialect="databricks",
+            default_schema=self.schema or "default",
+        )
+
+        run_facets = {}
         if self.statement_id:
-            self._hook.cancel_sql_statement(self.statement_id)
-            self.log.info(
-                "Task: %s with statement ID: %s was requested to be cancelled.",
-                self.task_id,
-                self.statement_id,
+            run_facets["externalQuery"] = ExternalQueryRunFacet(
+                externalQueryId=self.statement_id, source=sql_parser.create_namespace(db_info)
            )
-        else:
-            self.log.error(
-                "Error: Task: %s with invalid statement_id was requested to be cancelled.", self.task_id
+        job_facets = {"sql": SQLJobFacet(query=SQLParser.normalize_sql(self.statement))}
+
+        query = f"{self.statement}"
+        if self.parameters:
+            # Catalog, schema or table can be parameterized, so it's crucial to fill them before parsing
+            for param in self.parameters:
+                query = query.replace(f":{param['name']}", param.get("value") or "null")
+
+        parser_result = None
+        try:
+            # Try performing offline sql parsing, without db access,
+            parser_result = sql_parser.generate_openlineage_metadata_from_sql(
+                sql=query,
+                database_info=db_info,
+                database=None,  # Provided in db_info
+                use_connection=False,  # Prevents DB call for table details, that will fail with API
                sqlalchemy_engine=None,  # Not needed when use_connection is False
+                hook=None,  # type: ignore[arg-type] # Not needed when use_connection is False
             )
-
-    def execute_complete(self, context: dict | None, event: dict):
-        statement_state = SQLStatementState.from_json(event["state"])
-        error = event["error"]
-        statement_id = event["statement_id"]
-
-        if statement_state.is_successful:
-            self.log.info("SQL Statement with ID %s completed successfully.", statement_id)
-            return
-
-        error_message = f"SQL Statement execution failed with terminal state: {statement_state} and with the error {error}"
-        raise AirflowException(error_message)
+        except Exception as e:
+            self.log.debug("OpenLineage failed to parse query `%s` with error %s", query, e)
+
+        return OperatorLineage(
+            inputs=parser_result.inputs if parser_result else [],
+            outputs=parser_result.outputs if parser_result else [],
+            job_facets=parser_result.job_facets if parser_result else job_facets,
            run_facets={**parser_result.run_facets, **run_facets} if parser_result else run_facets,
+        )
 
 
 class DatabricksTaskBaseOperator(BaseOperator, ABC):
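
For orientation, the DAG-facing interface of DatabricksSQLStatementsOperator is unchanged by this refactor; polling, deferral, cancellation, and execute_complete handling now come from the shared DatabricksSQLStatementsMixin, and the operator additionally emits OpenLineage facets. A minimal usage sketch; the connection ID, warehouse ID, and SQL text below are illustrative placeholders, not values taken from this release:

    from airflow.providers.databricks.operators.databricks import DatabricksSQLStatementsOperator

    run_statement = DatabricksSQLStatementsOperator(
        task_id="run_statement",
        databricks_conn_id="databricks_default",  # placeholder connection
        warehouse_id="example-warehouse-id",      # placeholder SQL warehouse
        statement="SELECT 1",                     # placeholder statement
        wait_for_termination=True,  # polls via the mixin's _handle_execution()
        deferrable=True,            # defers via DatabricksSQLStatementExecutionTrigger instead of polling
    )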

airflow/providers/databricks/operators/databricks_sql.py

@@ -277,8 +277,13 @@ class DatabricksCopyIntoOperator(BaseOperator):
         self._client_parameters = client_parameters or {}
         if force_copy is not None:
             self._copy_options["force"] = "true" if force_copy else "false"
+        self._sql: str | None = None
 
     def _get_hook(self) -> DatabricksSqlHook:
+        return self._hook
+
+    @cached_property
+    def _hook(self) -> DatabricksSqlHook:
         return DatabricksSqlHook(
             self.databricks_conn_id,
             http_path=self._http_path,
@@ -354,12 +359,116 @@ FILEFORMAT = {self._file_format}
         return sql.strip()
 
     def execute(self, context: Context) -> Any:
-        sql = self._create_sql_query()
-        self.log.info("Executing: %s", sql)
+        self._sql = self._create_sql_query()
+        self.log.info("Executing: %s", self._sql)
         hook = self._get_hook()
-        hook.run(sql)
+        hook.run(self._sql)
 
     def on_kill(self) -> None:
         # NB: on_kill isn't required for this operator since query cancelling gets
         # handled in `DatabricksSqlHook.run()` method which is called in `execute()`
         ...
+
+    def _build_input_openlineage_dataset(self) -> tuple[Any, list[Any]]:
+        """Parse file_location to build the OpenLineage input dataset."""
+        from urllib.parse import urlparse
+
+        from airflow.providers.common.compat.openlineage.facet import Dataset, Error
+
+        try:
+            uri = urlparse(self.file_location)
+
+            # Only process schemes we know produce valid OL datasets with current implementation
+            if uri.scheme not in ("s3", "s3a", "s3n", "gs", "abfss", "wasbs"):
+                raise ValueError(f"Unsupported scheme: `{uri.scheme}` in `{self.file_location}`")
+
+            namespace = f"{uri.scheme}://{uri.netloc}"
+            name = uri.path.strip("/")
+            if name in ("", "."):
+                name = "/"
+            return Dataset(namespace=namespace, name=name), []
+        except Exception as e:
+            self.log.debug("Failed to parse file_location: `%s`, error: %s", self.file_location, str(e))
+            extraction_errors = [
+                Error(errorMessage=str(e), stackTrace=None, task=self.file_location, taskNumber=None)
+            ]
+            return None, extraction_errors
+
+    def _build_output_openlineage_dataset(self, namespace: str) -> tuple[Any, list[Any]]:
+        """Build output OpenLineage dataset from table information."""
+        from airflow.providers.common.compat.openlineage.facet import Dataset, Error
+
+        try:
+            table_parts = self.table_name.split(".")
+            if len(table_parts) == 3:  # catalog.schema.table
+                catalog, schema, table = table_parts
+            elif len(table_parts) == 2:  # schema.table
+                catalog = None
+                schema, table = table_parts
+            else:
+                catalog = None
+                schema = None
+                table = self.table_name
+
+            hook = self._get_hook()
+            schema = schema or hook.get_openlineage_default_schema()  # Fallback to default schema
+            catalog = catalog or hook.catalog  # Fallback to default catalog, if provided
+
+            # Combine schema/table with optional catalog for final dataset name
+            fq_name = table
+            if schema:
+                fq_name = f"{schema}.{fq_name}"
+            if catalog:
+                fq_name = f"{catalog}.{fq_name}"
+
+            return Dataset(namespace=namespace, name=fq_name), []
+        except Exception as e:
+            self.log.debug("Failed to construct output dataset: `%s`, error: %s", self.table_name, str(e))
+            extraction_errors = [
+                Error(errorMessage=str(e), stackTrace=None, task=self.table_name, taskNumber=None)
+            ]
+            return None, extraction_errors
+
+    def get_openlineage_facets_on_complete(self, _):
+        """Implement _on_complete as we are attaching query id."""
+        from airflow.providers.common.compat.openlineage.facet import (
+            ExternalQueryRunFacet,
+            ExtractionErrorRunFacet,
+            SQLJobFacet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self._sql:
+            self.log.warning("No SQL query found, returning empty OperatorLineage.")
+            return OperatorLineage()
+
+        hook = self._get_hook()
+        run_facets = {}
+
+        connection = hook.get_connection(self.databricks_conn_id)
+        database_info = hook.get_openlineage_database_info(connection)
+        dbx_namespace = SQLParser.create_namespace(database_info)
+
+        if hook.query_ids:
+            run_facets["externalQuery"] = ExternalQueryRunFacet(
+                externalQueryId=hook.query_ids[0], source=dbx_namespace
+            )
+
+        input_dataset, extraction_errors = self._build_input_openlineage_dataset()
+        output_dataset, output_errors = self._build_output_openlineage_dataset(dbx_namespace)
+        extraction_errors.extend(output_errors)
+
+        if extraction_errors:
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=1,
+                failedTasks=len(extraction_errors),
+                errors=extraction_errors,
+            )
+
+        return OperatorLineage(
+            inputs=[input_dataset] if input_dataset else [],
+            outputs=[output_dataset] if output_dataset else [],
+            job_facets={"sql": SQLJobFacet(query=SQLParser.normalize_sql(self._sql))},
+            run_facets=run_facets,
+        )

airflow/providers/databricks/sensors/databricks.py (new file)

@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState
+from airflow.providers.databricks.operators.databricks import DEFER_METHOD_NAME
+from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
+from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseSensorOperator
+else:
+    from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+XCOM_STATEMENT_ID_KEY = "statement_id"
+
+
+class DatabricksSQLStatementsSensor(DatabricksSQLStatementsMixin, BaseSensorOperator):
+    """DatabricksSQLStatementsSensor."""
+
+    template_fields: Sequence[str] = (
+        "databricks_conn_id",
+        "statement",
+        "statement_id",
+    )
+    template_ext: Sequence[str] = (".json-tpl",)
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        warehouse_id: str,
+        *,
+        statement: str | None = None,
+        statement_id: str | None = None,
+        catalog: str | None = None,
+        schema: str | None = None,
+        parameters: list[dict[str, Any]] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        do_xcom_push: bool = True,
+        wait_for_termination: bool = True,
+        timeout: float = 3600,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        # Handle the scenario where either both statement and statement_id are set/not set
+        if statement and statement_id:
+            raise AirflowException("Cannot provide both statement and statement_id.")
+        if not statement and not statement_id:
+            raise AirflowException("One of either statement or statement_id must be provided.")
+
+        if not warehouse_id:
+            raise AirflowException("warehouse_id must be provided.")
+
+        super().__init__(**kwargs)
+
+        self.statement = statement
+        self.statement_id = statement_id
+        self.warehouse_id = warehouse_id
+        self.catalog = catalog
+        self.schema = schema
+        self.parameters = parameters
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
+        self.timeout = timeout
+        self.do_xcom_push = do_xcom_push
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSQLStatementsSensor")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def execute(self, context: Context):
+        if not self.statement_id:
+            # Otherwise, we'll go ahead and "submit" the statement
+            json = {
+                "statement": self.statement,
+                "warehouse_id": self.warehouse_id,
+                "catalog": self.catalog,
+                "schema": self.schema,
+                "parameters": self.parameters,
+                "wait_timeout": "0s",
+            }
+
+            self.statement_id = self._hook.post_sql_statement(json)
+            self.log.info("SQL Statement submitted with statement_id: %s", self.statement_id)
+
+        if self.do_xcom_push and context is not None:
+            context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, value=self.statement_id)
+
+        # If we're not waiting for the query to complete execution, then we'll go ahead and return. However, a
+        # recommendation to use the DatabricksSQLStatementOperator is made in this case
+        if not self.wait_for_termination:
+            self.log.info(
+                "If setting wait_for_termination = False, consider using the DatabricksSQLStatementsOperator instead."
+            )
+            return
+
+        if self.deferrable:
+            self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME)  # type: ignore[misc]
+
+    def poke(self, context: Context):
+        """
+        Handle non-deferrable Sensor execution.
+
+        :param context: (Context)
+        :return: (bool)
+        """
+        # This is going to very closely mirror the execute_complete
+        statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id)
+
+        if statement_state.is_running:
+            self.log.info("SQL Statement with ID %s is running", self.statement_id)
+            return False
+        if statement_state.is_successful:
+            self.log.info("SQL Statement with ID %s completed successfully.", self.statement_id)
+            return True
+        raise AirflowException(
+            f"SQL Statement with ID {statement_state} failed with error: {statement_state.error_message}"
+        )
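
The new sensor mirrors the operator's constructor but surfaces the result through poke() (or the trigger when deferrable). A minimal sketch of how it might be wired into a DAG; the connection ID, warehouse ID, and query below are illustrative placeholders:

    import pendulum

    from airflow import DAG
    from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor

    with DAG(
        dag_id="example_databricks_sql_statement_sensor",  # hypothetical DAG id
        start_date=pendulum.datetime(2025, 1, 1, tz="UTC"),
        schedule=None,
    ):
        wait_for_statement = DatabricksSQLStatementsSensor(
            task_id="wait_for_statement",
            databricks_conn_id="databricks_default",  # placeholder connection
            warehouse_id="example-warehouse-id",      # placeholder SQL warehouse
            statement="SELECT 1",                     # placeholder query; statement_id may be passed instead
            poke_interval=60,
        )

Per the code above, the statement is submitted with wait_timeout="0s", the resulting statement_id is pushed to XCom under the "statement_id" key, and each poke checks get_sql_statement_state until a terminal state is reached.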

airflow/providers/databricks/utils/mixins.py (new file)

@@ -0,0 +1,194 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from __future__ import annotations
+
+import time
+from logging import Logger
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Protocol,
+)
+
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState
+from airflow.providers.databricks.triggers.databricks import DatabricksSQLStatementExecutionTrigger
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class GetHookHasFields(Protocol):
+    """Protocol for get_hook method."""
+
+    databricks_conn_id: str
+    databricks_retry_args: dict | None
+    databricks_retry_delay: int
+    databricks_retry_limit: int
+
+
+class HandleExecutionHasFields(Protocol):
+    """Protocol for _handle_execution method."""
+
+    _hook: DatabricksHook
+    log: Logger
+    polling_period_seconds: int
+    task_id: str
+    timeout: int
+    statement_id: str
+
+
+class HandleDeferrableExecutionHasFields(Protocol):
+    """Protocol for _handle_deferrable_execution method."""
+
+    _hook: DatabricksHook
+    databricks_conn_id: str
+    databricks_retry_args: dict[Any, Any] | None
+    databricks_retry_delay: int
+    databricks_retry_limit: int
+    defer: Any
+    log: Logger
+    polling_period_seconds: int
+    statement_id: str
+    task_id: str
+    timeout: int
+
+
+class ExecuteCompleteHasFields(Protocol):
+    """Protocol for execute_complete method."""
+
+    statement_id: str
+    _hook: DatabricksHook
+    log: Logger
+
+
+class OnKillHasFields(Protocol):
+    """Protocol for on_kill method."""
+
+    _hook: DatabricksHook
+    log: Logger
+    statement_id: str
+    task_id: str
+
+
+class DatabricksSQLStatementsMixin:
+    """
+    Mixin class to be used by both the DatabricksSqlStatementsOperator, and the DatabricksSqlStatementSensor.
+
+    - _handle_operator_execution (renamed to _handle_execution)
+    - _handle_deferrable_operator_execution (renamed to _handle_deferrable_execution)
+    - execute_complete
+    - on_kill
+    """
+
+    def _handle_execution(self: HandleExecutionHasFields) -> None:
+        """Execute a SQL statement in non-deferrable mode."""
+        # Determine the time at which the Task will timeout. The statement_state is defined here in the event
+        # the while-loop is never entered
+        end_time = time.time() + self.timeout
+
+        while end_time > time.time():
+            statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id)
+
+            if statement_state.is_terminal:
+                if statement_state.is_successful:
+                    self.log.info("%s completed successfully.", self.task_id)
+                    return
+
+                error_message = (
+                    f"{self.task_id} failed with terminal state: {statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+            self.log.info("%s in run state: %s", self.task_id, statement_state.state)
+            self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
+            time.sleep(self.polling_period_seconds)
+
+        # Once the timeout is exceeded, the query is cancelled. This is an important steps; if a query takes
+        # to log, it needs to be killed. Otherwise, it may be the case that there are "zombie" queries running
+        # that are no longer being orchestrated
+        self._hook.cancel_sql_statement(self.statement_id)
+        raise AirflowException(
+            f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state}",
+        )
+
+    def _handle_deferrable_execution(
+        self: HandleDeferrableExecutionHasFields, defer_method_name: str = "execute_complete"
+    ) -> None:
+        """Execute a SQL statement in deferrable mode."""
+        statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id)
+        end_time: float = time.time() + self.timeout
+
+        if not statement_state.is_terminal:
+            # If the query is still running and there is no statement_id, this is somewhat of a "zombie"
+            # query, and should throw an exception
+            if not self.statement_id:
+                raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.")
+
+            self.defer(
+                trigger=DatabricksSQLStatementExecutionTrigger(
+                    statement_id=self.statement_id,
+                    databricks_conn_id=self.databricks_conn_id,
+                    end_time=end_time,
+                    polling_period_seconds=self.polling_period_seconds,
+                    retry_limit=self.databricks_retry_limit,
+                    retry_delay=self.databricks_retry_delay,
+                    retry_args=self.databricks_retry_args,
+                ),
+                method_name=defer_method_name,
+            )
+
+        else:
+            if statement_state.is_successful:
+                self.log.info("%s completed successfully.", self.task_id)
+            else:
+                error_message = (
+                    f"{self.task_id} failed with terminal state: {statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+    def execute_complete(self: ExecuteCompleteHasFields, context: Context, event: dict):
+        statement_state = SQLStatementState.from_json(event["state"])
+        error = event["error"]
+        # Save as instance attribute again after coming back from defer (e.g., for later use in listeners)
+        self.statement_id = event["statement_id"]
+
+        if statement_state.is_successful:
+            self.log.info("SQL Statement with ID %s completed successfully.", self.statement_id)
+            return
+
+        error_message = f"SQL Statement execution failed with terminal state: {statement_state} and with the error {error}"
+        raise AirflowException(error_message)
+
+    def on_kill(self: OnKillHasFields) -> None:
+        if self.statement_id:
+            self._hook.cancel_sql_statement(self.statement_id)
+            self.log.info(
+                "Task: %s with statement ID: %s was requested to be cancelled.",
+                self.task_id,
+                self.statement_id,
+            )
+        else:
+            self.log.error(
+                "Error: Task: %s with invalid statement_id was requested to be cancelled.", self.task_id
+            )
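
Because the mixin's methods are typed against Protocol classes rather than a concrete operator base, any class that exposes the expected attributes (_hook, statement_id, timeout, polling_period_seconds, the retry settings, and so on) can reuse the shared polling and cancellation logic. A hedged sketch of such reuse; the class name and defaults here are hypothetical and not part of the provider:

    from functools import cached_property

    from airflow.models import BaseOperator
    from airflow.providers.databricks.hooks.databricks import DatabricksHook
    from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin


    class WaitForStatementOperator(DatabricksSQLStatementsMixin, BaseOperator):
        """Hypothetical operator that blocks on an already-submitted SQL statement."""

        def __init__(self, statement_id: str, databricks_conn_id: str = "databricks_default", **kwargs):
            super().__init__(**kwargs)
            self.statement_id = statement_id
            self.databricks_conn_id = databricks_conn_id
            # Attributes the mixin's protocols expect.
            self.polling_period_seconds = 30
            self.timeout = 3600
            self.databricks_retry_limit = 3
            self.databricks_retry_delay = 1
            self.databricks_retry_args = None

        @cached_property
        def _hook(self) -> DatabricksHook:
            return DatabricksHook(self.databricks_conn_id, caller="WaitForStatementOperator")

        def execute(self, context):
            # Polls get_sql_statement_state until a terminal state or timeout.
            self._handle_execution()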

apache_airflow_providers_databricks-7.5.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: apache-airflow-providers-databricks
-Version: 7.4.0
+Version: 7.5.0
 Summary: Provider package apache-airflow-providers-databricks for Apache Airflow
 Keywords: airflow-provider,databricks,airflow,integration
 Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -23,7 +23,7 @@ Classifier: Topic :: System :: Monitoring
 Requires-Dist: apache-airflow>=2.10.0
 Requires-Dist: apache-airflow-providers-common-compat>=1.6.0
 Requires-Dist: apache-airflow-providers-common-sql>=1.27.0
-Requires-Dist: requests>=2.31.0,<3
+Requires-Dist: requests>=2.32.0,<3
 Requires-Dist: databricks-sql-connector>=3.0.0
 Requires-Dist: databricks-sqlalchemy>=1.0.2
 Requires-Dist: aiohttp>=3.9.2, <4
@@ -36,8 +36,8 @@ Requires-Dist: apache-airflow-providers-openlineage>=2.3.0 ; extra == "openlinea
 Requires-Dist: databricks-sdk==0.10.0 ; extra == "sdk"
 Requires-Dist: apache-airflow-providers-standard ; extra == "standard"
 Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
-Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.4.0/changelog.html
-Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.4.0
+Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.5.0/changelog.html
+Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.5.0
 Project-URL: Mastodon, https://fosstodon.org/@airflow
 Project-URL: Slack Chat, https://s.apache.org/airflow-slack
 Project-URL: Source Code, https://github.com/apache/airflow
@@ -73,7 +73,7 @@ Provides-Extra: standard
 
 Package ``apache-airflow-providers-databricks``
 
-Release: ``7.4.0``
+Release: ``7.5.0``
 
 
 `Databricks <https://databricks.com/>`__
@@ -86,7 +86,7 @@ This is a provider package for ``databricks`` provider. All classes for this pro
 are in ``airflow.providers.databricks`` python package.
 
 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.4.0/>`_.
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.5.0/>`_.
 
 Installation
 ------------
@@ -106,7 +106,7 @@ PIP package Version required
 ``apache-airflow`` ``>=2.10.0``
 ``apache-airflow-providers-common-compat`` ``>=1.6.0``
 ``apache-airflow-providers-common-sql`` ``>=1.27.0``
-``requests`` ``>=2.31.0,<3``
+``requests`` ``>=2.32.0,<3``
 ``databricks-sql-connector`` ``>=3.0.0``
 ``databricks-sqlalchemy`` ``>=1.0.2``
 ``aiohttp`` ``>=3.9.2,<4``
@@ -125,16 +125,18 @@ You can install such cross-provider dependencies when installing from PyPI. For
 
 .. code-block:: bash
 
-    pip install apache-airflow-providers-databricks[common.sql]
+    pip install apache-airflow-providers-databricks[common.compat]
 
 
-============================================================================================================ ==============
-Dependent package Extra
-============================================================================================================ ==============
-`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
-`apache-airflow-providers-fab <https://airflow.apache.org/docs/apache-airflow-providers-fab>`_ ``fab``
-============================================================================================================ ==============
+================================================================================================================== =================
+Dependent package Extra
+================================================================================================================== =================
+`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
+`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
+`apache-airflow-providers-fab <https://airflow.apache.org/docs/apache-airflow-providers-fab>`_ ``fab``
+`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_ ``openlineage``
+================================================================================================================== =================
 
 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.4.0/changelog.html>`_.
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-databricks/7.5.0/changelog.html>`_.
 

apache_airflow_providers_databricks-7.5.0.dist-info/RECORD

@@ -1,28 +1,30 @@
 airflow/providers/databricks/LICENSE,sha256=gXPVwptPlW1TJ4HSuG5OMPg-a3h43OGMkZRR1rpwfJA,10850
-airflow/providers/databricks/__init__.py,sha256=D4S1V7H6S0R_iJ75oqoynUbmOYkMET00uTSUyPzvkzM,1499
+airflow/providers/databricks/__init__.py,sha256=d1KmgGbpEA3KHmX7l-luO2NBwkPvv7tSFPf-KMbN2LE,1499
 airflow/providers/databricks/exceptions.py,sha256=85RklmLOI_PnTzfXNIUd5fAu2aMMUhelwumQAX0wANE,1261
-airflow/providers/databricks/get_provider_info.py,sha256=qNMX4Lft-NItPhFewFBSCi8n0_ISid_MQeETKQ67vdo,5573
+airflow/providers/databricks/get_provider_info.py,sha256=NZ-rY6k6ctDZN7rDngN7mAzq7RMhLag5NwfnuBNcKuw,5644
 airflow/providers/databricks/version_compat.py,sha256=j5PCtXvZ71aBjixu-EFTNtVDPsngzzs7os0ZQDgFVDk,1536
 airflow/providers/databricks/hooks/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
 airflow/providers/databricks/hooks/databricks.py,sha256=FIoiKWIc9AP3s8Av3Av9yleTg1kI0norwW5CAc6jTQc,28867
 airflow/providers/databricks/hooks/databricks_base.py,sha256=D7-_74QgQaZm1NfHKl_UOXbVAXRo2xjnOx_r1MI-rWI,34871
 airflow/providers/databricks/hooks/databricks_sql.py,sha256=xTdi0JN-ZdsGe2XnCa8yBi-AINZUlyIVlP-5nb2d2T0,16964
 airflow/providers/databricks/operators/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
-airflow/providers/databricks/operators/databricks.py,sha256=E8fgk3Z67uOTSvWvbF23Miv6EruSGOTdFvHn7pGVWp0,80138
+airflow/providers/databricks/operators/databricks.py,sha256=yDy_pBaAi_muP3NstpXOqBNxSP9WL0_X3fX2OmR1f3c,79235
 airflow/providers/databricks/operators/databricks_repos.py,sha256=m_72OnnU9df7UB-8SK2Tp5VjfNyjYeAnil3dCKs9SbA,13282
-airflow/providers/databricks/operators/databricks_sql.py,sha256=Ycp5mcb3uScQrognB2k8IeSR9oBx-Vnv6NEYGYuE800,17159
+airflow/providers/databricks/operators/databricks_sql.py,sha256=yrYZa9Hq8JDc-8F5DGfW2mkcaNwu4o09JZj_SQQnsrE,21807
 airflow/providers/databricks/operators/databricks_workflow.py,sha256=9WNQR9COa90fbqb9qSzut34K9Z1S_ZdpNHAfIcuH454,14227
 airflow/providers/databricks/plugins/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
 airflow/providers/databricks/plugins/databricks_workflow.py,sha256=1UpsodBLRrTah9zBGBzfM7n1pdkzTo7yilt6QxASspQ,17460
 airflow/providers/databricks/sensors/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
+airflow/providers/databricks/sensors/databricks.py,sha256=zG2rS7xemVUk-ztDSj0t90Ws47kqRgPm3NsBMQQR8bA,6389
 airflow/providers/databricks/sensors/databricks_partition.py,sha256=2zWdnqVaSSd7PFTZadfvtbsR7zOI4GwfZFOuEnXRLSM,10023
 airflow/providers/databricks/sensors/databricks_sql.py,sha256=jIA9oGBUCAlXzyrqigxlg7JQDsBFuNIF8ZUEJM8gPxg,5766
 airflow/providers/databricks/triggers/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
 airflow/providers/databricks/triggers/databricks.py,sha256=dSogx6GlcJfZ4CFhtlMeWs9sYFEYthP82S_U8-tM2Tk,9240
 airflow/providers/databricks/utils/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
 airflow/providers/databricks/utils/databricks.py,sha256=s0qEr_DsFhKW4uUiq2VQbtqcj52isYIplPZsUcxGPrI,2862
+airflow/providers/databricks/utils/mixins.py,sha256=WUmkt3AmXalmV6zOUIJZWbTldxYunAZOstddDhKCC94,7407
 airflow/providers/databricks/utils/openlineage.py,sha256=7fR3CPcOruHapsz1DOZ38QN3ZcAGDADNHPY28CzYCbg,13194
-apache_airflow_providers_databricks-7.4.0.dist-info/entry_points.txt,sha256=hjmZm3ab2cteTR4t9eE28oKixHwNIKtLCThd6sx3XRQ,227
-apache_airflow_providers_databricks-7.4.0.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
-apache_airflow_providers_databricks-7.4.0.dist-info/METADATA,sha256=6IAI-03VfbWg0tayX29gnzSeq5Em510la7J98MSQPd8,6446
-apache_airflow_providers_databricks-7.4.0.dist-info/RECORD,,
+apache_airflow_providers_databricks-7.5.0.dist-info/entry_points.txt,sha256=hjmZm3ab2cteTR4t9eE28oKixHwNIKtLCThd6sx3XRQ,227
+apache_airflow_providers_databricks-7.5.0.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+apache_airflow_providers_databricks-7.5.0.dist-info/METADATA,sha256=3Of4M9LHN0OsBur6wL1Nr3olZEGS_zAMIvrYOW6Rjaw,6760
+apache_airflow_providers_databricks-7.5.0.dist-info/RECORD,,