apache-airflow-providers-openlineage 2.0.0rc2__py3-none-any.whl → 2.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release has been flagged as potentially problematic.


This version of apache-airflow-providers-openlineage might be problematic; consult the package registry's advisory page for details.

@@ -19,10 +19,10 @@ from __future__ import annotations
19
19
  import logging
20
20
  import os
21
21
  from concurrent.futures import ProcessPoolExecutor
22
+ from datetime import datetime
22
23
  from typing import TYPE_CHECKING
23
24
 
24
25
  import psutil
25
- from openlineage.client.serde import Serde
26
26
  from setproctitle import getproctitle, setproctitle
27
27
 
28
28
  from airflow import settings
@@ -33,6 +33,7 @@ from airflow.providers.openlineage.extractors import ExtractorManager
33
33
  from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState
34
34
  from airflow.providers.openlineage.utils.utils import (
35
35
  AIRFLOW_V_2_10_PLUS,
36
+ AIRFLOW_V_3_0_PLUS,
36
37
  get_airflow_dag_run_facet,
37
38
  get_airflow_debug_facet,
38
39
  get_airflow_job_facet,
@@ -42,7 +43,6 @@ from airflow.providers.openlineage.utils.utils import (
42
43
  get_user_provided_run_facets,
43
44
  is_operator_disabled,
44
45
  is_selective_lineage_enabled,
45
- is_ti_rescheduled_already,
46
46
  print_warning,
47
47
  )
48
48
  from airflow.settings import configure_orm
@@ -50,11 +50,12 @@ from airflow.stats import Stats
50
50
  from airflow.utils import timezone
51
51
  from airflow.utils.state import TaskInstanceState
52
52
  from airflow.utils.timeout import timeout
53
+ from openlineage.client.serde import Serde
53
54
 
54
55
  if TYPE_CHECKING:
55
- from sqlalchemy.orm import Session
56
-
57
56
  from airflow.models import TaskInstance
57
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
58
+ from airflow.settings import Session
58
59
 
59
60
  _openlineage_listener: OpenLineageListener | None = None
60
61
 
@@ -87,28 +88,58 @@ class OpenLineageListener:
87
88
  self.extractor_manager = ExtractorManager()
88
89
  self.adapter = OpenLineageAdapter()
89
90
 
90
- @hookimpl
91
- def on_task_instance_running(
92
- self,
93
- previous_state: TaskInstanceState,
94
- task_instance: TaskInstance,
95
- session: Session, # This will always be QUEUED
96
- ) -> None:
97
- if not getattr(task_instance, "task", None) is not None:
98
- self.log.warning(
99
- "No task set for TI object task_id: %s - dag_id: %s - run_id %s",
100
- task_instance.task_id,
101
- task_instance.dag_id,
102
- task_instance.run_id,
103
- )
104
- return
91
+ if AIRFLOW_V_3_0_PLUS:
105
92
 
106
- self.log.debug("OpenLineage listener got notification about task instance start")
107
- dagrun = task_instance.dag_run
108
- task = task_instance.task
109
- if TYPE_CHECKING:
110
- assert task
111
- dag = task.dag
93
+ @hookimpl
94
+ def on_task_instance_running(
95
+ self,
96
+ previous_state: TaskInstanceState,
97
+ task_instance: RuntimeTaskInstance,
98
+ ):
99
+ self.log.debug("OpenLineage listener got notification about task instance start")
100
+ context = task_instance.get_template_context()
101
+
102
+ task = context["task"]
103
+ if TYPE_CHECKING:
104
+ assert task
105
+ dagrun = context["dag_run"]
106
+ dag = context["dag"]
107
+ start_date = context["start_date"]
108
+ self._on_task_instance_running(task_instance, dag, dagrun, task, start_date)
109
+ else:
110
+
111
+ @hookimpl
112
+ def on_task_instance_running(
113
+ self,
114
+ previous_state: TaskInstanceState,
115
+ task_instance: TaskInstance,
116
+ session: Session, # type: ignore[valid-type]
117
+ ) -> None:
118
+ from airflow.providers.openlineage.utils.utils import is_ti_rescheduled_already
119
+
120
+ if not getattr(task_instance, "task", None) is not None:
121
+ self.log.warning(
122
+ "No task set for TI object task_id: %s - dag_id: %s - run_id %s",
123
+ task_instance.task_id,
124
+ task_instance.dag_id,
125
+ task_instance.run_id,
126
+ )
127
+ return
128
+
129
+ self.log.debug("OpenLineage listener got notification about task instance start")
130
+ task = task_instance.task
131
+ if TYPE_CHECKING:
132
+ assert task
133
+ start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow()
134
+
135
+ if is_ti_rescheduled_already(task_instance):
136
+ self.log.debug("Skipping this instance of rescheduled task - START event was emitted already")
137
+ return
138
+ self._on_task_instance_running(task_instance, task.dag, task_instance.dag_run, task, start_date)
139
+
140
+ def _on_task_instance_running(
141
+ self, task_instance: RuntimeTaskInstance | TaskInstance, dag, dagrun, task, start_date: datetime
142
+ ):
112
143
  if is_operator_disabled(task):
113
144
  self.log.debug(
114
145
  "Skipping OpenLineage event emission for operator `%s` "
@@ -127,35 +158,41 @@ class OpenLineageListener:
127
158
  return
128
159
 
129
160
  # Needs to be calculated outside of inner method so that it gets cached for usage in fork processes
161
+ data_interval_start = dagrun.data_interval_start
162
+ if isinstance(data_interval_start, datetime):
163
+ data_interval_start = data_interval_start.isoformat()
164
+ data_interval_end = dagrun.data_interval_end
165
+ if isinstance(data_interval_end, datetime):
166
+ data_interval_end = data_interval_end.isoformat()
167
+
168
+ clear_number = 0
169
+ if hasattr(dagrun, "clear_number"):
170
+ clear_number = dagrun.clear_number
171
+
130
172
  debug_facet = get_airflow_debug_facet()
131
173
 
132
174
  @print_warning(self.log)
133
175
  def on_running():
134
- # that's a workaround to detect task running from deferred state
135
- # we return here because Airflow 2.3 needs task from deferred state
136
- if task_instance.next_method is not None:
137
- return
138
-
139
- if is_ti_rescheduled_already(task_instance):
176
+ context = task_instance.get_template_context()
177
+ if hasattr(context, "task_reschedule_count") and context["task_reschedule_count"] > 0:
140
178
  self.log.debug("Skipping this instance of rescheduled task - START event was emitted already")
141
179
  return
142
180
 
181
+ date = dagrun.logical_date
182
+ if AIRFLOW_V_3_0_PLUS and date is None:
183
+ date = dagrun.run_after
184
+
143
185
  parent_run_id = self.adapter.build_dag_run_id(
144
186
  dag_id=dag.dag_id,
145
- logical_date=dagrun.logical_date,
146
- clear_number=dagrun.clear_number,
187
+ logical_date=date,
188
+ clear_number=clear_number,
147
189
  )
148
190
 
149
- if hasattr(task_instance, "logical_date"):
150
- logical_date = task_instance.logical_date
151
- else:
152
- logical_date = task_instance.execution_date
153
-
154
191
  task_uuid = self.adapter.build_task_instance_run_id(
155
192
  dag_id=dag.dag_id,
156
193
  task_id=task.task_id,
157
194
  try_number=task_instance.try_number,
158
- logical_date=logical_date,
195
+ logical_date=date,
159
196
  map_index=task_instance.map_index,
160
197
  )
161
198
  event_type = RunState.RUNNING.value.lower()
@@ -164,11 +201,6 @@ class OpenLineageListener:
164
201
  with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
165
202
  task_metadata = self.extractor_manager.extract_metadata(dagrun, task)
166
203
 
167
- start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow()
168
- data_interval_start = (
169
- dagrun.data_interval_start.isoformat() if dagrun.data_interval_start else None
170
- )
171
- data_interval_end = dagrun.data_interval_end.isoformat() if dagrun.data_interval_end else None
172
204
  redacted_event = self.adapter.start_task(
173
205
  run_id=task_uuid,
174
206
  job_name=get_job_name(task),
@@ -195,17 +227,39 @@ class OpenLineageListener:
195
227
 
196
228
  self._execute(on_running, "on_running", use_fork=True)
197
229
 
198
- @hookimpl
199
- def on_task_instance_success(
200
- self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session
201
- ) -> None:
202
- self.log.debug("OpenLineage listener got notification about task instance success")
230
+ if AIRFLOW_V_3_0_PLUS:
231
+
232
+ @hookimpl
233
+ def on_task_instance_success(
234
+ self, previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance
235
+ ) -> None:
236
+ self.log.debug("OpenLineage listener got notification about task instance success")
203
237
 
204
- dagrun = task_instance.dag_run
205
- task = task_instance.task
206
- if TYPE_CHECKING:
207
- assert task
208
- dag = task.dag
238
+ context = task_instance.get_template_context()
239
+ task = context["task"]
240
+ if TYPE_CHECKING:
241
+ assert task
242
+ dagrun = context["dag_run"]
243
+ dag = context["dag"]
244
+ self._on_task_instance_success(task_instance, dag, dagrun, task)
245
+
246
+ else:
247
+
248
+ @hookimpl
249
+ def on_task_instance_success(
250
+ self,
251
+ previous_state: TaskInstanceState,
252
+ task_instance: TaskInstance,
253
+ session: Session, # type: ignore[valid-type]
254
+ ) -> None:
255
+ self.log.debug("OpenLineage listener got notification about task instance success")
256
+ task = task_instance.task
257
+ if TYPE_CHECKING:
258
+ assert task
259
+ self._on_task_instance_success(task_instance, task.dag, task_instance.dag_run, task)
260
+
261
+ def _on_task_instance_success(self, task_instance: RuntimeTaskInstance, dag, dagrun, task):
262
+ end_date = timezone.utcnow()
209
263
 
210
264
  if is_operator_disabled(task):
211
265
  self.log.debug(
@@ -226,21 +280,21 @@ class OpenLineageListener:
226
280
 
227
281
  @print_warning(self.log)
228
282
  def on_success():
283
+ date = dagrun.logical_date
284
+ if AIRFLOW_V_3_0_PLUS and date is None:
285
+ date = dagrun.run_after
286
+
229
287
  parent_run_id = self.adapter.build_dag_run_id(
230
288
  dag_id=dag.dag_id,
231
- logical_date=dagrun.logical_date,
289
+ logical_date=date,
232
290
  clear_number=dagrun.clear_number,
233
291
  )
234
292
 
235
- if hasattr(task_instance, "logical_date"):
236
- logical_date = task_instance.logical_date
237
- else:
238
- logical_date = task_instance.execution_date
239
293
  task_uuid = self.adapter.build_task_instance_run_id(
240
294
  dag_id=dag.dag_id,
241
295
  task_id=task.task_id,
242
296
  try_number=_get_try_number_success(task_instance),
243
- logical_date=logical_date,
297
+ logical_date=date,
244
298
  map_index=task_instance.map_index,
245
299
  )
246
300
  event_type = RunState.COMPLETE.value.lower()
@@ -251,8 +305,6 @@ class OpenLineageListener:
251
305
  dagrun, task, complete=True, task_instance=task_instance
252
306
  )
253
307
 
254
- end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow()
255
-
256
308
  redacted_event = self.adapter.complete_task(
257
309
  run_id=task_uuid,
258
310
  job_name=get_job_name(task),
@@ -273,7 +325,7 @@ class OpenLineageListener:
273
325
 
274
326
  self._execute(on_success, "on_success", use_fork=True)
275
327
 
276
- if AIRFLOW_V_2_10_PLUS:
328
+ if AIRFLOW_V_3_0_PLUS:
277
329
 
278
330
  @hookimpl
279
331
  def on_task_instance_failed(
@@ -281,36 +333,54 @@ class OpenLineageListener:
281
333
  previous_state: TaskInstanceState,
282
334
  task_instance: TaskInstance,
283
335
  error: None | str | BaseException,
284
- session: Session,
285
336
  ) -> None:
286
- self._on_task_instance_failed(
287
- previous_state=previous_state, task_instance=task_instance, error=error, session=session
288
- )
337
+ self.log.debug("OpenLineage listener got notification about task instance failure")
338
+ context = task_instance.get_template_context()
339
+ task = context["task"]
340
+ if TYPE_CHECKING:
341
+ assert task
342
+ dagrun = context["dag_run"]
343
+ dag = context["dag"]
344
+ self._on_task_instance_failed(task_instance, dag, dagrun, task, error)
289
345
 
346
+ elif AIRFLOW_V_2_10_PLUS:
347
+
348
+ @hookimpl
349
+ def on_task_instance_failed(
350
+ self,
351
+ previous_state: TaskInstanceState,
352
+ task_instance: TaskInstance,
353
+ error: None | str | BaseException,
354
+ session: Session, # type: ignore[valid-type]
355
+ ) -> None:
356
+ self.log.debug("OpenLineage listener got notification about task instance failure")
357
+ task = task_instance.task
358
+ if TYPE_CHECKING:
359
+ assert task
360
+ self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task, error)
290
361
  else:
291
362
 
292
363
  @hookimpl
293
364
  def on_task_instance_failed(
294
- self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session
365
+ self,
366
+ previous_state: TaskInstanceState,
367
+ task_instance: TaskInstance,
368
+ session: Session, # type: ignore[valid-type]
295
369
  ) -> None:
296
- self._on_task_instance_failed(
297
- previous_state=previous_state, task_instance=task_instance, error=None, session=session
298
- )
370
+ task = task_instance.task
371
+ if TYPE_CHECKING:
372
+ assert task
373
+ self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task)
299
374
 
300
375
  def _on_task_instance_failed(
301
376
  self,
302
- previous_state: TaskInstanceState,
303
- task_instance: TaskInstance,
304
- session: Session,
377
+ task_instance: TaskInstance | RuntimeTaskInstance,
378
+ dag,
379
+ dagrun,
380
+ task,
305
381
  error: None | str | BaseException = None,
306
382
  ) -> None:
307
- self.log.debug("OpenLineage listener got notification about task instance failure")
308
-
309
- dagrun = task_instance.dag_run
310
- task = task_instance.task
311
- if TYPE_CHECKING:
312
- assert task
313
- dag = task.dag
383
+ end_date = timezone.utcnow()
314
384
 
315
385
  if is_operator_disabled(task):
316
386
  self.log.debug(
@@ -331,22 +401,21 @@ class OpenLineageListener:
331
401
 
332
402
  @print_warning(self.log)
333
403
  def on_failure():
404
+ date = dagrun.logical_date
405
+ if AIRFLOW_V_3_0_PLUS and date is None:
406
+ date = dagrun.run_after
407
+
334
408
  parent_run_id = self.adapter.build_dag_run_id(
335
409
  dag_id=dag.dag_id,
336
- logical_date=dagrun.logical_date,
410
+ logical_date=date,
337
411
  clear_number=dagrun.clear_number,
338
412
  )
339
413
 
340
- if hasattr(task_instance, "logical_date"):
341
- logical_date = task_instance.logical_date
342
- else:
343
- logical_date = task_instance.execution_date
344
-
345
414
  task_uuid = self.adapter.build_task_instance_run_id(
346
415
  dag_id=dag.dag_id,
347
416
  task_id=task.task_id,
348
417
  try_number=task_instance.try_number,
349
- logical_date=logical_date,
418
+ logical_date=date,
350
419
  map_index=task_instance.map_index,
351
420
  )
352
421
  event_type = RunState.FAIL.value.lower()
@@ -357,8 +426,6 @@ class OpenLineageListener:
357
426
  dagrun, task, complete=True, task_instance=task_instance
358
427
  )
359
428
 
360
- end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow()
361
-
362
429
  redacted_event = self.adapter.fail_task(
363
430
  run_id=task_uuid,
364
431
  job_name=get_job_name(task),
@@ -462,10 +529,14 @@ class OpenLineageListener:
462
529
 
463
530
  run_facets = {**get_airflow_dag_run_facet(dag_run)}
464
531
 
532
+ date = dag_run.logical_date
533
+ if AIRFLOW_V_3_0_PLUS and date is None:
534
+ date = dag_run.run_after
535
+
465
536
  self.submit_callable(
466
537
  self.adapter.dag_started,
467
538
  dag_id=dag_run.dag_id,
468
- logical_date=dag_run.logical_date,
539
+ logical_date=date,
469
540
  start_date=dag_run.start_date,
470
541
  nominal_start_time=data_interval_start,
471
542
  nominal_end_time=data_interval_end,
@@ -500,15 +571,21 @@ class OpenLineageListener:
500
571
  task_ids = DagRun._get_partial_task_ids(dag_run.dag)
501
572
  else:
502
573
  task_ids = dag_run.dag.task_ids if dag_run.dag and dag_run.dag.partial else None
574
+
575
+ date = dag_run.logical_date
576
+ if AIRFLOW_V_3_0_PLUS and date is None:
577
+ date = dag_run.run_after
578
+
503
579
  self.submit_callable(
504
580
  self.adapter.dag_success,
505
581
  dag_id=dag_run.dag_id,
506
582
  run_id=dag_run.run_id,
507
583
  end_date=dag_run.end_date,
508
- logical_date=dag_run.logical_date,
584
+ logical_date=date,
509
585
  clear_number=dag_run.clear_number,
510
586
  task_ids=task_ids,
511
587
  dag_run_state=dag_run.get_state(),
588
+ run_facets={**get_airflow_dag_run_facet(dag_run)},
512
589
  )
513
590
  except BaseException as e:
514
591
  self.log.warning("OpenLineage received exception in method on_dag_run_success", exc_info=e)
@@ -533,16 +610,22 @@ class OpenLineageListener:
533
610
  task_ids = DagRun._get_partial_task_ids(dag_run.dag)
534
611
  else:
535
612
  task_ids = dag_run.dag.task_ids if dag_run.dag and dag_run.dag.partial else None
613
+
614
+ date = dag_run.logical_date
615
+ if AIRFLOW_V_3_0_PLUS and date is None:
616
+ date = dag_run.run_after
617
+
536
618
  self.submit_callable(
537
619
  self.adapter.dag_failed,
538
620
  dag_id=dag_run.dag_id,
539
621
  run_id=dag_run.run_id,
540
622
  end_date=dag_run.end_date,
541
- logical_date=dag_run.logical_date,
623
+ logical_date=date,
542
624
  clear_number=dag_run.clear_number,
543
625
  dag_run_state=dag_run.get_state(),
544
626
  task_ids=task_ids,
545
627
  msg=msg,
628
+ run_facets={**get_airflow_dag_run_facet(dag_run)},
546
629
  )
547
630
  except BaseException as e:
548
631
  self.log.warning("OpenLineage received exception in method on_dag_run_failed", exc_info=e)
@@ -16,13 +16,11 @@
16
16
  # under the License.
17
17
  from __future__ import annotations
18
18
 
19
- from typing import TYPE_CHECKING, Callable
19
+ import logging
20
+ from typing import TYPE_CHECKING, Callable, TypedDict
20
21
 
21
22
  import sqlparse
22
23
  from attrs import define
23
- from openlineage.client.event_v2 import Dataset
24
- from openlineage.client.facet_v2 import column_lineage_dataset, extraction_error_run, sql_job
25
- from openlineage.common.sql import DbTableMeta, SqlMeta, parse
26
24
 
27
25
  from airflow.providers.openlineage.extractors.base import OperatorLineage
28
26
  from airflow.providers.openlineage.utils.sql import (
@@ -30,14 +28,20 @@ from airflow.providers.openlineage.utils.sql import (
30
28
  create_information_schema_query,
31
29
  get_table_schemas,
32
30
  )
33
- from airflow.typing_compat import TypedDict
31
+ from airflow.providers.openlineage.utils.utils import should_use_external_connection
34
32
  from airflow.utils.log.logging_mixin import LoggingMixin
33
+ from openlineage.client.event_v2 import Dataset
34
+ from openlineage.client.facet_v2 import column_lineage_dataset, extraction_error_run, sql_job
35
+ from openlineage.common.sql import DbTableMeta, SqlMeta, parse
35
36
 
36
37
  if TYPE_CHECKING:
37
- from openlineage.client.facet_v2 import JobFacet, RunFacet
38
38
  from sqlalchemy.engine import Engine
39
39
 
40
40
  from airflow.hooks.base import BaseHook
41
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
42
+ from openlineage.client.facet_v2 import JobFacet, RunFacet
43
+
44
+ log = logging.getLogger(__name__)
41
45
 
42
46
  DEFAULT_NAMESPACE = "default"
43
47
  DEFAULT_INFORMATION_SCHEMA_COLUMNS = [
@@ -78,10 +82,67 @@ class DatabaseInfo:
78
82
  :param database: Takes precedence over parsed database name.
79
83
  :param information_schema_columns: List of columns names from information schema table.
80
84
  :param information_schema_table_name: Information schema table name.
81
- :param use_flat_cross_db_query: Specifies if single information schema table should be used
82
- for cross-database queries (e.g. for Redshift).
83
- :param is_information_schema_cross_db: Specifies if information schema contains
84
- cross-database data.
85
+ :param use_flat_cross_db_query: Specifies whether a single, "global" information schema table should
86
+ be used for cross-database queries (e.g., in Redshift), or if multiple, per-database "local"
87
+ information schema tables should be queried individually.
88
+
89
+ If True, assumes a single, universal information schema table is available
90
+ (for example, in Redshift, the `SVV_REDSHIFT_COLUMNS` view)
91
+ [https://docs.aws.amazon.com/redshift/latest/dg/r_SVV_REDSHIFT_COLUMNS.html].
92
+ In this mode, we query only `information_schema_table_name` directly.
93
+ Depending on the `is_information_schema_cross_db` argument, you can also filter
94
+ by database name in the WHERE clause.
95
+
96
+ If False, treats each database as having its own local information schema table containing
97
+ metadata for that database only. As a result, one query per database may be generated
98
+ and then combined (often via `UNION ALL`).
99
+ This approach is necessary for dialects that do not maintain a single global view of
100
+ all metadata or that require per-database queries.
101
+ Depending on the `is_information_schema_cross_db` argument, queries can
102
+ include or omit database information in both identifiers and filters.
103
+
104
+ See `is_information_schema_cross_db` which also affects how final queries are constructed.
105
+ :param is_information_schema_cross_db: Specifies whether database information should be tracked
106
+ and included in queries that retrieve schema information from the information_schema_table.
107
+ In short, this determines whether queries are capable of spanning multiple databases.
108
+
109
+ If True, database identifiers are included wherever applicable, allowing retrieval of
110
+ metadata from more than one database. For instance, in Snowflake or MS SQL
111
+ (where each database is treated as a top-level namespace), you might have a query like:
112
+
113
+ ```
114
+ SELECT ...
115
+ FROM db1.information_schema.columns WHERE ...
116
+ UNION ALL
117
+ SELECT ...
118
+ FROM db2.information_schema.columns WHERE ...
119
+ ```
120
+
121
+ In Redshift, setting this to True together with `use_flat_cross_db_query=True` allows
122
+ adding database filters to the query, for example:
123
+
124
+ ```
125
+ SELECT ...
126
+ FROM SVV_REDSHIFT_COLUMNS
127
+ WHERE
128
+ SVV_REDSHIFT_COLUMNS.database == db1 # This is skipped when False
129
+ AND SVV_REDSHIFT_COLUMNS.schema == schema1
130
+ AND SVV_REDSHIFT_COLUMNS.table IN (table1, table2)
131
+ OR ...
132
+ ```
133
+
134
+ However, certain databases (e.g., PostgreSQL) do not permit true cross-database queries.
135
+ In such dialects, enabling cross-database support may lead to errors or be unnecessary.
136
+ Always consult your dialect's documentation or test sample queries to confirm if
137
+ cross-database querying is supported.
138
+
139
+ If False, database qualifiers are ignored, effectively restricting queries to a single
140
+ database (or making the database-level qualifier optional). This is typically
141
+ safer for databases that do not support cross-database operations or only provide a
142
+ two-level namespace (schema + table) instead of a three-level one (database + schema + table).
143
+ For example, some MySQL or PostgreSQL contexts might not need or permit cross-database queries at all.
144
+
145
+ See `use_flat_cross_db_query` which also affects how final queries are constructed.
85
146
  :param is_uppercase_names: Specifies if database accepts only uppercase names (e.g. Snowflake).
86
147
  :param normalize_name_method: Method to normalize database, schema and table names.
87
148
  Defaults to `name.lower()`.
@@ -397,3 +458,43 @@ class SQLParser(LoggingMixin):
397
458
  tables = schemas.setdefault(normalize_name(table.schema) if table.schema else None, [])
398
459
  tables.append(table.name)
399
460
  return hierarchy
461
+
462
+
463
+ def get_openlineage_facets_with_sql(
464
+ hook: DbApiHook, sql: str | list[str], conn_id: str, database: str | None
465
+ ) -> OperatorLineage | None:
466
+ connection = hook.get_connection(conn_id)
467
+ try:
468
+ database_info = hook.get_openlineage_database_info(connection)
469
+ except AttributeError:
470
+ database_info = None
471
+
472
+ if database_info is None:
473
+ log.debug("%s has no database info provided", hook)
474
+ return None
475
+
476
+ try:
477
+ sql_parser = SQLParser(
478
+ dialect=hook.get_openlineage_database_dialect(connection),
479
+ default_schema=hook.get_openlineage_default_schema(),
480
+ )
481
+ except AttributeError:
482
+ log.debug("%s failed to get database dialect", hook)
483
+ return None
484
+
485
+ try:
486
+ sqlalchemy_engine = hook.get_sqlalchemy_engine()
487
+ except Exception as e:
488
+ log.debug("Failed to get sql alchemy engine: %s", e)
489
+ sqlalchemy_engine = None
490
+
491
+ operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
492
+ sql=sql,
493
+ hook=hook,
494
+ database_info=database_info,
495
+ database=database,
496
+ sqlalchemy_engine=sqlalchemy_engine,
497
+ use_connection=should_use_external_connection(hook),
498
+ )
499
+
500
+ return operator_lineage
@@ -18,16 +18,29 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import logging
21
- from typing import TypeVar
21
+ from typing import TYPE_CHECKING, TypeVar
22
22
 
23
- from airflow.models import DAG, Operator, Param
23
+ from airflow.models import Param
24
24
  from airflow.models.xcom_arg import XComArg
25
25
 
26
+ if TYPE_CHECKING:
27
+ from airflow.sdk import DAG
28
+ from airflow.sdk.definitions._internal.abstractoperator import Operator
29
+ else:
30
+ try:
31
+ from airflow.sdk import DAG
32
+ except ImportError:
33
+ from airflow.models import DAG
34
+
26
35
  ENABLE_OL_PARAM_NAME = "_selective_enable_ol"
27
36
  ENABLE_OL_PARAM = Param(True, const=True)
28
37
  DISABLE_OL_PARAM = Param(False, const=False)
29
38
  T = TypeVar("T", bound="DAG | Operator")
30
39
 
40
+ if TYPE_CHECKING:
41
+ from airflow.sdk.definitions.baseoperator import BaseOperator as SdkBaseOperator
42
+
43
+
31
44
  log = logging.getLogger(__name__)
32
45
 
33
46
 
@@ -65,7 +78,7 @@ def disable_lineage(obj: T) -> T:
65
78
  return obj
66
79
 
67
80
 
68
- def is_task_lineage_enabled(task: Operator) -> bool:
81
+ def is_task_lineage_enabled(task: Operator | SdkBaseOperator) -> bool:
69
82
  """Check if selective enable OpenLineage parameter is set to True on task level."""
70
83
  if task.params.get(ENABLE_OL_PARAM_NAME) is False:
71
84
  log.debug(