acryl-datahub 1.3.0.1rc4__py3-none-any.whl → 1.3.0.1rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (31)
  1. {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/METADATA +2637 -2633
  2. {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/RECORD +31 -28
  3. datahub/_version.py +1 -1
  4. datahub/ingestion/source/aws/aws_common.py +161 -0
  5. datahub/ingestion/source/bigquery_v2/bigquery.py +17 -1
  6. datahub/ingestion/source/bigquery_v2/bigquery_config.py +16 -0
  7. datahub/ingestion/source/bigquery_v2/bigquery_schema_gen.py +5 -3
  8. datahub/ingestion/source/bigquery_v2/queries_extractor.py +41 -4
  9. datahub/ingestion/source/redshift/redshift_schema.py +17 -12
  10. datahub/ingestion/source/redshift/usage.py +2 -2
  11. datahub/ingestion/source/snowflake/snowflake_config.py +16 -0
  12. datahub/ingestion/source/snowflake/snowflake_queries.py +46 -6
  13. datahub/ingestion/source/snowflake/snowflake_v2.py +14 -1
  14. datahub/ingestion/source/sql/mysql.py +101 -4
  15. datahub/ingestion/source/sql/postgres.py +81 -4
  16. datahub/ingestion/source/sql/sqlalchemy_uri.py +39 -7
  17. datahub/ingestion/source/state/redundant_run_skip_handler.py +21 -0
  18. datahub/ingestion/source/state/stateful_ingestion_base.py +30 -2
  19. datahub/metadata/_internal_schema_classes.py +772 -546
  20. datahub/metadata/_urns/urn_defs.py +1751 -1695
  21. datahub/metadata/com/linkedin/pegasus2avro/file/__init__.py +19 -0
  22. datahub/metadata/com/linkedin/pegasus2avro/metadata/key/__init__.py +2 -0
  23. datahub/metadata/schema.avsc +18450 -18242
  24. datahub/metadata/schemas/DataHubFileInfo.avsc +228 -0
  25. datahub/metadata/schemas/DataHubFileKey.avsc +21 -0
  26. datahub/metadata/schemas/DataHubPageModuleProperties.avsc +3 -1
  27. datahub/sql_parsing/sql_parsing_aggregator.py +18 -4
  28. {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/WHEEL +0 -0
  29. {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/entry_points.txt +0 -0
  30. {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/licenses/LICENSE +0 -0
  31. {acryl_datahub-1.3.0.1rc4.dist-info → acryl_datahub-1.3.0.1rc6.dist-info}/top_level.txt +0 -0
datahub/ingestion/source/redshift/usage.py
@@ -25,6 +25,7 @@ from datahub.ingestion.source.redshift.query import (
     RedshiftServerlessQuery,
 )
 from datahub.ingestion.source.redshift.redshift_schema import (
+    RedshiftDataDictionary,
     RedshiftTable,
     RedshiftView,
 )
@@ -263,8 +264,7 @@ class RedshiftUsageExtractor:
         connection: redshift_connector.Connection,
         all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
     ) -> Iterable[RedshiftAccessEvent]:
-        cursor = connection.cursor()
-        cursor.execute(query)
+        cursor = RedshiftDataDictionary.get_query_result(conn=connection, query=query)
         results = cursor.fetchmany()
         field_names = [i[0] for i in cursor.description]
         while results:
datahub/ingestion/source/snowflake/snowflake_config.py
@@ -31,6 +31,7 @@ from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, SQLFilterCo
 from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulLineageConfigMixin,
     StatefulProfilingConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulUsageConfigMixin,
 )
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
@@ -199,6 +200,7 @@ class SnowflakeV2Config(
     SnowflakeUsageConfig,
     StatefulLineageConfigMixin,
     StatefulUsageConfigMixin,
+    StatefulTimeWindowConfigMixin,
     StatefulProfilingConfigMixin,
     ClassificationSourceConfigMixin,
     IncrementalPropertiesConfigMixin,
@@ -477,6 +479,20 @@ class SnowflakeV2Config(
 
         return shares
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict:
+        if values.get("use_queries_v2"):
+            if values.get("enable_stateful_lineage_ingestion") or values.get(
+                "enable_stateful_usage_ingestion"
+            ):
+                logger.warning(
+                    "enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion are deprecated "
+                    "when using use_queries_v2=True. These configs only work with the legacy (non-queries v2) extraction path. "
+                    "For queries v2, use enable_stateful_time_window instead to enable stateful ingestion "
+                    "for the unified time window extraction (lineage + usage + operations + queries)."
+                )
+        return values
+
     def outbounds(self) -> Dict[str, Set[DatabaseId]]:
         """
         Returns mapping of
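The practical effect for recipes: with use_queries_v2 enabled, stateful ingestion of the unified time window is controlled by the new enable_stateful_time_window flag rather than the legacy per-aspect flags. A minimal sketch follows; only use_queries_v2, stateful_ingestion, and enable_stateful_time_window come from this diff, and the connection values are placeholders.

```python
# Illustrative recipe fragment in Python dict form (not a complete recipe).
snowflake_source_config = {
    "account_id": "my_account",  # placeholder connection details
    "username": "datahub_user",  # placeholder
    "use_queries_v2": True,
    "stateful_ingestion": {"enabled": True},
    # Replaces enable_stateful_lineage_ingestion / enable_stateful_usage_ingestion,
    # which only apply to the legacy (non-queries v2) extraction path.
    "enable_stateful_time_window": True,
}
```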
datahub/ingestion/source/snowflake/snowflake_queries.py
@@ -17,6 +17,7 @@ from datahub.configuration.common import AllowDenyPattern, ConfigModel, HiddenFrom
 from datahub.configuration.time_window_config import (
     BaseTimeWindowConfig,
     BucketDuration,
+    get_time_bucket,
 )
 from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.api.common import PipelineContext
@@ -50,6 +51,9 @@ from datahub.ingestion.source.snowflake.stored_proc_lineage import (
     StoredProcLineageReport,
     StoredProcLineageTracker,
 )
+from datahub.ingestion.source.state.redundant_run_skip_handler import (
+    RedundantQueriesRunSkipHandler,
+)
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig
 from datahub.metadata.urns import CorpUserUrn
 from datahub.sql_parsing.schema_resolver import SchemaResolver
@@ -180,6 +184,7 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         structured_report: SourceReport,
         filters: SnowflakeFilter,
         identifiers: SnowflakeIdentifierBuilder,
+        redundant_run_skip_handler: Optional[RedundantQueriesRunSkipHandler] = None,
         graph: Optional[DataHubGraph] = None,
         schema_resolver: Optional[SchemaResolver] = None,
         discovered_tables: Optional[List[str]] = None,
@@ -191,9 +196,13 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         self.filters = filters
         self.identifiers = identifiers
         self.discovered_tables = set(discovered_tables) if discovered_tables else None
+        self.redundant_run_skip_handler = redundant_run_skip_handler
 
         self._structured_report = structured_report
 
+        # Adjust time window based on stateful ingestion state
+        self.start_time, self.end_time = self._get_time_window()
+
         # The exit stack helps ensure that we close all the resources we open.
         self._exit_stack = contextlib.ExitStack()
 
@@ -211,8 +220,8 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
                 generate_query_usage_statistics=self.config.include_query_usage_statistics,
                 usage_config=BaseUsageConfig(
                     bucket_duration=self.config.window.bucket_duration,
-                    start_time=self.config.window.start_time,
-                    end_time=self.config.window.end_time,
+                    start_time=self.start_time,
+                    end_time=self.end_time,
                     user_email_pattern=self.config.user_email_pattern,
                     # TODO make the rest of the fields configurable
                 ),
@@ -228,6 +237,34 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
     def structured_reporter(self) -> SourceReport:
         return self._structured_report
 
+    def _get_time_window(self) -> tuple[datetime, datetime]:
+        if self.redundant_run_skip_handler:
+            start_time, end_time = (
+                self.redundant_run_skip_handler.suggest_run_time_window(
+                    self.config.window.start_time,
+                    self.config.window.end_time,
+                )
+            )
+        else:
+            start_time = self.config.window.start_time
+            end_time = self.config.window.end_time
+
+        # Usage statistics are aggregated per bucket (typically per day).
+        # To ensure accurate aggregated metrics, we need to align the start_time
+        # to the beginning of a bucket so that we include complete bucket periods.
+        if self.config.include_usage_statistics:
+            start_time = get_time_bucket(start_time, self.config.window.bucket_duration)
+
+        return start_time, end_time
+
+    def _update_state(self) -> None:
+        if self.redundant_run_skip_handler:
+            self.redundant_run_skip_handler.update_state(
+                self.config.window.start_time,
+                self.config.window.end_time,
+                self.config.window.bucket_duration,
+            )
+
     @functools.cached_property
     def local_temp_path(self) -> pathlib.Path:
         if self.config.local_temp_path:
@@ -355,6 +392,9 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         with self.report.aggregator_generate_timer:
             yield from auto_workunit(self.aggregator.gen_metadata())
 
+        # Update the stateful ingestion state after successful extraction
+        self._update_state()
+
     def fetch_users(self) -> UsersMapping:
         users: UsersMapping = dict()
         with self.structured_reporter.report_exc("Error fetching users from Snowflake"):
@@ -378,8 +418,8 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         # Derived from _populate_external_lineage_from_copy_history.
 
         query: str = SnowflakeQuery.copy_lineage_history(
-            start_time_millis=int(self.config.window.start_time.timestamp() * 1000),
-            end_time_millis=int(self.config.window.end_time.timestamp() * 1000),
+            start_time_millis=int(self.start_time.timestamp() * 1000),
+            end_time_millis=int(self.end_time.timestamp() * 1000),
             downstreams_deny_pattern=self.config.temporary_tables_pattern,
         )
 
@@ -414,8 +454,8 @@ class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
         Union[PreparsedQuery, TableRename, TableSwap, ObservedQuery, StoredProcCall]
     ]:
         query_log_query = QueryLogQueryBuilder(
-            start_time=self.config.window.start_time,
-            end_time=self.config.window.end_time,
+            start_time=self.start_time,
+            end_time=self.end_time,
             bucket_duration=self.config.window.bucket_duration,
             deny_usernames=self.config.pushdown_deny_usernames,
             allow_usernames=self.config.pushdown_allow_usernames,
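The bucket alignment in _get_time_window above matters because usage statistics are aggregated per bucket. A small sketch of the intended behavior, assuming get_time_bucket truncates a timestamp to the start of its containing bucket (which is what the comment in the hunk relies on):

```python
# Illustrative only: assumes get_time_bucket truncates to the bucket start.
from datetime import datetime, timezone

from datahub.configuration.time_window_config import BucketDuration, get_time_bucket

suggested_start = datetime(2024, 6, 3, 14, 25, tzinfo=timezone.utc)
aligned_start = get_time_bucket(suggested_start, BucketDuration.DAY)
# Expected: 2024-06-03 00:00:00+00:00, so the first (partial) day is widened
# to a complete bucket before usage statistics are aggregated.
```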
datahub/ingestion/source/snowflake/snowflake_v2.py
@@ -73,6 +73,7 @@ from datahub.ingestion.source.snowflake.snowflake_utils import (
 from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler
 from datahub.ingestion.source.state.redundant_run_skip_handler import (
     RedundantLineageRunSkipHandler,
+    RedundantQueriesRunSkipHandler,
     RedundantUsageRunSkipHandler,
 )
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
@@ -207,7 +208,7 @@ class SnowflakeV2Source(
         )
         self.report.sql_aggregator = self.aggregator.report
 
-        if self.config.include_table_lineage:
+        if self.config.include_table_lineage and not self.config.use_queries_v2:
             redundant_lineage_run_skip_handler: Optional[
                 RedundantLineageRunSkipHandler
             ] = None
@@ -589,6 +590,17 @@ class SnowflakeV2Source(
         with self.report.new_stage(f"*: {QUERIES_EXTRACTION}"):
             schema_resolver = self.aggregator._schema_resolver
 
+            redundant_queries_run_skip_handler: Optional[
+                RedundantQueriesRunSkipHandler
+            ] = None
+            if self.config.enable_stateful_time_window:
+                redundant_queries_run_skip_handler = RedundantQueriesRunSkipHandler(
+                    source=self,
+                    config=self.config,
+                    pipeline_name=self.ctx.pipeline_name,
+                    run_id=self.ctx.run_id,
+                )
+
             queries_extractor = SnowflakeQueriesExtractor(
                 connection=self.connection,
                 # TODO: this should be its own section in main recipe
@@ -614,6 +626,7 @@ class SnowflakeV2Source(
                 structured_report=self.report,
                 filters=self.filters,
                 identifiers=self.identifiers,
+                redundant_run_skip_handler=redundant_queries_run_skip_handler,
                 schema_resolver=schema_resolver,
                 discovered_tables=self.discovered_datasets,
                 graph=self.ctx.graph,
datahub/ingestion/source/sql/mysql.py
@@ -1,14 +1,17 @@
 # This import verifies that the dependencies are available.
-
-from typing import List
+import logging
+from typing import TYPE_CHECKING, Any, List, Optional
 
 import pymysql  # noqa: F401
 from pydantic.fields import Field
-from sqlalchemy import util
+from sqlalchemy import create_engine, event, inspect, util
 from sqlalchemy.dialects.mysql import BIT, base
 from sqlalchemy.dialects.mysql.enumerated import SET
 from sqlalchemy.engine.reflection import Inspector
 
+if TYPE_CHECKING:
+    from sqlalchemy.engine import Engine
+
 from datahub.configuration.common import AllowDenyPattern, HiddenFromDocs
 from datahub.ingestion.api.decorators import (
     SourceCapability,
@@ -18,11 +21,16 @@ from datahub.ingestion.api.decorators import (
     platform_name,
     support_status,
 )
+from datahub.ingestion.source.aws.aws_common import (
+    AwsConnectionConfig,
+    RDSIAMTokenManager,
+)
 from datahub.ingestion.source.sql.sql_common import (
     make_sqlalchemy_type,
     register_custom_type,
 )
 from datahub.ingestion.source.sql.sql_config import SQLAlchemyConnectionConfig
+from datahub.ingestion.source.sql.sqlalchemy_uri import parse_host_port
 from datahub.ingestion.source.sql.stored_procedures.base import (
     BaseProcedure,
 )
@@ -31,6 +39,9 @@ from datahub.ingestion.source.sql.two_tier_sql_source import (
     TwoTierSQLAlchemySource,
 )
 from datahub.metadata.schema_classes import BytesTypeClass
+from datahub.utilities.str_enum import StrEnum
+
+logger = logging.getLogger(__name__)
 
 SET.__repr__ = util.generic_repr  # type:ignore
 
@@ -54,11 +65,33 @@ base.ischema_names["polygon"] = POLYGON
 base.ischema_names["decimal128"] = DECIMAL128
 
 
+class MySQLAuthMode(StrEnum):
+    """Authentication mode for MySQL connection."""
+
+    PASSWORD = "PASSWORD"
+    AWS_IAM = "AWS_IAM"
+
+
 class MySQLConnectionConfig(SQLAlchemyConnectionConfig):
     # defaults
     host_port: str = Field(default="localhost:3306", description="MySQL host URL.")
     scheme: HiddenFromDocs[str] = "mysql+pymysql"
 
+    # Authentication configuration
+    auth_mode: MySQLAuthMode = Field(
+        default=MySQLAuthMode.PASSWORD,
+        description="Authentication mode to use for the MySQL connection. "
+        "Options are 'PASSWORD' (default) for standard username/password authentication, "
+        "or 'AWS_IAM' for AWS RDS IAM authentication.",
+    )
+    aws_config: AwsConnectionConfig = Field(
+        default_factory=AwsConnectionConfig,
+        description="AWS configuration for RDS IAM authentication (only used when auth_mode is AWS_IAM). "
+        "Provides full control over AWS credentials, region, profiles, role assumption, retry logic, and proxy settings. "
+        "If not explicitly configured, boto3 will automatically use the default credential chain and region from "
+        "environment variables (AWS_DEFAULT_REGION, AWS_REGION), AWS config files (~/.aws/config), or IAM role metadata.",
+    )
+
 
 class MySQLConfig(MySQLConnectionConfig, TwoTierSQLAlchemyConfig):
     def get_identifier(self, *, schema: str, table: str) -> str:
@@ -91,9 +124,27 @@ class MySQLSource(TwoTierSQLAlchemySource):
     Table, row, and column statistics via optional SQL profiling
     """
 
-    def __init__(self, config, ctx):
+    config: MySQLConfig
+
+    def __init__(self, config: MySQLConfig, ctx: Any):
         super().__init__(config, ctx, self.get_platform())
 
+        self._rds_iam_token_manager: Optional[RDSIAMTokenManager] = None
+        if config.auth_mode == MySQLAuthMode.AWS_IAM:
+            hostname, port = parse_host_port(config.host_port, default_port=3306)
+            if port is None:
+                raise ValueError("Port must be specified for RDS IAM authentication")
+
+            if not config.username:
+                raise ValueError("username is required for RDS IAM authentication")
+
+            self._rds_iam_token_manager = RDSIAMTokenManager(
+                endpoint=hostname,
+                username=config.username,
+                port=port,
+                aws_config=config.aws_config,
+            )
+
     def get_platform(self):
         return "mysql"
 
@@ -102,6 +153,52 @@ class MySQLSource(TwoTierSQLAlchemySource):
         config = MySQLConfig.parse_obj(config_dict)
         return cls(config, ctx)
 
+    def _setup_rds_iam_event_listener(
+        self, engine: "Engine", database_name: Optional[str] = None
+    ) -> None:
+        """Setup SQLAlchemy event listener to inject RDS IAM tokens."""
+        if not (
+            self.config.auth_mode == MySQLAuthMode.AWS_IAM
+            and self._rds_iam_token_manager
+        ):
+            return
+
+        def do_connect_listener(_dialect, _conn_rec, _cargs, cparams):
+            if not self._rds_iam_token_manager:
+                raise RuntimeError("RDS IAM Token Manager is not initialized")
+            cparams["password"] = self._rds_iam_token_manager.get_token()
+            # PyMySQL requires SSL to be enabled for RDS IAM authentication.
+            # Preserve any existing SSL configuration, otherwise enable with default settings.
+            # The {"ssl": True} dict is a workaround to make PyMySQL recognize that SSL
+            # should be enabled, since the library requires a truthy value in the ssl parameter.
+            # See https://pymysql.readthedocs.io/en/latest/modules/connections.html#pymysql.connections.Connection
+            cparams["ssl"] = cparams.get("ssl") or {"ssl": True}
+
+        event.listen(engine, "do_connect", do_connect_listener)  # type: ignore[misc]
+
+    def get_inspectors(self):
+        url = self.config.get_sql_alchemy_url()
+        logger.debug(f"sql_alchemy_url={url}")
+
+        engine = create_engine(url, **self.config.options)
+        self._setup_rds_iam_event_listener(engine)
+
+        with engine.connect() as conn:
+            inspector = inspect(conn)
+            if self.config.database and self.config.database != "":
+                databases = [self.config.database]
+            else:
+                databases = inspector.get_schema_names()
+            for db in databases:
+                if self.config.database_pattern.allowed(db):
+                    url = self.config.get_sql_alchemy_url(current_db=db)
+                    db_engine = create_engine(url, **self.config.options)
+                    self._setup_rds_iam_event_listener(db_engine, database_name=db)
+
+                    with db_engine.connect() as conn:
+                        inspector = inspect(conn)
+                        yield inspector
+
     def add_profile_metadata(self, inspector: Inspector) -> None:
         if not self.config.is_profiling_enabled():
             return
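For reference, turning on the new IAM mode might look like the sketch below. The field names (host_port, username, auth_mode, aws_config) and MySQLConfig.parse_obj appear in this diff; the endpoint, user, and region values are placeholders, and aws_region is assumed to be the relevant AwsConnectionConfig field.

```python
# Hypothetical usage sketch for RDS IAM authentication with the MySQL source.
from datahub.ingestion.source.sql.mysql import MySQLConfig

config = MySQLConfig.parse_obj(
    {
        "host_port": "mydb.abc123.us-east-1.rds.amazonaws.com:3306",  # placeholder
        "username": "datahub_reader",  # placeholder
        "auth_mode": "AWS_IAM",
        "aws_config": {"aws_region": "us-east-1"},  # assumed field name
        # No password is set: the do_connect listener above injects a
        # short-lived IAM token per connection and forces SSL.
    }
)
```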
datahub/ingestion/source/sql/postgres.py
@@ -1,6 +1,6 @@
 import logging
 from collections import defaultdict
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
 
 # This import verifies that the dependencies are available.
 import psycopg2  # noqa: F401
@@ -14,9 +14,12 @@ import sqlalchemy.dialects.postgresql as custom_types
 from geoalchemy2 import Geometry  # noqa: F401
 from pydantic import BaseModel
 from pydantic.fields import Field
-from sqlalchemy import create_engine, inspect
+from sqlalchemy import create_engine, event, inspect
 from sqlalchemy.engine.reflection import Inspector
 
+if TYPE_CHECKING:
+    from sqlalchemy.engine import Engine
+
 from datahub.configuration.common import AllowDenyPattern
 from datahub.emitter import mce_builder
 from datahub.emitter.mcp_builder import mcps_from_mce
@@ -30,12 +33,17 @@ from datahub.ingestion.api.decorators import (
     support_status,
 )
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.aws.aws_common import (
+    AwsConnectionConfig,
+    RDSIAMTokenManager,
+)
 from datahub.ingestion.source.sql.sql_common import (
     SQLAlchemySource,
     SqlWorkUnit,
     register_custom_type,
 )
 from datahub.ingestion.source.sql.sql_config import BasicSQLAlchemyConfig
+from datahub.ingestion.source.sql.sqlalchemy_uri import parse_host_port
 from datahub.ingestion.source.sql.stored_procedures.base import (
     BaseProcedure,
 )
@@ -44,6 +52,7 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
     BytesTypeClass,
     MapTypeClass,
 )
+from datahub.utilities.str_enum import StrEnum
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -100,12 +109,34 @@ class ViewLineageEntry(BaseModel):
     dependent_schema: str
 
 
+class PostgresAuthMode(StrEnum):
+    """Authentication mode for PostgreSQL connection."""
+
+    PASSWORD = "PASSWORD"
+    AWS_IAM = "AWS_IAM"
+
+
 class BasePostgresConfig(BasicSQLAlchemyConfig):
     scheme: str = Field(default="postgresql+psycopg2", description="database scheme")
     schema_pattern: AllowDenyPattern = Field(
         default=AllowDenyPattern(deny=["information_schema"])
     )
 
+    # Authentication configuration
+    auth_mode: PostgresAuthMode = Field(
+        default=PostgresAuthMode.PASSWORD,
+        description="Authentication mode to use for the PostgreSQL connection. "
+        "Options are 'PASSWORD' (default) for standard username/password authentication, "
+        "or 'AWS_IAM' for AWS RDS IAM authentication.",
+    )
+    aws_config: AwsConnectionConfig = Field(
+        default_factory=AwsConnectionConfig,
+        description="AWS configuration for RDS IAM authentication (only used when auth_mode is AWS_IAM). "
+        "Provides full control over AWS credentials, region, profiles, role assumption, retry logic, and proxy settings. "
+        "If not explicitly configured, boto3 will automatically use the default credential chain and region from "
+        "environment variables (AWS_DEFAULT_REGION, AWS_REGION), AWS config files (~/.aws/config), or IAM role metadata.",
+    )
+
 
 class PostgresConfig(BasePostgresConfig):
     database_pattern: AllowDenyPattern = Field(
@@ -160,6 +191,22 @@ class PostgresSource(SQLAlchemySource):
     def __init__(self, config: PostgresConfig, ctx: PipelineContext):
         super().__init__(config, ctx, self.get_platform())
 
+        self._rds_iam_token_manager: Optional[RDSIAMTokenManager] = None
+        if config.auth_mode == PostgresAuthMode.AWS_IAM:
+            hostname, port = parse_host_port(config.host_port, default_port=5432)
+            if port is None:
+                raise ValueError("Port must be specified for RDS IAM authentication")
+
+            if not config.username:
+                raise ValueError("username is required for RDS IAM authentication")
+
+            self._rds_iam_token_manager = RDSIAMTokenManager(
+                endpoint=hostname,
+                username=config.username,
+                port=port,
+                aws_config=config.aws_config,
+            )
+
     def get_platform(self):
         return "postgres"
 
@@ -168,13 +215,36 @@ class PostgresSource(SQLAlchemySource):
         config = PostgresConfig.parse_obj(config_dict)
         return cls(config, ctx)
 
+    def _setup_rds_iam_event_listener(
+        self, engine: "Engine", database_name: Optional[str] = None
+    ) -> None:
+        """Setup SQLAlchemy event listener to inject RDS IAM tokens."""
+        if not (
+            self.config.auth_mode == PostgresAuthMode.AWS_IAM
+            and self._rds_iam_token_manager
+        ):
+            return
+
+        def do_connect_listener(_dialect, _conn_rec, _cargs, cparams):
+            if not self._rds_iam_token_manager:
+                raise RuntimeError("RDS IAM Token Manager is not initialized")
+            cparams["password"] = self._rds_iam_token_manager.get_token()
+            if cparams.get("sslmode") not in ("require", "verify-ca", "verify-full"):
+                cparams["sslmode"] = "require"
+
+        event.listen(engine, "do_connect", do_connect_listener)  # type: ignore[misc]
+
     def get_inspectors(self) -> Iterable[Inspector]:
         # Note: get_sql_alchemy_url will choose `sqlalchemy_uri` over the passed in database
         url = self.config.get_sql_alchemy_url(
             database=self.config.database or self.config.initial_database
         )
+
         logger.debug(f"sql_alchemy_url={url}")
+
         engine = create_engine(url, **self.config.options)
+        self._setup_rds_iam_event_listener(engine)
+
         with engine.connect() as conn:
             if self.config.database or self.config.sqlalchemy_uri:
                 inspector = inspect(conn)
@@ -182,14 +252,21 @@ class PostgresSource(SQLAlchemySource):
             else:
                 # pg_database catalog - https://www.postgresql.org/docs/current/catalog-pg-database.html
                 # exclude template databases - https://www.postgresql.org/docs/current/manage-ag-templatedbs.html
+                # exclude rdsadmin - AWS RDS administrative database
                 databases = conn.execute(
-                    "SELECT datname from pg_database where datname not in ('template0', 'template1')"
+                    "SELECT datname from pg_database where datname not in ('template0', 'template1', 'rdsadmin')"
                 )
                 for db in databases:
                     if not self.config.database_pattern.allowed(db["datname"]):
                        continue
+
                     url = self.config.get_sql_alchemy_url(database=db["datname"])
-                    with create_engine(url, **self.config.options).connect() as conn:
+                    db_engine = create_engine(url, **self.config.options)
+                    self._setup_rds_iam_event_listener(
+                        db_engine, database_name=db["datname"]
+                    )
+
+                    with db_engine.connect() as conn:
                         inspector = inspect(conn)
                         yield inspector
 
datahub/ingestion/source/sql/sqlalchemy_uri.py
@@ -1,8 +1,45 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 from sqlalchemy.engine import URL
 
 
+def parse_host_port(
+    host_port: str, default_port: Optional[int] = None
+) -> Tuple[str, Optional[int]]:
+    """
+    Parse a host:port string into separate host and port components.
+
+    Args:
+        host_port: String in format "host:port" or just "host"
+        default_port: Optional default port to use if not specified in host_port
+
+    Returns:
+        Tuple of (hostname, port) where port may be None if not specified
+
+    Examples:
+        >>> parse_host_port("localhost:3306")
+        ('localhost', 3306)
+        >>> parse_host_port("localhost")
+        ('localhost', None)
+        >>> parse_host_port("localhost", 5432)
+        ('localhost', 5432)
+        >>> parse_host_port("db.example.com:invalid", 3306)
+        ('db.example.com', 3306)
+    """
+    try:
+        host, port_str = host_port.rsplit(":", 1)
+        port: Optional[int]
+        try:
+            port = int(port_str)
+        except ValueError:
+            # Port is not a valid integer
+            port = default_port
+        return host, port
+    except ValueError:
+        # No colon found, entire string is the hostname
+        return host_port, default_port
+
+
 def make_sqlalchemy_uri(
     scheme: str,
     username: Optional[str],
@@ -14,12 +51,7 @@ def make_sqlalchemy_uri(
     host: Optional[str] = None
     port: Optional[int] = None
     if at:
-        try:
-            host, port_str = at.rsplit(":", 1)
-            port = int(port_str)
-        except ValueError:
-            host = at
-            port = None
+        host, port = parse_host_port(at)
     if uri_opts:
         uri_opts = {k: v for k, v in uri_opts.items() if v is not None}
 
datahub/ingestion/source/state/redundant_run_skip_handler.py
@@ -244,3 +244,24 @@ class RedundantUsageRunSkipHandler(RedundantRunSkipHandler):
             cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time)
             cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time)
             cur_state.bucket_duration = bucket_duration
+
+
+class RedundantQueriesRunSkipHandler(RedundantRunSkipHandler):
+    """
+    Handler for stateful ingestion of queries v2 extraction.
+    Manages the time window for audit log extraction that combines
+    lineage, usage, operations, and queries.
+    """
+
+    def get_job_name_suffix(self):
+        return "_audit_window"
+
+    def update_state(
+        self, start_time: datetime, end_time: datetime, bucket_duration: BucketDuration
+    ) -> None:
+        cur_checkpoint = self.get_current_checkpoint()
+        if cur_checkpoint:
+            cur_state = cast(BaseTimeWindowCheckpointState, cur_checkpoint.state)
+            cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time)
+            cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time)
+            cur_state.bucket_duration = bucket_duration
datahub/ingestion/source/state/stateful_ingestion_base.py
@@ -101,7 +101,9 @@ class StatefulLineageConfigMixin(ConfigModel):
         default=True,
         description="Enable stateful lineage ingestion."
         " This will store lineage window timestamps after successful lineage ingestion. "
-        "and will not run lineage ingestion for same timestamps in subsequent run. ",
+        "and will not run lineage ingestion for same timestamps in subsequent run. "
+        "NOTE: This only works with use_queries_v2=False (legacy extraction path). "
+        "For queries v2, use enable_stateful_time_window instead.",
     )
 
     _store_last_lineage_extraction_timestamp = pydantic_renamed_field(
@@ -150,7 +152,9 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig):
         default=True,
         description="Enable stateful lineage ingestion."
         " This will store usage window timestamps after successful usage ingestion. "
-        "and will not run usage ingestion for same timestamps in subsequent run. ",
+        "and will not run usage ingestion for same timestamps in subsequent run. "
+        "NOTE: This only works with use_queries_v2=False (legacy extraction path). "
+        "For queries v2, use enable_stateful_time_window instead.",
     )
 
     _store_last_usage_extraction_timestamp = pydantic_renamed_field(
@@ -169,6 +173,30 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig):
         return values
 
 
+class StatefulTimeWindowConfigMixin(BaseTimeWindowConfig):
+    enable_stateful_time_window: bool = Field(
+        default=False,
+        description="Enable stateful time window tracking."
+        " This will store the time window after successful extraction "
+        "and adjust the time window in subsequent runs to avoid reprocessing. "
+        "NOTE: This is ONLY applicable when using queries v2 (use_queries_v2=True). "
+        "This replaces enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion "
+        "for the queries v2 extraction path, since queries v2 extracts lineage, usage, operations, "
+        "and queries together from a single audit log and uses a unified time window.",
+    )
+
+    @root_validator(skip_on_failure=True)
+    def time_window_stateful_option_validator(cls, values: Dict) -> Dict:
+        sti = values.get("stateful_ingestion")
+        if not sti or not sti.enabled:
+            if values.get("enable_stateful_time_window"):
+                logger.warning(
+                    "Stateful ingestion is disabled, disabling enable_stateful_time_window config option as well"
+                )
+            values["enable_stateful_time_window"] = False
+        return values
+
+
 @dataclass
 class StatefulIngestionReport(SourceReport):
     pass