atlan-application-sdk 0.1.1rc40__py3-none-any.whl → 0.1.1rc41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- application_sdk/activities/metadata_extraction/sql.py +400 -25
- application_sdk/application/__init__.py +2 -0
- application_sdk/application/metadata_extraction/sql.py +3 -0
- application_sdk/clients/models.py +42 -0
- application_sdk/clients/sql.py +17 -13
- application_sdk/common/aws_utils.py +259 -11
- application_sdk/common/utils.py +145 -9
- application_sdk/handlers/__init__.py +8 -1
- application_sdk/handlers/sql.py +63 -22
- application_sdk/observability/decorators/observability_decorator.py +36 -22
- application_sdk/server/fastapi/__init__.py +59 -3
- application_sdk/server/fastapi/models.py +27 -0
- application_sdk/version.py +1 -1
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/METADATA +1 -1
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/RECORD +18 -17
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/WHEEL +0 -0
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/licenses/LICENSE +0 -0
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/licenses/NOTICE +0 -0
application_sdk/common/aws_utils.py
CHANGED

@@ -1,4 +1,13 @@
+import re
+from typing import Any, Dict, Optional
+
+import boto3
+from sqlalchemy.engine.url import URL
+
 from application_sdk.constants import AWS_SESSION_NAME
+from application_sdk.observability.logger_adaptor import get_logger
+
+logger = get_logger(__name__)


 def get_region_name_from_hostname(hostname: str) -> str:
@@ -12,11 +21,14 @@ def get_region_name_from_hostname(hostname: str) -> str:
     Returns:
         str: AWS region name
     """
-
-
-
-
-
+    match = re.search(r"\.([a-z]{2}-[a-z]+-\d)\.", hostname)
+    if match:
+        return match.group(1)
+    # Some services may use - instead of . (rare)
+    match = re.search(r"-([a-z]{2}-[a-z]+-\d)\.", hostname)
+    if match:
+        return match.group(1)
+    raise ValueError("Could not find valid AWS region from hostname")


 def generate_aws_rds_token_with_iam_role(
@@ -55,12 +67,10 @@ def generate_aws_rds_token_with_iam_role(
     )

     credentials = assumed_role["Credentials"]
-    aws_client =
-        "rds",
-
-
-        aws_session_token=credentials["SessionToken"],
-        region_name=region or get_region_name_from_hostname(host),
+    aws_client = create_aws_client(
+        service="rds",
+        region=region or get_region_name_from_hostname(host),
+        temp_credentials=credentials,
     )
     token: str = aws_client.generate_db_auth_token(
         DBHostname=host, Port=port, DBUsername=user
@@ -107,3 +117,241 @@ def generate_aws_rds_token_with_iam_user(
         return token
     except Exception as e:
         raise Exception(f"Failed to get user credentials: {str(e)}")
+
+
+def get_cluster_identifier(aws_client) -> Optional[str]:
+    """
+    Retrieve the cluster identifier from AWS Redshift clusters.
+
+    Args:
+        aws_client: Boto3 Redshift client instance
+
+    Returns:
+        str: The cluster identifier
+
+    Raises:
+        RuntimeError: If no clusters are found
+    """
+    clusters = aws_client.describe_clusters()
+
+    for cluster in clusters["Clusters"]:
+        cluster_identifier = cluster.get("ClusterIdentifier")
+        if cluster_identifier:
+            # Optionally, you can add logic to filter clusters if needed
+            # we are reading first clusters ID if not provided
+            return cluster_identifier  # Just return the string
+    return None
+
+
+def create_aws_session(credentials: Dict[str, Any]) -> boto3.Session:
+    """
+    Create a boto3 session with AWS credentials.
+
+    Args:
+        credentials: Dictionary containing AWS credentials
+
+    Returns:
+        boto3.Session: Configured boto3 session
+    """
+    aws_access_key_id = credentials.get("aws_access_key_id") or credentials.get(
+        "username"
+    )
+    aws_secret_access_key = credentials.get("aws_secret_access_key") or credentials.get(
+        "password"
+    )
+
+    return boto3.Session(
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+    )
+
+
+def get_cluster_credentials(
+    aws_client, credentials: Dict[str, Any], extra: Dict[str, Any]
+) -> Dict[str, str]:
+    """
+    Retrieve cluster credentials using IAM authentication.
+
+    Args:
+        aws_client: Boto3 Redshift client instance
+        credentials: Dictionary containing connection credentials
+
+    Returns:
+        Dict[str, str]: Dictionary containing DbUser and DbPassword
+    """
+    database = extra["database"]
+    cluster_identifier = credentials.get("cluster_id") or get_cluster_identifier(
+        aws_client
+    )
+    return aws_client.get_cluster_credentials_with_iam(
+        DbName=database,
+        ClusterIdentifier=cluster_identifier,
+    )
+
+
+def create_aws_client(
+    service: str,
+    region: str,
+    session: Optional[boto3.Session] = None,
+    temp_credentials: Optional[Dict[str, str]] = None,
+    use_default_credentials: bool = False,
+) -> Any:
+    """
+    Create an AWS client with flexible credential options.
+
+    Args:
+        service: AWS service name (e.g., 'redshift', 'redshift-serverless', 'sts', 'rds')
+        region: AWS region name
+        session: Optional boto3 session instance. If provided, uses session credentials
+        temp_credentials: Optional dictionary containing temporary credentials from assume_role.
+            Must contain 'AccessKeyId', 'SecretAccessKey', and 'SessionToken'
+        use_default_credentials: If True, uses default AWS credentials (environment, IAM role, etc.)
+            This is the fallback if no other credentials are provided
+
+    Returns:
+        AWS client instance
+
+    Raises:
+        ValueError: If invalid credential combination is provided
+        Exception: If client creation fails
+
+    Examples:
+        Using temporary credentials::
+
+            client = create_aws_client(
+                service="redshift",
+                region="us-east-1",
+                temp_credentials={
+                    "AccessKeyId": "AKIA...",
+                    "SecretAccessKey": "...",
+                    "SessionToken": "..."
+                }
+            )
+
+        Using a session::
+
+            session = boto3.Session(profile_name="my-profile")
+            client = create_aws_client(
+                service="rds",
+                region="us-west-2",
+                session=session
+            )
+
+        Using default credentials::
+
+            client = create_aws_client(
+                service="sts",
+                region="us-east-1",
+                use_default_credentials=True
+            )
+    """
+    # Validate credential options
+    credential_sources = sum(
+        [session is not None, temp_credentials is not None, use_default_credentials]
+    )
+
+    if credential_sources == 0:
+        raise ValueError("At least one credential source must be provided")
+    if credential_sources > 1:
+        raise ValueError("Only one credential source should be provided at a time")
+
+    try:
+        # Priority 1: Use provided session
+        if session is not None:
+            logger.debug(
+                f"Creating {service} client using provided session in region {region}"
+            )
+            return session.client(service, region_name=region)  # type: ignore
+
+        # Priority 2: Use temporary credentials
+        if temp_credentials is not None:
+            logger.debug(
+                f"Creating {service} client using temporary credentials in region {region}"
+            )
+            return boto3.client(  # type: ignore
+                service,
+                aws_access_key_id=temp_credentials["AccessKeyId"],
+                aws_secret_access_key=temp_credentials["SecretAccessKey"],
+                aws_session_token=temp_credentials["SessionToken"],
+                region_name=region,
+            )
+
+        # Priority 3: Use default credentials
+        if use_default_credentials:
+            logger.debug(
+                f"Creating {service} client using default credentials in region {region}"
+            )
+            return boto3.client(service, region_name=region)  # type: ignore
+
+    except Exception as e:
+        logger.error(f"Failed to create {service} client in region {region}: {e}")
+        raise Exception(f"Failed to create {service} client: {str(e)}")
+
+
+def create_engine_url(
+    drivername: str,
+    credentials: Dict[str, Any],
+    cluster_credentials: Dict[str, str],
+    extra: Dict[str, Any],
+) -> URL:
+    """
+    Create SQLAlchemy engine URL for Redshift connection.
+
+    Args:
+        credentials: Dictionary containing connection credentials
+        cluster_credentials: Dictionary containing DbUser and DbPassword
+
+    Returns:
+        URL: SQLAlchemy engine URL
+    """
+    host = credentials["host"]
+    port = credentials.get("port")
+    database = extra["database"]
+
+    return URL.create(
+        drivername=drivername,
+        username=cluster_credentials["DbUser"],
+        password=cluster_credentials["DbPassword"],
+        host=host,
+        port=port,
+        database=database,
+    )
+
+
+def get_all_aws_regions() -> list[str]:
+    """
+    Get all available AWS regions dynamically using EC2 describe_regions API.
+    Returns:
+        list[str]: List of all AWS region names
+    Raises:
+        Exception: If unable to retrieve regions from AWS
+    """
+    try:
+        # Use us-east-1 as the default region for the EC2 client since it's always available
+        ec2_client = boto3.client("ec2", region_name="us-east-1")
+        response = ec2_client.describe_regions()
+        regions = [region["RegionName"] for region in response["Regions"]]
+        return sorted(regions)  # Sort for consistent ordering
+    except Exception as e:
+        # Fallback to a comprehensive hardcoded list if API call fails
+        logger.warning(
+            f"Failed to retrieve AWS regions dynamically: {e}. Using fallback list."
+        )
+        return [
+            "ap-northeast-1",
+            "ap-south-1",
+            "ap-southeast-1",
+            "ap-southeast-2",
+            "aws-global",
+            "ca-central-1",
+            "eu-central-1",
+            "eu-north-1",
+            "eu-west-1",
+            "eu-west-2",
+            "eu-west-3",
+            "sa-east-1",
+            "us-east-1",
+            "us-east-2",
+            "us-west-1",
+            "us-west-2",
+        ]
application_sdk/common/utils.py
CHANGED

@@ -17,8 +17,12 @@ from typing import (
     Union,
 )

+from application_sdk.activities.common.utils import get_object_store_prefix
 from application_sdk.common.error_codes import CommonError
+from application_sdk.constants import TEMPORARY_PATH
+from application_sdk.inputs.sql_query import SQLQueryInput
 from application_sdk.observability.logger_adaptor import get_logger
+from application_sdk.services.objectstore import ObjectStore

 logger = get_logger(__name__)

@@ -106,10 +110,42 @@ def extract_database_names_from_regex_common(
         return empty_default


+def transform_posix_regex(regex_pattern: str) -> str:
+    r"""
+    Transform regex pattern for POSIX compatibility.
+
+    Rules:
+    1. Add ^ before each database name before \.
+    2. Add an additional . between \. and * if * follows \.
+
+    Example: 'dev\.public$|dev\.atlan_test_schema$|wide_world_importers\.*'
+    Becomes: '^dev\.public$|^dev\.atlan_test_schema$|^wide_world_importers\..*'
+    """
+    if not regex_pattern:
+        return regex_pattern
+
+    # Split by | to handle each pattern separately
+    patterns = regex_pattern.split("|")
+    transformed_patterns = []
+
+    for pattern in patterns:
+        # Add ^ at the beginning if it's not already there
+        if not pattern.startswith("^"):
+            pattern = "^" + pattern
+
+        # Add additional . between \. and * if * follows \.
+        pattern = re.sub(r"\\\.\*", r"\..*", pattern)
+
+        transformed_patterns.append(pattern)
+
+    return "|".join(transformed_patterns)
+
+
 def prepare_query(
     query: Optional[str],
     workflow_args: Dict[str, Any],
     temp_table_regex_sql: Optional[str] = "",
+    use_posix_regex: Optional[bool] = False,
 ) -> Optional[str]:
     """
     Prepares a SQL query by applying include and exclude filters, and optional
@@ -158,6 +194,14 @@ def prepare_query(
            include_filter, exclude_filter
        )

+        if use_posix_regex:
+            normalized_include_regex_posix = transform_posix_regex(
+                normalized_include_regex
+            )
+            normalized_exclude_regex_posix = transform_posix_regex(
+                normalized_exclude_regex
+            )
+
        # Extract database names from the normalized regex patterns
        include_databases = extract_database_names_from_regex_common(
            normalized_regex=normalized_include_regex,
@@ -176,15 +220,26 @@ def prepare_query(
        )
        exclude_views = workflow_args.get("metadata", {}).get("exclude_views", False)

-
-
-
-
-
-
-
-
-
+        if use_posix_regex:
+            return query.format(
+                include_databases=include_databases,
+                exclude_databases=exclude_databases,
+                normalized_include_regex=normalized_include_regex_posix,
+                normalized_exclude_regex=normalized_exclude_regex_posix,
+                temp_table_regex_sql=temp_table_regex_sql,
+                exclude_empty_tables=exclude_empty_tables,
+                exclude_views=exclude_views,
+            )
+        else:
+            return query.format(
+                include_databases=include_databases,
+                exclude_databases=exclude_databases,
+                normalized_include_regex=normalized_include_regex,
+                normalized_exclude_regex=normalized_exclude_regex,
+                temp_table_regex_sql=temp_table_regex_sql,
+                exclude_empty_tables=exclude_empty_tables,
+                exclude_views=exclude_views,
+            )
     except CommonError as e:
         # Extract the original error message from the CommonError
         error_message = str(e).split(": ", 1)[-1] if ": " in str(e) else str(e)
@@ -195,6 +250,47 @@ def prepare_query(
         return None


+async def get_database_names(
+    sql_client, workflow_args, fetch_database_sql
+) -> Optional[List[str]]:
+    """
+    Get the database names from the workflow args if include-filter is present
+    Args:
+        workflow_args: The workflow args
+    Returns:
+        List[str]: The database names
+    """
+    database_names = parse_filter_input(
+        workflow_args.get("metadata", {}).get("include-filter", {})
+    )
+
+    database_names = [
+        re.sub(r"^[^\w]+|[^\w]+$", "", database_name)
+        for database_name in database_names
+    ]
+    if not database_names:
+        # if database_names are not provided in the include-filter, we'll run the query to get all the database names
+        # because by default for an empty include-filter, we fetch details corresponding to all the databases.
+        temp_table_regex_sql = workflow_args.get("metadata", {}).get(
+            "temp-table-regex", ""
+        )
+        prepared_query = prepare_query(
+            query=fetch_database_sql,
+            workflow_args=workflow_args,
+            temp_table_regex_sql=temp_table_regex_sql,
+            use_posix_regex=True,
+        )
+        # We'll run the query to get all the database names
+        database_sql_input = SQLQueryInput(
+            engine=sql_client.engine,
+            query=prepared_query,  # type: ignore
+            chunk_size=None,
+        )
+        database_dataframe = await database_sql_input.get_dataframe()
+        database_names = list(database_dataframe["database_name"])
+    return database_names
+
+
 def parse_filter_input(
     filter_input: Union[str, Dict[str, Any], None],
 ) -> Dict[str, Any]:
@@ -416,6 +512,46 @@ def parse_credentials_extra(credentials: Dict[str, Any]) -> Dict[str, Any]:
     return extra  # We know it's a Dict[str, Any] due to the Union type and str check


+def has_custom_control_config(workflow_args: Dict[str, Any]) -> bool:
+    """
+    Check if custom control configuration is present in workflow arguments.
+
+    Args:
+        workflow_args: The workflow arguments
+
+    Returns:
+        bool: True if custom control configuration is present, False otherwise
+    """
+    return (
+        workflow_args.get("control-config-strategy") == "custom"
+        and workflow_args.get("control-config") is not None
+    )
+
+
+async def get_file_names(output_path: str, typename: str) -> List[str]:
+    """
+    Get file names for a specific asset type from the transformed directory.
+
+    Args:
+        output_path (str): The base output path
+        typename (str): The asset type (e.g., 'table', 'schema', 'column')
+
+    Returns:
+        List[str]: List of relative file paths for the asset type
+    """
+
+    source = get_object_store_prefix(os.path.join(output_path, typename))
+    await ObjectStore.download_prefix(source, TEMPORARY_PATH)
+
+    file_pattern = os.path.join(output_path, typename, "*.json")
+    file_names = glob.glob(file_pattern)
+    file_name_list = [
+        "/".join(file_name.rsplit("/", 2)[-2:]) for file_name in file_names
+    ]
+
+    return file_name_list
+
+
 def run_sync(func):
     """Run a function in a thread pool executor.

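A quick illustration (not part of the package) of the new `transform_posix_regex` helper, using the example from its own docstring:

```python
from application_sdk.common.utils import transform_posix_regex

# Anchors each alternative and turns the trailing '\.*' into '\..*'
pattern = r"dev\.public$|dev\.atlan_test_schema$|wide_world_importers\.*"
print(transform_posix_regex(pattern))
# ^dev\.public$|^dev\.atlan_test_schema$|^wide_world_importers\..*
```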
application_sdk/handlers/__init__.py
CHANGED

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Dict


 class HandlerInterface(ABC):
@@ -37,3 +37,10 @@ class HandlerInterface(ABC):
         To be implemented by the subclass
         """
         raise NotImplementedError("fetch_metadata method not implemented")
+
+    @staticmethod
+    async def get_configmap(config_map_id: str) -> Dict[str, Any]:
+        """
+        Static method to get the configmap
+        """
+        return {}
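A hypothetical subclass sketch (not from the package) showing how a handler could override the new `HandlerInterface.get_configmap` hook, which returns an empty dict by default:

```python
from typing import Any, Dict

from application_sdk.handlers import HandlerInterface


class MyHandler(HandlerInterface):
    @staticmethod
    async def get_configmap(config_map_id: str) -> Dict[str, Any]:
        # Resolve the config map however the application stores it;
        # the base implementation simply returns {}.
        return {"id": config_map_id, "source": "example"}
```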
application_sdk/handlers/sql.py
CHANGED

@@ -56,9 +56,13 @@ class BaseSQLHandler(HandlerInterface):
     schema_alias_key: str = SQLConstants.SCHEMA_ALIAS_KEY.value
     database_result_key: str = SQLConstants.DATABASE_RESULT_KEY.value
     schema_result_key: str = SQLConstants.SCHEMA_RESULT_KEY.value
+    multidb: bool = False

-    def __init__(
+    def __init__(
+        self, sql_client: BaseSQLClient | None = None, multidb: Optional[bool] = False
+    ):
         self.sql_client = sql_client
+        self.multidb = multidb

     async def load(self, credentials: Dict[str, Any]) -> None:
         """
@@ -294,35 +298,26 @@ class BaseSQLHandler(HandlerInterface):
                 return False, f"{db}.{sch} schema"
         return True, ""

-    async def tables_check(
-        self,
-        payload: Dict[str, Any],
-    ) -> Dict[str, Any]:
+    async def tables_check(self, payload: Dict[str, Any]) -> Dict[str, Any]:
         """
         Method to check the count of tables
         """
         logger.info("Starting tables check")
-
-
-
-
-
-
-
-
-            engine=self.sql_client.engine, query=query, chunk_size=None
-        )
-        sql_input = await sql_input.get_dataframe()
-        try:
-            result = 0
-            for row in sql_input.to_dict(orient="records"):
-                result += row["count"]
+
+        def _sum_counts_from_records(records_iter) -> int:
+            total = 0
+            for row in records_iter:
+                total += row["count"]
+            return total
+
+        def _build_success(total: int) -> Dict[str, Any]:
            return {
                "success": True,
-                "successMessage": f"Tables check successful. Table count: {
+                "successMessage": f"Tables check successful. Table count: {total}",
                "failureMessage": "",
            }
-
+
+        def _build_failure(exc: Exception) -> Dict[str, Any]:
            logger.error("Error during tables check", exc_info=True)
            return {
                "success": False,
@@ -331,6 +326,52 @@ class BaseSQLHandler(HandlerInterface):
                "error": str(exc),
            }

+        if self.multidb:
+            try:
+                from application_sdk.activities.metadata_extraction.sql import (
+                    BaseSQLMetadataExtractionActivities,
+                )
+
+                # Use the base query executor in multidb mode to get concatenated df
+                activities = BaseSQLMetadataExtractionActivities()
+                activities.multidb = True
+                concatenated_df = await activities.query_executor(
+                    sql_engine=self.sql_client.engine if self.sql_client else None,
+                    sql_query=self.tables_check_sql,
+                    workflow_args=payload,
+                    output_suffix="raw/table",
+                    typename="table",
+                    write_to_file=False,
+                    concatenate=True,
+                    return_dataframe=True,
+                    sql_client=self.sql_client,
+                )
+
+                if concatenated_df is None:
+                    return _build_success(0)
+
+                total = int(concatenated_df["count"].sum())  # type: ignore[index]
+                return _build_success(total)
+            except Exception as exc:
+                return _build_failure(exc)
+        else:
+            query = prepare_query(
+                query=self.tables_check_sql,
+                workflow_args=payload,
+                temp_table_regex_sql=self.extract_temp_table_regex_table_sql,
+            )
+            if not query:
+                raise ValueError("tables_check_sql is not defined")
+            sql_input = SQLQueryInput(
+                engine=self.sql_client.engine, query=query, chunk_size=None
+            )
+            sql_input = await sql_input.get_dataframe()
+            try:
+                total = _sum_counts_from_records(sql_input.to_dict(orient="records"))
+                return _build_success(total)
+            except Exception as exc:
+                return _build_failure(exc)
+
     async def check_client_version(self) -> Dict[str, Any]:
         """
         Check if the client version meets the minimum required version.