atlan-application-sdk 0.1.1rc40__py3-none-any.whl → 0.1.1rc42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. application_sdk/activities/common/utils.py +78 -4
  2. application_sdk/activities/metadata_extraction/sql.py +400 -27
  3. application_sdk/application/__init__.py +2 -0
  4. application_sdk/application/metadata_extraction/sql.py +3 -0
  5. application_sdk/clients/models.py +42 -0
  6. application_sdk/clients/sql.py +17 -13
  7. application_sdk/common/aws_utils.py +259 -11
  8. application_sdk/common/utils.py +145 -9
  9. application_sdk/handlers/__init__.py +8 -1
  10. application_sdk/handlers/sql.py +63 -22
  11. application_sdk/inputs/__init__.py +98 -2
  12. application_sdk/inputs/json.py +59 -87
  13. application_sdk/inputs/parquet.py +173 -94
  14. application_sdk/observability/decorators/observability_decorator.py +36 -22
  15. application_sdk/server/fastapi/__init__.py +59 -3
  16. application_sdk/server/fastapi/models.py +27 -0
  17. application_sdk/test_utils/hypothesis/strategies/inputs/json_input.py +10 -5
  18. application_sdk/test_utils/hypothesis/strategies/inputs/parquet_input.py +9 -4
  19. application_sdk/version.py +1 -1
  20. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/METADATA +1 -1
  21. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/RECORD +24 -23
  22. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/WHEEL +0 -0
  23. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/licenses/LICENSE +0 -0
  24. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,42 @@
1
+ """
2
+ Pydantic models for database client configurations.
3
+ This module provides Pydantic models for database connection configurations,
4
+ ensuring type safety and validation for database client settings.
5
+ """
6
+
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
class DatabaseConfig(BaseModel):
    """
    Pydantic model for database connection configuration.

    This model defines the structure for database connection configurations,
    including connection templates, required parameters, defaults, and additional
    connection parameters.
    """

    template: str = Field(
        ...,
        description="SQLAlchemy connection string template with placeholders for connection parameters",
    )
    required: List[str] = Field(
        # default_factory avoids ever sharing a single mutable list between
        # model instances (Python's classic mutable-default pitfall).
        default_factory=list,
        description="List of required connection parameters that must be provided",
    )
    defaults: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Default connection parameters to be added to the connection string",
    )
    parameters: Optional[List[str]] = Field(
        default=None,
        description="List of additional connection parameter names that can be dynamically added from credentials",
    )

    class Config:
        """Pydantic configuration for the DatabaseConfig model."""

        extra = "forbid"  # Prevent additional fields
        validate_assignment = True  # Validate on assignment
        use_enum_values = True  # Use enum values instead of enum objects
@@ -7,13 +7,14 @@ database operations, supporting batch processing and server-side cursors.
7
7
 
8
8
  import asyncio
9
9
  from concurrent.futures import ThreadPoolExecutor
10
- from typing import Any, Dict, List
10
+ from typing import Any, Dict, List, Optional
11
11
  from urllib.parse import quote_plus
12
12
 
13
13
  from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
14
14
  from temporalio import activity
15
15
 
16
16
  from application_sdk.clients import ClientInterface
17
+ from application_sdk.clients.models import DatabaseConfig
17
18
  from application_sdk.common.aws_utils import (
18
19
  generate_aws_rds_token_with_iam_role,
19
20
  generate_aws_rds_token_with_iam_user,
@@ -48,7 +49,7 @@ class BaseSQLClient(ClientInterface):
48
49
  credentials: Dict[str, Any] = {}
49
50
  resolved_credentials: Dict[str, Any] = {}
50
51
  use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR
51
- DB_CONFIG: Dict[str, Any] = {}
52
+ DB_CONFIG: Optional[DatabaseConfig] = None
52
53
 
53
54
  def __init__(
54
55
  self,
@@ -262,7 +263,9 @@ class BaseSQLClient(ClientInterface):
262
263
  Returns:
263
264
  str: The updated URL with the dialect.
264
265
  """
265
- installed_dialect = self.DB_CONFIG["template"].split("://")[0]
266
+ if not self.DB_CONFIG:
267
+ raise ValueError("DB_CONFIG is not configured for this SQL client.")
268
+ installed_dialect = self.DB_CONFIG.template.split("://")[0]
266
269
  url_dialect = sqlalchemy_url.split("://")[0]
267
270
  if installed_dialect != url_dialect:
268
271
  sqlalchemy_url = sqlalchemy_url.replace(url_dialect, installed_dialect)
@@ -281,6 +284,9 @@ class BaseSQLClient(ClientInterface):
281
284
  Raises:
282
285
  ValueError: If required connection parameters are missing.
283
286
  """
287
+ if not self.DB_CONFIG:
288
+ raise ValueError("DB_CONFIG is not configured for this SQL client.")
289
+
284
290
  extra = parse_credentials_extra(self.credentials)
285
291
 
286
292
  # TODO: Uncomment this when the native deployment is ready
@@ -293,7 +299,7 @@ class BaseSQLClient(ClientInterface):
293
299
 
294
300
  # Prepare parameters
295
301
  param_values = {}
296
- for param in self.DB_CONFIG["required"]:
302
+ for param in self.DB_CONFIG.required:
297
303
  if param == "password":
298
304
  param_values[param] = auth_token
299
305
  else:
@@ -303,21 +309,19 @@ class BaseSQLClient(ClientInterface):
303
309
  param_values[param] = value
304
310
 
305
311
  # Fill in base template
306
- conn_str = self.DB_CONFIG["template"].format(**param_values)
312
+ conn_str = self.DB_CONFIG.template.format(**param_values)
307
313
 
308
314
  # Append defaults if not already in the template
309
- if self.DB_CONFIG.get("defaults"):
310
- conn_str = self.add_connection_params(conn_str, self.DB_CONFIG["defaults"])
315
+ if self.DB_CONFIG.defaults:
316
+ conn_str = self.add_connection_params(conn_str, self.DB_CONFIG.defaults)
311
317
 
312
- if self.DB_CONFIG.get("parameters"):
313
- parameter_keys = self.DB_CONFIG["parameters"]
314
- self.DB_CONFIG["parameters"] = {
318
+ if self.DB_CONFIG.parameters:
319
+ parameter_keys = self.DB_CONFIG.parameters
320
+ parameter_values = {
315
321
  key: self.credentials.get(key) or extra.get(key)
316
322
  for key in parameter_keys
317
323
  }
318
- conn_str = self.add_connection_params(
319
- conn_str, self.DB_CONFIG["parameters"]
320
- )
324
+ conn_str = self.add_connection_params(conn_str, parameter_values)
321
325
 
322
326
  return conn_str
323
327
 
@@ -1,4 +1,13 @@
1
+ import re
2
+ from typing import Any, Dict, Optional
3
+
4
+ import boto3
5
+ from sqlalchemy.engine.url import URL
6
+
1
7
  from application_sdk.constants import AWS_SESSION_NAME
8
+ from application_sdk.observability.logger_adaptor import get_logger
9
+
10
+ logger = get_logger(__name__)
2
11
 
3
12
 
4
13
def get_region_name_from_hostname(hostname: str) -> str:
    """Extract the AWS region name embedded in an AWS endpoint hostname.

    Args:
        hostname: An AWS endpoint hostname, e.g.
            "mydb.abc123.us-east-1.rds.amazonaws.com".

    Returns:
        str: AWS region name

    Raises:
        ValueError: If no region token can be found in the hostname.
    """
    # Region tokens look like "us-east-1" but may contain more than one
    # middle word (e.g. "us-gov-west-1") — the previous two-part pattern
    # missed GovCloud-style regions, so allow one-or-more "-word" segments
    # and a multi-digit suffix.
    region_token = r"([a-z]{2}(?:-[a-z]+)+-\d+)"
    match = re.search(r"\." + region_token + r"\.", hostname)
    if match:
        return match.group(1)
    # Some services may use - instead of . (rare)
    match = re.search(r"-" + region_token + r"\.", hostname)
    if match:
        return match.group(1)
    raise ValueError("Could not find valid AWS region from hostname")
20
32
 
21
33
 
22
34
  def generate_aws_rds_token_with_iam_role(
@@ -55,12 +67,10 @@ def generate_aws_rds_token_with_iam_role(
55
67
  )
56
68
 
57
69
  credentials = assumed_role["Credentials"]
58
- aws_client = client(
59
- "rds",
60
- aws_access_key_id=credentials["AccessKeyId"],
61
- aws_secret_access_key=credentials["SecretAccessKey"],
62
- aws_session_token=credentials["SessionToken"],
63
- region_name=region or get_region_name_from_hostname(host),
70
+ aws_client = create_aws_client(
71
+ service="rds",
72
+ region=region or get_region_name_from_hostname(host),
73
+ temp_credentials=credentials,
64
74
  )
65
75
  token: str = aws_client.generate_db_auth_token(
66
76
  DBHostname=host, Port=port, DBUsername=user
@@ -107,3 +117,241 @@ def generate_aws_rds_token_with_iam_user(
107
117
  return token
108
118
  except Exception as e:
109
119
  raise Exception(f"Failed to get user credentials: {str(e)}")
120
+
121
+
122
def get_cluster_identifier(aws_client) -> Optional[str]:
    """
    Return the identifier of the first Redshift cluster that exposes one.

    Args:
        aws_client: Boto3 Redshift client instance

    Returns:
        Optional[str]: The first ``ClusterIdentifier`` found, or ``None``
        when no cluster reports an identifier.
    """
    response = aws_client.describe_clusters()
    # When no explicit cluster id was configured we simply pick the first
    # cluster that reports an identifier; add filtering logic here if a
    # more specific selection is ever needed.
    return next(
        (
            cid
            for entry in response["Clusters"]
            if (cid := entry.get("ClusterIdentifier"))
        ),
        None,
    )
144
+
145
+
146
def create_aws_session(credentials: Dict[str, Any]) -> boto3.Session:
    """
    Build a boto3 session from a generic credentials mapping.

    The access key pair may arrive either under the canonical AWS key names
    or under generic "username"/"password" keys.

    Args:
        credentials: Dictionary containing AWS credentials

    Returns:
        boto3.Session: Configured boto3 session
    """
    # Falsy values (None / empty string) fall through to the generic keys.
    access_key = credentials.get("aws_access_key_id") or credentials.get("username")
    secret_key = credentials.get("aws_secret_access_key") or credentials.get(
        "password"
    )
    return boto3.Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
167
+
168
+
169
def get_cluster_credentials(
    aws_client, credentials: Dict[str, Any], extra: Dict[str, Any]
) -> Dict[str, str]:
    """
    Retrieve temporary Redshift cluster credentials using IAM authentication.

    Args:
        aws_client: Boto3 Redshift client instance
        credentials: Dictionary containing connection credentials; may
            provide an explicit "cluster_id"
        extra: Extra connection settings; must contain "database"

    Returns:
        Dict[str, str]: Dictionary containing DbUser and DbPassword
    """
    # Fall back to discovering the first available cluster when no explicit
    # cluster id was configured.
    cluster_id = credentials.get("cluster_id") or get_cluster_identifier(aws_client)
    return aws_client.get_cluster_credentials_with_iam(
        DbName=extra["database"],
        ClusterIdentifier=cluster_id,
    )
190
+
191
+
192
def create_aws_client(
    service: str,
    region: str,
    session: Optional[boto3.Session] = None,
    temp_credentials: Optional[Dict[str, str]] = None,
    use_default_credentials: bool = False,
) -> Any:
    """
    Create an AWS client with flexible credential options.

    Args:
        service: AWS service name (e.g., 'redshift', 'redshift-serverless', 'sts', 'rds')
        region: AWS region name
        session: Optional boto3 session instance. If provided, uses session credentials
        temp_credentials: Optional dictionary containing temporary credentials from assume_role.
            Must contain 'AccessKeyId', 'SecretAccessKey', and 'SessionToken'
        use_default_credentials: If True, uses default AWS credentials
            (environment, IAM role, etc.)

    Returns:
        AWS client instance

    Raises:
        ValueError: If zero or more than one credential source is provided
        Exception: If client creation fails (original error is chained)

    Examples:
        Using temporary credentials::

            client = create_aws_client(
                service="redshift",
                region="us-east-1",
                temp_credentials={
                    "AccessKeyId": "AKIA...",
                    "SecretAccessKey": "...",
                    "SessionToken": "..."
                }
            )

        Using a session::

            session = boto3.Session(profile_name="my-profile")
            client = create_aws_client(service="rds", region="us-west-2", session=session)

        Using default credentials::

            client = create_aws_client(service="sts", region="us-east-1", use_default_credentials=True)
    """
    # Exactly one credential source must be selected; booleans sum as 0/1.
    credential_sources = sum(
        [session is not None, temp_credentials is not None, use_default_credentials]
    )

    if credential_sources == 0:
        raise ValueError("At least one credential source must be provided")
    if credential_sources > 1:
        raise ValueError("Only one credential source should be provided at a time")

    try:
        # Priority 1: Use provided session
        if session is not None:
            logger.debug(
                f"Creating {service} client using provided session in region {region}"
            )
            return session.client(service, region_name=region)  # type: ignore

        # Priority 2: Use temporary credentials
        if temp_credentials is not None:
            logger.debug(
                f"Creating {service} client using temporary credentials in region {region}"
            )
            return boto3.client(  # type: ignore
                service,
                aws_access_key_id=temp_credentials["AccessKeyId"],
                aws_secret_access_key=temp_credentials["SecretAccessKey"],
                aws_session_token=temp_credentials["SessionToken"],
                region_name=region,
            )

        # Priority 3: Use default credentials
        if use_default_credentials:
            logger.debug(
                f"Creating {service} client using default credentials in region {region}"
            )
            return boto3.client(service, region_name=region)  # type: ignore

    except Exception as e:
        logger.error(f"Failed to create {service} client in region {region}: {e}")
        # Chain the original exception so the real boto3 failure (missing
        # credentials, bad region, ...) stays visible in the traceback.
        raise Exception(f"Failed to create {service} client: {str(e)}") from e
289
+
290
+
291
def create_engine_url(
    drivername: str,
    credentials: Dict[str, Any],
    cluster_credentials: Dict[str, str],
    extra: Dict[str, Any],
) -> URL:
    """
    Assemble a SQLAlchemy engine URL for a Redshift connection.

    Args:
        drivername: SQLAlchemy driver name (e.g. "redshift+psycopg2")
        credentials: Dictionary containing connection credentials; must
            contain "host" and may contain "port"
        cluster_credentials: Dictionary containing DbUser and DbPassword
        extra: Extra connection settings; must contain "database"

    Returns:
        URL: SQLAlchemy engine URL
    """
    # Resolve the endpoint pieces first so a missing key fails fast with a
    # clear KeyError before the URL is built.
    endpoint_host = credentials["host"]
    endpoint_port = credentials.get("port")
    target_database = extra["database"]

    return URL.create(
        drivername=drivername,
        username=cluster_credentials["DbUser"],
        password=cluster_credentials["DbPassword"],
        host=endpoint_host,
        port=endpoint_port,
        database=target_database,
    )
319
+
320
+
321
def get_all_aws_regions() -> list[str]:
    """
    Return all available AWS region names.

    Queries the EC2 ``describe_regions`` API; when the call fails (for
    example due to missing credentials), a hardcoded fallback list is
    returned instead, so this function never raises.

    Returns:
        list[str]: Sorted list of AWS region names
    """
    try:
        # us-east-1 is always available, so it is a safe home for the EC2
        # client used purely for region discovery.
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        response = ec2_client.describe_regions()
        # Sort so callers always see a consistent ordering.
        return sorted(region["RegionName"] for region in response["Regions"])
    except Exception as e:
        logger.warning(
            f"Failed to retrieve AWS regions dynamically: {e}. Using fallback list."
        )
        return [
            "ap-northeast-1",
            "ap-south-1",
            "ap-southeast-1",
            "ap-southeast-2",
            "aws-global",
            "ca-central-1",
            "eu-central-1",
            "eu-north-1",
            "eu-west-1",
            "eu-west-2",
            "eu-west-3",
            "sa-east-1",
            "us-east-1",
            "us-east-2",
            "us-west-1",
            "us-west-2",
        ]
@@ -17,8 +17,12 @@ from typing import (
17
17
  Union,
18
18
  )
19
19
 
20
+ from application_sdk.activities.common.utils import get_object_store_prefix
20
21
  from application_sdk.common.error_codes import CommonError
22
+ from application_sdk.constants import TEMPORARY_PATH
23
+ from application_sdk.inputs.sql_query import SQLQueryInput
21
24
  from application_sdk.observability.logger_adaptor import get_logger
25
+ from application_sdk.services.objectstore import ObjectStore
22
26
 
23
27
  logger = get_logger(__name__)
24
28
 
@@ -106,10 +110,42 @@ def extract_database_names_from_regex_common(
106
110
  return empty_default
107
111
 
108
112
 
113
def transform_posix_regex(regex_pattern: str) -> str:
    r"""
    Rewrite an alternation regex for POSIX-compatible matching.

    Two rewrites are applied to every ``|``-separated alternative:

    1. Prefix the alternative with ``^`` when it is not already anchored.
    2. Replace ``\.*`` with ``\..*`` so the ``*`` quantifies a real
       wildcard instead of the escaped dot.

    Example: 'dev\.public$|dev\.atlan_test_schema$|wide_world_importers\.*'
    Becomes: '^dev\.public$|^dev\.atlan_test_schema$|^wide_world_importers\..*'
    """
    if not regex_pattern:
        return regex_pattern

    def _transform_alternative(alternative: str) -> str:
        # Anchor, then expand the escaped-dot-star into escaped-dot wildcard.
        anchored = alternative if alternative.startswith("^") else "^" + alternative
        return re.sub(r"\\\.\*", r"\..*", anchored)

    return "|".join(
        _transform_alternative(part) for part in regex_pattern.split("|")
    )
142
+
143
+
109
144
  def prepare_query(
110
145
  query: Optional[str],
111
146
  workflow_args: Dict[str, Any],
112
147
  temp_table_regex_sql: Optional[str] = "",
148
+ use_posix_regex: Optional[bool] = False,
113
149
  ) -> Optional[str]:
114
150
  """
115
151
  Prepares a SQL query by applying include and exclude filters, and optional
@@ -158,6 +194,14 @@ def prepare_query(
158
194
  include_filter, exclude_filter
159
195
  )
160
196
 
197
+ if use_posix_regex:
198
+ normalized_include_regex_posix = transform_posix_regex(
199
+ normalized_include_regex
200
+ )
201
+ normalized_exclude_regex_posix = transform_posix_regex(
202
+ normalized_exclude_regex
203
+ )
204
+
161
205
  # Extract database names from the normalized regex patterns
162
206
  include_databases = extract_database_names_from_regex_common(
163
207
  normalized_regex=normalized_include_regex,
@@ -176,15 +220,26 @@ def prepare_query(
176
220
  )
177
221
  exclude_views = workflow_args.get("metadata", {}).get("exclude_views", False)
178
222
 
179
- return query.format(
180
- include_databases=include_databases,
181
- exclude_databases=exclude_databases,
182
- normalized_include_regex=normalized_include_regex,
183
- normalized_exclude_regex=normalized_exclude_regex,
184
- temp_table_regex_sql=temp_table_regex_sql,
185
- exclude_empty_tables=exclude_empty_tables,
186
- exclude_views=exclude_views,
187
- )
223
+ if use_posix_regex:
224
+ return query.format(
225
+ include_databases=include_databases,
226
+ exclude_databases=exclude_databases,
227
+ normalized_include_regex=normalized_include_regex_posix,
228
+ normalized_exclude_regex=normalized_exclude_regex_posix,
229
+ temp_table_regex_sql=temp_table_regex_sql,
230
+ exclude_empty_tables=exclude_empty_tables,
231
+ exclude_views=exclude_views,
232
+ )
233
+ else:
234
+ return query.format(
235
+ include_databases=include_databases,
236
+ exclude_databases=exclude_databases,
237
+ normalized_include_regex=normalized_include_regex,
238
+ normalized_exclude_regex=normalized_exclude_regex,
239
+ temp_table_regex_sql=temp_table_regex_sql,
240
+ exclude_empty_tables=exclude_empty_tables,
241
+ exclude_views=exclude_views,
242
+ )
188
243
  except CommonError as e:
189
244
  # Extract the original error message from the CommonError
190
245
  error_message = str(e).split(": ", 1)[-1] if ": " in str(e) else str(e)
@@ -195,6 +250,47 @@ def prepare_query(
195
250
  return None
196
251
 
197
252
 
253
async def get_database_names(
    sql_client, workflow_args, fetch_database_sql
) -> Optional[List[str]]:
    """
    Get the database names from the workflow args if include-filter is present;
    otherwise discover them by running ``fetch_database_sql``.

    Args:
        sql_client: SQL client exposing an ``engine`` attribute used to run
            the discovery query.
        workflow_args: The workflow args; the filter is read from
            ``workflow_args["metadata"]["include-filter"]``.
        fetch_database_sql: Query template executed when the include-filter
            is empty.

    Returns:
        List[str]: The database names
    """
    database_names = parse_filter_input(
        workflow_args.get("metadata", {}).get("include-filter", {})
    )

    # Strip leading/trailing non-word characters (e.g. regex anchors such as
    # "^" and "$") from each filter key to obtain plain database names.
    database_names = [
        re.sub(r"^[^\w]+|[^\w]+$", "", database_name)
        for database_name in database_names
    ]
    if not database_names:
        # if database_names are not provided in the include-filter, we'll run the query to get all the database names
        # because by default for an empty include-filter, we fetch details corresponding to all the databases.
        temp_table_regex_sql = workflow_args.get("metadata", {}).get(
            "temp-table-regex", ""
        )
        prepared_query = prepare_query(
            query=fetch_database_sql,
            workflow_args=workflow_args,
            temp_table_regex_sql=temp_table_regex_sql,
            use_posix_regex=True,
        )
        # We'll run the query to get all the database names
        database_sql_input = SQLQueryInput(
            engine=sql_client.engine,
            query=prepared_query,  # type: ignore
            chunk_size=None,
        )
        # NOTE(review): assumes the query result exposes a "database_name"
        # column — confirm against the SQL templates used by callers.
        database_dataframe = await database_sql_input.get_dataframe()
        database_names = list(database_dataframe["database_name"])
    return database_names
292
+
293
+
198
294
  def parse_filter_input(
199
295
  filter_input: Union[str, Dict[str, Any], None],
200
296
  ) -> Dict[str, Any]:
@@ -416,6 +512,46 @@ def parse_credentials_extra(credentials: Dict[str, Any]) -> Dict[str, Any]:
416
512
  return extra # We know it's a Dict[str, Any] due to the Union type and str check
417
513
 
418
514
 
515
def has_custom_control_config(workflow_args: Dict[str, Any]) -> bool:
    """
    Check if custom control configuration is present in workflow arguments.

    Args:
        workflow_args: The workflow arguments

    Returns:
        bool: True if custom control configuration is present, False otherwise
    """
    # Both conditions must hold: the strategy explicitly opts into "custom"
    # AND an actual control-config payload was supplied.
    strategy_is_custom = workflow_args.get("control-config-strategy") == "custom"
    config_is_present = workflow_args.get("control-config") is not None
    return strategy_is_custom and config_is_present
529
+
530
+
531
async def get_file_names(output_path: str, typename: str) -> List[str]:
    """
    Get file names for a specific asset type from the transformed directory.

    Downloads the object-store prefix for the asset type, then lists the
    ``*.json`` files found locally under ``output_path/typename``.

    Args:
        output_path (str): The base output path
        typename (str): The asset type (e.g., 'table', 'schema', 'column')

    Returns:
        List[str]: List of relative file paths (the trailing
        ``typename/filename.json`` portion) for the asset type
    """

    source = get_object_store_prefix(os.path.join(output_path, typename))
    # NOTE(review): files are downloaded under TEMPORARY_PATH but globbed
    # under output_path below — confirm the two locations coincide.
    await ObjectStore.download_prefix(source, TEMPORARY_PATH)

    file_pattern = os.path.join(output_path, typename, "*.json")
    file_names = glob.glob(file_pattern)
    # Keep only the last two path segments of each match.
    # NOTE(review): splits on "/", so this assumes POSIX-style paths.
    file_name_list = [
        "/".join(file_name.rsplit("/", 2)[-2:]) for file_name in file_names
    ]

    return file_name_list
553
+
554
+
419
555
  def run_sync(func):
420
556
  """Run a function in a thread pool executor.
421
557
 
@@ -1,5 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Any
2
+ from typing import Any, Dict
3
3
 
4
4
 
5
5
  class HandlerInterface(ABC):
@@ -37,3 +37,10 @@ class HandlerInterface(ABC):
37
37
  To be implemented by the subclass
38
38
  """
39
39
  raise NotImplementedError("fetch_metadata method not implemented")
40
+
41
    @staticmethod
    async def get_configmap(config_map_id: str) -> Dict[str, Any]:
        """
        Static method to get the configmap.

        Args:
            config_map_id: Identifier of the configmap to fetch.

        Returns:
            Dict[str, Any]: The configmap contents; this base implementation
            returns an empty dict and is intended to be overridden by
            subclasses.
        """
        return {}