atlan-application-sdk 0.1.1rc40__py3-none-any.whl → 0.1.1rc42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- application_sdk/activities/common/utils.py +78 -4
- application_sdk/activities/metadata_extraction/sql.py +400 -27
- application_sdk/application/__init__.py +2 -0
- application_sdk/application/metadata_extraction/sql.py +3 -0
- application_sdk/clients/models.py +42 -0
- application_sdk/clients/sql.py +17 -13
- application_sdk/common/aws_utils.py +259 -11
- application_sdk/common/utils.py +145 -9
- application_sdk/handlers/__init__.py +8 -1
- application_sdk/handlers/sql.py +63 -22
- application_sdk/inputs/__init__.py +98 -2
- application_sdk/inputs/json.py +59 -87
- application_sdk/inputs/parquet.py +173 -94
- application_sdk/observability/decorators/observability_decorator.py +36 -22
- application_sdk/server/fastapi/__init__.py +59 -3
- application_sdk/server/fastapi/models.py +27 -0
- application_sdk/test_utils/hypothesis/strategies/inputs/json_input.py +10 -5
- application_sdk/test_utils/hypothesis/strategies/inputs/parquet_input.py +9 -4
- application_sdk/version.py +1 -1
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/METADATA +1 -1
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/RECORD +24 -23
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/WHEEL +0 -0
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/licenses/LICENSE +0 -0
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pydantic models for database client configurations.
|
|
3
|
+
This module provides Pydantic models for database connection configurations,
|
|
4
|
+
ensuring type safety and validation for database client settings.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DatabaseConfig(BaseModel):
    """
    Pydantic model for database connection configuration.

    Describes how a connection string is assembled: the template to fill in,
    the parameter names that must be supplied for it, default connection
    parameters, and the names of additional parameters that may be taken
    from credentials.
    """

    template: str = Field(
        ...,
        description="SQLAlchemy connection string template with placeholders for connection parameters",
    )
    required: List[str] = Field(
        # A factory (rather than a shared ``[]`` literal) makes it explicit
        # that every instance starts with its own empty list.
        default_factory=list,
        description="List of required connection parameters that must be provided",
    )
    defaults: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Default connection parameters to be added to the connection string",
    )
    parameters: Optional[List[str]] = Field(
        default=None,
        description="List of additional connection parameter names that can be dynamically added from credentials",
    )

    class Config:
        """Pydantic configuration for the DatabaseConfig model."""

        extra = "forbid"  # Reject fields that are not declared above
        validate_assignment = True  # Re-validate when an attribute is reassigned
        use_enum_values = True  # Store enum values instead of enum objects
|
application_sdk/clients/sql.py
CHANGED
|
@@ -7,13 +7,14 @@ database operations, supporting batch processing and server-side cursors.
|
|
|
7
7
|
|
|
8
8
|
import asyncio
|
|
9
9
|
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
-
from typing import Any, Dict, List
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
11
|
from urllib.parse import quote_plus
|
|
12
12
|
|
|
13
13
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
|
|
14
14
|
from temporalio import activity
|
|
15
15
|
|
|
16
16
|
from application_sdk.clients import ClientInterface
|
|
17
|
+
from application_sdk.clients.models import DatabaseConfig
|
|
17
18
|
from application_sdk.common.aws_utils import (
|
|
18
19
|
generate_aws_rds_token_with_iam_role,
|
|
19
20
|
generate_aws_rds_token_with_iam_user,
|
|
@@ -48,7 +49,7 @@ class BaseSQLClient(ClientInterface):
|
|
|
48
49
|
credentials: Dict[str, Any] = {}
|
|
49
50
|
resolved_credentials: Dict[str, Any] = {}
|
|
50
51
|
use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR
|
|
51
|
-
DB_CONFIG:
|
|
52
|
+
DB_CONFIG: Optional[DatabaseConfig] = None
|
|
52
53
|
|
|
53
54
|
def __init__(
|
|
54
55
|
self,
|
|
@@ -262,7 +263,9 @@ class BaseSQLClient(ClientInterface):
|
|
|
262
263
|
Returns:
|
|
263
264
|
str: The updated URL with the dialect.
|
|
264
265
|
"""
|
|
265
|
-
|
|
266
|
+
if not self.DB_CONFIG:
|
|
267
|
+
raise ValueError("DB_CONFIG is not configured for this SQL client.")
|
|
268
|
+
installed_dialect = self.DB_CONFIG.template.split("://")[0]
|
|
266
269
|
url_dialect = sqlalchemy_url.split("://")[0]
|
|
267
270
|
if installed_dialect != url_dialect:
|
|
268
271
|
sqlalchemy_url = sqlalchemy_url.replace(url_dialect, installed_dialect)
|
|
@@ -281,6 +284,9 @@ class BaseSQLClient(ClientInterface):
|
|
|
281
284
|
Raises:
|
|
282
285
|
ValueError: If required connection parameters are missing.
|
|
283
286
|
"""
|
|
287
|
+
if not self.DB_CONFIG:
|
|
288
|
+
raise ValueError("DB_CONFIG is not configured for this SQL client.")
|
|
289
|
+
|
|
284
290
|
extra = parse_credentials_extra(self.credentials)
|
|
285
291
|
|
|
286
292
|
# TODO: Uncomment this when the native deployment is ready
|
|
@@ -293,7 +299,7 @@ class BaseSQLClient(ClientInterface):
|
|
|
293
299
|
|
|
294
300
|
# Prepare parameters
|
|
295
301
|
param_values = {}
|
|
296
|
-
for param in self.DB_CONFIG
|
|
302
|
+
for param in self.DB_CONFIG.required:
|
|
297
303
|
if param == "password":
|
|
298
304
|
param_values[param] = auth_token
|
|
299
305
|
else:
|
|
@@ -303,21 +309,19 @@ class BaseSQLClient(ClientInterface):
|
|
|
303
309
|
param_values[param] = value
|
|
304
310
|
|
|
305
311
|
# Fill in base template
|
|
306
|
-
conn_str = self.DB_CONFIG
|
|
312
|
+
conn_str = self.DB_CONFIG.template.format(**param_values)
|
|
307
313
|
|
|
308
314
|
# Append defaults if not already in the template
|
|
309
|
-
if self.DB_CONFIG.
|
|
310
|
-
conn_str = self.add_connection_params(conn_str, self.DB_CONFIG
|
|
315
|
+
if self.DB_CONFIG.defaults:
|
|
316
|
+
conn_str = self.add_connection_params(conn_str, self.DB_CONFIG.defaults)
|
|
311
317
|
|
|
312
|
-
if self.DB_CONFIG.
|
|
313
|
-
parameter_keys = self.DB_CONFIG
|
|
314
|
-
|
|
318
|
+
if self.DB_CONFIG.parameters:
|
|
319
|
+
parameter_keys = self.DB_CONFIG.parameters
|
|
320
|
+
parameter_values = {
|
|
315
321
|
key: self.credentials.get(key) or extra.get(key)
|
|
316
322
|
for key in parameter_keys
|
|
317
323
|
}
|
|
318
|
-
conn_str = self.add_connection_params(
|
|
319
|
-
conn_str, self.DB_CONFIG["parameters"]
|
|
320
|
-
)
|
|
324
|
+
conn_str = self.add_connection_params(conn_str, parameter_values)
|
|
321
325
|
|
|
322
326
|
return conn_str
|
|
323
327
|
|
|
@@ -1,4 +1,13 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import boto3
|
|
5
|
+
from sqlalchemy.engine.url import URL
|
|
6
|
+
|
|
1
7
|
from application_sdk.constants import AWS_SESSION_NAME
|
|
8
|
+
from application_sdk.observability.logger_adaptor import get_logger
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
2
11
|
|
|
3
12
|
|
|
4
13
|
def get_region_name_from_hostname(hostname: str) -> str:
    """
    Extract the AWS region name embedded in a hostname.

    The region is searched for as a token that is followed by a dot and
    preceded either by a dot (``db.xyz.us-east-1.rds.amazonaws.com``) or by a
    hyphen (``svc-us-east-1.example.com``). Single-word regions such as
    ``us-east-1`` are tried first; multi-word regions such as
    ``us-gov-west-1`` are tried only if that fails, so hostnames that already
    resolved keep resolving to the same value.

    Args:
        hostname (str): Hostname to inspect

    Returns:
        str: AWS region name

    Raises:
        ValueError: If no region-shaped token is found in the hostname
    """
    region_shapes = (
        # "<cc>-<word>-<digit>", e.g. us-east-1
        r"[a-z]{2}-[a-z]+-\d",
        # "<cc>-<word>[-<word>]-<digits>", e.g. us-gov-west-1
        r"[a-z]{2}(?:-[a-z]+){1,2}-\d+",
    )
    # Dot-delimited labels are preferred; some services may use - instead of . (rare)
    leading_delimiters = (r"\.", "-")

    for shape in region_shapes:
        for delimiter in leading_delimiters:
            match = re.search(delimiter + "(" + shape + r")\.", hostname)
            if match:
                return match.group(1)
    raise ValueError("Could not find valid AWS region from hostname")
|
|
20
32
|
|
|
21
33
|
|
|
22
34
|
def generate_aws_rds_token_with_iam_role(
|
|
@@ -55,12 +67,10 @@ def generate_aws_rds_token_with_iam_role(
|
|
|
55
67
|
)
|
|
56
68
|
|
|
57
69
|
credentials = assumed_role["Credentials"]
|
|
58
|
-
aws_client =
|
|
59
|
-
"rds",
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
aws_session_token=credentials["SessionToken"],
|
|
63
|
-
region_name=region or get_region_name_from_hostname(host),
|
|
70
|
+
aws_client = create_aws_client(
|
|
71
|
+
service="rds",
|
|
72
|
+
region=region or get_region_name_from_hostname(host),
|
|
73
|
+
temp_credentials=credentials,
|
|
64
74
|
)
|
|
65
75
|
token: str = aws_client.generate_db_auth_token(
|
|
66
76
|
DBHostname=host, Port=port, DBUsername=user
|
|
@@ -107,3 +117,241 @@ def generate_aws_rds_token_with_iam_user(
|
|
|
107
117
|
return token
|
|
108
118
|
except Exception as e:
|
|
109
119
|
raise Exception(f"Failed to get user credentials: {str(e)}")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def get_cluster_identifier(aws_client) -> Optional[str]:
    """
    Return the identifier of the first cluster reported by the client.

    Calls ``describe_clusters()`` once and returns the first non-empty
    ``ClusterIdentifier`` in the response. No filtering is applied: when
    several clusters are listed, the first one wins.

    Args:
        aws_client: Boto3 Redshift client instance

    Returns:
        Optional[str]: The first cluster identifier, or None when the
        response lists no cluster with an identifier
    """
    response = aws_client.describe_clusters()

    # A response without a "Clusters" entry is treated like an empty listing
    # instead of failing with a KeyError.
    identifiers = (
        cluster.get("ClusterIdentifier") for cluster in response.get("Clusters") or []
    )
    return next((identifier for identifier in identifiers if identifier), None)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def create_aws_session(credentials: Dict[str, Any]) -> boto3.Session:
    """
    Build a boto3 session from a credentials dictionary.

    The access key is read from ``aws_access_key_id`` and, when that is
    missing or empty, from ``username``. The secret key is read from
    ``aws_secret_access_key`` and, when that is missing or empty, from
    ``password``.

    Args:
        credentials: Dictionary containing AWS credentials

    Returns:
        boto3.Session: Session created with the resolved key pair
    """
    access_key = credentials.get("aws_access_key_id")
    if not access_key:
        access_key = credentials.get("username")

    secret_key = credentials.get("aws_secret_access_key")
    if not secret_key:
        secret_key = credentials.get("password")

    session = boto3.Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    return session
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def get_cluster_credentials(
    aws_client, credentials: Dict[str, Any], extra: Dict[str, Any]
) -> Dict[str, str]:
    """
    Retrieve cluster credentials using IAM authentication.

    The cluster is taken from ``credentials["cluster_id"]`` when present;
    otherwise it is looked up through ``get_cluster_identifier``.

    Args:
        aws_client: Boto3 Redshift client instance
        credentials: Dictionary containing connection credentials; may carry
            an optional ``cluster_id``
        extra: Dictionary of extra connection settings; must contain
            ``database``

    Returns:
        Dict[str, str]: Dictionary containing DbUser and DbPassword

    Raises:
        KeyError: If ``extra`` has no ``database`` entry
        ValueError: If no cluster identifier is provided and none can be
            discovered
    """
    database = extra["database"]
    cluster_identifier = credentials.get("cluster_id") or get_cluster_identifier(
        aws_client
    )
    if not cluster_identifier:
        # Fail here with an actionable message rather than handing None to
        # the AWS client as the cluster identifier.
        raise ValueError(
            "No cluster identifier available: provide 'cluster_id' in credentials "
            "or make sure at least one cluster is visible to the AWS client"
        )
    return aws_client.get_cluster_credentials_with_iam(
        DbName=database,
        ClusterIdentifier=cluster_identifier,
    )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def create_aws_client(
    service: str,
    region: str,
    session: Optional[boto3.Session] = None,
    temp_credentials: Optional[Dict[str, str]] = None,
    use_default_credentials: bool = False,
) -> Any:
    """
    Create an AWS client from exactly one credential source.

    Exactly one of ``session``, ``temp_credentials`` or
    ``use_default_credentials=True`` must be supplied. There is no implicit
    fallback: passing none of them raises ValueError.

    Args:
        service: AWS service name (e.g., 'redshift', 'redshift-serverless', 'sts', 'rds')
        region: AWS region name
        session: Optional boto3 session instance. If provided, the client is created from it
        temp_credentials: Optional dictionary containing temporary credentials from assume_role.
            Must contain 'AccessKeyId', 'SecretAccessKey', and 'SessionToken'
        use_default_credentials: If True, lets boto3 resolve credentials on its own
            (no explicit keys are passed to the client)

    Returns:
        AWS client instance

    Raises:
        ValueError: If no credential source, or more than one, is provided
        Exception: If client creation fails; the underlying error is chained
            as the cause

    Examples:
        Using temporary credentials::

            client = create_aws_client(
                service="redshift",
                region="us-east-1",
                temp_credentials={
                    "AccessKeyId": "AKIA...",
                    "SecretAccessKey": "...",
                    "SessionToken": "..."
                }
            )

        Using a session::

            session = boto3.Session(profile_name="my-profile")
            client = create_aws_client(
                service="rds",
                region="us-west-2",
                session=session
            )

        Using default credentials::

            client = create_aws_client(
                service="sts",
                region="us-east-1",
                use_default_credentials=True
            )
    """
    # Validate credential options: count how many sources were supplied
    credential_sources = sum(
        [session is not None, temp_credentials is not None, use_default_credentials]
    )

    if credential_sources == 0:
        raise ValueError("At least one credential source must be provided")
    if credential_sources > 1:
        raise ValueError("Only one credential source should be provided at a time")

    try:
        # Source 1: caller-provided session
        if session is not None:
            logger.debug(
                f"Creating {service} client using provided session in region {region}"
            )
            return session.client(service, region_name=region)  # type: ignore

        # Source 2: temporary credentials
        if temp_credentials is not None:
            logger.debug(
                f"Creating {service} client using temporary credentials in region {region}"
            )
            return boto3.client(  # type: ignore
                service,
                aws_access_key_id=temp_credentials["AccessKeyId"],
                aws_secret_access_key=temp_credentials["SecretAccessKey"],
                aws_session_token=temp_credentials["SessionToken"],
                region_name=region,
            )

        # Source 3: default credentials
        if use_default_credentials:
            logger.debug(
                f"Creating {service} client using default credentials in region {region}"
            )
            return boto3.client(service, region_name=region)  # type: ignore

    except Exception as e:
        logger.error(f"Failed to create {service} client in region {region}: {e}")
        # Chain the original error so its traceback is not lost.
        raise Exception(f"Failed to create {service} client: {str(e)}") from e
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def create_engine_url(
    drivername: str,
    credentials: Dict[str, Any],
    cluster_credentials: Dict[str, str],
    extra: Dict[str, Any],
) -> URL:
    """
    Assemble a SQLAlchemy engine URL from connection and cluster credentials.

    Args:
        drivername: Driver name passed through to ``URL.create``
        credentials: Dictionary containing connection credentials; ``host``
            is required, ``port`` is optional
        cluster_credentials: Dictionary containing DbUser and DbPassword
        extra: Dictionary of extra connection settings; must contain
            ``database``

    Returns:
        URL: SQLAlchemy engine URL
    """
    # Resolve the connection target first so that a missing host or database
    # is reported before the cluster credentials are read.
    connection_target = {
        "host": credentials["host"],
        "port": credentials.get("port"),
        "database": extra["database"],
    }

    return URL.create(
        drivername=drivername,
        username=cluster_credentials["DbUser"],
        password=cluster_credentials["DbPassword"],
        **connection_target,
    )
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def get_all_aws_regions() -> list[str]:
    """
    List AWS region names, preferring a live lookup over a static list.

    The regions are requested through the EC2 ``describe_regions`` API. If
    that call fails for any reason, a warning is logged and a hardcoded list
    is returned instead, so this function does not propagate the lookup error.

    Returns:
        list[str]: Region names; sorted when they come from the API
    """
    try:
        # The lookup itself is issued against us-east-1
        described = boto3.client("ec2", region_name="us-east-1").describe_regions()
        # Sorted for consistent ordering
        return sorted(entry["RegionName"] for entry in described["Regions"])
    except Exception as e:
        logger.warning(
            f"Failed to retrieve AWS regions dynamically: {e}. Using fallback list."
        )
        # NOTE(review): "aws-global" does not look like a regular region
        # name - confirm it is intended in this fallback list.
        return [
            "ap-northeast-1",
            "ap-south-1",
            "ap-southeast-1",
            "ap-southeast-2",
            "aws-global",
            "ca-central-1",
            "eu-central-1",
            "eu-north-1",
            "eu-west-1",
            "eu-west-2",
            "eu-west-3",
            "sa-east-1",
            "us-east-1",
            "us-east-2",
            "us-west-1",
            "us-west-2",
        ]
|
application_sdk/common/utils.py
CHANGED
|
@@ -17,8 +17,12 @@ from typing import (
|
|
|
17
17
|
Union,
|
|
18
18
|
)
|
|
19
19
|
|
|
20
|
+
from application_sdk.activities.common.utils import get_object_store_prefix
|
|
20
21
|
from application_sdk.common.error_codes import CommonError
|
|
22
|
+
from application_sdk.constants import TEMPORARY_PATH
|
|
23
|
+
from application_sdk.inputs.sql_query import SQLQueryInput
|
|
21
24
|
from application_sdk.observability.logger_adaptor import get_logger
|
|
25
|
+
from application_sdk.services.objectstore import ObjectStore
|
|
22
26
|
|
|
23
27
|
logger = get_logger(__name__)
|
|
24
28
|
|
|
@@ -106,10 +110,42 @@ def extract_database_names_from_regex_common(
|
|
|
106
110
|
return empty_default
|
|
107
111
|
|
|
108
112
|
|
|
113
|
+
def transform_posix_regex(regex_pattern: str) -> str:
    r"""
    Rewrite a ``|``-separated filter regex into its POSIX-compatible form.

    Each alternative is handled on its own:
        1. It is anchored with a leading ^ unless it already starts with one.
        2. Every ``\.*`` becomes ``\..*`` (a literal dot followed by anything).

    Example: 'dev\.public$|dev\.atlan_test_schema$|wide_world_importers\.*'
    Becomes: '^dev\.public$|^dev\.atlan_test_schema$|^wide_world_importers\..*'

    An empty (or otherwise falsy) pattern is returned unchanged.
    """
    if not regex_pattern:
        return regex_pattern

    def _rewrite(alternative: str) -> str:
        """Anchor one alternative and expand its trailing-wildcard form."""
        anchored = alternative if alternative.startswith("^") else "^" + alternative
        return re.sub(r"\\\.\*", r"\..*", anchored)

    # The pattern is split on every |, so each alternative is rewritten separately
    return "|".join(_rewrite(alternative) for alternative in regex_pattern.split("|"))
|
|
142
|
+
|
|
143
|
+
|
|
109
144
|
def prepare_query(
|
|
110
145
|
query: Optional[str],
|
|
111
146
|
workflow_args: Dict[str, Any],
|
|
112
147
|
temp_table_regex_sql: Optional[str] = "",
|
|
148
|
+
use_posix_regex: Optional[bool] = False,
|
|
113
149
|
) -> Optional[str]:
|
|
114
150
|
"""
|
|
115
151
|
Prepares a SQL query by applying include and exclude filters, and optional
|
|
@@ -158,6 +194,14 @@ def prepare_query(
|
|
|
158
194
|
include_filter, exclude_filter
|
|
159
195
|
)
|
|
160
196
|
|
|
197
|
+
if use_posix_regex:
|
|
198
|
+
normalized_include_regex_posix = transform_posix_regex(
|
|
199
|
+
normalized_include_regex
|
|
200
|
+
)
|
|
201
|
+
normalized_exclude_regex_posix = transform_posix_regex(
|
|
202
|
+
normalized_exclude_regex
|
|
203
|
+
)
|
|
204
|
+
|
|
161
205
|
# Extract database names from the normalized regex patterns
|
|
162
206
|
include_databases = extract_database_names_from_regex_common(
|
|
163
207
|
normalized_regex=normalized_include_regex,
|
|
@@ -176,15 +220,26 @@ def prepare_query(
|
|
|
176
220
|
)
|
|
177
221
|
exclude_views = workflow_args.get("metadata", {}).get("exclude_views", False)
|
|
178
222
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
223
|
+
if use_posix_regex:
|
|
224
|
+
return query.format(
|
|
225
|
+
include_databases=include_databases,
|
|
226
|
+
exclude_databases=exclude_databases,
|
|
227
|
+
normalized_include_regex=normalized_include_regex_posix,
|
|
228
|
+
normalized_exclude_regex=normalized_exclude_regex_posix,
|
|
229
|
+
temp_table_regex_sql=temp_table_regex_sql,
|
|
230
|
+
exclude_empty_tables=exclude_empty_tables,
|
|
231
|
+
exclude_views=exclude_views,
|
|
232
|
+
)
|
|
233
|
+
else:
|
|
234
|
+
return query.format(
|
|
235
|
+
include_databases=include_databases,
|
|
236
|
+
exclude_databases=exclude_databases,
|
|
237
|
+
normalized_include_regex=normalized_include_regex,
|
|
238
|
+
normalized_exclude_regex=normalized_exclude_regex,
|
|
239
|
+
temp_table_regex_sql=temp_table_regex_sql,
|
|
240
|
+
exclude_empty_tables=exclude_empty_tables,
|
|
241
|
+
exclude_views=exclude_views,
|
|
242
|
+
)
|
|
188
243
|
except CommonError as e:
|
|
189
244
|
# Extract the original error message from the CommonError
|
|
190
245
|
error_message = str(e).split(": ", 1)[-1] if ": " in str(e) else str(e)
|
|
@@ -195,6 +250,47 @@ def prepare_query(
|
|
|
195
250
|
return None
|
|
196
251
|
|
|
197
252
|
|
|
253
|
+
async def get_database_names(
    sql_client, workflow_args, fetch_database_sql
) -> Optional[List[str]]:
    """
    Get the database names for a workflow run.

    If the include-filter in the workflow metadata yields any entries, those
    are returned with leading and trailing non-word characters stripped.
    Otherwise ``fetch_database_sql`` is prepared (with POSIX regex handling)
    and executed, and the ``database_name`` column of the result is returned.

    Args:
        sql_client: SQL client whose ``engine`` is used to run the query
        workflow_args: The workflow args
        fetch_database_sql: Query template used to list the databases

    Returns:
        List[str]: The database names

    Raises:
        ValueError: If the database listing query could not be prepared
    """
    database_names = parse_filter_input(
        workflow_args.get("metadata", {}).get("include-filter", {})
    )

    database_names = [
        re.sub(r"^[^\w]+|[^\w]+$", "", database_name)
        for database_name in database_names
    ]
    if database_names:
        return database_names

    # if database_names are not provided in the include-filter, we'll run the query to get all the database names
    # because by default for an empty include-filter, we fetch details corresponding to all the databases.
    temp_table_regex_sql = workflow_args.get("metadata", {}).get(
        "temp-table-regex", ""
    )
    prepared_query = prepare_query(
        query=fetch_database_sql,
        workflow_args=workflow_args,
        temp_table_regex_sql=temp_table_regex_sql,
        use_posix_regex=True,
    )
    if prepared_query is None:
        # prepare_query signals failure by returning None; stop here with a
        # clear error instead of handing a None query to the SQL input.
        raise ValueError("Failed to prepare the query for fetching database names")

    database_sql_input = SQLQueryInput(
        engine=sql_client.engine,
        query=prepared_query,
        chunk_size=None,
    )
    database_dataframe = await database_sql_input.get_dataframe()
    return list(database_dataframe["database_name"])
|
|
292
|
+
|
|
293
|
+
|
|
198
294
|
def parse_filter_input(
|
|
199
295
|
filter_input: Union[str, Dict[str, Any], None],
|
|
200
296
|
) -> Dict[str, Any]:
|
|
@@ -416,6 +512,46 @@ def parse_credentials_extra(credentials: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
416
512
|
return extra # We know it's a Dict[str, Any] due to the Union type and str check
|
|
417
513
|
|
|
418
514
|
|
|
515
|
+
def has_custom_control_config(workflow_args: Dict[str, Any]) -> bool:
    """
    Tell whether the workflow arguments carry a custom control configuration.

    Args:
        workflow_args: The workflow arguments

    Returns:
        bool: True only when ``control-config-strategy`` equals "custom" and
        ``control-config`` is present with a non-None value
    """
    # Any other strategy (or none at all) rules out a custom configuration
    if workflow_args.get("control-config-strategy") != "custom":
        return False
    return workflow_args.get("control-config") is not None
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
async def get_file_names(output_path: str, typename: str) -> List[str]:
    """
    Get file names for a specific asset type under the output path.

    The object-store prefix for ``<output_path>/<typename>`` is downloaded
    first, then the ``*.json`` files in ``<output_path>/<typename>`` are
    listed.

    Args:
        output_path (str): The base output path
        typename (str): The asset type (e.g., 'table', 'schema', 'column')

    Returns:
        List[str]: One ``<parent directory>/<file name>`` entry per matching
        file, always joined with forward slashes
    """
    source = get_object_store_prefix(os.path.join(output_path, typename))
    # NOTE(review): the download targets TEMPORARY_PATH while the listing
    # below scans output_path - confirm both resolve to the same location.
    await ObjectStore.download_prefix(source, TEMPORARY_PATH)

    file_pattern = os.path.join(output_path, typename, "*.json")

    # Normalise the platform separator before splitting, so the last two path
    # components are found even where paths are not "/"-separated.
    return [
        "/".join(file_name.replace(os.sep, "/").rsplit("/", 2)[-2:])
        for file_name in glob.glob(file_pattern)
    ]
|
|
553
|
+
|
|
554
|
+
|
|
419
555
|
def run_sync(func):
|
|
420
556
|
"""Run a function in a thread pool executor.
|
|
421
557
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Any
|
|
2
|
+
from typing import Any, Dict
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class HandlerInterface(ABC):
|
|
@@ -37,3 +37,10 @@ class HandlerInterface(ABC):
|
|
|
37
37
|
To be implemented by the subclass
|
|
38
38
|
"""
|
|
39
39
|
raise NotImplementedError("fetch_metadata method not implemented")
|
|
40
|
+
|
|
41
|
+
@staticmethod
async def get_configmap(config_map_id: str) -> Dict[str, Any]:
    """
    Static method to get the configmap.

    This implementation ignores ``config_map_id`` and always returns an
    empty dictionary.
    """
    configmap: Dict[str, Any] = {}
    return configmap
|