rgwfuncs 0.0.109__py3-none-any.whl → 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rgwfuncs/__init__.py +1 -1
- rgwfuncs/df_lib.py +318 -137
- {rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/METADATA +1 -1
- rgwfuncs-0.0.111.dist-info/RECORD +11 -0
- rgwfuncs-0.0.109.dist-info/RECORD +0 -11
- {rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/WHEEL +0 -0
- {rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/entry_points.txt +0 -0
- {rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/licenses/LICENSE +0 -0
- {rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/top_level.txt +0 -0
rgwfuncs/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 # This file is automatically generated
 # Dynamically importing functions from modules

-from .df_lib import append_columns, append_percentile_classification_column, append_ranged_classification_column, append_ranged_date_classification_column, append_rows, append_xgb_labels, append_xgb_logistic_regression_predictions, append_xgb_regression_predictions, bag_union_join, bottom_n_unique_values, cascade_sort, delete_rows, drop_duplicates, drop_duplicates_retain_first, drop_duplicates_retain_last, filter_dataframe, filter_indian_mobiles, first_n_rows, from_raw_data, insert_dataframe_in_sqlite_database, last_n_rows, left_join, limit_dataframe, load_data_from_path, load_data_from_query, load_data_from_sqlite_path, load_fresh_data_or_pull_from_cache, mask_against_dataframe, mask_against_dataframe_converse, numeric_clean, order_columns, print_correlation, print_dataframe, print_memory_usage, print_n_frequency_cascading, print_n_frequency_linear, rename_columns, retain_columns, right_join, send_data_to_email, send_data_to_slack, send_dataframe_via_telegram, sync_dataframe_to_sqlite_database, top_n_unique_values, union_join, update_rows
+from .df_lib import append_columns, append_percentile_classification_column, append_ranged_classification_column, append_ranged_date_classification_column, append_rows, append_xgb_labels, append_xgb_logistic_regression_predictions, append_xgb_regression_predictions, bag_union_join, bottom_n_unique_values, cascade_sort, delete_rows, drop_duplicates, drop_duplicates_retain_first, drop_duplicates_retain_last, filter_dataframe, filter_indian_mobiles, first_n_rows, from_raw_data, insert_dataframe_in_sqlite_database, last_n_rows, left_join, limit_dataframe, load_data_from_aws_athena_query, load_data_from_big_query, load_data_from_path, load_data_from_query, load_data_from_sqlite_path, load_fresh_data_or_pull_from_cache, mask_against_dataframe, mask_against_dataframe_converse, numeric_clean, order_columns, print_correlation, print_dataframe, print_memory_usage, print_n_frequency_cascading, print_n_frequency_linear, rename_columns, retain_columns, right_join, send_data_to_email, send_data_to_slack, send_dataframe_via_telegram, sync_dataframe_to_sqlite_database, top_n_unique_values, union_join, update_rows
 from .interactive_shell_lib import interactive_shell
 from .docs_lib import docs
 from .str_lib import heading, send_telegram_message, sub_heading, title
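The net effect of this one-line change is two new public exports, load_data_from_aws_athena_query and load_data_from_big_query. After upgrading, both are importable from the package root alongside the existing loaders:

    # available from 0.0.111 onwards
    from rgwfuncs import load_data_from_aws_athena_query, load_data_from_big_query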
rgwfuncs/df_lib.py
CHANGED
@@ -310,98 +310,150 @@ def drop_duplicates_retain_last(
     return df.drop_duplicates(subset=columns_list, keep='last')


-def load_data_from_query(db_preset_name: str, query: str, config: Optional[Union
+def load_data_from_query(
+        query: str,
+        preset: Optional[str] = None,
+        db_type: Optional[str] = None,
+        host: Optional[str] = None,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        database: Optional[str] = None
+) -> pd.DataFrame:
     """
-    Load data from a database query into a DataFrame
+    Load data from a database query into a DataFrame using either a preset or provided credentials.

     Parameters:
+        query (str): The SQL query to execute.
+        preset (Optional[str]): The name of the database preset in the .rgwfuncsrc file.
+        db_type (Optional[str]): The database type ('mssql', 'mysql', or 'clickhouse').
+        host (Optional[str]): Database host.
+        user (Optional[str]): Database username.
+        password (Optional[str]): Database password.
+        database (Optional[str]): Database name.

     Returns:
+        pd.DataFrame: DataFrame containing the query result.

     Raises:
-        FileNotFoundError: If no '.rgwfuncsrc' file is found
-        ValueError: If
+        FileNotFoundError: If no '.rgwfuncsrc' file is found when using preset.
+        ValueError: If both preset and direct credentials are provided, neither is provided,
+            required credentials are missing, or db_type is invalid.
+        RuntimeError: If the preset is not found or necessary preset details are missing.
     """
-    """Get configuration either from a path, direct dictionary, or by searching upwards."""
-    def get_config_from_file(config_path: str) -> dict:
-        """Load configuration from a JSON file."""
-        with open(config_path, 'r') as file:
-            return json.load(file)
+    def get_config() -> dict:
+        """Get configuration by searching for '.rgwfuncsrc' in current directory and upwards."""
         def find_config_file() -> str:
-            """Search for '.rgwfuncsrc' in current directory and upwards."""
             current_dir = os.getcwd()
             while True:
                 config_path = os.path.join(current_dir, '.rgwfuncsrc')
                 if os.path.isfile(config_path):
                     return config_path
                 parent_dir = os.path.dirname(current_dir)
                 if parent_dir == current_dir:
                     raise FileNotFoundError("No '.rgwfuncsrc' file found in current or parent directories")
                 current_dir = parent_dir

+        config_path = find_config_file()
+        with open(config_path, 'r', encoding='utf-8') as file:
+            content = file.read()
+        if not content.strip():
+            raise ValueError(f"Config file {config_path} is empty")
+        try:
+            return json.loads(content)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in config file {config_path}: {e}")
+
+    def get_db_preset(config: dict, preset_name: str) -> dict:
+        """Retrieve the database preset from the configuration."""
+        db_presets = config.get('db_presets', [])
+        for preset in db_presets:
+            if preset.get('name') == preset_name:
+                return preset
+        raise RuntimeError(f"Database preset '{preset_name}' not found in the configuration file")
+
+    def validate_credentials(
+            db_type: str,
+            host: Optional[str] = None,
+            user: Optional[str] = None,
+            password: Optional[str] = None,
+            database: Optional[str] = None
+    ) -> dict:
+        """Validate credentials and return a credentials dictionary."""
+        required_fields = {
+            'mssql': ['host', 'user', 'password'],
+            'mysql': ['host', 'user', 'password'],
+            'clickhouse': ['host', 'user', 'password', 'database']
+        }
+        if db_type not in required_fields:
+            raise ValueError(f"Unsupported db_type: {db_type}")
+
+        credentials = {
+            'db_type': db_type,
+            'host': host,
+            'user': user,
+            'password': password,
+            'database': database
+        }
+        missing = [field for field in required_fields[db_type] if credentials[field] is None]
+        if missing:
+            raise ValueError(f"Missing required credentials for {db_type}: {missing}")
+        return credentials
+
+    # Validate input parameters
+    all_credentials = [host, user, password, database]
+    if preset and any(all_credentials):
+        raise ValueError("Cannot specify both preset and direct credentials")
+    if not preset and not db_type:
+        raise ValueError("Either preset or db_type with credentials must be provided")
+
+    # Get credentials
+    if preset:
+        config_dict = get_config()
+        credentials = get_db_preset(config_dict, preset)
+        db_type = credentials.get('db_type')
+        if not db_type:
+            raise ValueError(f"Preset '{preset}' does not specify db_type")
+    else:
+        credentials = validate_credentials(
+            db_type=db_type,
+            host=host,
+            user=user,
+            password=password,
+            database=database
+        )

+    # Query functions
+    def query_mssql(credentials: dict, query: str) -> pd.DataFrame:
+        server = credentials['host']
+        user = credentials['user']
+        password = credentials['password']
+        database = credentials.get('database', '')
         with pymssql.connect(server=server, user=user, password=password, database=database) as conn:
             with conn.cursor() as cursor:
                 cursor.execute(query)
                 rows = cursor.fetchall()
                 columns = [desc[0] for desc in cursor.description]
         return pd.DataFrame(rows, columns=columns)

-    def query_mysql(db_preset: Dict[str, Any], query: str) -> pd.DataFrame:
-        host = db_preset['host']
-        user = db_preset['user']
-        password = db_preset['password']
-        database = db_preset['database']
+    def query_mysql(credentials: dict, query: str) -> pd.DataFrame:
+        host = credentials['host']
+        user = credentials['user']
+        password = credentials['password']
+        database = credentials.get('database', '')
         with mysql.connector.connect(host=host, user=user, password=password, database=database) as conn:
             with conn.cursor() as cursor:
                 cursor.execute(query)
                 rows = cursor.fetchall()
-                columns = ([desc[0] for desc in cursor.description]
-                           if cursor.description else [])
+                columns = [desc[0] for desc in cursor.description] if cursor.description else []
         return pd.DataFrame(rows, columns=columns)

-    def query_clickhouse(db_preset: Dict[str, Any], query: str) -> pd.DataFrame:
-        database = db_preset['database']
+    def query_clickhouse(credentials: dict, query: str) -> pd.DataFrame:
+        host = credentials['host']
+        user = credentials['user']
+        password = credentials['password']
+        database = credentials['database']
         max_retries = 5
         retry_delay = 5
         for attempt in range(max_retries):
             try:
                 client = clickhouse_connect.get_client(
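This hunk rewrites the call contract of load_data_from_query: the old positional db_preset_name/config pair gives way to an either/or choice between a named preset and direct credentials, enforced by the validation block above. A minimal sketch of the two call styles; the preset name, host, and query values are hypothetical:

    import rgwfuncs

    # Preset style: credentials come from a 'db_presets' entry in .rgwfuncsrc
    df = rgwfuncs.load_data_from_query(
        query="SELECT id, amount FROM orders LIMIT 10",
        preset="analytics_mysql"
    )

    # Direct style: db_type must be 'mssql', 'mysql', or 'clickhouse'
    df = rgwfuncs.load_data_from_query(
        query="SELECT id, amount FROM orders LIMIT 10",
        db_type="mysql",
        host="db.example.com",
        user="reader",
        password="s3cr3t",
        database="analytics"
    )

Passing both a preset and any direct credential raises ValueError, as does supplying neither a preset nor a db_type.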
@@ -411,99 +463,228 @@ def load_data_from_query(db_preset_name: str, query: str, config: Optional[Union
                 columns = data.column_names
                 return pd.DataFrame(rows, columns=columns)
             except Exception as e:
-                print(f"Attempt {attempt + 1} failed: {e}")
                 if attempt < max_retries - 1:
-                    print(f"Retrying in {retry_delay} seconds...")
                     time.sleep(retry_delay)
                 else:
-                    raise ConnectionError(
-                        "All attempts to connect to ClickHouse failed.")
-    def query_google_big_query(db_preset: Dict[str, Any], query: str) -> pd.DataFrame:
-        json_file_path = db_preset['json_file_path']
-        project_id = db_preset['project_id']
+                    raise ConnectionError("All attempts to connect to ClickHouse failed.")

+    # Execute query based on db_type
+    if db_type == 'mssql':
+        return query_mssql(credentials, query)
+    elif db_type == 'mysql':
+        return query_mysql(credentials, query)
+    elif db_type == 'clickhouse':
+        return query_clickhouse(credentials, query)
+    else:
+        raise ValueError(f"Unsupported db_type: {db_type}")

-        query_job = client.query(query)
-        results = query_job.result()
-        rows = [list(row.values()) for row in results]
-        columns = [field.name for field in results.schema]

+def load_data_from_big_query(
+        query: str,
+        json_file_path: Optional[str] = None,
+        project_id: Optional[str] = None,
+        preset: Optional[str] = None
+) -> pd.DataFrame:
+    """
+    Load data from a Google Big Query query into a DataFrame.

+    Parameters:
+        query (str): The SQL query to execute.
+        json_file_path (Optional[str]): Path to the Google Cloud service account JSON file.
+        project_id (Optional[str]): Google Cloud project ID.
+        preset (Optional[str]): The name of the Big Query preset in the .rgwfuncsrc file.

-            QueryString=query,
-            QueryExecutionContext={"Database": database},
-            ResultConfiguration={"OutputLocation": output_bucket}
-        )
-        return response["QueryExecutionId"]
+    Returns:
+        pd.DataFrame: DataFrame containing the query result.

+    Raises:
+        ValueError: If both preset and direct credentials are provided, neither is provided,
+            or required credentials are missing.
+        FileNotFoundError: If no '.rgwfuncsrc' file is found when using preset.
+        RuntimeError: If the preset is not found or necessary preset details are missing.
+    """
+    def get_config() -> dict:
+        """Get configuration by searching for '.rgwfuncsrc' in current directory and upwards."""
+        def find_config_file() -> str:
+            current_dir = os.getcwd()
             while True:
-                raise
-    def download_athena_query_results(athena_client, query_execution_id: str) -> pd.DataFrame:
-        paginator = athena_client.get_paginator("get_query_results")
-        result_pages = paginator.paginate(QueryExecutionId=query_execution_id)
-        rows = []
-        columns = []
-        for page in result_pages:
-            if not columns:
-                columns = [col["Name"] for col in page["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
-            rows.extend(page["ResultSet"]["Rows"])
-        data = [[col.get("VarCharValue", None) for col in row["Data"]] for row in rows[1:]]
-        return pd.DataFrame(data, columns=columns)
-    aws_region = db_preset['aws_region']
-    database = db_preset['database']
-    output_bucket = db_preset['output_bucket']
-    athena_client = boto3.client(
-        'athena',
-        region_name=aws_region,
-        aws_access_key_id=db_preset['aws_access_key'],
-        aws_secret_access_key=db_preset['aws_secret_key']
-    )
+                config_path = os.path.join(current_dir, '.rgwfuncsrc')
+                if os.path.isfile(config_path):
+                    return config_path
+                parent_dir = os.path.dirname(current_dir)
+                if parent_dir == current_dir:
+                    raise FileNotFoundError("No '.rgwfuncsrc' file found in current or parent directories")
+                current_dir = parent_dir

+        config_path = find_config_file()
+        with open(config_path, 'r', encoding='utf-8') as file:
+            content = file.read()
+        if not content.strip():
+            raise ValueError(f"Config file {config_path} is empty")
+        try:
+            return json.loads(content)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in config file {config_path}: {e}")
+
+    def get_db_preset(config: dict, preset_name: str) -> dict:
+        """Retrieve the database preset from the configuration."""
+        db_presets = config.get('db_presets', [])
+        for preset in db_presets:
+            if preset.get('name') == preset_name:
+                return preset
+        raise RuntimeError(f"Database preset '{preset_name}' not found in the configuration file")
+
+    # Validate inputs
+    if preset and (json_file_path or project_id):
+        raise ValueError("Cannot specify both preset and json_file_path/project_id")
+    if not preset and (bool(json_file_path) != bool(project_id)):
+        raise ValueError("Both json_file_path and project_id must be provided if preset is not used")
+    if not preset and not json_file_path and not project_id:
+        raise ValueError("Either preset or both json_file_path and project_id must be provided")
+
+    # Get credentials
+    if preset:
+        config_dict = get_config()
+        credentials = get_db_preset(config_dict, preset)
+        if credentials.get('db_type') != 'google_big_query':
+            raise ValueError(f"Preset '{preset}' is not for google_big_query")
+        json_file_path = credentials.get('json_file_path')
+        project_id = credentials.get('project_id')
+        if not json_file_path or not project_id:
+            raise ValueError(f"Missing json_file_path or project_id in preset '{preset}'")
+
+    # Execute query
+    credentials_obj = service_account.Credentials.from_service_account_file(json_file_path)
+    client = bigquery.Client(credentials=credentials_obj, project=project_id)
+    query_job = client.query(query)
+    results = query_job.result()
+    rows = [list(row.values()) for row in results]
+    columns = [field.name for field in results.schema]
+    return pd.DataFrame(rows, columns=columns)
+
+
+def load_data_from_aws_athena_query(
+        query: str,
+        aws_region: Optional[str] = None,
+        database: Optional[str] = None,
+        output_bucket: Optional[str] = None,
+        aws_access_key: Optional[str] = None,
+        aws_secret_key: Optional[str] = None,
+        preset: Optional[str] = None
+) -> pd.DataFrame:
+    """
+    Load data from an AWS Athena query into a DataFrame.

+    Parameters:
+        query (str): The SQL query to execute.
+        aws_region (Optional[str]): AWS region for Athena.
+        database (Optional[str]): Athena database name.
+        output_bucket (Optional[str]): S3 bucket for query results.
+        aws_access_key (Optional[str]): AWS access key ID.
+        aws_secret_key (Optional[str]): AWS secret access key.
+        preset (Optional[str]): The name of the Athena preset in the .rgwfuncsrc file.

+    Returns:
+        pd.DataFrame: DataFrame containing the query result.

+    Raises:
+        ValueError: If both preset and direct credentials are provided, neither is provided,
+            or required credentials are missing.
+        FileNotFoundError: If no '.rgwfuncsrc' file is found when using preset.
+        RuntimeError: If the preset is not found or necessary preset details are missing.
+    """
+    def get_config() -> dict:
+        """Get configuration by searching for '.rgwfuncsrc' in current directory and upwards."""
+        def find_config_file() -> str:
+            current_dir = os.getcwd()
+            while True:
+                config_path = os.path.join(current_dir, '.rgwfuncsrc')
+                if os.path.isfile(config_path):
+                    return config_path
+                parent_dir = os.path.dirname(current_dir)
+                if parent_dir == current_dir:
+                    raise FileNotFoundError("No '.rgwfuncsrc' file found in current or parent directories")
+                current_dir = parent_dir
+
+        config_path = find_config_file()
+        with open(config_path, 'r', encoding='utf-8') as file:
+            content = file.read()
+        if not content.strip():
+            raise ValueError(f"Config file {config_path} is empty")
+        try:
+            return json.loads(content)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in config file {config_path}: {e}")
+
+    def get_db_preset(config: dict, preset_name: str) -> dict:
+        """Retrieve the database preset from the configuration."""
+        db_presets = config.get('db_presets', [])
+        for preset in db_presets:
+            if preset.get('name') == preset_name:
+                return preset
+        raise RuntimeError(f"Database preset '{preset_name}' not found in the configuration file")
+
+    # Validate inputs
+    if preset and any([aws_region, database, output_bucket, aws_access_key, aws_secret_key]):
+        raise ValueError("Cannot specify both preset and direct Athena credentials")
+    required = [aws_region, database, output_bucket, aws_access_key, aws_secret_key]
+    if not preset and not all(required):
+        raise ValueError("All Athena credentials (aws_region, database, output_bucket, aws_access_key, aws_secret_key) must be provided if preset is not used")
+
+    # Get credentials
+    if preset:
+        config_dict = get_config()
+        credentials = get_db_preset(config_dict, preset)
+        if credentials.get('db_type') != 'aws_athena':
+            raise ValueError(f"Preset '{preset}' is not for aws_athena")
+        aws_region = credentials.get('aws_region')
+        database = credentials.get('database')
+        output_bucket = credentials.get('output_bucket')
+        aws_access_key = credentials.get('aws_access_key')
+        aws_secret_key = credentials.get('aws_secret_key')
+        if not all([aws_region, database, output_bucket, aws_access_key, aws_secret_key]):
+            raise ValueError(f"Missing required Athena credentials in preset '{preset}'")
+
+    # Execute query
+    def execute_athena_query(athena_client, query: str, database: str, output_bucket: str) -> str:
+        response = athena_client.start_query_execution(
+            QueryString=query,
+            QueryExecutionContext={"Database": database},
+            ResultConfiguration={"OutputLocation": output_bucket}
+        )
+        return response["QueryExecutionId"]
+
+    def wait_for_athena_query_to_complete(athena_client, query_execution_id: str):
+        while True:
+            response = athena_client.get_query_execution(QueryExecutionId=query_execution_id)
+            state = response["QueryExecution"]["Status"]["State"]
+            if state == "SUCCEEDED":
+                break
+            elif state in ("FAILED", "CANCELLED"):
+                raise Exception(f"Query failed with state: {state}")
+            time.sleep(1)
+
+    def download_athena_query_results(athena_client, query_execution_id: str) -> pd.DataFrame:
+        paginator = athena_client.get_paginator("get_query_results")
+        result_pages = paginator.paginate(QueryExecutionId=query_execution_id)
+        rows = []
+        columns = []
+        for page in result_pages:
+            if not columns:
+                columns = [col["Name"] for col in page["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
+            rows.extend(page["ResultSet"]["Rows"])
+        data = [[col.get("VarCharValue", None) for col in row["Data"]] for row in rows[1:]]
+        return pd.DataFrame(data, columns=columns)
+
+    athena_client = boto3.client(
+        'athena',
+        region_name=aws_region,
+        aws_access_key_id=aws_access_key,
+        aws_secret_access_key=aws_secret_key
+    )
+    query_execution_id = execute_athena_query(athena_client, query, database, output_bucket)
+    wait_for_athena_query_to_complete(athena_client, query_execution_id)
+    return download_athena_query_results(athena_client, query_execution_id)


 def load_data_from_path(file_path: str) -> pd.DataFrame:
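Across the two hunks, the old monolithic loader is split into three public functions (load_data_from_query, load_data_from_big_query, load_data_from_aws_athena_query), each resolving credentials from a db_presets entry in .rgwfuncsrc when a preset name is given. A sketch of a config file shaped the way the new preset branches read it; every name, path, key, and ID below is hypothetical:

    import json

    # Hypothetical .rgwfuncsrc contents; keys mirror what get_db_preset and
    # the preset branches above read via credentials.get(...)
    config = {
        "db_presets": [
            {"name": "analytics_mysql", "db_type": "mysql",
             "host": "db.example.com", "user": "reader",
             "password": "s3cr3t", "database": "analytics"},
            {"name": "bq_main", "db_type": "google_big_query",
             "json_file_path": "/home/me/service_account.json",
             "project_id": "my-gcp-project"},
            {"name": "athena_main", "db_type": "aws_athena",
             "aws_region": "us-east-1", "database": "events",
             "output_bucket": "s3://my-athena-results/",
             "aws_access_key": "<access-key>", "aws_secret_key": "<secret-key>"}
        ]
    }
    with open(".rgwfuncsrc", "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)

    # With the file in the current directory (or any parent), the new
    # loaders need only the preset name:
    import rgwfuncs
    bq_df = rgwfuncs.load_data_from_big_query("SELECT 1 AS x", preset="bq_main")
    athena_df = rgwfuncs.load_data_from_aws_athena_query("SELECT 1", preset="athena_main")

Note that the config-search and preset-lookup helpers are duplicated verbatim inside each of the three functions; a shared module-level helper would remove the repetition, but the behavior is identical either way.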
rgwfuncs-0.0.111.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
+rgwfuncs/__init__.py,sha256=WafmLtqJRnQ7LWU7Son0inje75tF-m6qHJqrmCgiM84,1354
+rgwfuncs/df_lib.py,sha256=lcKrMj8IpxWxnWkBFekWBQd9-Ed2nXTsHo57gTeATjE,85153
+rgwfuncs/docs_lib.py,sha256=i63NzX-V8cGhikYdtkRGAEe2VcuwpXxDUyTRa9xI7l8,1972
+rgwfuncs/interactive_shell_lib.py,sha256=YeJBW9YgH5Nv77ONdOyIKFgtf0ItXStdlKGN9GGf8bU,4228
+rgwfuncs/str_lib.py,sha256=rfzRd7bOc0KQ7UxUN-J6yxY23pHEUHqerVR1u-TyIAY,10690
+rgwfuncs-0.0.111.dist-info/licenses/LICENSE,sha256=jLvt20gcUZYB8UOvyBvyKQ1qhYYhD__qP7ZDx2lPFkU,1062
+rgwfuncs-0.0.111.dist-info/METADATA,sha256=Krg_iCOkr1AD2LJai679k3oLih0ljcr15Lt4was6H48,42972
+rgwfuncs-0.0.111.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+rgwfuncs-0.0.111.dist-info/entry_points.txt,sha256=j-c5IOPIQ0252EaOV6j6STio56sbXl2C4ym_fQ0lXx0,43
+rgwfuncs-0.0.111.dist-info/top_level.txt,sha256=aGuVIzWsKiV1f2gCb6mynx0zx5ma0B1EwPGFKVEMTi4,9
+rgwfuncs-0.0.111.dist-info/RECORD,,
rgwfuncs-0.0.109.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
-rgwfuncs/__init__.py,sha256=XIbMHeEimBLNcX2PK7uQGquG-eAYXsetVEJrlBZIH5U,1295
-rgwfuncs/df_lib.py,sha256=3foomRHlunCpf_accTcqfdgwDmSPIVkVCnlWJ4ag4XQ,76947
-rgwfuncs/docs_lib.py,sha256=i63NzX-V8cGhikYdtkRGAEe2VcuwpXxDUyTRa9xI7l8,1972
-rgwfuncs/interactive_shell_lib.py,sha256=YeJBW9YgH5Nv77ONdOyIKFgtf0ItXStdlKGN9GGf8bU,4228
-rgwfuncs/str_lib.py,sha256=rfzRd7bOc0KQ7UxUN-J6yxY23pHEUHqerVR1u-TyIAY,10690
-rgwfuncs-0.0.109.dist-info/licenses/LICENSE,sha256=jLvt20gcUZYB8UOvyBvyKQ1qhYYhD__qP7ZDx2lPFkU,1062
-rgwfuncs-0.0.109.dist-info/METADATA,sha256=8x4cokNH7a30Xk_Bym-UHPLPNXopZ9NmbBEtm9_sWKE,42972
-rgwfuncs-0.0.109.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-rgwfuncs-0.0.109.dist-info/entry_points.txt,sha256=j-c5IOPIQ0252EaOV6j6STio56sbXl2C4ym_fQ0lXx0,43
-rgwfuncs-0.0.109.dist-info/top_level.txt,sha256=aGuVIzWsKiV1f2gCb6mynx0zx5ma0B1EwPGFKVEMTi4,9
-rgwfuncs-0.0.109.dist-info/RECORD,,
{rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/WHEEL
File without changes
{rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/entry_points.txt
File without changes
{rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/licenses/LICENSE
File without changes
{rgwfuncs-0.0.109.dist-info → rgwfuncs-0.0.111.dist-info}/top_level.txt
File without changes