acryl-datahub-airflow-plugin 1.3.1.5__py3-none-any.whl → 1.3.1.5rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acryl_datahub_airflow_plugin-1.3.1.5rc1.dist-info/METADATA +91 -0
- acryl_datahub_airflow_plugin-1.3.1.5rc1.dist-info/RECORD +33 -0
- datahub_airflow_plugin/_airflow_shims.py +31 -64
- datahub_airflow_plugin/_config.py +19 -97
- datahub_airflow_plugin/_datahub_ol_adapter.py +2 -14
- datahub_airflow_plugin/_extractors.py +365 -0
- datahub_airflow_plugin/_version.py +1 -1
- datahub_airflow_plugin/client/airflow_generator.py +43 -147
- datahub_airflow_plugin/datahub_listener.py +790 -19
- datahub_airflow_plugin/example_dags/__init__.py +0 -32
- datahub_airflow_plugin/example_dags/graph_usage_sample_dag.py +4 -12
- datahub_airflow_plugin/hooks/datahub.py +2 -11
- datahub_airflow_plugin/operators/datahub.py +3 -20
- acryl_datahub_airflow_plugin-1.3.1.5.dist-info/METADATA +0 -303
- acryl_datahub_airflow_plugin-1.3.1.5.dist-info/RECORD +0 -65
- datahub_airflow_plugin/_airflow_compat.py +0 -32
- datahub_airflow_plugin/_airflow_version_specific.py +0 -184
- datahub_airflow_plugin/_constants.py +0 -16
- datahub_airflow_plugin/airflow2/__init__.py +0 -6
- datahub_airflow_plugin/airflow2/_airflow2_sql_parser_patch.py +0 -402
- datahub_airflow_plugin/airflow2/_airflow_compat.py +0 -95
- datahub_airflow_plugin/airflow2/_extractors.py +0 -477
- datahub_airflow_plugin/airflow2/_legacy_shims.py +0 -20
- datahub_airflow_plugin/airflow2/_openlineage_compat.py +0 -123
- datahub_airflow_plugin/airflow2/_provider_shims.py +0 -29
- datahub_airflow_plugin/airflow2/_shims.py +0 -88
- datahub_airflow_plugin/airflow2/datahub_listener.py +0 -1072
- datahub_airflow_plugin/airflow3/__init__.py +0 -6
- datahub_airflow_plugin/airflow3/_airflow3_sql_parser_patch.py +0 -408
- datahub_airflow_plugin/airflow3/_airflow_compat.py +0 -108
- datahub_airflow_plugin/airflow3/_athena_openlineage_patch.py +0 -153
- datahub_airflow_plugin/airflow3/_bigquery_openlineage_patch.py +0 -273
- datahub_airflow_plugin/airflow3/_shims.py +0 -82
- datahub_airflow_plugin/airflow3/_sqlite_openlineage_patch.py +0 -88
- datahub_airflow_plugin/airflow3/_teradata_openlineage_patch.py +0 -308
- datahub_airflow_plugin/airflow3/datahub_listener.py +0 -1452
- datahub_airflow_plugin/example_dags/airflow2/__init__.py +0 -8
- datahub_airflow_plugin/example_dags/airflow2/generic_recipe_sample_dag.py +0 -54
- datahub_airflow_plugin/example_dags/airflow2/graph_usage_sample_dag.py +0 -43
- datahub_airflow_plugin/example_dags/airflow2/lineage_backend_demo.py +0 -69
- datahub_airflow_plugin/example_dags/airflow2/lineage_backend_taskflow_demo.py +0 -69
- datahub_airflow_plugin/example_dags/airflow2/lineage_emission_dag.py +0 -81
- datahub_airflow_plugin/example_dags/airflow2/mysql_sample_dag.py +0 -68
- datahub_airflow_plugin/example_dags/airflow2/snowflake_sample_dag.py +0 -99
- datahub_airflow_plugin/example_dags/airflow3/__init__.py +0 -8
- datahub_airflow_plugin/example_dags/airflow3/lineage_backend_demo.py +0 -51
- datahub_airflow_plugin/example_dags/airflow3/lineage_backend_taskflow_demo.py +0 -51
- datahub_airflow_plugin/example_dags/airflow3/snowflake_sample_dag.py +0 -89
- {acryl_datahub_airflow_plugin-1.3.1.5.dist-info → acryl_datahub_airflow_plugin-1.3.1.5rc1.dist-info}/WHEEL +0 -0
- {acryl_datahub_airflow_plugin-1.3.1.5.dist-info → acryl_datahub_airflow_plugin-1.3.1.5rc1.dist-info}/entry_points.txt +0 -0
- {acryl_datahub_airflow_plugin-1.3.1.5.dist-info → acryl_datahub_airflow_plugin-1.3.1.5rc1.dist-info}/top_level.txt +0 -0
datahub_airflow_plugin/_extractors.py (+365 -0)

@@ -0,0 +1,365 @@
+import contextlib
+import logging
+import unittest.mock
+from typing import TYPE_CHECKING, Optional
+
+from airflow.models.operator import Operator
+from openlineage.airflow.extractors import (
+    BaseExtractor,
+    ExtractorManager as OLExtractorManager,
+    TaskMetadata,
+)
+from openlineage.airflow.extractors.snowflake_extractor import SnowflakeExtractor
+from openlineage.airflow.extractors.sql_extractor import SqlExtractor
+from openlineage.airflow.utils import get_operator_class, try_import_from_string
+from openlineage.client.facet import (
+    ExtractionError,
+    ExtractionErrorRunFacet,
+    SqlJobFacet,
+)
+
+import datahub.emitter.mce_builder as builder
+from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import (
+    get_platform_from_sqlalchemy_uri,
+)
+from datahub.sql_parsing.sqlglot_lineage import (
+    SqlParsingResult,
+    create_lineage_sql_parsed_result,
+)
+from datahub_airflow_plugin._datahub_ol_adapter import OL_SCHEME_TWEAKS
+
+if TYPE_CHECKING:
+    from airflow.models import DagRun, TaskInstance
+
+    from datahub.ingestion.graph.client import DataHubGraph
+
+logger = logging.getLogger(__name__)
+_DATAHUB_GRAPH_CONTEXT_KEY = "datahub_graph"
+SQL_PARSING_RESULT_KEY = "datahub_sql"
+
+
+class ExtractorManager(OLExtractorManager):
+    # TODO: On Airflow 2.7, the OLExtractorManager is part of the built-in Airflow API.
+    # When available, we should use that instead. The same goe for most of the OL
+    # extractors.
+
+    def __init__(self):
+        super().__init__()
+
+        _sql_operator_overrides = [
+            # The OL BigQuery extractor has some complex logic to fetch detect
+            # the BigQuery job_id and fetch lineage from there. However, it can't
+            # generate CLL, so we disable it and use our own extractor instead.
+            "BigQueryOperator",
+            "BigQueryExecuteQueryOperator",
+            # Athena also does something similar.
+            "AWSAthenaOperator",
+            # Additional types that OL doesn't support. This is only necessary because
+            # on older versions of Airflow, these operators don't inherit from SQLExecuteQueryOperator.
+            "SqliteOperator",
+        ]
+        for operator in _sql_operator_overrides:
+            self.task_to_extractor.extractors[operator] = GenericSqlExtractor
+
+        self.task_to_extractor.extractors["AthenaOperator"] = AthenaOperatorExtractor
+
+        self.task_to_extractor.extractors["BigQueryInsertJobOperator"] = (
+            BigQueryInsertJobOperatorExtractor
+        )
+
+        self.task_to_extractor.extractors["TeradataOperator"] = (
+            TeradataOperatorExtractor
+        )
+
+        self._graph: Optional["DataHubGraph"] = None
+
+    @contextlib.contextmanager
+    def _patch_extractors(self):
+        with contextlib.ExitStack() as stack:
+            # Patch the SqlExtractor.extract() method.
+            stack.enter_context(
+                unittest.mock.patch.object(
+                    SqlExtractor,
+                    "extract",
+                    _sql_extractor_extract,
+                )
+            )
+
+            # Patch the SnowflakeExtractor.default_schema property.
+            stack.enter_context(
+                unittest.mock.patch.object(
+                    SnowflakeExtractor,
+                    "default_schema",
+                    property(_snowflake_default_schema),
+                )
+            )
+
+            # TODO: Override the BigQuery extractor to use the DataHub SQL parser.
+            # self.extractor_manager.add_extractor()
+
+            # TODO: Override the Athena extractor to use the DataHub SQL parser.
+
+            yield
+
+    def extract_metadata(
+        self,
+        dagrun: "DagRun",
+        task: "Operator",
+        complete: bool = False,
+        task_instance: Optional["TaskInstance"] = None,
+        task_uuid: Optional[str] = None,
+        graph: Optional["DataHubGraph"] = None,
+    ) -> TaskMetadata:
+        self._graph = graph
+        with self._patch_extractors():
+            return super().extract_metadata(
+                dagrun, task, complete, task_instance, task_uuid
+            )
+
+    def _get_extractor(self, task: "Operator") -> Optional[BaseExtractor]:
+        # By adding this, we can use the generic extractor as a fallback for
+        # any operator that inherits from SQLExecuteQueryOperator.
+        clazz = get_operator_class(task)
+        SQLExecuteQueryOperator = try_import_from_string(
+            "airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator"
+        )
+        if SQLExecuteQueryOperator and issubclass(clazz, SQLExecuteQueryOperator):
+            self.task_to_extractor.extractors.setdefault(
+                clazz.__name__, GenericSqlExtractor
+            )
+
+        extractor = super()._get_extractor(task)
+        if extractor:
+            extractor.set_context(_DATAHUB_GRAPH_CONTEXT_KEY, self._graph)
+        return extractor
+
+
+class GenericSqlExtractor(SqlExtractor):
+    # Note that the extract() method is patched elsewhere.
+
+    @property
+    def default_schema(self):
+        return super().default_schema
+
+    def _get_scheme(self) -> Optional[str]:
+        # Best effort conversion to DataHub platform names.
+
+        with contextlib.suppress(Exception):
+            if self.hook:
+                if hasattr(self.hook, "get_uri"):
+                    uri = self.hook.get_uri()
+                    return get_platform_from_sqlalchemy_uri(uri)
+
+        return self.conn.conn_type or super().dialect
+
+    def _get_database(self) -> Optional[str]:
+        if self.conn:
+            # For BigQuery, the "database" is the project name.
+            if hasattr(self.conn, "project_id"):
+                return self.conn.project_id
+
+            return self.conn.schema
+        return None
+
+
+def _sql_extractor_extract(self: "SqlExtractor") -> TaskMetadata:
+    # Why not override the OL sql_parse method directly, instead of overriding
+    # extract()? A few reasons:
+    #
+    # 1. We would want to pass the default_db and graph instance into our sql parser
+    #    method. The OL code doesn't pass the default_db (despite having it available),
+    #    and it's not clear how to get the graph instance into that method.
+    # 2. OL has some janky logic to fetch table schemas as part of the sql extractor.
+    #    We don't want that behavior and this lets us disable it.
+    # 3. Our SqlParsingResult already has DataHub urns, whereas using SqlMeta would
+    #    require us to convert those urns to OL uris, just for them to get converted
+    #    back to urns later on in our processing.
+
+    task_name = f"{self.operator.dag_id}.{self.operator.task_id}"
+    sql = self.operator.sql
+
+    default_database = getattr(self.operator, "database", None)
+    if not default_database:
+        default_database = self.database
+    default_schema = self.default_schema
+
+    # TODO: Add better handling for sql being a list of statements.
+    if isinstance(sql, list):
+        logger.info(f"Got list of SQL statements for {task_name}. Using first one.")
+        sql = sql[0]
+
+    # Run the SQL parser.
+    scheme = self.scheme
+    platform = OL_SCHEME_TWEAKS.get(scheme, scheme)
+
+    return _parse_sql_into_task_metadata(
+        self,
+        sql,
+        platform=platform,
+        default_database=default_database,
+        default_schema=default_schema,
+    )
+
+
+def _parse_sql_into_task_metadata(
+    self: "BaseExtractor",
+    sql: str,
+    platform: str,
+    default_database: Optional[str],
+    default_schema: Optional[str],
+) -> TaskMetadata:
+    task_name = f"{self.operator.dag_id}.{self.operator.task_id}"
+
+    run_facets = {}
+    job_facets = {"sql": SqlJobFacet(query=SqlExtractor._normalize_sql(sql))}
+
+    # Prepare to run the SQL parser.
+    graph = self.context.get(_DATAHUB_GRAPH_CONTEXT_KEY, None)
+
+    self.log.debug(
+        "Running the SQL parser %s (platform=%s, default db=%s, schema=%s): %s",
+        "with graph client" if graph else "in offline mode",
+        platform,
+        default_database,
+        default_schema,
+        sql,
+    )
+    sql_parsing_result: SqlParsingResult = create_lineage_sql_parsed_result(
+        query=sql,
+        graph=graph,
+        platform=platform,
+        platform_instance=None,
+        env=builder.DEFAULT_ENV,
+        default_db=default_database,
+        default_schema=default_schema,
+    )
+    self.log.debug(f"Got sql lineage {sql_parsing_result}")
+
+    if sql_parsing_result.debug_info.error:
+        error = sql_parsing_result.debug_info.error
+        run_facets["extractionError"] = ExtractionErrorRunFacet(
+            totalTasks=1,
+            failedTasks=1,
+            errors=[
+                ExtractionError(
+                    errorMessage=str(error),
+                    stackTrace=None,
+                    task="datahub_sql_parser",
+                    taskNumber=None,
+                )
+            ],
+        )
+
+    # Save sql_parsing_result to the facets dict. It is removed from the
+    # facet dict in the extractor's processing logic.
+    run_facets[SQL_PARSING_RESULT_KEY] = sql_parsing_result  # type: ignore
+
+    return TaskMetadata(
+        name=task_name,
+        inputs=[],
+        outputs=[],
+        run_facets=run_facets,
+        job_facets=job_facets,
+    )
+
+
+class BigQueryInsertJobOperatorExtractor(BaseExtractor):
+    def extract(self) -> Optional[TaskMetadata]:
+        from airflow.providers.google.cloud.operators.bigquery import (
+            BigQueryInsertJobOperator,  # type: ignore
+        )
+
+        operator: "BigQueryInsertJobOperator" = self.operator
+        sql = operator.configuration.get("query", {}).get("query")
+        if not sql:
+            self.log.warning("No query found in BigQueryInsertJobOperator")
+            return None
+
+        destination_table = operator.configuration.get("query", {}).get(
+            "destinationTable"
+        )
+        destination_table_urn = None
+        if destination_table:
+            project_id = destination_table.get("projectId")
+            dataset_id = destination_table.get("datasetId")
+            table_id = destination_table.get("tableId")
+
+            if project_id and dataset_id and table_id:
+                destination_table_urn = builder.make_dataset_urn(
+                    platform="bigquery",
+                    name=f"{project_id}.{dataset_id}.{table_id}",
+                    env=builder.DEFAULT_ENV,
+                )
+
+        task_metadata = _parse_sql_into_task_metadata(
+            self,
+            sql,
+            platform="bigquery",
+            default_database=operator.project_id,
+            default_schema=None,
+        )
+
+        if destination_table_urn and task_metadata:
+            sql_parsing_result = task_metadata.run_facets.get(SQL_PARSING_RESULT_KEY)
+            if sql_parsing_result and isinstance(sql_parsing_result, SqlParsingResult):
+                sql_parsing_result.out_tables.append(destination_table_urn)
+
+        return task_metadata
+
+
+class AthenaOperatorExtractor(BaseExtractor):
+    def extract(self) -> Optional[TaskMetadata]:
+        from airflow.providers.amazon.aws.operators.athena import (
+            AthenaOperator,  # type: ignore
+        )
+
+        operator: "AthenaOperator" = self.operator
+        sql = operator.query
+        if not sql:
+            self.log.warning("No query found in AthenaOperator")
+            return None
+
+        return _parse_sql_into_task_metadata(
+            self,
+            sql,
+            platform="athena",
+            default_database=None,
+            default_schema=self.operator.database,
+        )
+
+
+def _snowflake_default_schema(self: "SnowflakeExtractor") -> Optional[str]:
+    if hasattr(self.operator, "schema") and self.operator.schema is not None:
+        return self.operator.schema
+    return (
+        self.conn.extra_dejson.get("extra__snowflake__schema", "")
+        or self.conn.extra_dejson.get("schema", "")
+        or self.conn.schema
+    )
+    # TODO: Should we try a fallback of:
+    # execute_query_on_hook(self.hook, "SELECT current_schema();")
+
+
+class TeradataOperatorExtractor(BaseExtractor):
+    """Extractor for Teradata SQL operations.
+
+    Extracts lineage from TeradataOperator tasks by parsing the SQL queries
+    and understanding Teradata's two-tier database.table naming convention.
+    """
+
+    def extract(self) -> Optional[TaskMetadata]:
+        from airflow.providers.teradata.operators.teradata import TeradataOperator
+
+        operator: "TeradataOperator" = self.operator
+        sql = operator.sql
+        if not sql:
+            self.log.warning("No query found in TeradataOperator")
+            return None
+
+        return _parse_sql_into_task_metadata(
+            self,
+            sql,
+            platform="teradata",
+            default_database=None,
+            default_schema=None,
+        )
datahub_airflow_plugin/client/airflow_generator.py (+43 -147)

@@ -1,6 +1,5 @@
-import
-from
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, cast
+from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, cast
 
 from airflow.configuration import conf
 
@@ -13,60 +12,12 @@ from datahub.emitter.generic_emitter import Emitter
 from datahub.metadata.schema_classes import DataProcessTypeClass
 from datahub.utilities.urns.data_flow_urn import DataFlowUrn
 from datahub.utilities.urns.data_job_urn import DataJobUrn
-from datahub_airflow_plugin._airflow_version_specific import (
-    get_task_instance_attributes,
-)
 from datahub_airflow_plugin._config import DatahubLineageConfig, DatajobUrl
 
 if TYPE_CHECKING:
     from airflow import DAG
     from airflow.models import DagRun, TaskInstance
-
-    from datahub_airflow_plugin._airflow_shims import Operator
-
-try:
-    from airflow.serialization.serialized_objects import (
-        SerializedBaseOperator,
-        SerializedDAG,
-    )
-
-    DagType = Union[DAG, SerializedDAG]
-    OperatorType = Union[Operator, SerializedBaseOperator]
-except ImportError:
-    DagType = DAG  # type: ignore[misc]
-    OperatorType = Operator  # type: ignore[misc]
-
-# Add type ignore for ti.task which can be MappedOperator from different modules
-# airflow.models.mappedoperator.MappedOperator (2.x) vs airflow.sdk.definitions.mappedoperator.MappedOperator (3.x)
-TaskType = Union[OperatorType, Any]  # type: ignore[misc]
-
-
-def _get_base_url() -> str:
-    """
-    Get the Airflow base URL for constructing web UI links.
-
-    Tries multiple configuration sources for backward compatibility:
-    1. webserver.base_url (Airflow 2.x and 3.x with computed default)
-    2. api.base_url (Airflow 3.x alternative configuration)
-    3. Fallback to http://localhost:8080 (safe default)
-
-    Returns:
-        str: The base URL for the Airflow web UI
-    """
-    # Try webserver.base_url first (works in both Airflow 2.x and 3.x)
-    # In Airflow 3.x, this is computed from web_server_host + web_server_port
-    base_url = conf.get("webserver", "base_url", fallback=None)
-    if base_url:
-        return base_url
-
-    # Fallback to api.base_url for environments that use it
-    # Some Airflow 3.x deployments may set this explicitly
-    api_base_url = conf.get("api", "base_url", fallback=None)
-    if api_base_url:
-        return api_base_url
-
-    # Final fallback to localhost (safe default for development/testing)
-    return "http://localhost:8080"
+    from airflow.models.operator import Operator
 
 
 def _task_downstream_task_ids(operator: "Operator") -> Set[str]:
@@ -78,8 +29,8 @@ def _task_downstream_task_ids(operator: "Operator") -> Set[str]:
 class AirflowGenerator:
     @staticmethod
     def _get_dependencies(
-        task: "
-        dag: "
+        task: "Operator",
+        dag: "DAG",
         flow_urn: DataFlowUrn,
         config: Optional[DatahubLineageConfig] = None,
     ) -> List[DataJobUrn]:
@@ -116,18 +67,14 @@ class AirflowGenerator:
 
         # subdags are always named with 'parent.child' style or Airflow won't run them
        # add connection from subdag trigger(s) if subdag task has no upstreams
-        # Note: is_subdag was removed in Airflow 3.x (subdags deprecated in Airflow 2.0)
-        parent_dag = getattr(dag, "parent_dag", None)
         if (
-
-            and parent_dag is not None
+            dag.is_subdag
+            and dag.parent_dag is not None
             and len(task.upstream_task_ids) == 0
         ):
             # filter through the parent dag's tasks and find the subdag trigger(s)
             subdags = [
-                x
-                for x in parent_dag.task_dict.values()
-                if x.subdag is not None  # type: ignore[union-attr]
+                x for x in dag.parent_dag.task_dict.values() if x.subdag is not None
             ]
             matched_subdags = [
                 x for x in subdags if x.subdag and x.subdag.dag_id == dag.dag_id
@@ -137,14 +84,14 @@ class AirflowGenerator:
             subdag_task_id = matched_subdags[0].task_id
 
             # iterate through the parent dag's tasks and find the ones that trigger the subdag
-            for upstream_task_id in parent_dag.task_dict:
-                upstream_task = parent_dag.task_dict[upstream_task_id]
+            for upstream_task_id in dag.parent_dag.task_dict:
+                upstream_task = dag.parent_dag.task_dict[upstream_task_id]
                 upstream_task_urn = DataJobUrn.create_from_ids(
                     data_flow_urn=str(flow_urn), job_id=upstream_task_id
                 )
 
                 # if the task triggers the subdag, link it to this node in the subdag
-                if subdag_task_id in sorted(_task_downstream_task_ids(upstream_task)):
+                if subdag_task_id in sorted(_task_downstream_task_ids(upstream_task)):
                     upstream_subdag_triggers.append(upstream_task_urn)
 
         # If the operator is an ExternalTaskSensor then we set the remote task as upstream.
@@ -153,16 +100,14 @@ class AirflowGenerator:
         external_task_upstreams = []
         if isinstance(task, ExternalTaskSensor):
             task = cast(ExternalTaskSensor, task)
-
-            external_dag_id = getattr(task, "external_dag_id", None)
-            if external_task_id is not None and external_dag_id is not None:
+            if hasattr(task, "external_task_id") and task.external_task_id is not None:
                 external_task_upstreams = [
                     DataJobUrn.create_from_ids(
-                        job_id=external_task_id,
+                        job_id=task.external_task_id,
                         data_flow_urn=str(
                             DataFlowUrn.create_from_ids(
                                 orchestrator=flow_urn.orchestrator,
-                                flow_id=external_dag_id,
+                                flow_id=task.external_dag_id,
                                 env=flow_urn.cluster,
                                 platform_instance=config.platform_instance
                                 if config
@@ -185,13 +130,13 @@ class AirflowGenerator:
         return upstream_tasks
 
     @staticmethod
-    def _extract_owners(dag: "
+    def _extract_owners(dag: "DAG") -> List[str]:
         return [owner.strip() for owner in dag.owner.split(",")]
 
     @staticmethod
     def generate_dataflow(
         config: DatahubLineageConfig,
-        dag: "
+        dag: "DAG",
     ) -> DataFlow:
         """
         Generates a Dataflow object from an Airflow DAG
@@ -228,34 +173,12 @@ class AirflowGenerator:
             "timezone",
         ]
 
-        def _serialize_dag_property(value: Any) -> str:
-            """Serialize DAG property values to string format (JSON-compatible when possible)."""
-            if value is None:
-                return ""
-            elif isinstance(value, bool):
-                return "true" if value else "false"
-            elif isinstance(value, datetime):
-                return value.isoformat()
-            elif isinstance(value, (set, frozenset)):
-                # Convert set to JSON array string
-                return json.dumps(sorted(list(value)))
-            elif isinstance(value, tzinfo):
-                return str(value.tzname(None))
-            elif isinstance(value, (int, float)):
-                return str(value)
-            elif isinstance(value, str):
-                return value
-            else:
-                # For other types, convert to string but avoid repr() format
-                return str(value)
-
         for key in allowed_flow_keys:
             if hasattr(dag, key):
-
-                flow_property_bag[key] = _serialize_dag_property(value)
+                flow_property_bag[key] = repr(getattr(dag, key))
 
         data_flow.properties = flow_property_bag
-        base_url =
+        base_url = conf.get("webserver", "base_url")
         data_flow.url = f"{base_url}/tree?dag_id={dag.dag_id}"
 
         if config.capture_ownership_info and dag.owner:
@@ -271,8 +194,8 @@ class AirflowGenerator:
         return data_flow
 
     @staticmethod
-    def _get_description(task: "
-        from
+    def _get_description(task: "Operator") -> Optional[str]:
+        from airflow.models.baseoperator import BaseOperator
 
         if not isinstance(task, BaseOperator):
             # TODO: Get docs for mapped operators.
@@ -293,8 +216,8 @@ class AirflowGenerator:
     @staticmethod
     def generate_datajob(
         cluster: str,
-        task: "
-        dag: "
+        task: "Operator",
+        dag: "DAG",
         set_dependencies: bool = True,
         capture_owner: bool = True,
         capture_tags: bool = True,
@@ -366,15 +289,11 @@ class AirflowGenerator:
                 break
 
         datajob.properties = job_property_bag
-        base_url =
+        base_url = conf.get("webserver", "base_url")
 
         if config and config.datajob_url_link == DatajobUrl.GRID:
             datajob.url = f"{base_url}/dags/{dag.dag_id}/grid?task_id={task.task_id}"
-        elif config and config.datajob_url_link == DatajobUrl.TASKS:
-            # Airflow 3.x task URL format
-            datajob.url = f"{base_url}/dags/{dag.dag_id}/tasks/{task.task_id}"
         else:
-            # Airflow 2.x taskinstance list URL format
             datajob.url = f"{base_url}/taskinstance/list/?flt1_dag_id_equals={dag.dag_id}&_flt_3_task_id={task.task_id}"
 
         if capture_owner and dag.owner:
@@ -528,12 +447,8 @@ class AirflowGenerator:
     ) -> DataProcessInstance:
         if datajob is None:
             assert ti.task is not None
-            # ti.task can be MappedOperator from different modules (airflow.models vs airflow.sdk.definitions)
             datajob = AirflowGenerator.generate_datajob(
-                config.cluster,
-                ti.task,  # type: ignore[arg-type]
-                dag,
-                config=config,
+                config.cluster, ti.task, dag, config=config
             )
 
         assert dag_run.run_id
@@ -543,23 +458,26 @@ class AirflowGenerator:
             clone_inlets=True,
            clone_outlets=True,
         )
-
-        job_property_bag =
-
-
+        job_property_bag: Dict[str, str] = {}
+        job_property_bag["run_id"] = str(dag_run.run_id)
+        job_property_bag["duration"] = str(ti.duration)
+        job_property_bag["start_date"] = str(ti.start_date)
+        job_property_bag["end_date"] = str(ti.end_date)
+        job_property_bag["execution_date"] = str(ti.execution_date)
+        job_property_bag["try_number"] = str(ti.try_number - 1)
+        job_property_bag["max_tries"] = str(ti.max_tries)
+        # Not compatible with Airflow 1
+        if hasattr(ti, "external_executor_id"):
+            job_property_bag["external_executor_id"] = str(ti.external_executor_id)
+        job_property_bag["state"] = str(ti.state)
+        job_property_bag["operator"] = str(ti.operator)
+        job_property_bag["priority_weight"] = str(ti.priority_weight)
+        job_property_bag["log_url"] = ti.log_url
         job_property_bag["orchestrator"] = "airflow"
-
-
-        if "task_id" not in job_property_bag:
-            job_property_bag["task_id"] = str(ti.task_id)
-        if "run_id" not in job_property_bag:
-            job_property_bag["run_id"] = str(dag_run.run_id)
-
+        job_property_bag["dag_id"] = str(dag.dag_id)
+        job_property_bag["task_id"] = str(ti.task_id)
         dpi.properties.update(job_property_bag)
-
-        # Set URL if log_url is available
-        if "log_url" in job_property_bag:
-            dpi.url = job_property_bag["log_url"]
+        dpi.url = ti.log_url
 
         # This property only exists in Airflow2
         if hasattr(ti, "dag_run") and hasattr(ti.dag_run, "run_type"):
@@ -620,12 +538,8 @@ class AirflowGenerator:
         """
         if datajob is None:
             assert ti.task is not None
-            # ti.task can be MappedOperator from different modules (airflow.models vs airflow.sdk.definitions)
            datajob = AirflowGenerator.generate_datajob(
-                cluster,
-                ti.task,  # type: ignore[arg-type]
-                dag,
-                config=config,
+                cluster, ti.task, dag, config=config
             )
 
         if end_timestamp_millis is None:
@@ -652,24 +566,6 @@ class AirflowGenerator:
             clone_inlets=True,
             clone_outlets=True,
         )
-
-        job_property_bag = get_task_instance_attributes(ti)
-
-        # Add orchestrator and DAG/task IDs
-        job_property_bag["orchestrator"] = "airflow"
-        if "dag_id" not in job_property_bag:
-            job_property_bag["dag_id"] = str(dag.dag_id)
-        if "task_id" not in job_property_bag:
-            job_property_bag["task_id"] = str(ti.task_id)
-        if "run_id" not in job_property_bag:
-            job_property_bag["run_id"] = str(dag_run.run_id)
-
-        dpi.properties.update(job_property_bag)
-
-        # Set URL if log_url is available
-        if "log_url" in job_property_bag:
-            dpi.url = job_property_bag["log_url"]
-
         dpi.emit_process_end(
             emitter=emitter,
             end_timestamp_millis=end_timestamp_millis,