acryl-datahub 1.3.0.1rc4__py3-none-any.whl → 1.3.0.1rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of acryl-datahub might be problematic.
- {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/METADATA +2637 -2633
- {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/RECORD +31 -28
- datahub/_version.py +1 -1
- datahub/ingestion/source/aws/aws_common.py +161 -0
- datahub/ingestion/source/bigquery_v2/bigquery.py +17 -1
- datahub/ingestion/source/bigquery_v2/bigquery_config.py +16 -0
- datahub/ingestion/source/bigquery_v2/bigquery_schema_gen.py +5 -3
- datahub/ingestion/source/bigquery_v2/queries_extractor.py +41 -4
- datahub/ingestion/source/redshift/redshift_schema.py +17 -12
- datahub/ingestion/source/redshift/usage.py +2 -2
- datahub/ingestion/source/snowflake/snowflake_config.py +16 -0
- datahub/ingestion/source/snowflake/snowflake_queries.py +46 -6
- datahub/ingestion/source/snowflake/snowflake_v2.py +14 -1
- datahub/ingestion/source/sql/mysql.py +101 -4
- datahub/ingestion/source/sql/postgres.py +81 -4
- datahub/ingestion/source/sql/sqlalchemy_uri.py +39 -7
- datahub/ingestion/source/state/redundant_run_skip_handler.py +21 -0
- datahub/ingestion/source/state/stateful_ingestion_base.py +30 -2
- datahub/metadata/_internal_schema_classes.py +772 -546
- datahub/metadata/_urns/urn_defs.py +1751 -1695
- datahub/metadata/com/linkedin/pegasus2avro/file/__init__.py +19 -0
- datahub/metadata/com/linkedin/pegasus2avro/metadata/key/__init__.py +2 -0
- datahub/metadata/schema.avsc +18450 -18242
- datahub/metadata/schemas/DataHubFileInfo.avsc +228 -0
- datahub/metadata/schemas/DataHubFileKey.avsc +21 -0
- datahub/metadata/schemas/DataHubPageModuleProperties.avsc +3 -1
- datahub/sql_parsing/sql_parsing_aggregator.py +18 -4
- {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/WHEEL +0 -0
- {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/entry_points.txt +0 -0
- {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/licenses/LICENSE +0 -0
- {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/top_level.txt +0 -0
--- a/datahub/ingestion/source/redshift/usage.py
+++ b/datahub/ingestion/source/redshift/usage.py
@@ -25,6 +25,7 @@ from datahub.ingestion.source.redshift.query import (
     RedshiftServerlessQuery,
 )
 from datahub.ingestion.source.redshift.redshift_schema import (
+    RedshiftDataDictionary,
     RedshiftTable,
     RedshiftView,
 )
@@ -263,8 +264,7 @@ class RedshiftUsageExtractor:
         connection: redshift_connector.Connection,
         all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
     ) -> Iterable[RedshiftAccessEvent]:
-        cursor = connection.cursor()
-        cursor.execute(query)
+        cursor = RedshiftDataDictionary.get_query_result(conn=connection, query=query)
        results = cursor.fetchmany()
         field_names = [i[0] for i in cursor.description]
         while results:
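The usage extractor now routes cursor handling through RedshiftDataDictionary.get_query_result instead of creating and executing a cursor inline. That helper lives in redshift_schema.py, whose +17/-12 diff is not expanded on this page, so the sketch below is only an assumption about its shape, inferred from the call site above: a single shared path for running data-dictionary queries.

    # Hedged sketch: signature inferred from the call site
    # RedshiftDataDictionary.get_query_result(conn=connection, query=query);
    # the real implementation in redshift_schema.py is not shown in this diff.
    import redshift_connector

    def get_query_result(
        conn: redshift_connector.Connection, query: str
    ) -> redshift_connector.Cursor:
        cursor = conn.cursor()
        cursor.execute(query)  # caller then iterates via fetchmany()/description
        return cursor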

--- a/datahub/ingestion/source/snowflake/snowflake_config.py
+++ b/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -31,6 +31,7 @@ from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, SQLFilterConfig
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulLineageConfigMixin,
     StatefulProfilingConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulUsageConfigMixin,
 )
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
@@ -199,6 +200,7 @@ class SnowflakeV2Config(
     SnowflakeUsageConfig,
     StatefulLineageConfigMixin,
     StatefulUsageConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulProfilingConfigMixin,
     ClassificationSourceConfigMixin,
     IncrementalPropertiesConfigMixin,
@@ -477,6 +479,20 @@ class SnowflakeV2Config(

         return shares

+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict:
+        if values.get("use_queries_v2"):
+            if values.get("enable_stateful_lineage_ingestion") or values.get(
+                "enable_stateful_usage_ingestion"
+            ):
+                logger.warning(
+                    "enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion are deprecated "
+                    "when using use_queries_v2=True. These configs only work with the legacy (non-queries v2) extraction path. "
+                    "For queries v2, use enable_stateful_time_window instead to enable stateful ingestion "
+                    "for the unified time window extraction (lineage + usage + operations + queries)."
+                )
+        return values
+
     def outbounds(self) -> Dict[str, Set[DatabaseId]]:
         """
         Returns mapping of
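For recipe authors, the practical effect of the config changes above, together with the StatefulTimeWindowConfigMixin added to stateful_ingestion_base.py further below, is a single enable_stateful_time_window switch for queries v2 in place of the two legacy per-feature flags. A hedged sketch of the relevant source.config fields as a Python dict; only the stateful-ingestion fields come from this diff, the connection values are placeholders:

    # Hypothetical recipe fragment. use_queries_v2, enable_stateful_time_window, and
    # stateful_ingestion.enabled appear in this diff; account_id/username/password are
    # placeholder connection settings, not part of the change.
    snowflake_source_config = {
        "account_id": "my_account",  # placeholder
        "username": "datahub_user",  # placeholder
        "password": "********",      # placeholder
        "use_queries_v2": True,
        # New in this release: unified stateful time window for queries v2.
        "enable_stateful_time_window": True,
        # Legacy flags are ignored (with a warning) when use_queries_v2=True:
        # "enable_stateful_lineage_ingestion": True,
        # "enable_stateful_usage_ingestion": True,
        "stateful_ingestion": {"enabled": True},
    }

With use_queries_v2 enabled, setting the legacy flags instead only triggers the deprecation warning added by validate_queries_v2_stateful_ingestion above.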

--- a/datahub/ingestion/source/snowflake/snowflake_queries.py
+++ b/datahub/ingestion/source/snowflake/snowflake_queries.py
@@ -17,6 +17,7 @@ from datahub.configuration.common import AllowDenyPattern, ConfigModel, HiddenFromDocs
 from datahub.configuration.time_window_config import (
     BaseTimeWindowConfig,
     BucketDuration,
+    get_time_bucket,
 )
 from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.common import PipelineContext
@@ -50,6 +51,9 @@ from datahub.ingestion.source.snowflake.stored_proc_lineage import (
     StoredProcLineageReport,
     StoredProcLineageTracker,
 )
+from datahub.ingestion.source.state.redundant_run_skip_handler import (
+    RedundantQueriesRunSkipHandler,
+)
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 from datahub.metadata.urns import CorpUserUrn
 from datahub.sql_parsing.schema_resolver import SchemaResolver
@@ -180,6 +184,7 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         structured_report: SourceReport,
         filters: SnowflakeFilter,
         identifiers: SnowflakeIdentifierBuilder,
+        redundant_run_skip_handler: Optional[RedundantQueriesRunSkipHandler] = None,
         graph: Optional[DataHubGraph] = None,
         schema_resolver: Optional[SchemaResolver] = None,
         discovered_tables: Optional[List[str]] = None,
@@ -191,9 +196,13 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         self.filters = filters
         self.identifiers = identifiers
         self.discovered_tables = set(discovered_tables) if discovered_tables else None
+        self.redundant_run_skip_handler = redundant_run_skip_handler

         self._structured_report = structured_report

+        # Adjust time window based on stateful ingestion state
+        self.start_time, self.end_time = self._get_time_window()
+
         # The exit stack helps ensure that we close all the resources we open.
         self._exit_stack = contextlib.ExitStack()

@@ -211,8 +220,8 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
             generate_query_usage_statistics=self.config.include_query_usage_statistics,
             usage_config=BaseUsageConfig(
                 bucket_duration=self.config.window.bucket_duration,
-                start_time=self.config.window.start_time,
-                end_time=self.config.window.end_time,
+                start_time=self.start_time,
+                end_time=self.end_time,
                 user_email_pattern=self.config.user_email_pattern,
                 # TODO make the rest of the fields configurable
             ),
@@ -228,6 +237,34 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
     def structured_reporter(self) -> SourceReport:
         return self._structured_report

+    def _get_time_window(self) -> tuple[datetime, datetime]:
+        if self.redundant_run_skip_handler:
+            start_time, end_time = (
+                self.redundant_run_skip_handler.suggest_run_time_window(
+                    self.config.window.start_time,
+                    self.config.window.end_time,
+                )
+            )
+        else:
+            start_time = self.config.window.start_time
+            end_time = self.config.window.end_time
+
+        # Usage statistics are aggregated per bucket (typically per day).
+        # To ensure accurate aggregated metrics, we need to align the start_time
+        # to the beginning of a bucket so that we include complete bucket periods.
+        if self.config.include_usage_statistics:
+            start_time = get_time_bucket(start_time, self.config.window.bucket_duration)
+
+        return start_time, end_time
+
+    def _update_state(self) -> None:
+        if self.redundant_run_skip_handler:
+            self.redundant_run_skip_handler.update_state(
+                self.config.window.start_time,
+                self.config.window.end_time,
+                self.config.window.bucket_duration,
+            )
+
     @functools.cached_property
     def local_temp_path(self) -> pathlib.Path:
         if self.config.local_temp_path:
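The bucket-alignment branch in _get_time_window exists because usage statistics are aggregated per bucket (DAY or HOUR): if a resumed window started mid-bucket, the first bucket's counts would be partial. A minimal, self-contained illustration of what snapping a start time back to its DAY bucket means; the real helper is get_time_bucket from datahub.configuration.time_window_config, and this stand-in only mirrors its effect for daily buckets:

    from datetime import datetime, timezone

    def floor_to_day_bucket(ts: datetime) -> datetime:
        # Stand-in for get_time_bucket(ts, BucketDuration.DAY): snap the timestamp
        # back to the start of its day so the first usage bucket is complete.
        return ts.replace(hour=0, minute=0, second=0, microsecond=0)

    start = datetime(2024, 6, 3, 14, 25, tzinfo=timezone.utc)
    print(floor_to_day_bucket(start))  # 2024-06-03 00:00:00+00:00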

@@ -355,6 +392,9 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         with self.report.aggregator_generate_timer:
             yield from auto_workunit(self.aggregator.gen_metadata())

+        # Update the stateful ingestion state after successful extraction
+        self._update_state()
+
     def fetch_users(self) -> UsersMapping:
         users: UsersMapping = dict()
         with self.structured_reporter.report_exc("Error fetching users from Snowflake"):
@@ -378,8 +418,8 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         # Derived from _populate_external_lineage_from_copy_history.

         query: str = SnowflakeQuery.copy_lineage_history(
-            start_time_millis=int(self.config.window.start_time.timestamp() * 1000),
-            end_time_millis=int(self.config.window.end_time.timestamp() * 1000),
+            start_time_millis=int(self.start_time.timestamp() * 1000),
+            end_time_millis=int(self.end_time.timestamp() * 1000),
             downstreams_deny_pattern=self.config.temporary_tables_pattern,
         )

@@ -414,8 +454,8 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         Union[PreparsedQuery, TableRename, TableSwap, ObservedQuery, StoredProcCall]
     ]:
         query_log_query = QueryLogQueryBuilder(
-            start_time=self.config.window.start_time,
-            end_time=self.config.window.end_time,
+            start_time=self.start_time,
+            end_time=self.end_time,
             bucket_duration=self.config.window.bucket_duration,
             deny_usernames=self.config.pushdown_deny_usernames,
             allow_usernames=self.config.pushdown_allow_usernames,

--- a/datahub/ingestion/source/snowflake/snowflake_v2.py
+++ b/datahub/ingestion/source/snowflake/snowflake_v2.py
@@ -73,6 +73,7 @@ from datahub.ingestion.source.snowflake.snowflake_utils import (
 from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler
 from datahub.ingestion.source.state.redundant_run_skip_handler import (
     RedundantLineageRunSkipHandler,
+    RedundantQueriesRunSkipHandler,
     RedundantUsageRunSkipHandler,
 )
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
@@ -207,7 +208,7 @@ class SnowflakeV2Source(
         )
         self.report.sql_aggregator = self.aggregator.report

-        if self.config.include_table_lineage:
+        if self.config.include_table_lineage and not self.config.use_queries_v2:
             redundant_lineage_run_skip_handler: Optional[
                 RedundantLineageRunSkipHandler
             ] = None
@@ -589,6 +590,17 @@ class SnowflakeV2Source(
         with self.report.new_stage(f"*: {QUERIES_EXTRACTION}"):
             schema_resolver = self.aggregator._schema_resolver

+            redundant_queries_run_skip_handler: Optional[
+                RedundantQueriesRunSkipHandler
+            ] = None
+            if self.config.enable_stateful_time_window:
+                redundant_queries_run_skip_handler = RedundantQueriesRunSkipHandler(
+                    source=self,
+                    config=self.config,
+                    pipeline_name=self.ctx.pipeline_name,
+                    run_id=self.ctx.run_id,
+                )
+
             queries_extractor = SnowflakeQueriesExtractor(
                 connection=self.connection,
                 # TODO: this should be its own section in main recipe
@@ -614,6 +626,7 @@ class SnowflakeV2Source(
                 structured_report=self.report,
                 filters=self.filters,
                 identifiers=self.identifiers,
+                redundant_run_skip_handler=redundant_queries_run_skip_handler,
                 schema_resolver=schema_resolver,
                 discovered_tables=self.discovered_datasets,
                 graph=self.ctx.graph,

--- a/datahub/ingestion/source/sql/mysql.py
+++ b/datahub/ingestion/source/sql/mysql.py
@@ -1,14 +1,17 @@
 # This import verifies that the dependencies are available.
-
-from typing import List
+import logging
+from typing import TYPE_CHECKING, Any, List, Optional

 import pymysql  # noqa: F401
 from pydantic.fields import Field
-from sqlalchemy import util
+from sqlalchemy import create_engine, event, inspect, util
 from sqlalchemy.dialects.mysql import BIT, base
 from sqlalchemy.dialects.mysql.enumerated import SET
 from sqlalchemy.engine.reflection import Inspector

+if TYPE_CHECKING:
+    from sqlalchemy.engine import Engine
+
 from datahub.configuration.common import AllowDenyPattern, HiddenFromDocs
 from datahub.ingestion.api.decorators import (
     SourceCapability,
@@ -18,11 +21,16 @@ from datahub.ingestion.api.decorators import (
     platform_name,
     support_status,
 )
+from datahub.ingestion.source.aws.aws_common import (
+    AwsConnectionConfig,
+    RDSIAMTokenManager,
+)
 from datahub.ingestion.source.sql.sql_common import (
     make_sqlalchemy_type,
     register_custom_type,
 )
 from datahub.ingestion.source.sql.sql_config import SQLAlchemyConnectionConfig
+from datahub.ingestion.source.sql.sqlalchemy_uri import parse_host_port
 from datahub.ingestion.source.sql.stored_procedures.base import (
     BaseProcedure,
 )
@@ -31,6 +39,9 @@ from datahub.ingestion.source.sql.two_tier_sql_source import (
     TwoTierSQLAlchemySource,
 )
 from datahub.metadata.schema_classes import BytesTypeClass
+from datahub.utilities.str_enum import StrEnum
+
+logger = logging.getLogger(__name__)

 SET.__repr__ = util.generic_repr  # type:ignore

@@ -54,11 +65,33 @@ base.ischema_names["polygon"] = POLYGON
 base.ischema_names["decimal128"] = DECIMAL128


+class MySQLAuthMode(StrEnum):
+    """Authentication mode for MySQL connection."""
+
+    PASSWORD = "PASSWORD"
+    AWS_IAM = "AWS_IAM"
+
+
 class MySQLConnectionConfig(SQLAlchemyConnectionConfig):
     # defaults
     host_port: str = Field(default="localhost:3306", description="MySQL host URL.")
     scheme: HiddenFromDocs[str] = "mysql+pymysql"

+    # Authentication configuration
+    auth_mode: MySQLAuthMode = Field(
+        default=MySQLAuthMode.PASSWORD,
+        description="Authentication mode to use for the MySQL connection. "
+        "Options are 'PASSWORD' (default) for standard username/password authentication, "
+        "or 'AWS_IAM' for AWS RDS IAM authentication.",
+    )
+    aws_config: AwsConnectionConfig = Field(
+        default_factory=AwsConnectionConfig,
+        description="AWS configuration for RDS IAM authentication (only used when auth_mode is AWS_IAM). "
+        "Provides full control over AWS credentials, region, profiles, role assumption, retry logic, and proxy settings. "
+        "If not explicitly configured, boto3 will automatically use the default credential chain and region from "
+        "environment variables (AWS_DEFAULT_REGION, AWS_REGION), AWS config files (~/.aws/config), or IAM role metadata.",
+    )
+

 class MySQLConfig(MySQLConnectionConfig, TwoTierSQLAlchemyConfig):
     def get_identifier(self, *, schema: str, table: str) -> str:
@@ -91,9 +124,27 @@ class MySQLSource(TwoTierSQLAlchemySource):
     Table, row, and column statistics via optional SQL profiling
     """

-    def __init__(self, config, ctx):
+    config: MySQLConfig
+
+    def __init__(self, config: MySQLConfig, ctx: Any):
         super().__init__(config, ctx, self.get_platform())

+        self._rds_iam_token_manager: Optional[RDSIAMTokenManager] = None
+        if config.auth_mode == MySQLAuthMode.AWS_IAM:
+            hostname, port = parse_host_port(config.host_port, default_port=3306)
+            if port is None:
+                raise ValueError("Port must be specified for RDS IAM authentication")
+
+            if not config.username:
+                raise ValueError("username is required for RDS IAM authentication")
+
+            self._rds_iam_token_manager = RDSIAMTokenManager(
+                endpoint=hostname,
+                username=config.username,
+                port=port,
+                aws_config=config.aws_config,
+            )
+
     def get_platform(self):
         return "mysql"

@@ -102,6 +153,52 @@ class MySQLSource(TwoTierSQLAlchemySource):
         config = MySQLConfig.parse_obj(config_dict)
         return cls(config, ctx)

+    def _setup_rds_iam_event_listener(
+        self, engine: "Engine", database_name: Optional[str] = None
+    ) -> None:
+        """Setup SQLAlchemy event listener to inject RDS IAM tokens."""
+        if not (
+            self.config.auth_mode == MySQLAuthMode.AWS_IAM
+            and self._rds_iam_token_manager
+        ):
+            return
+
+        def do_connect_listener(_dialect, _conn_rec, _cargs, cparams):
+            if not self._rds_iam_token_manager:
+                raise RuntimeError("RDS IAM Token Manager is not initialized")
+            cparams["password"] = self._rds_iam_token_manager.get_token()
+            # PyMySQL requires SSL to be enabled for RDS IAM authentication.
+            # Preserve any existing SSL configuration, otherwise enable with default settings.
+            # The {"ssl": True} dict is a workaround to make PyMySQL recognize that SSL
+            # should be enabled, since the library requires a truthy value in the ssl parameter.
+            # See https://pymysql.readthedocs.io/en/latest/modules/connections.html#pymysql.connections.Connection
+            cparams["ssl"] = cparams.get("ssl") or {"ssl": True}
+
+        event.listen(engine, "do_connect", do_connect_listener)  # type: ignore[misc]
+
+    def get_inspectors(self):
+        url = self.config.get_sql_alchemy_url()
+        logger.debug(f"sql_alchemy_url={url}")
+
+        engine = create_engine(url, **self.config.options)
+        self._setup_rds_iam_event_listener(engine)
+
+        with engine.connect() as conn:
+            inspector = inspect(conn)
+            if self.config.database and self.config.database != "":
+                databases = [self.config.database]
+            else:
+                databases = inspector.get_schema_names()
+            for db in databases:
+                if self.config.database_pattern.allowed(db):
+                    url = self.config.get_sql_alchemy_url(current_db=db)
+                    db_engine = create_engine(url, **self.config.options)
+                    self._setup_rds_iam_event_listener(db_engine, database_name=db)
+
+                    with db_engine.connect() as conn:
+                        inspector = inspect(conn)
+                        yield inspector
+
     def add_profile_metadata(self, inspector: Inspector) -> None:
         if not self.config.is_profiling_enabled():
             return
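The MySQL changes above rely on RDSIAMTokenManager from datahub/ingestion/source/aws/aws_common.py (+161 lines, not expanded on this page). A minimal sketch of the underlying mechanism, assuming only public boto3 and SQLAlchemy APIs: RDS IAM auth tokens are short-lived (about 15 minutes) signed strings generated locally from AWS credentials, which is why the source injects a fresh one on every DBAPI connect via the do_connect hook instead of baking a password into the URL. The function name, endpoint, and credentials below are illustrative, not part of the package.

    import boto3
    from sqlalchemy import create_engine, event

    def fetch_rds_iam_token(host: str, port: int, username: str, region: str) -> str:
        # generate_db_auth_token signs a short-lived token locally using the
        # caller's AWS credentials; no call is made to the database here.
        client = boto3.client("rds", region_name=region)
        return client.generate_db_auth_token(
            DBHostname=host, Port=port, DBUsername=username, Region=region
        )

    # Placeholder endpoint/user; engine creation is lazy and does not connect yet.
    engine = create_engine(
        "mysql+pymysql://datahub_user@mydb.abc123.us-east-1.rds.amazonaws.com:3306/mydb"
    )

    @event.listens_for(engine, "do_connect")
    def _inject_token(dialect, conn_rec, cargs, cparams):
        # Same idea as do_connect_listener above: refresh the password per connection.
        cparams["password"] = fetch_rds_iam_token(
            "mydb.abc123.us-east-1.rds.amazonaws.com", 3306, "datahub_user", "us-east-1"
        )
        cparams["ssl"] = cparams.get("ssl") or {"ssl": True}  # PyMySQL needs SSL for IAM auth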

--- a/datahub/ingestion/source/sql/postgres.py
+++ b/datahub/ingestion/source/sql/postgres.py
@@ -1,6 +1,6 @@
 import logging
 from collections import defaultdict
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

 # This import verifies that the dependencies are available.
 import psycopg2  # noqa: F401
@@ -14,9 +14,12 @@ import sqlalchemy.dialects.postgresql as custom_types
 from geoalchemy2 import Geometry  # noqa: F401
 from pydantic import BaseModel
 from pydantic.fields import Field
-from sqlalchemy import create_engine, inspect
+from sqlalchemy import create_engine, event, inspect
 from sqlalchemy.engine.reflection import Inspector

+if TYPE_CHECKING:
+    from sqlalchemy.engine import Engine
+
 from datahub.configuration.common import AllowDenyPattern
 from datahub.emitter import mce_builder
 from datahub.emitter.mcp_builder import mcps_from_mce
@@ -30,12 +33,17 @@ from datahub.ingestion.api.decorators import (
     support_status,
 )
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.aws_common import (
+    AwsConnectionConfig,
+    RDSIAMTokenManager,
+)
 from datahub.ingestion.source.sql.sql_common import (
     SQLAlchemySource,
     SqlWorkUnit,
     register_custom_type,
 )
 from datahub.ingestion.source.sql.sql_config import BasicSQLAlchemyConfig
+from datahub.ingestion.source.sql.sqlalchemy_uri import parse_host_port
 from datahub.ingestion.source.sql.stored_procedures.base import (
     BaseProcedure,
 )
@@ -44,6 +52,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
     BytesTypeClass,
     MapTypeClass,
 )
+from datahub.utilities.str_enum import StrEnum

 logger: logging.Logger = logging.getLogger(__name__)

@@ -100,12 +109,34 @@ class ViewLineageEntry(BaseModel):
     dependent_schema: str


+class PostgresAuthMode(StrEnum):
+    """Authentication mode for PostgreSQL connection."""
+
+    PASSWORD = "PASSWORD"
+    AWS_IAM = "AWS_IAM"
+
+
 class BasePostgresConfig(BasicSQLAlchemyConfig):
     scheme: str = Field(default="postgresql+psycopg2", description="database scheme")
     schema_pattern: AllowDenyPattern = Field(
         default=AllowDenyPattern(deny=["information_schema"])
     )

+    # Authentication configuration
+    auth_mode: PostgresAuthMode = Field(
+        default=PostgresAuthMode.PASSWORD,
+        description="Authentication mode to use for the PostgreSQL connection. "
+        "Options are 'PASSWORD' (default) for standard username/password authentication, "
+        "or 'AWS_IAM' for AWS RDS IAM authentication.",
+    )
+    aws_config: AwsConnectionConfig = Field(
+        default_factory=AwsConnectionConfig,
+        description="AWS configuration for RDS IAM authentication (only used when auth_mode is AWS_IAM). "
+        "Provides full control over AWS credentials, region, profiles, role assumption, retry logic, and proxy settings. "
+        "If not explicitly configured, boto3 will automatically use the default credential chain and region from "
+        "environment variables (AWS_DEFAULT_REGION, AWS_REGION), AWS config files (~/.aws/config), or IAM role metadata.",
+    )
+

 class PostgresConfig(BasePostgresConfig):
     database_pattern: AllowDenyPattern = Field(
@@ -160,6 +191,22 @@ class PostgresSource(SQLAlchemySource):
     def __init__(self, config: PostgresConfig, ctx: PipelineContext):
         super().__init__(config, ctx, self.get_platform())

+        self._rds_iam_token_manager: Optional[RDSIAMTokenManager] = None
+        if config.auth_mode == PostgresAuthMode.AWS_IAM:
+            hostname, port = parse_host_port(config.host_port, default_port=5432)
+            if port is None:
+                raise ValueError("Port must be specified for RDS IAM authentication")
+
+            if not config.username:
+                raise ValueError("username is required for RDS IAM authentication")
+
+            self._rds_iam_token_manager = RDSIAMTokenManager(
+                endpoint=hostname,
+                username=config.username,
+                port=port,
+                aws_config=config.aws_config,
+            )
+
     def get_platform(self):
         return "postgres"

@@ -168,13 +215,36 @@ class PostgresSource(SQLAlchemySource):
         config = PostgresConfig.parse_obj(config_dict)
         return cls(config, ctx)

+    def _setup_rds_iam_event_listener(
+        self, engine: "Engine", database_name: Optional[str] = None
+    ) -> None:
+        """Setup SQLAlchemy event listener to inject RDS IAM tokens."""
+        if not (
+            self.config.auth_mode == PostgresAuthMode.AWS_IAM
+            and self._rds_iam_token_manager
+        ):
+            return
+
+        def do_connect_listener(_dialect, _conn_rec, _cargs, cparams):
+            if not self._rds_iam_token_manager:
+                raise RuntimeError("RDS IAM Token Manager is not initialized")
+            cparams["password"] = self._rds_iam_token_manager.get_token()
+            if cparams.get("sslmode") not in ("require", "verify-ca", "verify-full"):
+                cparams["sslmode"] = "require"
+
+        event.listen(engine, "do_connect", do_connect_listener)  # type: ignore[misc]
+
     def get_inspectors(self) -> Iterable[Inspector]:
         # Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
         url = self.config.get_sql_alchemy_url(
             database=self.config.database or self.config.initial_database
         )
+
         logger.debug(f"sql_alchemy_url={url}")
+
         engine = create_engine(url, **self.config.options)
+        self._setup_rds_iam_event_listener(engine)
+
         with engine.connect() as conn:
             if self.config.database or self.config.sqlalchemy_uri:
                 inspector = inspect(conn)
@@ -182,14 +252,21 @@ class PostgresSource(SQLAlchemySource):
             else:
                 # pg_database catalog - https://www.postgresql.org/docs/current/catalog-pg-database.html
                 # exclude template databases - https://www.postgresql.org/docs/current/manage-ag-templatedbs.html
+                # exclude rdsadmin - AWS RDS administrative database
                 databases = conn.execute(
-                    "SELECT datname from pg_database where datname not in ('template0', 'template1')"
+                    "SELECT datname from pg_database where datname not in ('template0', 'template1', 'rdsadmin')"
                 )
                 for db in databases:
                     if not self.config.database_pattern.allowed(db["datname"]):
                         continue
+
                     url = self.config.get_sql_alchemy_url(database=db["datname"])
-                    with create_engine(url, **self.config.options).connect() as conn:
+                    db_engine = create_engine(url, **self.config.options)
+                    self._setup_rds_iam_event_listener(
+                        db_engine, database_name=db["datname"]
+                    )
+
+                    with db_engine.connect() as conn:
                         inspector = inspect(conn)
                         yield inspector

--- a/datahub/ingestion/source/sql/sqlalchemy_uri.py
+++ b/datahub/ingestion/source/sql/sqlalchemy_uri.py
@@ -1,8 +1,45 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple

 from sqlalchemy.engine import URL


+def parse_host_port(
+    host_port: str, default_port: Optional[int] = None
+) -> Tuple[str, Optional[int]]:
+    """
+    Parse a host:port string into separate host and port components.
+
+    Args:
+        host_port: String in format "host:port" or just "host"
+        default_port: Optional default port to use if not specified in host_port
+
+    Returns:
+        Tuple of (hostname, port) where port may be None if not specified
+
+    Examples:
+        >>> parse_host_port("localhost:3306")
+        ('localhost', 3306)
+        >>> parse_host_port("localhost")
+        ('localhost', None)
+        >>> parse_host_port("localhost", 5432)
+        ('localhost', 5432)
+        >>> parse_host_port("db.example.com:invalid", 3306)
+        ('db.example.com', 3306)
+    """
+    try:
+        host, port_str = host_port.rsplit(":", 1)
+        port: Optional[int]
+        try:
+            port = int(port_str)
+        except ValueError:
+            # Port is not a valid integer
+            port = default_port
+        return host, port
+    except ValueError:
+        # No colon found, entire string is the hostname
+        return host_port, default_port
+
+
 def make_sqlalchemy_uri(
     scheme: str,
     username: Optional[str],
@@ -14,12 +51,7 @@ def make_sqlalchemy_uri(
     host: Optional[str] = None
     port: Optional[int] = None
     if at:
-        try:
-            host, port_str = at.rsplit(":", 1)
-            port = int(port_str)
-        except ValueError:
-            host = at
-            port = None
+        host, port = parse_host_port(at)
     if uri_opts:
         uri_opts = {k: v for k, v in uri_opts.items() if v is not None}

--- a/datahub/ingestion/source/state/redundant_run_skip_handler.py
+++ b/datahub/ingestion/source/state/redundant_run_skip_handler.py
@@ -244,3 +244,24 @@ class RedundantUsageRunSkipHandler(RedundantRunSkipHandler):
             cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time)
             cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time)
             cur_state.bucket_duration = bucket_duration
+
+
+class RedundantQueriesRunSkipHandler(RedundantRunSkipHandler):
+    """
+    Handler for stateful ingestion of queries v2 extraction.
+    Manages the time window for audit log extraction that combines
+    lineage, usage, operations, and queries.
+    """
+
+    def get_job_name_suffix(self):
+        return "_audit_window"
+
+    def update_state(
+        self, start_time: datetime, end_time: datetime, bucket_duration: BucketDuration
+    ) -> None:
+        cur_checkpoint = self.get_current_checkpoint()
+        if cur_checkpoint:
+            cur_state = cast(BaseTimeWindowCheckpointState, cur_checkpoint.state)
+            cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time)
+            cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time)
+            cur_state.bucket_duration = bucket_duration

--- a/datahub/ingestion/source/state/stateful_ingestion_base.py
+++ b/datahub/ingestion/source/state/stateful_ingestion_base.py
@@ -101,7 +101,9 @@ class StatefulLineageConfigMixin(ConfigModel):
         default=True,
         description="Enable stateful lineage ingestion."
         " This will store lineage window timestamps after successful lineage ingestion. "
-        "and will not run lineage ingestion for same timestamps in subsequent run. ",
+        "and will not run lineage ingestion for same timestamps in subsequent run. "
+        "NOTE: This only works with use_queries_v2=False (legacy extraction path). "
+        "For queries v2, use enable_stateful_time_window instead.",
     )

     _store_last_lineage_extraction_timestamp = pydantic_renamed_field(
@@ -150,7 +152,9 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig):
         default=True,
         description="Enable stateful lineage ingestion."
         " This will store usage window timestamps after successful usage ingestion. "
-        "and will not run usage ingestion for same timestamps in subsequent run. ",
+        "and will not run usage ingestion for same timestamps in subsequent run. "
+        "NOTE: This only works with use_queries_v2=False (legacy extraction path). "
+        "For queries v2, use enable_stateful_time_window instead.",
     )

     _store_last_usage_extraction_timestamp = pydantic_renamed_field(
@@ -169,6 +173,30 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig):
         return values


+class StatefulTimeWindowConfigMixin(BaseTimeWindowConfig):
+    enable_stateful_time_window: bool = Field(
+        default=False,
+        description="Enable stateful time window tracking."
+        " This will store the time window after successful extraction "
+        "and adjust the time window in subsequent runs to avoid reprocessing. "
+        "NOTE: This is ONLY applicable when using queries v2 (use_queries_v2=True). "
+        "This replaces enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion "
+        "for the queries v2 extraction path, since queries v2 extracts lineage, usage, operations, "
+        "and queries together from a single audit log and uses a unified time window.",
+    )
+
+    @root_validator(skip_on_failure=True)
+    def time_window_stateful_option_validator(cls, values: Dict) -> Dict:
+        sti = values.get("stateful_ingestion")
+        if not sti or not sti.enabled:
+            if values.get("enable_stateful_time_window"):
+                logger.warning(
+                    "Stateful ingestion is disabled, disabling enable_stateful_time_window config option as well"
+                )
+                values["enable_stateful_time_window"] = False
+        return values
+
+
 @dataclass
 class StatefulIngestionReport(SourceReport):
     pass