nui-python-shared-utils 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,623 @@
1
+ """
2
+ Refactored Database client using BaseClient for DRY code patterns.
3
+ """
4
+
5
+ import os
6
+ import time
7
+ import logging
8
+ from typing import Dict, List, Optional, Any
9
+ from contextlib import contextmanager
10
+ import pymysql
11
+
12
+ from .base_client import BaseClient, ServiceHealthMixin
13
+ from .utils import handle_client_errors, safe_close_connection
14
+ from .secrets_helper import get_database_credentials
15
+
16
+ # Optional PostgreSQL support
17
+ try:
18
+ import psycopg2
19
+ import psycopg2.extras
20
+ HAS_POSTGRESQL = True
21
+ except ImportError:
22
+ HAS_POSTGRESQL = False
23
+
24
# Module-level logger for all database clients in this file.
log = logging.getLogger(__name__)

# Global connection pool for reuse across invocations.
# Maps pool key -> list of {"connection": <conn>, "timestamp": <epoch seconds>}
# entries (see DatabaseClient.get_connection, which keys by "host:port_<database>").
# NOTE(review): plain dict with no locking — presumably only safe for
# single-threaded access (one invocation at a time); confirm before using from threads.
_connection_pool = {}
28
+
29
+
30
def _clean_expired_connections(pool_key: str, pool_recycle: int) -> None:
    """
    Close and remove expired connections from the module-level pool.

    Args:
        pool_key: Pool identifier
        pool_recycle: Maximum connection age in seconds (0 disables recycling)
    """
    # Recycling disabled, or nothing pooled under this key: nothing to do.
    if not pool_recycle or pool_recycle <= 0:
        return
    if pool_key not in _connection_pool:
        return

    current_time = time.time()
    fresh_connections = []
    expired_connections = []

    # Single pass: split entries into fresh and expired.
    # Uses ">=" for the age cutoff so the boundary matches
    # DatabaseClient._clean_expired_connections (the previous strict ">"
    # disagreed with the method version at exactly pool_recycle seconds).
    for entry in _connection_pool[pool_key]:
        if current_time - entry["timestamp"] >= pool_recycle:
            expired_connections.append(entry["connection"])
        else:
            fresh_connections.append(entry)

    # Close expired connections best-effort (safe_close_connection swallows errors).
    for conn in expired_connections:
        safe_close_connection(conn)

    # Update pool with only fresh connections
    _connection_pool[pool_key] = fresh_connections

    if expired_connections:
        log.debug(
            f"Cleaned {len(expired_connections)} expired connections from pool {pool_key}"
        )
69
+
70
+
71
def get_pool_stats() -> Dict[str, Any]:
    """
    Get current connection pool statistics.

    Returns:
        Dictionary with pool status for monitoring
    """
    now = time.time()
    stats: Dict[str, Any] = {"total_pools": len(_connection_pool), "pools": {}}

    for pool_key, entries in _connection_pool.items():
        healthy = 0
        aged = 0

        for entry in entries:
            conn = entry["connection"]
            age = now - entry["timestamp"]

            # A connection counts as healthy when a no-reconnect ping succeeds.
            try:
                conn.ping(reconnect=False)
            except Exception:
                pass  # Connection is unhealthy
            else:
                healthy += 1

            # Connections older than one hour count as aged.
            if age > 3600:
                aged += 1

        stats["pools"][pool_key] = {
            "active_connections": len(entries),
            "healthy_connections": healthy,
            "aged_connections": aged,
        }

    return stats
111
+
112
+
113
class DatabaseClient(BaseClient, ServiceHealthMixin):
    """
    Refactored Database client with connection pooling and standardized patterns.

    Connections are cached in the module-level ``_connection_pool``, keyed by
    ``"host:port_<database>"``, and recycled once older than ``pool_recycle``
    seconds. NOTE(review): the pool dict has no locking — presumably intended
    for single-threaded use; confirm before sharing a client across threads.
    """

    def __init__(
        self,
        secret_name: Optional[str] = None,
        use_pool: bool = True,
        pool_size: int = 5,
        pool_recycle: int = 3600,
        credentials: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """
        Initialize database client.

        Args:
            secret_name: Override secret name
            use_pool: Enable connection pooling
            pool_size: Maximum pooled connections
            pool_recycle: Recycle connections after seconds (<= 0 disables recycling)
            credentials: Direct credentials dict (keys: host, port, username, password, database),
                bypasses Secrets Manager
            **kwargs: Additional configuration
        """
        # Pool settings are assigned before super().__init__ because the base
        # initializer resolves credentials via the hook methods below.
        self.use_pool = use_pool
        self.pool_size = pool_size
        self.pool_recycle = pool_recycle

        super().__init__(secret_name=secret_name, credentials=credentials, **kwargs)

        # Build pool key for connection management
        self._pool_key = f"{self.credentials['host']}:{self.credentials['port']}"

    def _get_default_config_prefix(self) -> str:
        """Return configuration prefix for database."""
        return "db"

    def _get_default_secret_name(self) -> str:
        """Return default secret name for DB credentials."""
        return "database-credentials"

    def _create_service_client(self) -> None:
        """Database client doesn't have a single service client - uses connections."""
        return None

    def _resolve_credentials_from_env(self) -> Optional[Dict[str, Any]]:
        """Resolve MySQL credentials from environment variables.

        Requires both DB_HOST and DB_PASSWORD to trigger; otherwise returns
        None so the caller falls back to Secrets Manager.

        Raises:
            ValueError: If DB_PORT is set to a non-integer value.
        """
        host = os.environ.get("DB_HOST")
        password = os.environ.get("DB_PASSWORD")
        if not host or not password:
            return None
        port_str = os.environ.get("DB_PORT", "3306")
        try:
            port = int(port_str)
        except ValueError:
            raise ValueError(f"DB_PORT must be an integer, got: {port_str!r}")
        return {
            "host": host,
            "port": port,
            "username": os.environ.get("DB_USERNAME", "root"),
            "password": password,
            "database": os.environ.get("DB_DATABASE", "app"),
        }

    def _fetch_credentials_from_sm(self, secret_name: Optional[str]) -> Dict[str, Any]:
        """Override to use get_database_credentials for normalized field names."""
        return get_database_credentials(self._resolve_secret_name(secret_name))

    def _clean_expired_connections(self, pool_key: str) -> None:
        """
        Clean expired connections from the pool.

        Closes every pooled connection whose age is >= self.pool_recycle and
        rewrites the pool list with only the surviving entries. No-op when
        recycling is disabled (pool_recycle <= 0) or the key has no pool.

        Args:
            pool_key: Pool identifier
        """
        if not self.pool_recycle or self.pool_recycle <= 0:
            return

        if pool_key not in _connection_pool:
            return

        current_time = time.time()
        pool_entries = _connection_pool[pool_key]

        # Filter out expired connections
        active_entries = []
        expired_count = 0

        for entry in pool_entries:
            age = current_time - entry["timestamp"]
            if age >= self.pool_recycle:
                safe_close_connection(entry["connection"])
                expired_count += 1
            else:
                active_entries.append(entry)

        _connection_pool[pool_key] = active_entries

        if expired_count > 0:
            log.debug(f"Cleaned {expired_count} expired connections from pool {pool_key}")

    @contextmanager
    def get_connection(self, database: Optional[str] = None):
        """
        Context manager for database connections with pooling support.

        Pooling applies only when use_pool is True AND no database override is
        given; an override always creates a fresh connection that is closed on
        exit instead of pooled.

        Args:
            database: Override default database

        Yields:
            Database connection object (pymysql, DictCursor rows)
        """
        connection = None
        pool_key = None

        try:
            # Use pooling if enabled and for default database only
            if self.use_pool and not database:
                pool_key = f"{self._pool_key}_{self.credentials.get('database', 'app')}"
                current_time = time.time()

                # Try to get from pool first
                if pool_key in _connection_pool:
                    pool_entries = _connection_pool[pool_key]

                    # Look for a fresh, healthy connection (LIFO: pops newest first;
                    # each popped entry is either reused, recycled, or discarded)
                    while pool_entries:
                        entry = pool_entries.pop()
                        conn = entry["connection"]
                        timestamp = entry["timestamp"]
                        age = current_time - timestamp

                        # Check if connection has exceeded recycle time
                        if self.pool_recycle and self.pool_recycle > 0 and age >= self.pool_recycle:
                            safe_close_connection(conn)
                            log.debug(f"Recycled expired connection (age: {age:.1f}s) for {pool_key}")
                            continue

                        # Test if connection is still alive (no auto-reconnect so
                        # a dead connection raises instead of silently resetting)
                        try:
                            conn.ping(reconnect=False)
                            connection = conn
                            log.debug(f"Reused pooled connection (age: {age:.1f}s) for {pool_key}")
                            break
                        except Exception as e:
                            safe_close_connection(conn)
                            log.debug(f"Closed dead pooled connection for {pool_key}: {e}")
                            continue

            # Create new connection if no pooled connection available
            if connection is None:
                connection = pymysql.connect(
                    host=self.credentials["host"],
                    port=self.credentials.get("port", 3306),
                    user=self.credentials["username"],
                    password=self.credentials["password"],
                    database=database or self.credentials.get("database", "app"),
                    charset="utf8mb4",
                    cursorclass=pymysql.cursors.DictCursor,
                    connect_timeout=10,
                    read_timeout=30,
                )
                if self.use_pool and not database:
                    log.debug(f"Created new pooled connection for {pool_key}")

            yield connection

        finally:
            if connection:
                # Return to pool if pooling enabled and healthy
                if self.use_pool and not database and pool_key:
                    try:
                        # Test connection health before returning to pool.
                        # NOTE(review): no rollback is issued here, so if the
                        # caller raised mid-transaction the connection goes back
                        # to the pool with that transaction state intact —
                        # confirm this is acceptable for all callers.
                        connection.ping(reconnect=False)

                        # Initialize pool for this key if needed
                        if pool_key not in _connection_pool:
                            _connection_pool[pool_key] = []

                        # Clean up expired connections before adding new one
                        self._clean_expired_connections(pool_key)

                        # Add back to pool if under limit
                        if len(_connection_pool[pool_key]) < self.pool_size:
                            entry = {"connection": connection, "timestamp": time.time()}
                            _connection_pool[pool_key].append(entry)
                            log.debug(
                                f"Returned connection to pool {pool_key} "
                                f"(pool size: {len(_connection_pool[pool_key])})"
                            )
                        else:
                            # Pool full, close connection
                            safe_close_connection(connection)
                            log.debug(f"Pool {pool_key} full, closed connection")
                    except Exception as e:
                        # Connection unhealthy, close it
                        safe_close_connection(connection)
                        log.debug(f"Connection unhealthy, closed instead of pooling: {e}")
                        connection = None
                else:
                    # Not using pooling, close immediately
                    safe_close_connection(connection)

    @handle_client_errors(default_return=[])
    def query(
        self,
        sql: str,
        params: Optional[tuple] = None,
        database: Optional[str] = None
    ) -> List[Dict]:
        """
        Execute a SELECT query with error handling.

        NOTE(review): errors pass through two layers — the decorator and
        _execute_with_error_handling — matching execute()/bulk_insert();
        presumably intentional, but verify the decorator's default_return
        is what callers see on failure.

        Args:
            sql: SQL query with %s placeholders
            params: Query parameters
            database: Override default database

        Returns:
            List of result rows as dictionaries
        """
        def _query_operation():
            with self.get_connection(database) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    return cursor.fetchall()

        return self._execute_with_error_handling(
            "query",
            _query_operation,
            sql=sql[:100],  # First 100 chars for safety
            database=database
        )

    @handle_client_errors(reraise=True)
    def execute(
        self,
        sql: str,
        params: Optional[tuple] = None,
        database: Optional[str] = None
    ) -> int:
        """
        Execute an INSERT, UPDATE, or DELETE query with error handling.

        Commits before returning; errors are re-raised (reraise=True).

        Args:
            sql: SQL query with %s placeholders
            params: Query parameters
            database: Override default database

        Returns:
            Number of affected rows
        """
        def _execute_operation():
            with self.get_connection(database) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    conn.commit()
                    return cursor.rowcount

        return self._execute_with_error_handling(
            "execute",
            _execute_operation,
            sql=sql[:100],
            database=database
        )

    @handle_client_errors(reraise=True)
    def bulk_insert(
        self,
        table: str,
        records: List[Dict],
        database: Optional[str] = None,
        batch_size: int = 1000,
        ignore_duplicates: bool = False,
    ) -> int:
        """
        Bulk insert records with error handling.

        Column list is taken from the first record; missing keys in later
        records insert NULL (record.get). A single commit covers all batches.
        NOTE(review): table/column names are interpolated into the SQL string
        (backtick-quoted) — callers must not pass untrusted identifiers.

        Args:
            table: Table name
            records: List of dictionaries to insert
            database: Override default database
            batch_size: Records per batch
            ignore_duplicates: Use INSERT IGNORE

        Returns:
            Total number of inserted rows
        """
        if not records:
            return 0

        def _bulk_insert_operation():
            # Prepare query
            columns = list(records[0].keys())
            placeholders = ", ".join(["%s"] * len(columns))
            columns_str = ", ".join(f"`{col}`" for col in columns)

            insert_cmd = "INSERT IGNORE" if ignore_duplicates else "INSERT"
            sql = f"{insert_cmd} INTO `{table}` ({columns_str}) VALUES ({placeholders})"

            total_inserted = 0

            with self.get_connection(database) as conn:
                with conn.cursor() as cursor:
                    # Process in batches
                    for i in range(0, len(records), batch_size):
                        batch = records[i : i + batch_size]
                        values = [tuple(record.get(col) for col in columns) for record in batch]

                        cursor.executemany(sql, values)
                        total_inserted += cursor.rowcount

                conn.commit()

            log.info(f"Bulk inserted {total_inserted} rows into {table}")
            return total_inserted

        return self._execute_with_error_handling(
            "bulk_insert",
            _bulk_insert_operation,
            table=table,
            record_count=len(records),
            database=database
        )

    def _perform_health_check(self):
        """Perform database health check.

        Raises:
            Exception: If the SELECT 1 probe fails or returns the wrong value.
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    result = cursor.fetchone()
                    # DictCursor is configured in get_connection, so the
                    # SELECT 1 row comes back keyed by the column name "1".
                    if not result or result.get("1") != 1:
                        raise Exception("Database health check query failed")
        except Exception as e:
            # Wraps any failure (including the inner raise) in a single
            # health-check error for the caller.
            raise Exception(f"Database health check failed: {e}")

    def get_connection_info(self) -> Dict:
        """
        Get database connection information.

        Password is intentionally omitted; only non-secret connection and
        pool settings are exposed for diagnostics.

        Returns:
            Dictionary with connection details
        """
        return {
            "host": self.credentials["host"],
            "port": self.credentials["port"],
            "database": self.credentials["database"],
            "username": self.credentials["username"],
            "pool_enabled": self.use_pool,
            "pool_size": self.pool_size,
            "pool_recycle_seconds": self.pool_recycle,
        }
471
+
472
+
473
class PostgreSQLClient(BaseClient, ServiceHealthMixin):
    """
    PostgreSQL client with connection management.

    Unlike DatabaseClient, connections are not pooled: each get_connection()
    opens a fresh psycopg2 connection and closes it on exit.
    """

    def __init__(
        self,
        secret_name: Optional[str] = None,
        use_auth_credentials: bool = True,
        credentials: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """
        Initialize PostgreSQL client.

        Args:
            secret_name: Override secret name
            use_auth_credentials: Use auth-specific credentials
            credentials: Direct credentials dict (keys: host, port, username, password, database),
                bypasses Secrets Manager
            **kwargs: Additional configuration

        Raises:
            ImportError: If psycopg2 is not installed (optional dependency).
        """
        if not HAS_POSTGRESQL:
            raise ImportError("psycopg2 is not installed. Install with: pip install psycopg2-binary")

        # Must be set before super().__init__, which resolves credentials
        # through _fetch_credentials_from_sm below.
        self.use_auth_credentials = use_auth_credentials
        super().__init__(secret_name=secret_name, credentials=credentials, **kwargs)

    def _get_default_config_prefix(self) -> str:
        """Return configuration prefix for PostgreSQL."""
        return "db"

    def _get_default_secret_name(self) -> str:
        """Return default secret name for PostgreSQL credentials."""
        return "database-credentials"

    def _create_service_client(self) -> None:
        """PostgreSQL client doesn't have a single service client - uses connections."""
        return None

    def _resolve_credentials_from_env(self) -> Optional[Dict[str, Any]]:
        """Resolve PostgreSQL credentials from environment variables.

        Requires both DB_HOST and DB_PASSWORD to trigger; otherwise returns
        None so the caller falls back to Secrets Manager.

        Raises:
            ValueError: If DB_PORT is set to a non-integer value.
        """
        host = os.environ.get("DB_HOST")
        password = os.environ.get("DB_PASSWORD")
        if not host or not password:
            return None
        port_str = os.environ.get("DB_PORT", "5432")
        try:
            port = int(port_str)
        except ValueError:
            raise ValueError(f"DB_PORT must be an integer, got: {port_str!r}")
        return {
            "host": host,
            "port": port,
            "username": os.environ.get("DB_USERNAME", "postgres"),
            "password": password,
            "database": os.environ.get("DB_DATABASE", "postgres"),
        }

    def _fetch_credentials_from_sm(self, secret_name: Optional[str]) -> Dict[str, Any]:
        """Override with auth-specific credential handling.

        If use_auth_credentials is set and the secret contains an "auth_host"
        key, the auth_* fields are used; otherwise falls back to the shared
        get_database_credentials normalization.
        """
        # Local import — presumably to defer/avoid an import cycle with
        # secrets_helper; confirm before hoisting to module level.
        from .secrets_helper import get_secret

        resolved_secret_name = self._resolve_secret_name(secret_name)
        raw_creds = get_secret(resolved_secret_name)

        # Use auth-specific credentials if available and requested
        if self.use_auth_credentials and "auth_host" in raw_creds:
            return {
                "host": raw_creds["auth_host"],
                "port": int(raw_creds.get("auth_port", 5432)),
                "username": raw_creds.get("auth_username"),
                "password": raw_creds.get("auth_password"),
                "database": raw_creds.get("auth_database", "auth-service-db"),
            }
        else:
            return get_database_credentials(resolved_secret_name)

    @contextmanager
    def get_connection(self, database: Optional[str] = None):
        """
        Context manager for PostgreSQL connections.

        Args:
            database: Override default database

        Yields:
            psycopg2 connection object (always closed on exit — no pooling)
        """
        connection = None
        try:
            # NOTE(review): psycopg2 treats "database" as an alias of "dbname";
            # confirm against the psycopg2 version pinned by this package.
            connect_params = {
                "host": self.credentials["host"],
                "port": self.credentials.get("port", 5432),
                "user": self.credentials["username"],
                "password": self.credentials["password"],
                "database": database or self.credentials.get("database", "postgres"),
                "connect_timeout": 5,
            }

            connection = psycopg2.connect(**connect_params)
            yield connection
        finally:
            if connection:
                connection.close()

    @handle_client_errors(default_return=[])
    def query(
        self,
        sql: str,
        params: Optional[tuple] = None,
        database: Optional[str] = None
    ) -> List[Dict]:
        """
        Execute PostgreSQL SELECT query.

        Args:
            sql: SQL query with %s placeholders
            params: Query parameters
            database: Override default database

        Returns:
            List of result rows as dictionaries
        """
        def _query_operation():
            with self.get_connection(database) as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
                    cursor.execute(sql, params)
                    # DictRow objects are converted to plain dicts so callers
                    # don't depend on psycopg2 types.
                    return [dict(row) for row in cursor.fetchall()]

        return self._execute_with_error_handling(
            "query",
            _query_operation,
            sql=sql[:100],
            database=database
        )

    def _perform_health_check(self):
        """Perform PostgreSQL health check.

        Raises:
            Exception: If the SELECT 1 probe fails or returns the wrong value.
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    result = cursor.fetchone()
                    # Default cursor returns tuples here (no DictCursor),
                    # so the probe value is positional.
                    if not result or result[0] != 1:
                        raise Exception("PostgreSQL health check query failed")
        except Exception as e:
            raise Exception(f"PostgreSQL health check failed: {e}")