acryl-datahub 1.1.0.5rc6__py3-none-any.whl → 1.1.0.5rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of acryl-datahub might be problematic.

Files changed (38)
  1. {acryl_datahub-1.1.0.5rc6.dist-info → acryl_datahub-1.1.0.5rc7.dist-info}/METADATA +2547 -2547
  2. {acryl_datahub-1.1.0.5rc6.dist-info → acryl_datahub-1.1.0.5rc7.dist-info}/RECORD +38 -35
  3. datahub/_version.py +1 -1
  4. datahub/ingestion/api/report.py +183 -35
  5. datahub/ingestion/autogenerated/capability_summary.json +3366 -0
  6. datahub/ingestion/autogenerated/lineage.json +401 -0
  7. datahub/ingestion/autogenerated/lineage_helper.py +30 -128
  8. datahub/ingestion/run/pipeline.py +4 -1
  9. datahub/ingestion/source/bigquery_v2/bigquery.py +23 -22
  10. datahub/ingestion/source/cassandra/cassandra_profiling.py +6 -5
  11. datahub/ingestion/source/common/subtypes.py +1 -1
  12. datahub/ingestion/source/data_lake_common/object_store.py +40 -0
  13. datahub/ingestion/source/dremio/dremio_source.py +6 -3
  14. datahub/ingestion/source/gcs/gcs_source.py +4 -1
  15. datahub/ingestion/source/ge_data_profiler.py +28 -20
  16. datahub/ingestion/source/kafka_connect/source_connectors.py +59 -4
  17. datahub/ingestion/source/mock_data/datahub_mock_data.py +45 -0
  18. datahub/ingestion/source/redshift/usage.py +4 -3
  19. datahub/ingestion/source/s3/source.py +19 -3
  20. datahub/ingestion/source/snowflake/snowflake_queries.py +47 -3
  21. datahub/ingestion/source/snowflake/snowflake_usage_v2.py +8 -2
  22. datahub/ingestion/source/snowflake/stored_proc_lineage.py +143 -0
  23. datahub/ingestion/source/unity/proxy.py +4 -3
  24. datahub/ingestion/source/unity/source.py +10 -8
  25. datahub/integrations/assertion/snowflake/compiler.py +4 -3
  26. datahub/metadata/_internal_schema_classes.py +85 -4
  27. datahub/metadata/com/linkedin/pegasus2avro/settings/global/__init__.py +2 -0
  28. datahub/metadata/schema.avsc +54 -1
  29. datahub/metadata/schemas/CorpUserSettings.avsc +17 -1
  30. datahub/metadata/schemas/GlobalSettingsInfo.avsc +37 -0
  31. datahub/sdk/lineage_client.py +2 -0
  32. datahub/sql_parsing/sql_parsing_aggregator.py +3 -3
  33. datahub/sql_parsing/sqlglot_lineage.py +2 -0
  34. datahub/utilities/sqlalchemy_query_combiner.py +5 -2
  35. {acryl_datahub-1.1.0.5rc6.dist-info → acryl_datahub-1.1.0.5rc7.dist-info}/WHEEL +0 -0
  36. {acryl_datahub-1.1.0.5rc6.dist-info → acryl_datahub-1.1.0.5rc7.dist-info}/entry_points.txt +0 -0
  37. {acryl_datahub-1.1.0.5rc6.dist-info → acryl_datahub-1.1.0.5rc7.dist-info}/licenses/LICENSE +0 -0
  38. {acryl_datahub-1.1.0.5rc6.dist-info → acryl_datahub-1.1.0.5rc7.dist-info}/top_level.txt +0 -0

datahub/ingestion/source/data_lake_common/object_store.py

@@ -519,6 +519,13 @@ class ObjectStoreSourceAdapter:
  "get_external_url",
  lambda table_data: self.get_gcs_external_url(table_data),
  )
+ # Fix URI mismatch issue in pattern matching
+ self.register_customization(
+ "_normalize_uri_for_pattern_matching",
+ self._normalize_gcs_uri_for_pattern_matching,
+ )
+ # Fix URI handling in schema extraction - override strip_s3_prefix for GCS
+ self.register_customization("strip_s3_prefix", self._strip_gcs_prefix)
  elif platform == "s3":
  self.register_customization("is_s3_platform", lambda: True)
  self.register_customization("create_s3_path", self.create_s3_path)

@@ -612,6 +619,39 @@ class ObjectStoreSourceAdapter:
  return self.get_abs_external_url(table_data)
  return None

+ def _normalize_gcs_uri_for_pattern_matching(self, uri: str) -> str:
+ """
+ Normalize GCS URI for pattern matching.
+
+ This method converts gs:// URIs to s3:// URIs for pattern matching purposes,
+ fixing the URI mismatch issue in GCS ingestion.
+
+ Args:
+ uri: The URI to normalize
+
+ Returns:
+ The normalized URI for pattern matching
+ """
+ if uri.startswith("gs://"):
+ return uri.replace("gs://", "s3://", 1)
+ return uri
+
+ def _strip_gcs_prefix(self, uri: str) -> str:
+ """
+ Strip GCS prefix from URI.
+
+ This method removes the gs:// prefix from GCS URIs for path processing.
+
+ Args:
+ uri: The URI to strip the prefix from
+
+ Returns:
+ The URI without the gs:// prefix
+ """
+ if uri.startswith("gs://"):
+ return uri[5:]  # Remove "gs://" prefix
+ return uri
+

  # Factory function to create an adapter for a specific platform
  def create_object_store_adapter(
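
The two customizations registered above change how GCS paths flow through the shared S3 code path. An illustrative, package-independent sketch of the behavior (the function names below are local stand-ins, not the package API):

def normalize_gcs_uri_for_pattern_matching(uri: str) -> str:
    # gs:// URIs are rewritten to s3:// only for path_spec pattern matching,
    # so patterns written against the shared S3 code path still match GCS files.
    return uri.replace("gs://", "s3://", 1) if uri.startswith("gs://") else uri

def strip_gcs_prefix(uri: str) -> str:
    # The scheme is removed entirely when building browse paths.
    return uri[len("gs://"):] if uri.startswith("gs://") else uri

print(normalize_gcs_uri_for_pattern_matching("gs://my-bucket/events/2024/part-0.parquet"))
# -> s3://my-bucket/events/2024/part-0.parquet
print(strip_gcs_prefix("gs://my-bucket/events/2024/part-0.parquet"))
# -> my-bucket/events/2024/part-0.parquet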

datahub/ingestion/source/dremio/dremio_source.py

@@ -261,9 +261,12 @@ class DremioSource(StatefulIngestionSourceBase):

  # Profiling
  if self.config.is_profiling_enabled():
- with self.report.new_stage(PROFILING), ThreadPoolExecutor(
- max_workers=self.config.profiling.max_workers
- ) as executor:
+ with (
+ self.report.new_stage(PROFILING),
+ ThreadPoolExecutor(
+ max_workers=self.config.profiling.max_workers
+ ) as executor,
+ ):
  future_to_dataset = {
  executor.submit(self.generate_profiles, dataset): dataset
  for dataset in datasets
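
This hunk, like several others in this release (Redshift usage, Snowflake usage, Unity proxy, GE profiler), only regroups chained context managers into the parenthesized with-statement form, which needs Python 3.10+. A minimal runnable sketch of the pattern with stand-in context managers:

from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext

with (
    nullcontext(),  # stand-in for report.new_stage(PROFILING)
    ThreadPoolExecutor(max_workers=4) as executor,
):
    print(executor.submit(sum, [1, 2, 3]).result())  # prints 6

The parenthesized form avoids backslash continuations and nested with blocks while keeping the same runtime behavior.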

datahub/ingestion/source/gcs/gcs_source.py

@@ -112,6 +112,7 @@ class GCSSource(StatefulIngestionSourceBase):
  env=self.config.env,
  max_rows=self.config.max_rows,
  number_of_files_to_sample=self.config.number_of_files_to_sample,
+ platform=PLATFORM_GCS,  # Ensure GCS platform is used for correct container subtypes
  )
  return s3_config

@@ -138,7 +139,9 @@ class GCSSource(StatefulIngestionSourceBase):

  def create_equivalent_s3_source(self, ctx: PipelineContext) -> S3Source:
  config = self.create_equivalent_s3_config()
- s3_source = S3Source(config, PipelineContext(ctx.run_id))
+ # Create a new context for S3 source without graph to avoid duplicate checkpointer registration
+ s3_ctx = PipelineContext(run_id=ctx.run_id, pipeline_name=ctx.pipeline_name)
+ s3_source = S3Source(config, s3_ctx)
  return self.s3_source_overrides(s3_source)

  def s3_source_overrides(self, source: S3Source) -> S3Source:
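
A hedged sketch of the context isolation applied above: the embedded S3Source receives a fresh PipelineContext that carries only the run and pipeline identity (no graph client), so it does not register a second stateful-ingestion checkpointer next to the wrapping GCS source. The helper name below is hypothetical.

from datahub.ingestion.api.common import PipelineContext

def make_embedded_ctx(outer: PipelineContext) -> PipelineContext:
    # Copy identity only; deliberately leave out the graph connection.
    return PipelineContext(run_id=outer.run_id, pipeline_name=outer.pipeline_name)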

datahub/ingestion/source/ge_data_profiler.py

@@ -1213,26 +1213,34 @@ class DatahubGEProfiler:
  f"Will profile {len(requests)} table(s) with {max_workers} worker(s) - this may take a while"
  )

- with PerfTimer() as timer, unittest.mock.patch(
- "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_unique_count",
- get_column_unique_count_dh_patch,
- ), unittest.mock.patch(
- "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_bigquery",
- _get_column_quantiles_bigquery_patch,
- ), unittest.mock.patch(
- "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_awsathena",
- _get_column_quantiles_awsathena_patch,
- ), unittest.mock.patch(
- "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_median",
- _get_column_median_patch,
- ), concurrent.futures.ThreadPoolExecutor(
- max_workers=max_workers
- ) as async_executor, SQLAlchemyQueryCombiner(
- enabled=self.config.query_combiner_enabled,
- catch_exceptions=self.config.catch_exceptions,
- is_single_row_query_method=_is_single_row_query_method,
- serial_execution_fallback_enabled=True,
- ).activate() as query_combiner:
+ with (
+ PerfTimer() as timer,
+ unittest.mock.patch(
+ "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_unique_count",
+ get_column_unique_count_dh_patch,
+ ),
+ unittest.mock.patch(
+ "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_bigquery",
+ _get_column_quantiles_bigquery_patch,
+ ),
+ unittest.mock.patch(
+ "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_awsathena",
+ _get_column_quantiles_awsathena_patch,
+ ),
+ unittest.mock.patch(
+ "great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_median",
+ _get_column_median_patch,
+ ),
+ concurrent.futures.ThreadPoolExecutor(
+ max_workers=max_workers
+ ) as async_executor,
+ SQLAlchemyQueryCombiner(
+ enabled=self.config.query_combiner_enabled,
+ catch_exceptions=self.config.catch_exceptions,
+ is_single_row_query_method=_is_single_row_query_method,
+ serial_execution_fallback_enabled=True,
+ ).activate() as query_combiner,
+ ):
  # Submit the profiling requests to the thread pool executor.
  async_profiles = collections.deque(
  async_executor.submit(

datahub/ingestion/source/kafka_connect/source_connectors.py

@@ -20,6 +20,8 @@ from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import (
  get_platform_from_sqlalchemy_uri,
  )

+ logger = logging.getLogger(__name__)
+

  @dataclass
  class ConfluentJDBCSourceConnector(BaseConnector):

@@ -392,7 +394,7 @@ class MongoSourceConnector(BaseConnector):
  db_connection_url=connector_manifest.config.get("connection.uri"),
  source_platform="mongodb",
  database_name=connector_manifest.config.get("database"),
- topic_prefix=connector_manifest.config.get("topic_prefix"),
+ topic_prefix=connector_manifest.config.get("topic.prefix"),
  transforms=(
  connector_manifest.config["transforms"].split(",")
  if "transforms" in connector_manifest.config

@@ -406,7 +408,11 @@ class MongoSourceConnector(BaseConnector):
  lineages: List[KafkaConnectLineage] = list()
  parser = self.get_parser(self.connector_manifest)
  source_platform = parser.source_platform
- topic_naming_pattern = r"mongodb\.(\w+)\.(\w+)"
+ topic_prefix = parser.topic_prefix or ""
+
+ # Escape topic_prefix to handle cases where it contains dots
+ # Some users configure topic.prefix like "my.mongodb" which breaks the regex
+ topic_naming_pattern = rf"{re.escape(topic_prefix)}\.(\w+)\.(\w+)"

  if not self.connector_manifest.topic_names:
  return lineages
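
An illustrative check of the escaped prefix pattern (not part of the package; the topic names are made up):

import re

topic_prefix = "my.mongodb"  # a prefix that contains dots
pattern = rf"{re.escape(topic_prefix)}\.(\w+)\.(\w+)"

m = re.search(pattern, "my.mongodb.sales.customers")
print(m.group(1), m.group(2))  # sales customers

# Without re.escape, each dot in the prefix matches any character,
# so an unrelated topic would also match:
print(bool(re.search(r"my.mongodb\.(\w+)\.(\w+)", "myXmongodb.sales.customers")))  # True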

datahub/ingestion/source/kafka_connect/source_connectors.py (continued)

@@ -429,6 +435,26 @@ class MongoSourceConnector(BaseConnector):

  @dataclass
  class DebeziumSourceConnector(BaseConnector):
+ # Debezium topic naming patterns by connector type
+ # - MySQL: {topic.prefix}.{database}.{table}
+ # - PostgreSQL: {topic.prefix}.{schema}.{table}
+ # - SQL Server: {topic.prefix}.{database}.{schema}.{table}
+ # - Oracle: {topic.prefix}.{schema}.{table}
+ # - DB2: {topic.prefix}.{schema}.{table}
+ # - MongoDB: {topic.prefix}.{database}.{collection}
+ # - Vitess: {topic.prefix}.{keyspace}.{table}
+
+ # Note SQL Server allows for "database.names" (multiple databases) config,
+ # and so database is in the topic naming pattern.
+ # However, others have "database.dbname" which is a single database name. For these connectors,
+ # additional databases would require a different connector instance
+
+ # Connectors with 2-level container in pattern (database + schema)
+ # Others have either database XOR schema, but not both
+ DEBEZIUM_CONNECTORS_WITH_2_LEVEL_CONTAINER_IN_PATTERN = {
+ "io.debezium.connector.sqlserver.SqlServerConnector",
+ }
+
  @dataclass
  class DebeziumParser:
  source_platform: str

@@ -514,16 +540,45 @@ class DebeziumSourceConnector(BaseConnector):
  source_platform = parser.source_platform
  server_name = parser.server_name
  database_name = parser.database_name
- topic_naming_pattern = rf"({server_name})\.(\w+\.\w+)"
+ # Escape server_name to handle cases where topic.prefix contains dots
+ # Some users configure topic.prefix like "my.server" which breaks the regex
+ server_name = server_name or ""
+ # Regex pattern (\w+\.\w+(?:\.\w+)?) supports BOTH 2-part and 3-part table names
+ topic_naming_pattern = rf"({re.escape(server_name)})\.(\w+\.\w+(?:\.\w+)?)"

  if not self.connector_manifest.topic_names:
  return lineages

+ # Handle connectors with 2-level container (database + schema) in topic pattern
+ connector_class = self.connector_manifest.config.get(CONNECTOR_CLASS, "")
+ maybe_duplicated_database_name = (
+ connector_class
+ in self.DEBEZIUM_CONNECTORS_WITH_2_LEVEL_CONTAINER_IN_PATTERN
+ )
+
  for topic in self.connector_manifest.topic_names:
  found = re.search(re.compile(topic_naming_pattern), topic)
+ logger.debug(
+ f"Processing topic: '{topic}' with regex pattern '{topic_naming_pattern}', found: {found}"
+ )

  if found:
- table_name = get_dataset_name(database_name, found.group(2))
+ # Extract the table part after server_name
+ table_part = found.group(2)
+
+ if (
+ maybe_duplicated_database_name
+ and database_name
+ and table_part.startswith(f"{database_name}.")
+ ):
+ table_part = table_part[len(database_name) + 1 :]
+
+ logger.debug(
+ f"Extracted table part: '{table_part}' from topic '{topic}'"
+ )
+ # Apply database name to create final dataset name
+ table_name = get_dataset_name(database_name, table_part)
+ logger.debug(f"Final table name: '{table_name}'")

  lineage = KafkaConnectLineage(
  source_dataset=table_name,
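
An illustrative check of the widened Debezium pattern and the SQL Server database stripping (topic names and identifiers are made up):

import re

server_name = "my.server"
pattern = rf"({re.escape(server_name)})\.(\w+\.\w+(?:\.\w+)?)"

# group(2) may be schema.table (2 parts) or database.schema.table (3 parts).
print(re.search(pattern, "my.server.public.orders").group(2))       # public.orders
print(re.search(pattern, "my.server.salesdb.dbo.orders").group(2))  # salesdb.dbo.orders

# For SQL Server, the leading database name is stripped before building the
# dataset name, mirroring the table_part handling above.
table_part, database_name = "salesdb.dbo.orders", "salesdb"
if table_part.startswith(f"{database_name}."):
    table_part = table_part[len(database_name) + 1:]
print(table_part)  # dbo.orders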

datahub/ingestion/source/mock_data/datahub_mock_data.py

@@ -21,9 +21,13 @@ from datahub.ingestion.source.mock_data.datahub_mock_data_report import (
  )
  from datahub.ingestion.source.mock_data.table_naming_helper import TableNamingHelper
  from datahub.metadata.schema_classes import (
+ CalendarIntervalClass,
  DatasetLineageTypeClass,
+ DatasetProfileClass,
+ DatasetUsageStatisticsClass,
  StatusClass,
  SubTypesClass,
+ TimeWindowSizeClass,
  UpstreamClass,
  UpstreamLineageClass,
  )

@@ -278,6 +282,10 @@ class DataHubMockDataSource(Source):

  yield self._get_subtypes_aspect(table_name, i, j)

+ yield self._get_profile_aspect(table_name)
+
+ yield self._get_usage_aspect(table_name)
+
  yield from self._generate_lineage_for_table(
  table_name=table_name,
  table_level=i,

@@ -381,5 +389,42 @@ class DataHubMockDataSource(Source):
  )
  return mcp.as_workunit()

+ def _get_profile_aspect(self, table: str) -> MetadataWorkUnit:
+ urn = make_dataset_urn(
+ platform="fake",
+ name=table,
+ )
+ mcp = MetadataChangeProposalWrapper(
+ entityUrn=urn,
+ entityType="dataset",
+ aspect=DatasetProfileClass(
+ timestampMillis=0,
+ rowCount=100,
+ columnCount=10,
+ sizeInBytes=1000,
+ ),
+ )
+ return mcp.as_workunit()
+
+ def _get_usage_aspect(self, table: str) -> MetadataWorkUnit:
+ urn = make_dataset_urn(
+ platform="fake",
+ name=table,
+ )
+ mcp = MetadataChangeProposalWrapper(
+ entityUrn=urn,
+ entityType="dataset",
+ aspect=DatasetUsageStatisticsClass(
+ timestampMillis=0,
+ eventGranularity=TimeWindowSizeClass(unit=CalendarIntervalClass.DAY),
+ uniqueUserCount=0,
+ totalSqlQueries=0,
+ topSqlQueries=[],
+ userCounts=[],
+ fieldCounts=[],
+ ),
+ )
+ return mcp.as_workunit()
+
  def get_report(self) -> SourceReport:
  return self.report

datahub/ingestion/source/redshift/usage.py

@@ -182,9 +182,10 @@ class RedshiftUsageExtractor:
  self.report.num_operational_stats_filtered = 0

  if self.config.include_operational_stats:
- with self.report.new_stage(
- USAGE_EXTRACTION_OPERATIONAL_STATS
- ), PerfTimer() as timer:
+ with (
+ self.report.new_stage(USAGE_EXTRACTION_OPERATIONAL_STATS),
+ PerfTimer() as timer,
+ ):
  # Generate operation aspect workunits
  yield from self._gen_operation_aspect_workunits(
  self.connection, all_tables

datahub/ingestion/source/s3/source.py

@@ -682,7 +682,7 @@ class S3Source(StatefulIngestionSourceBase):

  logger.info(f"Extracting table schema from file: {table_data.full_path}")
  browse_path: str = (
- strip_s3_prefix(table_data.table_path)
+ self.strip_s3_prefix(table_data.table_path)
  if self.is_s3_platform()
  else table_data.table_path.strip("/")
  )

@@ -949,7 +949,10 @@ class S3Source(StatefulIngestionSourceBase):
  """

  def _is_allowed_path(path_spec_: PathSpec, s3_uri: str) -> bool:
- allowed = path_spec_.allowed(s3_uri)
+ # Normalize URI for pattern matching
+ normalized_uri = self._normalize_uri_for_pattern_matching(s3_uri)
+
+ allowed = path_spec_.allowed(normalized_uri)
  if not allowed:
  logger.debug(f"File {s3_uri} not allowed and skipping")
  self.report.report_file_dropped(s3_uri)

@@ -1394,8 +1397,13 @@ class S3Source(StatefulIngestionSourceBase):
  )
  table_dict: Dict[str, TableData] = {}
  for browse_path in file_browser:
+ # Normalize URI for pattern matching
+ normalized_file_path = self._normalize_uri_for_pattern_matching(
+ browse_path.file
+ )
+
  if not path_spec.allowed(
- browse_path.file,
+ normalized_file_path,
  ignore_ext=self.is_s3_platform()
  and self.source_config.use_s3_content_type,
  ):

@@ -1471,5 +1479,13 @@ class S3Source(StatefulIngestionSourceBase):
  def is_s3_platform(self):
  return self.source_config.platform == "s3"

+ def strip_s3_prefix(self, s3_uri: str) -> str:
+ """Strip S3 prefix from URI. Can be overridden by adapters for other platforms."""
+ return strip_s3_prefix(s3_uri)
+
+ def _normalize_uri_for_pattern_matching(self, uri: str) -> str:
+ """Normalize URI for pattern matching. Can be overridden by adapters for other platforms."""
+ return uri
+
  def get_report(self):
  return self.report

datahub/ingestion/source/snowflake/snowflake_queries.py

@@ -44,6 +44,11 @@ from datahub.ingestion.source.snowflake.snowflake_utils import (
  SnowflakeIdentifierBuilder,
  SnowflakeStructuredReportMixin,
  )
+ from datahub.ingestion.source.snowflake.stored_proc_lineage import (
+ StoredProcCall,
+ StoredProcLineageReport,
+ StoredProcLineageTracker,
+ )
  from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
  from datahub.metadata.urns import CorpUserUrn
  from datahub.sql_parsing.schema_resolver import SchemaResolver

@@ -130,6 +135,7 @@ class SnowflakeQueriesExtractorReport(Report):
  aggregator_generate_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)

  sql_aggregator: Optional[SqlAggregatorReport] = None
+ stored_proc_lineage: Optional[StoredProcLineageReport] = None

  num_ddl_queries_dropped: int = 0
  num_stream_queries_observed: int = 0

@@ -261,6 +267,7 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
  TableRename,
  TableSwap,
  ObservedQuery,
+ StoredProcCall,
  ]
  ] = self._exit_stack.enter_context(FileBackedList(shared_connection))

@@ -277,12 +284,34 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
  for entry in self.fetch_query_log(users):
  queries.append(entry)

+ stored_proc_tracker: StoredProcLineageTracker = self._exit_stack.enter_context(
+ StoredProcLineageTracker(
+ platform=self.identifiers.platform,
+ shared_connection=shared_connection,
+ )
+ )
+ self.report.stored_proc_lineage = stored_proc_tracker.report
+
  with self.report.audit_log_load_timer:
  for i, query in enumerate(queries):
  if i % 1000 == 0:
  logger.info(f"Added {i} query log entries to SQL aggregator")

- self.aggregator.add(query)
+ if isinstance(query, StoredProcCall):
+ stored_proc_tracker.add_stored_proc_call(query)
+ continue
+
+ if not (
+ isinstance(query, PreparsedQuery)
+ and stored_proc_tracker.add_related_query(query)
+ ):
+ # Only add to aggregator if it's not part of a stored procedure.
+ self.aggregator.add(query)
+
+ # Generate and add stored procedure lineage entries.
+ for lineage_entry in stored_proc_tracker.build_merged_lineage_entries():
+ # TODO: Make this the lowest priority lineage - so that it doesn't override other lineage entries.
+ self.aggregator.add(lineage_entry)

  with self.report.aggregator_generate_timer:
  yield from auto_workunit(self.aggregator.gen_metadata())

@@ -342,7 +371,9 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):

  def fetch_query_log(
  self, users: UsersMapping
- ) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap, ObservedQuery]]:
+ ) -> Iterable[
+ Union[PreparsedQuery, TableRename, TableSwap, ObservedQuery, StoredProcCall]
+ ]:
  query_log_query = _build_enriched_query_log_query(
  start_time=self.config.window.start_time,
  end_time=self.config.window.end_time,

@@ -382,7 +413,9 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):

  def _parse_audit_log_row(
  self, row: Dict[str, Any], users: UsersMapping
- ) -> Optional[Union[TableRename, TableSwap, PreparsedQuery, ObservedQuery]]:
+ ) -> Optional[
+ Union[TableRename, TableSwap, PreparsedQuery, ObservedQuery, StoredProcCall]
+ ]:
  json_fields = {
  "DIRECT_OBJECTS_ACCESSED",
  "OBJECTS_MODIFIED",

@@ -482,6 +515,17 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
  extra_info=extra_info,
  )

+ if snowflake_query_type == "CALL" and res["root_query_id"] is None:
+ return StoredProcCall(
+ # This is the top-level query ID that other entries will reference.
+ snowflake_root_query_id=res["query_id"],
+ query_text=query_text,
+ timestamp=timestamp,
+ user=user,
+ default_db=res["default_db"],
+ default_schema=res["default_schema"],
+ )
+
  upstreams = []
  column_usage = {}

datahub/ingestion/source/snowflake/snowflake_usage_v2.py

@@ -231,7 +231,10 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):

  with self.report.usage_aggregation.result_fetch_timer as fetch_timer:
  for row in results:
- with fetch_timer.pause(), self.report.usage_aggregation.result_skip_timer as skip_timer:
+ with (
+ fetch_timer.pause(),
+ self.report.usage_aggregation.result_skip_timer as skip_timer,
+ ):
  if results.rownumber is not None and results.rownumber % 1000 == 0:
  logger.debug(f"Processing usage row number {results.rownumber}")
  logger.debug(self.report.usage_aggregation.as_string())

@@ -255,7 +258,10 @@ class SnowflakeUsageExtractor(SnowflakeCommonMixin, Closeable):
  f"Skipping usage for {row['OBJECT_DOMAIN']} {dataset_identifier}, as table is not accessible."
  )
  continue
- with skip_timer.pause(), self.report.usage_aggregation.result_map_timer as map_timer:
+ with (
+ skip_timer.pause(),
+ self.report.usage_aggregation.result_map_timer as map_timer,
+ ):
  wu = self.build_usage_statistics_for_dataset(
  dataset_identifier, row
  )

datahub/ingestion/source/snowflake/stored_proc_lineage.py (new file)

@@ -0,0 +1,143 @@
+ import dataclasses
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Any, Iterable, List, Optional
+
+ from datahub.ingestion.api.closeable import Closeable
+ from datahub.metadata.urns import CorpUserUrn
+ from datahub.sql_parsing.sql_parsing_aggregator import (
+ PreparsedQuery,
+ UrnStr,
+ )
+ from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint
+ from datahub.utilities.file_backed_collections import FileBackedDict
+
+
+ @dataclasses.dataclass
+ class StoredProcCall:
+ snowflake_root_query_id: str
+
+ # Query text will typically be something like:
+ # "CALL SALES_FORECASTING.CUSTOMER_ANALYSIS_PROC();"
+ query_text: str
+
+ timestamp: datetime
+ user: CorpUserUrn
+ default_db: str
+ default_schema: str
+
+
+ @dataclass
+ class StoredProcExecutionLineage:
+ call: StoredProcCall
+
+ inputs: List[UrnStr]
+ outputs: List[UrnStr]
+
+
+ @dataclass
+ class StoredProcLineageReport:
+ num_stored_proc_calls: int = 0
+ num_related_queries: int = 0
+ num_related_queries_without_proc_call: int = 0
+
+ # Incremented at generation/build time.
+ num_stored_proc_lineage_entries: int = 0
+ num_stored_proc_calls_with_no_inputs: int = 0
+ num_stored_proc_calls_with_no_outputs: int = 0
+
+
+ class StoredProcLineageTracker(Closeable):
+ """
+ Tracks table-level lineage for Snowflake stored procedures.
+
+ Stored procedures in Snowflake trigger multiple SQL queries during execution.
+ Snowflake assigns each stored procedure call a unique query_id and uses this as the
+ root_query_id for all subsequent queries executed within that procedure. This allows
+ us to trace which queries belong to a specific stored procedure execution and build
+ table-level lineage by aggregating inputs/outputs from all related queries.
+ """
+
+ def __init__(self, platform: str, shared_connection: Optional[Any] = None):
+ self.platform = platform
+ self.report = StoredProcLineageReport()
+
+ # { root_query_id -> StoredProcExecutionLineage }
+ self._stored_proc_execution_lineage: FileBackedDict[
+ StoredProcExecutionLineage
+ ] = FileBackedDict(shared_connection)
+
+ def add_stored_proc_call(self, call: StoredProcCall) -> None:
+ """Add a stored procedure call to track."""
+ self._stored_proc_execution_lineage[call.snowflake_root_query_id] = (
+ StoredProcExecutionLineage(
+ call=call,
+ # Will be populated by subsequent queries.
+ inputs=[],
+ outputs=[],
+ )
+ )
+ self.report.num_stored_proc_calls += 1
+
+ def add_related_query(self, query: PreparsedQuery) -> bool:
+ """Add a query that might be related to a stored procedure execution.
+
+ Returns True if the query was added to a stored procedure execution, False otherwise.
+ """
+ snowflake_root_query_id = (query.extra_info or {}).get(
+ "snowflake_root_query_id"
+ )
+
+ if snowflake_root_query_id:
+ if snowflake_root_query_id not in self._stored_proc_execution_lineage:
+ self.report.num_related_queries_without_proc_call += 1
+ return False
+
+ stored_proc_execution = self._stored_proc_execution_lineage.for_mutation(
+ snowflake_root_query_id
+ )
+ stored_proc_execution.inputs.extend(query.upstreams)
+ if query.downstream is not None:
+ stored_proc_execution.outputs.append(query.downstream)
+ self.report.num_related_queries += 1
+ return True
+
+ return False
+
+ def build_merged_lineage_entries(self) -> Iterable[PreparsedQuery]:
+ # For stored procedures, we can only get table-level lineage from the audit log.
+ # We represent these as PreparsedQuery objects for now. Eventually we'll want to
+ # create dataJobInputOutput lineage instead.
+
+ for stored_proc_execution in self._stored_proc_execution_lineage.values():
+ if not stored_proc_execution.inputs:
+ self.report.num_stored_proc_calls_with_no_inputs += 1
+ continue
+
+ if not stored_proc_execution.outputs:
+ self.report.num_stored_proc_calls_with_no_outputs += 1
+ # Still continue to generate lineage for cases where we have inputs but no outputs
+
+ for downstream in stored_proc_execution.outputs:
+ stored_proc_query_id = get_query_fingerprint(
+ stored_proc_execution.call.query_text,
+ self.platform,
+ fast=True,
+ secondary_id=downstream,
+ )
+
+ lineage_entry = PreparsedQuery(
+ query_id=stored_proc_query_id,
+ query_text=stored_proc_execution.call.query_text,
+ upstreams=stored_proc_execution.inputs,
+ downstream=downstream,
+ query_count=0,
+ user=stored_proc_execution.call.user,
+ timestamp=stored_proc_execution.call.timestamp,
+ )
+
+ self.report.num_stored_proc_lineage_entries += 1
+ yield lineage_entry
+
+ def close(self) -> None:
+ self._stored_proc_execution_lineage.close()
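
A hedged usage sketch of the new tracker, tying the pieces above together. The identifiers, URNs, and field values are invented for illustration, and PreparsedQuery is assumed to accept the fields used in this diff (upstreams, downstream, extra_info):

from datetime import datetime, timezone

from datahub.ingestion.source.snowflake.stored_proc_lineage import (
    StoredProcCall,
    StoredProcLineageTracker,
)
from datahub.metadata.urns import CorpUserUrn
from datahub.sql_parsing.sql_parsing_aggregator import PreparsedQuery

tracker = StoredProcLineageTracker(platform="snowflake")
tracker.add_stored_proc_call(
    StoredProcCall(
        snowflake_root_query_id="root-123",  # made-up query id
        query_text="CALL SALES_FORECASTING.CUSTOMER_ANALYSIS_PROC();",
        timestamp=datetime.now(timezone.utc),
        user=CorpUserUrn("etl_user"),
        default_db="SALES",
        default_schema="PUBLIC",
    )
)

# A query executed inside the procedure carries the root query id in extra_info,
# so it is routed to the tracker instead of the SQL aggregator.
child = PreparsedQuery(
    query_id="child-456",
    query_text="INSERT INTO sales.forecast SELECT * FROM sales.raw_orders",
    upstreams=["urn:li:dataset:(urn:li:dataPlatform:snowflake,sales.raw_orders,PROD)"],
    downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales.forecast,PROD)",
    extra_info={"snowflake_root_query_id": "root-123"},
)
assert tracker.add_related_query(child)

# The CALL and its child queries collapse into one table-level lineage entry.
merged = list(tracker.build_merged_lineage_entries())
assert len(merged) == 1 and merged[0].downstream == child.downstream
tracker.close()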

datahub/ingestion/source/unity/proxy.py

@@ -507,9 +507,10 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
  def _execute_sql_query(self, query: str) -> List[List[str]]:
  """Execute SQL query using databricks-sql connector for better performance"""
  try:
- with connect(
- **self._sql_connection_params
- ) as connection, connection.cursor() as cursor:
+ with (
+ connect(**self._sql_connection_params) as connection,
+ connection.cursor() as cursor,
+ ):
  cursor.execute(query)
  return cursor.fetchall()