atlan-application-sdk 0.1.1rc40__py3-none-any.whl → 0.1.1rc41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- application_sdk/activities/metadata_extraction/sql.py +400 -25
- application_sdk/application/__init__.py +2 -0
- application_sdk/application/metadata_extraction/sql.py +3 -0
- application_sdk/clients/models.py +42 -0
- application_sdk/clients/sql.py +17 -13
- application_sdk/common/aws_utils.py +259 -11
- application_sdk/common/utils.py +145 -9
- application_sdk/handlers/__init__.py +8 -1
- application_sdk/handlers/sql.py +63 -22
- application_sdk/observability/decorators/observability_decorator.py +36 -22
- application_sdk/server/fastapi/__init__.py +59 -3
- application_sdk/server/fastapi/models.py +27 -0
- application_sdk/version.py +1 -1
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/METADATA +1 -1
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/RECORD +18 -17
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/WHEEL +0 -0
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/licenses/LICENSE +0 -0
- {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc41.dist-info}/licenses/NOTICE +0 -0

application_sdk/activities/metadata_extraction/sql.py
CHANGED
```diff
@@ -1,5 +1,20 @@
 import os
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+    overload,
+)
 
 from temporalio import activity
 
@@ -13,7 +28,12 @@ from application_sdk.activities.common.utils import (
 from application_sdk.clients.sql import BaseSQLClient
 from application_sdk.common.dataframe_utils import is_empty_dataframe
 from application_sdk.common.error_codes import ActivityError
-from application_sdk.common.utils import
+from application_sdk.common.utils import (
+    get_database_names,
+    parse_credentials_extra,
+    prepare_query,
+    read_sql_files,
+)
 from application_sdk.constants import APP_TENANT_ID, APPLICATION_NAME, SQL_QUERIES_PATH
 from application_sdk.handlers.sql import BaseSQLHandler
 from application_sdk.inputs.parquet import ParquetInput
@@ -31,6 +51,9 @@ activity.logger = logger
 
 queries = read_sql_files(queries_prefix=SQL_QUERIES_PATH)
 
+if TYPE_CHECKING:
+    import pandas as pd
+
 
 class BaseSQLMetadataExtractionActivitiesState(ActivitiesState):
     """State class for SQL metadata extraction activities.
@@ -90,6 +113,7 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
         sql_client_class: Optional[Type[BaseSQLClient]] = None,
         handler_class: Optional[Type[BaseSQLHandler]] = None,
         transformer_class: Optional[Type[TransformerInterface]] = None,
+        multidb: bool = False,
     ):
         """Initialize the SQL metadata extraction activities.
 
@@ -100,6 +124,8 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
                 Defaults to BaseSQLHandler.
             transformer_class (Type[TransformerInterface], optional): Class for metadata transformation.
                 Defaults to QueryBasedTransformer.
+            multidb (bool): When True, executes queries across multiple databases using
+                `multidb_query_executor`. Defaults to False.
         """
         if sql_client_class:
             self.sql_client_class = sql_client_class
@@ -108,6 +134,9 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
         if transformer_class:
             self.transformer_class = transformer_class
 
+        # Control whether to execute per-db using multidb executor
+        self.multidb = multidb
+
         super().__init__()
 
     # State methods
@@ -206,6 +235,7 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
             raise ValueError("Missing required workflow arguments")
         return output_prefix, output_path, typename, workflow_id, workflow_run_id
 
+    @overload
    async def query_executor(
        self,
        sql_engine: Any,
@@ -213,7 +243,38 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
         workflow_args: Dict[str, Any],
         output_suffix: str,
         typename: str,
-
+        write_to_file: bool = True,
+        concatenate: bool = False,
+        return_dataframe: bool = False,
+        sql_client: Optional[BaseSQLClient] = None,
+    ) -> Optional[ActivityStatistics]: ...
+
+    @overload
+    async def query_executor(
+        self,
+        sql_engine: Any,
+        sql_query: Optional[str],
+        workflow_args: Dict[str, Any],
+        output_suffix: str,
+        typename: str,
+        write_to_file: bool = True,
+        concatenate: bool = False,
+        return_dataframe: bool = True,
+        sql_client: Optional[BaseSQLClient] = None,
+    ) -> Optional[Union[ActivityStatistics, "pd.DataFrame"]]: ...
+
+    async def query_executor(
+        self,
+        sql_engine: Any,
+        sql_query: Optional[str],
+        workflow_args: Dict[str, Any],
+        output_suffix: str,
+        typename: str,
+        write_to_file: bool = True,
+        concatenate: bool = False,
+        return_dataframe: bool = False,
+        sql_client: Optional[BaseSQLClient] = None,
+    ) -> Optional[Union[ActivityStatistics, "pd.DataFrame"]]:
         """
         Executes a SQL query using the provided engine and saves the results to Parquet.
 
@@ -233,44 +294,358 @@
             typename: Type name used for generating output statistics.
 
         Returns:
-            Optional[ActivityStatistics]: Statistics about the generated Parquet file,
-            or None if the query is empty or execution fails
+            Optional[Union[ActivityStatistics, pd.DataFrame]]: Statistics about the generated Parquet file,
+            or a DataFrame if return_dataframe=True, or None if the query is empty or execution fails.
 
         Raises:
             ValueError: If `sql_engine` is not provided.
         """
+        # Common pre-checks and setup shared by both multidb and single-db paths
+        if not sql_query:
+            logger.warning("Query is empty, skipping execution.")
+            return None
+
         if not sql_engine:
             logger.error("SQL engine is not set.")
             raise ValueError("SQL engine must be provided.")
-
-
+
+        # Setup parquet output using helper method
+        parquet_output = self._setup_parquet_output(
+            workflow_args, output_suffix, write_to_file
+        )
+
+        # If multidb mode is enabled, run per-database flow
+        if getattr(self, "multidb", False):
+            return await self._execute_multidb_flow(
+                sql_client,
+                sql_query,
+                workflow_args,
+                output_suffix,
+                typename,
+                write_to_file,
+                concatenate,
+                return_dataframe,
+                parquet_output,
+            )
+
+        # Single-db execution path
+        # Prepare query for single-db execution
+        prepared_query = self._prepare_database_query(
+            sql_query, None, workflow_args, typename
+        )
+
+        # Execute using helper method
+        success, _ = await self._execute_single_db(
+            sql_engine, prepared_query, parquet_output, write_to_file
+        )
+
+        if not success:
+            logger.error("Failed to execute single-db query")
             return None
 
-
-
-
+        if parquet_output:
+            logger.info(
+                f"Successfully wrote query results to {parquet_output.get_full_path()}"
+            )
+            return await parquet_output.get_statistics(typename=typename)
+
+        logger.warning("No parquet output configured for single-db execution")
+        return None
+
+    def _setup_parquet_output(
+        self,
+        workflow_args: Dict[str, Any],
+        output_suffix: str,
+        write_to_file: bool,
+    ) -> Optional[ParquetOutput]:
+        if not write_to_file:
+            return None
+        output_prefix = workflow_args.get("output_prefix")
+        output_path = workflow_args.get("output_path")
+        if not output_prefix or not output_path:
+            logger.error("Output prefix or path not provided in workflow_args.")
+            raise ValueError(
+                "Output prefix and path must be specified in workflow_args."
+            )
+        return ParquetOutput(
+            output_prefix=output_prefix,
+            output_path=output_path,
+            output_suffix=output_suffix,
+        )
 
-
-
+    def _get_temp_table_regex_sql(self, typename: str) -> str:
+        """Get the appropriate temp table regex SQL based on typename."""
+        if typename == "column":
+            return self.extract_temp_table_regex_column_sql or ""
+        elif typename == "table":
+            return self.extract_temp_table_regex_table_sql or ""
+        else:
+            return ""
 
-
-
-
-
+    def _prepare_database_query(
+        self,
+        sql_query: str,
+        database_name: Optional[str],
+        workflow_args: Dict[str, Any],
+        typename: str,
+        use_posix_regex: bool = False,
+    ) -> Optional[str]:
+        """Prepare query for database execution with proper substitutions."""
+        # Replace database name placeholder if provided
+        fetch_sql = sql_query
+        if database_name:
+            fetch_sql = fetch_sql.replace("{database_name}", database_name)
+
+        # Get temp table regex SQL
+        temp_table_regex_sql = self._get_temp_table_regex_sql(typename)
+
+        # Prepare the query
+        prepared_query = prepare_query(
+            query=fetch_sql,
+            workflow_args=workflow_args,
+            temp_table_regex_sql=temp_table_regex_sql,
+            use_posix_regex=use_posix_regex,
+        )
+
+        if prepared_query is None:
+            db_context = f" for database {database_name}" if database_name else ""
+            raise ValueError(f"Failed to prepare query{db_context}")
+
+        return prepared_query
+
+    async def _setup_database_connection(
+        self,
+        sql_client: BaseSQLClient,
+        database_name: str,
+    ) -> None:
+        """Setup connection for a specific database."""
+        extra = parse_credentials_extra(sql_client.credentials)
+        extra["database"] = database_name
+        sql_client.credentials["extra"] = extra
+        await sql_client.load(sql_client.credentials)
+
+    # NOTE: Consolidated: per-database processing is now inlined in the multi-DB loop
+
+    async def _finalize_multidb_results(
+        self,
+        write_to_file: bool,
+        concatenate: bool,
+        return_dataframe: bool,
+        parquet_output: Optional[ParquetOutput],
+        dataframe_list: List[
+            Union[AsyncIterator["pd.DataFrame"], Iterator["pd.DataFrame"]]
+        ],
+        workflow_args: Dict[str, Any],
+        output_suffix: str,
+        typename: str,
+    ) -> Optional[Union[ActivityStatistics, "pd.DataFrame"]]:
+        """Finalize results for multi-database execution."""
+        if write_to_file and parquet_output:
+            return await parquet_output.get_statistics(typename=typename)
+
+        if not write_to_file and concatenate:
+            try:
+                import pandas as pd  # type: ignore
+
+                valid_dataframes: List[pd.DataFrame] = []
+                for df_generator in dataframe_list:
+                    if df_generator is None:
+                        continue
+                    for dataframe in df_generator:  # type: ignore[assignment]
+                        if dataframe is None:
+                            continue
+                        if hasattr(dataframe, "empty") and getattr(dataframe, "empty"):
+                            continue
+                        valid_dataframes.append(dataframe)
+
+                if not valid_dataframes:
+                    logger.warning(
+                        "No valid dataframes collected across databases for concatenation"
+                    )
+                    return None
+
+                concatenated = pd.concat(valid_dataframes, ignore_index=True)
+
+                if return_dataframe:
+                    return concatenated  # type: ignore[return-value]
+
+                # Create new parquet output for concatenated data
+                concatenated_parquet_output = self._setup_parquet_output(
+                    workflow_args, output_suffix, True
                )
+                if concatenated_parquet_output:
+                    await concatenated_parquet_output.write_dataframe(concatenated)  # type: ignore[arg-type]
+                    return await concatenated_parquet_output.get_statistics(
+                        typename=typename
+                    )
+            except Exception as e:  # noqa: BLE001
+                logger.error(
+                    f"Error concatenating multi-DB dataframes: {str(e)}",
+                    exc_info=True,
+                )
+                raise
+
+        logger.warning(
+            "multidb execution returned no output to write (write_to_file=False, concatenate=False)"
+        )
+        return None
 
-
-
-
-
+    async def _execute_multidb_flow(
+        self,
+        sql_client: Optional[BaseSQLClient],
+        sql_query: str,
+        workflow_args: Dict[str, Any],
+        output_suffix: str,
+        typename: str,
+        write_to_file: bool,
+        concatenate: bool,
+        return_dataframe: bool,
+        parquet_output: Optional[ParquetOutput],
+    ) -> Optional[Union[ActivityStatistics, "pd.DataFrame"]]:
+        """Execute multi-database flow with proper error handling and result finalization."""
+        # Get effective SQL client
+        effective_sql_client = sql_client
+        if effective_sql_client is None:
+            state = cast(
+                BaseSQLMetadataExtractionActivitiesState,
+                await self._get_state(workflow_args),
            )
-
-
-
+            effective_sql_client = state.sql_client
+
+        if not effective_sql_client:
+            logger.error("SQL client not initialized for multidb execution")
+            raise ValueError("SQL client not initialized")
+
+        # Resolve databases to iterate
+        database_names = await get_database_names(
+            effective_sql_client, workflow_args, self.fetch_database_sql
+        )
+        if not database_names:
+            logger.warning("No databases found to process")
+            return None
+
+        # Validate client
+        if not effective_sql_client.engine:
+            logger.error("SQL client engine not initialized")
+            raise ValueError("SQL client engine not initialized")
+
+        successful_databases: List[str] = []
+        failed_databases: List[str] = []
+        dataframe_list: List[
+            Union[AsyncIterator["pd.DataFrame"], Iterator["pd.DataFrame"]]
+        ] = []
+
+        # Iterate databases and execute (consolidated single-db processing)
+        for database_name in database_names or []:
+            try:
+                # Setup connection for this database
+                await self._setup_database_connection(
+                    effective_sql_client, database_name
+                )
+
+                # Prepare query for this database
+                prepared_query = self._prepare_database_query(
+                    sql_query,
+                    database_name,
+                    workflow_args,
+                    typename,
+                    use_posix_regex=True,
+                )
+
+                # Execute using helper method
+                success, batched_iter = await self._execute_single_db(
+                    effective_sql_client.engine,
+                    prepared_query,
+                    parquet_output,
+                    write_to_file,
+                )
+
+                if success:
+                    logger.info(f"Successfully processed database: {database_name}")
+                else:
+                    logger.warning(
+                        f"Failed to execute query for database: {database_name}"
+                    )
+            except Exception as e:  # noqa: BLE001
+                logger.warning(
+                    f"Failed to process database '{database_name}': {str(e)}. Skipping to next database."
+                )
+                success, batched_iter = False, None
+
+            if success:
+                successful_databases.append(database_name)
+                if not write_to_file and batched_iter:
+                    dataframe_list.append(batched_iter)
+            else:
+                failed_databases.append(database_name)
+
+        # Log results
+        logger.info(
+            f"Successfully processed {len(successful_databases)} databases: {successful_databases}"
+        )
+        if failed_databases:
+            logger.warning(
+                f"Failed to process {len(failed_databases)} databases: {failed_databases}"
            )
 
-
-
+        # Finalize results
+        return await self._finalize_multidb_results(
+            write_to_file,
+            concatenate,
+            return_dataframe,
+            parquet_output,
+            dataframe_list,
+            workflow_args,
+            output_suffix,
+            typename,
+        )
+
+    async def _execute_single_db(
+        self,
+        sql_engine: Any,
+        prepared_query: Optional[str],
+        parquet_output: Optional[ParquetOutput],
+        write_to_file: bool,
+    ) -> Tuple[
+        bool, Optional[Union[AsyncIterator["pd.DataFrame"], Iterator["pd.DataFrame"]]]
+    ]:  # type: ignore
+        if not prepared_query:
+            logger.error("Prepared query is None, cannot execute")
+            return False, None
+
+        try:
+            sql_input = SQLQueryInput(engine=sql_engine, query=prepared_query)
+            batched_iter = await sql_input.get_batched_dataframe()
+
+            if write_to_file and parquet_output:
+                # Wrap iterator into a proper (async)generator for type safety
+                if hasattr(batched_iter, "__anext__"):
+
+                    async def _to_async_gen(
+                        it: AsyncIterator["pd.DataFrame"],
+                    ) -> AsyncGenerator["pd.DataFrame", None]:
+                        async for item in it:
+                            yield item
+
+                    wrapped: AsyncGenerator["pd.DataFrame", None] = _to_async_gen(  # type: ignore
+                        batched_iter  # type: ignore
+                    )
+                    await parquet_output.write_batched_dataframe(wrapped)
+                else:
+
+                    def _to_gen(
+                        it: Iterator["pd.DataFrame"],
+                    ) -> Generator["pd.DataFrame", None, None]:
+                        for item in it:
+                            yield item
+
+                    wrapped_sync: Generator["pd.DataFrame", None, None] = _to_gen(  # type: ignore
+                        batched_iter  # type: ignore
+                    )
+                    await parquet_output.write_batched_dataframe(wrapped_sync)
+                return True, None
+
+            return True, batched_iter
         except Exception as e:
             logger.error(
                 f"Error during query execution or output writing: {e}", exc_info=True
```
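The net effect of this change is that a connector can now ask the SDK to fan a metadata query out across every database returned by `fetch_database_sql`. A minimal sketch of opting in, assuming a concrete subclass that already defines its extraction SQL (the class name and query string below are illustrative, not part of the SDK):

```python
from application_sdk.activities.metadata_extraction.sql import (
    BaseSQLMetadataExtractionActivities,
)


class ExampleActivities(BaseSQLMetadataExtractionActivities):
    # Illustrative query; a real connector ships its own SQL files.
    fetch_database_sql = "SELECT database_name FROM information_schema.databases"


# multidb=True makes query_executor resolve the database list via
# get_database_names() and run the prepared query once per database.
activities = ExampleActivities(multidb=True)
```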
application_sdk/application/__init__.py
CHANGED
```diff
@@ -164,6 +164,7 @@ class BaseApplication:
         self,
         workflow_class,
         ui_enabled: bool = True,
+        has_configmap: bool = False,
     ):
         """
         Optionally set up a server for the application. (No-op by default)
@@ -176,6 +177,7 @@ class BaseApplication:
             workflow_client=self.workflow_client,
             ui_enabled=ui_enabled,
             handler=self.handler_class(client=self.client_class()),
+            has_configmap=has_configmap,
         )
 
         if self.event_subscriptions:
```
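For application authors, the new flag is simply forwarded when the server is created. A hedged sketch of passing it through, assuming `app` is an already-constructed `BaseApplication` subclass, `MyWorkflow` is a concrete workflow class, and the server-setup coroutine is named `setup_server` (the method name is an assumption, not shown in this diff):

```python
# Sketch only: `app` and MyWorkflow are placeholders, and the method name
# `setup_server` is assumed rather than taken from this diff.
await app.setup_server(
    workflow_class=MyWorkflow,
    ui_enabled=True,
    has_configmap=True,  # new in 0.1.1rc41; defaults to False
)
```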
application_sdk/application/metadata_extraction/sql.py
CHANGED
```diff
@@ -161,12 +161,14 @@ class BaseSQLMetadataExtractionApplication(BaseApplication):
         workflow_class: Type[
             BaseSQLMetadataExtractionWorkflow
         ] = BaseSQLMetadataExtractionWorkflow,
+        has_configmap: bool = False,
     ) -> Any:
         """
         Set up the FastAPI server for the SQL metadata extraction application.
 
         Args:
             workflow_class (Type): Workflow class to register with the server. Defaults to BaseSQLMetadataExtractionWorkflow.
+            has_configmap (bool): Whether the application has a configmap. Defaults to False.
 
         Returns:
             Any: None
@@ -178,6 +180,7 @@ class BaseSQLMetadataExtractionApplication(BaseApplication):
         self.server = APIServer(
             handler=self.handler_class(sql_client=self.client_class()),
             workflow_client=self.workflow_client,
+            has_configmap=has_configmap,
         )
 
         # register the workflow on the application server
```
application_sdk/clients/models.py
ADDED
```diff
@@ -0,0 +1,42 @@
+"""
+Pydantic models for database client configurations.
+This module provides Pydantic models for database connection configurations,
+ensuring type safety and validation for database client settings.
+"""
+
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class DatabaseConfig(BaseModel):
+    """
+    Pydantic model for database connection configuration.
+    This model defines the structure for database connection configurations,
+    including connection templates, required parameters, defaults, and additional
+    connection parameters.
+    """
+
+    template: str = Field(
+        ...,
+        description="SQLAlchemy connection string template with placeholders for connection parameters",
+    )
+    required: List[str] = Field(
+        default=[],
+        description="List of required connection parameters that must be provided",
+    )
+    defaults: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description="Default connection parameters to be added to the connection string",
+    )
+    parameters: Optional[List[str]] = Field(
+        default=None,
+        description="List of additional connection parameter names that can be dynamically added from credentials",
+    )
+
+    class Config:
+        """Pydantic configuration for the DatabaseConfig model."""
+
+        extra = "forbid"  # Prevent additional fields
+        validate_assignment = True  # Validate on assignment
+        use_enum_values = True  # Use enum values instead of enum objects
```
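Concrete SQL clients are expected to declare their connection recipe through this model (see `DB_CONFIG` in `clients/sql.py` below). A hedged sketch of what such a declaration could look like; the dialect, parameter names, and defaults are purely illustrative:

```python
from application_sdk.clients.models import DatabaseConfig
from application_sdk.clients.sql import BaseSQLClient


class ExamplePostgresClient(BaseSQLClient):
    # Illustrative values only; a real client defines its own template and parameters.
    DB_CONFIG = DatabaseConfig(
        template="postgresql+psycopg://{username}:{password}@{host}:{port}/{database}",
        required=["username", "password", "host", "port", "database"],
        defaults={"sslmode": "require"},
        parameters=["application_name"],
    )
```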
application_sdk/clients/sql.py
CHANGED
```diff
@@ -7,13 +7,14 @@ database operations, supporting batch processing and server-side cursors.
 
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 from urllib.parse import quote_plus
 
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
 from temporalio import activity
 
 from application_sdk.clients import ClientInterface
+from application_sdk.clients.models import DatabaseConfig
 from application_sdk.common.aws_utils import (
     generate_aws_rds_token_with_iam_role,
     generate_aws_rds_token_with_iam_user,
@@ -48,7 +49,7 @@ class BaseSQLClient(ClientInterface):
     credentials: Dict[str, Any] = {}
     resolved_credentials: Dict[str, Any] = {}
     use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR
-    DB_CONFIG:
+    DB_CONFIG: Optional[DatabaseConfig] = None
 
     def __init__(
         self,
@@ -262,7 +263,9 @@
         Returns:
             str: The updated URL with the dialect.
         """
-
+        if not self.DB_CONFIG:
+            raise ValueError("DB_CONFIG is not configured for this SQL client.")
+        installed_dialect = self.DB_CONFIG.template.split("://")[0]
         url_dialect = sqlalchemy_url.split("://")[0]
         if installed_dialect != url_dialect:
             sqlalchemy_url = sqlalchemy_url.replace(url_dialect, installed_dialect)
@@ -281,6 +284,9 @@
         Raises:
             ValueError: If required connection parameters are missing.
         """
+        if not self.DB_CONFIG:
+            raise ValueError("DB_CONFIG is not configured for this SQL client.")
+
         extra = parse_credentials_extra(self.credentials)
 
         # TODO: Uncomment this when the native deployment is ready
@@ -293,7 +299,7 @@
 
         # Prepare parameters
         param_values = {}
-        for param in self.DB_CONFIG
+        for param in self.DB_CONFIG.required:
             if param == "password":
                 param_values[param] = auth_token
             else:
@@ -303,21 +309,19 @@
                 param_values[param] = value
 
         # Fill in base template
-        conn_str = self.DB_CONFIG
+        conn_str = self.DB_CONFIG.template.format(**param_values)
 
         # Append defaults if not already in the template
-        if self.DB_CONFIG.
-            conn_str = self.add_connection_params(conn_str, self.DB_CONFIG
+        if self.DB_CONFIG.defaults:
+            conn_str = self.add_connection_params(conn_str, self.DB_CONFIG.defaults)
 
-        if self.DB_CONFIG.
-            parameter_keys = self.DB_CONFIG
-
+        if self.DB_CONFIG.parameters:
+            parameter_keys = self.DB_CONFIG.parameters
+            parameter_values = {
                 key: self.credentials.get(key) or extra.get(key)
                 for key in parameter_keys
             }
-            conn_str = self.add_connection_params(
-                conn_str, self.DB_CONFIG["parameters"]
-            )
+            conn_str = self.add_connection_params(conn_str, parameter_values)
 
         return conn_str
 
```
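The rewritten connection-string assembly is now driven entirely by the typed `DB_CONFIG`: required placeholders are filled from credentials (with `password` coming from the resolved auth token), then defaults and any optional parameters are appended. A simplified, standalone sketch of that flow, leaving out the real client's `add_connection_params` URL handling and quoting:

```python
from urllib.parse import urlencode

from application_sdk.clients.models import DatabaseConfig


def build_connection_string(
    config: DatabaseConfig, credentials: dict, extra: dict, auth_token: str
) -> str:
    # Fill required placeholders; "password" is always the resolved auth token.
    values = {
        p: (auth_token if p == "password" else credentials.get(p) or extra.get(p))
        for p in config.required
    }
    conn_str = config.template.format(**values)

    # Append defaults, then any optional parameters present in the credentials.
    params = dict(config.defaults or {})
    for key in config.parameters or []:
        value = credentials.get(key) or extra.get(key)
        if value is not None:
            params[key] = value
    if params:
        conn_str += ("&" if "?" in conn_str else "?") + urlencode(params)
    return conn_str
```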