kailash 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. kailash/__init__.py +3 -3
  2. kailash/api/custom_nodes_secure.py +3 -3
  3. kailash/api/gateway.py +1 -1
  4. kailash/api/studio.py +1 -1
  5. kailash/api/workflow_api.py +2 -2
  6. kailash/core/resilience/bulkhead.py +475 -0
  7. kailash/core/resilience/circuit_breaker.py +92 -10
  8. kailash/core/resilience/health_monitor.py +578 -0
  9. kailash/edge/discovery.py +86 -0
  10. kailash/mcp_server/__init__.py +309 -33
  11. kailash/mcp_server/advanced_features.py +1022 -0
  12. kailash/mcp_server/ai_registry_server.py +27 -2
  13. kailash/mcp_server/auth.py +789 -0
  14. kailash/mcp_server/client.py +645 -378
  15. kailash/mcp_server/discovery.py +1593 -0
  16. kailash/mcp_server/errors.py +673 -0
  17. kailash/mcp_server/oauth.py +1727 -0
  18. kailash/mcp_server/protocol.py +1126 -0
  19. kailash/mcp_server/registry_integration.py +587 -0
  20. kailash/mcp_server/server.py +1228 -96
  21. kailash/mcp_server/transports.py +1169 -0
  22. kailash/mcp_server/utils/__init__.py +6 -1
  23. kailash/mcp_server/utils/cache.py +250 -7
  24. kailash/middleware/auth/auth_manager.py +3 -3
  25. kailash/middleware/communication/api_gateway.py +1 -1
  26. kailash/middleware/communication/realtime.py +1 -1
  27. kailash/middleware/mcp/enhanced_server.py +1 -1
  28. kailash/nodes/__init__.py +2 -0
  29. kailash/nodes/admin/audit_log.py +6 -6
  30. kailash/nodes/admin/permission_check.py +8 -8
  31. kailash/nodes/admin/role_management.py +32 -28
  32. kailash/nodes/admin/schema.sql +6 -1
  33. kailash/nodes/admin/schema_manager.py +13 -13
  34. kailash/nodes/admin/security_event.py +15 -15
  35. kailash/nodes/admin/tenant_isolation.py +3 -3
  36. kailash/nodes/admin/transaction_utils.py +3 -3
  37. kailash/nodes/admin/user_management.py +21 -21
  38. kailash/nodes/ai/a2a.py +11 -11
  39. kailash/nodes/ai/ai_providers.py +9 -12
  40. kailash/nodes/ai/embedding_generator.py +13 -14
  41. kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
  42. kailash/nodes/ai/iterative_llm_agent.py +2 -2
  43. kailash/nodes/ai/llm_agent.py +210 -33
  44. kailash/nodes/ai/self_organizing.py +2 -2
  45. kailash/nodes/alerts/discord.py +4 -4
  46. kailash/nodes/api/graphql.py +6 -6
  47. kailash/nodes/api/http.py +10 -10
  48. kailash/nodes/api/rate_limiting.py +4 -4
  49. kailash/nodes/api/rest.py +15 -15
  50. kailash/nodes/auth/mfa.py +3 -3
  51. kailash/nodes/auth/risk_assessment.py +2 -2
  52. kailash/nodes/auth/session_management.py +5 -5
  53. kailash/nodes/auth/sso.py +143 -0
  54. kailash/nodes/base.py +8 -2
  55. kailash/nodes/base_async.py +16 -2
  56. kailash/nodes/base_with_acl.py +2 -2
  57. kailash/nodes/cache/__init__.py +9 -0
  58. kailash/nodes/cache/cache.py +1172 -0
  59. kailash/nodes/cache/cache_invalidation.py +874 -0
  60. kailash/nodes/cache/redis_pool_manager.py +595 -0
  61. kailash/nodes/code/async_python.py +2 -1
  62. kailash/nodes/code/python.py +194 -30
  63. kailash/nodes/compliance/data_retention.py +6 -6
  64. kailash/nodes/compliance/gdpr.py +5 -5
  65. kailash/nodes/data/__init__.py +10 -0
  66. kailash/nodes/data/async_sql.py +1956 -129
  67. kailash/nodes/data/optimistic_locking.py +906 -0
  68. kailash/nodes/data/readers.py +8 -8
  69. kailash/nodes/data/redis.py +378 -0
  70. kailash/nodes/data/sql.py +314 -3
  71. kailash/nodes/data/streaming.py +21 -0
  72. kailash/nodes/enterprise/__init__.py +8 -0
  73. kailash/nodes/enterprise/audit_logger.py +285 -0
  74. kailash/nodes/enterprise/batch_processor.py +22 -3
  75. kailash/nodes/enterprise/data_lineage.py +1 -1
  76. kailash/nodes/enterprise/mcp_executor.py +205 -0
  77. kailash/nodes/enterprise/service_discovery.py +150 -0
  78. kailash/nodes/enterprise/tenant_assignment.py +108 -0
  79. kailash/nodes/logic/async_operations.py +2 -2
  80. kailash/nodes/logic/convergence.py +1 -1
  81. kailash/nodes/logic/operations.py +1 -1
  82. kailash/nodes/monitoring/__init__.py +11 -1
  83. kailash/nodes/monitoring/health_check.py +456 -0
  84. kailash/nodes/monitoring/log_processor.py +817 -0
  85. kailash/nodes/monitoring/metrics_collector.py +627 -0
  86. kailash/nodes/monitoring/performance_benchmark.py +137 -11
  87. kailash/nodes/rag/advanced.py +7 -7
  88. kailash/nodes/rag/agentic.py +49 -2
  89. kailash/nodes/rag/conversational.py +3 -3
  90. kailash/nodes/rag/evaluation.py +3 -3
  91. kailash/nodes/rag/federated.py +3 -3
  92. kailash/nodes/rag/graph.py +3 -3
  93. kailash/nodes/rag/multimodal.py +3 -3
  94. kailash/nodes/rag/optimized.py +5 -5
  95. kailash/nodes/rag/privacy.py +3 -3
  96. kailash/nodes/rag/query_processing.py +6 -6
  97. kailash/nodes/rag/realtime.py +1 -1
  98. kailash/nodes/rag/registry.py +1 -1
  99. kailash/nodes/rag/router.py +1 -1
  100. kailash/nodes/rag/similarity.py +7 -7
  101. kailash/nodes/rag/strategies.py +4 -4
  102. kailash/nodes/security/abac_evaluator.py +6 -6
  103. kailash/nodes/security/behavior_analysis.py +5 -5
  104. kailash/nodes/security/credential_manager.py +1 -1
  105. kailash/nodes/security/rotating_credentials.py +11 -11
  106. kailash/nodes/security/threat_detection.py +8 -8
  107. kailash/nodes/testing/credential_testing.py +2 -2
  108. kailash/nodes/transform/processors.py +5 -5
  109. kailash/runtime/local.py +163 -9
  110. kailash/runtime/parameter_injection.py +425 -0
  111. kailash/runtime/parameter_injector.py +657 -0
  112. kailash/runtime/testing.py +2 -2
  113. kailash/testing/fixtures.py +2 -2
  114. kailash/workflow/builder.py +99 -14
  115. kailash/workflow/builder_improvements.py +207 -0
  116. kailash/workflow/input_handling.py +170 -0
  117. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/METADATA +22 -9
  118. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/RECORD +122 -95
  119. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/WHEEL +0 -0
  120. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/entry_points.txt +0 -0
  121. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/licenses/LICENSE +0 -0
  122. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/top_level.txt +0 -0
@@ -27,6 +27,8 @@ Key Features:
  import asyncio
  import json
  import os
+ import random
+ import re
  from abc import ABC, abstractmethod
  from dataclasses import dataclass
  from datetime import date, datetime
@@ -34,10 +36,37 @@ from decimal import Decimal
  from enum import Enum
  from typing import Any, AsyncIterator, Optional, Union

+ import yaml
+
  from kailash.nodes.base import NodeParameter, register_node
  from kailash.nodes.base_async import AsyncNode
  from kailash.sdk_exceptions import NodeExecutionError, NodeValidationError

+ # Import optimistic locking for version control
+ try:
+     from kailash.nodes.data.optimistic_locking import (
+         ConflictResolution,
+         LockStatus,
+         OptimisticLockingNode,
+     )
+
+     OPTIMISTIC_LOCKING_AVAILABLE = True
+ except ImportError:
+     OPTIMISTIC_LOCKING_AVAILABLE = False
+
+     # Define minimal enums if not available
+     class ConflictResolution:
+         FAIL_FAST = "fail_fast"
+         RETRY = "retry"
+         MERGE = "merge"
+         LAST_WRITER_WINS = "last_writer_wins"
+
+     class LockStatus:
+         SUCCESS = "success"
+         VERSION_CONFLICT = "version_conflict"
+         RECORD_NOT_FOUND = "record_not_found"
+         RETRY_EXHAUSTED = "retry_exhausted"
+

  class DatabaseType(Enum):
      """Supported database types."""
@@ -47,6 +76,124 @@ class DatabaseType(Enum):
      SQLITE = "sqlite"


+ class QueryValidator:
+     """Validates SQL queries for common security issues."""
+
+     # Dangerous SQL patterns that could indicate injection attempts
+     DANGEROUS_PATTERNS = [
+         # Multiple statements
+         r";\s*(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|GRANT|REVOKE)",
+         # Comments that might hide malicious code
+         r"--.*$",
+         r"/\*.*\*/",
+         # Union-based injection
+         r"\bUNION\b.*\bSELECT\b",
+         # Time-based blind injection
+         r"\b(SLEEP|WAITFOR|PG_SLEEP)\b",
+         # Out-of-band injection
+         r"\b(LOAD_FILE|INTO\s+OUTFILE|INTO\s+DUMPFILE)\b",
+         # System command execution
+         r"\b(XP_CMDSHELL|EXEC\s+MASTER)",
+     ]
+
+     # Patterns that should only appear in admin queries
+     ADMIN_ONLY_PATTERNS = [
+         r"\b(CREATE|ALTER|DROP)\s+(?:\w+\s+)*(TABLE|INDEX|VIEW|PROCEDURE|FUNCTION|TRIGGER)",
+         r"\b(GRANT|REVOKE)\b",
+         r"\bTRUNCATE\b",
+     ]
+
+     @classmethod
+     def validate_query(cls, query: str, allow_admin: bool = False) -> None:
+         """Validate a SQL query for security issues.
+
+         Args:
+             query: The SQL query to validate
+             allow_admin: Whether to allow administrative commands
+
+         Raises:
+             NodeValidationError: If the query contains dangerous patterns
+         """
+         query_upper = query.upper()
+
+         # Check for dangerous patterns
+         for pattern in cls.DANGEROUS_PATTERNS:
+             if re.search(pattern, query, re.IGNORECASE | re.MULTILINE):
+                 raise NodeValidationError(
+                     f"Query contains potentially dangerous pattern: {pattern}"
+                 )
+
+         # Check for admin-only patterns if not allowed
+         if not allow_admin:
+             for pattern in cls.ADMIN_ONLY_PATTERNS:
+                 if re.search(pattern, query, re.IGNORECASE):
+                     raise NodeValidationError(
+                         f"Query contains administrative command that is not allowed: {pattern}"
+                     )
+
+     @classmethod
+     def validate_identifier(cls, identifier: str) -> None:
+         """Validate a database identifier (table/column name).
+
+         Args:
+             identifier: The identifier to validate
+
+         Raises:
+             NodeValidationError: If the identifier is invalid
+         """
+         # Allow alphanumeric, underscore, and dot (for schema.table)
+         if not re.match(
+             r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?$", identifier
+         ):
+             raise NodeValidationError(
+                 f"Invalid identifier: {identifier}. "
+                 "Identifiers must start with letter/underscore and contain only letters, numbers, underscores."
+             )
+
+     @classmethod
+     def sanitize_string_literal(cls, value: str) -> str:
+         """Sanitize a string value for SQL by escaping quotes.
+
+         Args:
+             value: The string value to sanitize
+
+         Returns:
+             Escaped string safe for SQL
+         """
+         # This is a basic implementation - real escaping should be done by the driver
+         return value.replace("'", "''").replace("\\", "\\\\")
+
+     @classmethod
+     def validate_connection_string(cls, connection_string: str) -> None:
+         """Validate a database connection string.
+
+         Args:
+             connection_string: The connection string to validate
+
+         Raises:
+             NodeValidationError: If the connection string appears malicious
+         """
+         # Check for suspicious patterns in connection strings
+         suspicious_patterns = [
+             # SQL injection attempts
+             r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)",
+             # Command execution attempts
+             r';.*\bhost\s*=\s*[\'"]?\|',
+             r';.*\bhost\s*=\s*[\'"]?`',
+             r"\$\(",  # Command substitution
+             r"`",  # Backticks
+             # File access attempts
+             r'sslcert\s*=\s*[\'"]?(/etc/passwd|/etc/shadow)',
+             r'sslkey\s*=\s*[\'"]?(/etc/passwd|/etc/shadow)',
+         ]
+
+         for pattern in suspicious_patterns:
+             if re.search(pattern, connection_string, re.IGNORECASE):
+                 raise NodeValidationError(
+                     "Connection string contains suspicious pattern"
+                 )
+
+
  class FetchMode(Enum):
      """Result fetch modes."""

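Usage sketch for the new QueryValidator above: it is pattern-based screening, not a full SQL parser. This is not part of the diff; it assumes kailash 0.6.5 is installed, and the queries are illustrative:

    from kailash.nodes.data.async_sql import QueryValidator
    from kailash.sdk_exceptions import NodeValidationError

    QueryValidator.validate_query("SELECT * FROM users WHERE id = :id")  # passes

    try:
        # "; DROP" matches the multiple-statement pattern
        QueryValidator.validate_query("SELECT 1; DROP TABLE users")
    except NodeValidationError as e:
        print(e)  # Query contains potentially dangerous pattern: ...

    # DDL is rejected unless allow_admin=True
    QueryValidator.validate_query("CREATE TABLE t (id INT)", allow_admin=True)  # passes

    QueryValidator.validate_identifier("public.users")  # schema.table form is allowed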
@@ -56,6 +203,72 @@ class FetchMode(Enum):
      ITERATOR = "iterator"  # Return async iterator


+ @dataclass
+ class RetryConfig:
+     """Configuration for retry logic."""
+
+     max_retries: int = 3
+     initial_delay: float = 1.0
+     max_delay: float = 60.0
+     exponential_base: float = 2.0
+     jitter: bool = True
+
+     # Retryable error patterns (database-specific)
+     retryable_errors: list[str] = None
+
+     def __post_init__(self):
+         """Initialize default retryable errors."""
+         if self.retryable_errors is None:
+             self.retryable_errors = [
+                 # PostgreSQL
+                 "connection_refused",
+                 "connection_reset",
+                 "connection reset",  # Handle different cases
+                 "connection_aborted",
+                 "could not connect",
+                 "server closed the connection",
+                 "terminating connection",
+                 "connectionreseterror",
+                 "connectionrefusederror",
+                 "brokenpipeerror",
+                 # MySQL
+                 "lost connection to mysql server",
+                 "mysql server has gone away",
+                 "can't connect to mysql server",
+                 # SQLite
+                 "database is locked",
+                 "disk i/o error",
+                 # General
+                 "timeout",
+                 "timed out",
+                 "pool is closed",
+                 # DNS/Network errors
+                 "nodename nor servname provided",
+                 "name or service not known",
+                 "gaierror",
+                 "getaddrinfo failed",
+                 "temporary failure in name resolution",
+             ]
+
+     def should_retry(self, error: Exception) -> bool:
+         """Check if an error is retryable."""
+         error_str = str(error).lower()
+         return any(pattern.lower() in error_str for pattern in self.retryable_errors)
+
+     def get_delay(self, attempt: int) -> float:
+         """Calculate delay for a retry attempt."""
+         delay = min(
+             self.initial_delay * (self.exponential_base**attempt), self.max_delay
+         )
+
+         if self.jitter:
+             # Add random jitter (±25%)
+             jitter_amount = delay * 0.25
+             delay += random.uniform(-jitter_amount, jitter_amount)
+
+         return max(0, delay)  # Ensure non-negative
+
+
  @dataclass
  class DatabaseConfig:
      """Database connection configuration."""
@@ -96,19 +309,43 @@ class DatabaseAdapter(ABC):
          """Convert database-specific types to JSON-serializable types."""
          converted = {}
          for key, value in row.items():
-             if isinstance(value, Decimal):
-                 # Convert Decimal to float for JSON serialization
-                 converted[key] = float(value)
-             elif isinstance(value, datetime):
-                 # Convert datetime to ISO format string
-                 converted[key] = value.isoformat()
-             elif isinstance(value, date):
-                 # Convert date to ISO format string
-                 converted[key] = value.isoformat()
-             else:
-                 converted[key] = value
+             converted[key] = self._serialize_value(value)
          return converted

+     def _serialize_value(self, value: Any) -> Any:
+         """Convert database-specific types to JSON-serializable types."""
+         if value is None:
+             return None
+         elif isinstance(value, bool):
+             # Handle bool before int (bool is subclass of int in Python)
+             return value
+         elif isinstance(value, (int, float)):
+             # Return numeric types as-is
+             return value
+         elif isinstance(value, str):
+             # Return strings as-is
+             return value
+         elif isinstance(value, bytes):
+             import base64
+
+             result = base64.b64encode(value).decode("utf-8")
+             return result
+         elif isinstance(value, Decimal):
+             return float(value)
+         elif isinstance(value, datetime):
+             return value.isoformat()
+         elif isinstance(value, date):
+             return value.isoformat()
+         elif hasattr(value, "total_seconds"):  # timedelta
+             return value.total_seconds()
+         elif hasattr(value, "hex"):  # UUID
+             return str(value)
+         elif isinstance(value, (list, tuple)):
+             return [self._serialize_value(item) for item in value]
+         elif isinstance(value, dict):
+             return {k: self._serialize_value(v) for k, v in value.items()}
+         return value
+
      @abstractmethod
      async def connect(self) -> None:
          """Establish connection pool."""
@@ -126,8 +363,9 @@ class DatabaseAdapter(ABC):
          params: Optional[Union[tuple, dict]] = None,
          fetch_mode: FetchMode = FetchMode.ALL,
          fetch_size: Optional[int] = None,
+         transaction: Optional[Any] = None,
      ) -> Any:
-         """Execute query and return results."""
+         """Execute query and return results, optionally within a transaction."""
          pass

      @abstractmethod
@@ -192,24 +430,57 @@ class PostgreSQLAdapter(DatabaseAdapter):
          params: Optional[Union[tuple, dict]] = None,
          fetch_mode: FetchMode = FetchMode.ALL,
          fetch_size: Optional[int] = None,
+         transaction: Optional[Any] = None,
      ) -> Any:
          """Execute query and return results."""
-         async with self._pool.acquire() as conn:
-             # Convert dict params to positional for asyncpg
-             if isinstance(params, dict):
-                 # Simple parameter substitution for named params
-                 # In production, use a proper SQL parser
-                 query_params = []
-                 for i, (key, value) in enumerate(params.items(), 1):
-                     query = query.replace(f":{key}", f"${i}")
-                     query_params.append(value)
-                 params = query_params
-
-             # Ensure params is a list/tuple for asyncpg
-             if params is None:
-                 params = []
-             elif not isinstance(params, (list, tuple)):
-                 params = [params]
+         # Convert dict params to positional for asyncpg
+         if isinstance(params, dict):
+             # Simple parameter substitution for named params
+             # In production, use a proper SQL parser
+             import json
+
+             query_params = []
+             for i, (key, value) in enumerate(params.items(), 1):
+                 query = query.replace(f":{key}", f"${i}")
+                 # For PostgreSQL, lists should remain as lists for array operations
+                 # Only convert dicts to JSON strings
+                 if isinstance(value, dict):
+                     value = json.dumps(value)
+                 query_params.append(value)
+             params = query_params
+
+         # Ensure params is a list/tuple for asyncpg
+         if params is None:
+             params = []
+         elif not isinstance(params, (list, tuple)):
+             params = [params]
+
+         # Execute query on appropriate connection
+         if transaction:
+             # Use transaction connection
+             conn, tx = transaction
+
+             # For UPDATE/DELETE queries without RETURNING, use execute() to get affected rows
+             query_upper = query.upper()
+             if (
+                 (
+                     "UPDATE" in query_upper
+                     or "DELETE" in query_upper
+                     or "INSERT" in query_upper
+                 )
+                 and "RETURNING" not in query_upper
+                 and fetch_mode == FetchMode.ALL
+             ):
+                 result = await conn.execute(query, *params)
+                 # asyncpg returns a string like "UPDATE 1", extract the count
+                 if isinstance(result, str):
+                     parts = result.split()
+                     if len(parts) >= 2 and parts[1].isdigit():
+                         rows_affected = int(parts[1])
+                     else:
+                         rows_affected = 0
+                     return [{"rows_affected": rows_affected}]
+                 return []

              if fetch_mode == FetchMode.ONE:
                  row = await conn.fetchrow(query, *params)
@@ -224,26 +495,78 @@ class PostgreSQLAdapter(DatabaseAdapter):
                  return [self._convert_row(dict(row)) for row in rows[:fetch_size]]
              elif fetch_mode == FetchMode.ITERATOR:
                  raise NotImplementedError("Iterator mode not yet implemented")
+         else:
+             # Use pool connection
+             async with self._pool.acquire() as conn:
+                 # For UPDATE/DELETE queries without RETURNING, use execute() to get affected rows
+                 query_upper = query.upper()
+                 if (
+                     (
+                         "UPDATE" in query_upper
+                         or "DELETE" in query_upper
+                         or "INSERT" in query_upper
+                     )
+                     and "RETURNING" not in query_upper
+                     and fetch_mode == FetchMode.ALL
+                 ):
+                     result = await conn.execute(query, *params)
+                     # asyncpg returns a string like "UPDATE 1", extract the count
+                     if isinstance(result, str):
+                         parts = result.split()
+                         if len(parts) >= 2 and parts[1].isdigit():
+                             rows_affected = int(parts[1])
+                         else:
+                             rows_affected = 0
+                         return [{"rows_affected": rows_affected}]
+                     return []
+
+                 if fetch_mode == FetchMode.ONE:
+                     row = await conn.fetchrow(query, *params)
+                     return self._convert_row(dict(row)) if row else None
+                 elif fetch_mode == FetchMode.ALL:
+                     rows = await conn.fetch(query, *params)
+                     return [self._convert_row(dict(row)) for row in rows]
+                 elif fetch_mode == FetchMode.MANY:
+                     if not fetch_size:
+                         raise ValueError("fetch_size required for MANY mode")
+                     rows = await conn.fetch(query, *params)
+                     return [self._convert_row(dict(row)) for row in rows[:fetch_size]]
+                 elif fetch_mode == FetchMode.ITERATOR:
+                     raise NotImplementedError("Iterator mode not yet implemented")

      async def execute_many(
-         self, query: str, params_list: list[Union[tuple, dict]]
+         self,
+         query: str,
+         params_list: list[Union[tuple, dict]],
+         transaction: Optional[Any] = None,
      ) -> None:
          """Execute query multiple times with different parameters."""
-         async with self._pool.acquire() as conn:
-             # Convert all dict params to tuples
-             converted_params = []
-             for params in params_list:
-                 if isinstance(params, dict):
-                     query_params = []
-                     for i, (key, value) in enumerate(params.items(), 1):
-                         if i == 1:  # Only replace on first iteration
-                             query = query.replace(f":{key}", f"${i}")
-                         query_params.append(value)
-                     converted_params.append(query_params)
-                 else:
-                     converted_params.append(params)
+         # Convert all dict params to tuples

-             await conn.executemany(query, converted_params)
+         converted_params = []
+         query_converted = query
+         for params in params_list:
+             if isinstance(params, dict):
+                 query_params = []
+                 for i, (key, value) in enumerate(params.items(), 1):
+                     if converted_params == []:  # Only replace on first iteration
+                         query_converted = query_converted.replace(f":{key}", f"${i}")
+                     # Serialize complex objects to JSON strings for PostgreSQL
+                     if isinstance(value, (dict, list)):
+                         value = json.dumps(value)
+                     query_params.append(value)
+                 converted_params.append(query_params)
+             else:
+                 converted_params.append(params)
+
+         if transaction:
+             # Use transaction connection
+             conn, tx = transaction
+             await conn.executemany(query_converted, converted_params)
+         else:
+             # Use pool connection
+             async with self._pool.acquire() as conn:
+                 await conn.executemany(query_converted, converted_params)

      async def begin_transaction(self) -> Any:
          """Begin a transaction."""
@@ -300,9 +623,12 @@ class MySQLAdapter(DatabaseAdapter):
          params: Optional[Union[tuple, dict]] = None,
          fetch_mode: FetchMode = FetchMode.ALL,
          fetch_size: Optional[int] = None,
+         transaction: Optional[Any] = None,
      ) -> Any:
          """Execute query and return results."""
-         async with self._pool.acquire() as conn:
+         # Use transaction connection if provided, otherwise get from pool
+         if transaction:
+             conn = transaction
              async with conn.cursor() as cursor:
                  await cursor.execute(query, params)

@@ -330,15 +656,56 @@
                              self._convert_row(dict(zip(columns, row))) for row in rows
                          ]
                      return []
+         else:
+             async with self._pool.acquire() as conn:
+                 async with conn.cursor() as cursor:
+                     await cursor.execute(query, params)
+
+                     if fetch_mode == FetchMode.ONE:
+                         row = await cursor.fetchone()
+                         if row and cursor.description:
+                             columns = [desc[0] for desc in cursor.description]
+                             return self._convert_row(dict(zip(columns, row)))
+                         return None
+                     elif fetch_mode == FetchMode.ALL:
+                         rows = await cursor.fetchall()
+                         if rows and cursor.description:
+                             columns = [desc[0] for desc in cursor.description]
+                             return [
+                                 self._convert_row(dict(zip(columns, row)))
+                                 for row in rows
+                             ]
+                         return []
+                     elif fetch_mode == FetchMode.MANY:
+                         if not fetch_size:
+                             raise ValueError("fetch_size required for MANY mode")
+                         rows = await cursor.fetchmany(fetch_size)
+                         if rows and cursor.description:
+                             columns = [desc[0] for desc in cursor.description]
+                             return [
+                                 self._convert_row(dict(zip(columns, row)))
+                                 for row in rows
+                             ]
+                         return []

      async def execute_many(
-         self, query: str, params_list: list[Union[tuple, dict]]
+         self,
+         query: str,
+         params_list: list[Union[tuple, dict]],
+         transaction: Optional[Any] = None,
      ) -> None:
          """Execute query multiple times with different parameters."""
-         async with self._pool.acquire() as conn:
-             async with conn.cursor() as cursor:
+         if transaction:
+             # Use transaction connection
+             async with transaction.cursor() as cursor:
                  await cursor.executemany(query, params_list)
-             await conn.commit()
+                 # Don't commit here - let transaction handling do it
+         else:
+             # Use pool connection
+             async with self._pool.acquire() as conn:
+                 async with conn.cursor() as cursor:
+                     await cursor.executemany(query, params_list)
+                 await conn.commit()

      async def begin_transaction(self) -> Any:
          """Begin a transaction."""
@@ -385,10 +752,12 @@ class SQLiteAdapter(DatabaseAdapter):
          params: Optional[Union[tuple, dict]] = None,
          fetch_mode: FetchMode = FetchMode.ALL,
          fetch_size: Optional[int] = None,
+         transaction: Optional[Any] = None,
      ) -> Any:
          """Execute query and return results."""
-         async with self._aiosqlite.connect(self._db_path) as db:
-             db.row_factory = self._aiosqlite.Row
+         if transaction:
+             # Use existing transaction connection
+             db = transaction
              cursor = await db.execute(query, params or [])

              if fetch_mode == FetchMode.ONE:
@@ -402,16 +771,42 @@
                  raise ValueError("fetch_size required for MANY mode")
              rows = await cursor.fetchmany(fetch_size)
              return [self._convert_row(dict(row)) for row in rows]
+         else:
+             # Create new connection for non-transactional queries
+             async with self._aiosqlite.connect(self._db_path) as db:
+                 db.row_factory = self._aiosqlite.Row
+                 cursor = await db.execute(query, params or [])

-             await db.commit()
+                 if fetch_mode == FetchMode.ONE:
+                     row = await cursor.fetchone()
+                     return self._convert_row(dict(row)) if row else None
+                 elif fetch_mode == FetchMode.ALL:
+                     rows = await cursor.fetchall()
+                     return [self._convert_row(dict(row)) for row in rows]
+                 elif fetch_mode == FetchMode.MANY:
+                     if not fetch_size:
+                         raise ValueError("fetch_size required for MANY mode")
+                     rows = await cursor.fetchmany(fetch_size)
+                     return [self._convert_row(dict(row)) for row in rows]
+
+                 await db.commit()

      async def execute_many(
-         self, query: str, params_list: list[Union[tuple, dict]]
+         self,
+         query: str,
+         params_list: list[Union[tuple, dict]],
+         transaction: Optional[Any] = None,
      ) -> None:
          """Execute query multiple times with different parameters."""
-         async with self._aiosqlite.connect(self._db_path) as db:
-             await db.executemany(query, params_list)
-             await db.commit()
+         if transaction:
+             # Use existing transaction connection
+             await transaction.executemany(query, params_list)
+             # Don't commit here - let transaction handling do it
+         else:
+             # Create new connection for non-transactional queries
+             async with self._aiosqlite.connect(self._db_path) as db:
+                 await db.executemany(query, params_list)
+                 await db.commit()

      async def begin_transaction(self) -> Any:
          """Begin a transaction."""
@@ -431,6 +826,150 @@ class SQLiteAdapter(DatabaseAdapter):
          await transaction.close()


+ class DatabaseConfigManager:
+     """Manager for database configurations from YAML files."""
+
+     def __init__(self, config_path: Optional[str] = None):
+         """Initialize with configuration file path.
+
+         Args:
+             config_path: Path to YAML configuration file. If not provided,
+                 looks for 'database.yaml' in current directory.
+         """
+         self.config_path = config_path or "database.yaml"
+         self._config: Optional[dict[str, Any]] = None
+         self._config_cache: dict[str, tuple[str, dict[str, Any]]] = {}
+
+     def _load_config(self) -> dict[str, Any]:
+         """Load configuration from YAML file."""
+         if self._config is not None:
+             return self._config
+
+         if not os.path.exists(self.config_path):
+             # No config file, return empty config
+             self._config = {}
+             return self._config
+
+         try:
+             with open(self.config_path, "r") as f:
+                 self._config = yaml.safe_load(f) or {}
+                 return self._config
+         except yaml.YAMLError as e:
+             raise NodeValidationError(f"Invalid YAML in configuration file: {e}")
+         except Exception as e:
+             raise NodeExecutionError(f"Failed to load configuration file: {e}")
+
+     def get_database_config(self, connection_name: str) -> tuple[str, dict[str, Any]]:
+         """Get database configuration by connection name.
+
+         Args:
+             connection_name: Name of the database connection from config
+
+         Returns:
+             Tuple of (connection_string, additional_config)
+
+         Raises:
+             NodeExecutionError: If connection not found
+         """
+         # Check cache first
+         if connection_name in self._config_cache:
+             return self._config_cache[connection_name]
+
+         config = self._load_config()
+         databases = config.get("databases", {})
+
+         if connection_name in databases:
+             db_config = databases[connection_name].copy()
+             connection_string = db_config.pop(
+                 "connection_string", db_config.pop("url", None)
+             )
+
+             if not connection_string:
+                 raise NodeExecutionError(
+                     f"No 'connection_string' or 'url' specified for database '{connection_name}'"
+                 )
+
+             # Handle environment variable substitution
+             connection_string = self._substitute_env_vars(connection_string)
+
+             # Process other config values
+             for key, value in db_config.items():
+                 if isinstance(value, str):
+                     db_config[key] = self._substitute_env_vars(value)
+
+             # Cache the result
+             self._config_cache[connection_name] = (connection_string, db_config)
+             return connection_string, db_config
+
+         # Try default connection
+         if "default" in databases:
+             return self.get_database_config("default")
+
+         # No configuration found
+         available = list(databases.keys()) if databases else []
+         raise NodeExecutionError(
+             f"Database connection '{connection_name}' not found in configuration. "
+             f"Available connections: {available}"
+         )
+
+     def _substitute_env_vars(self, value: str) -> str:
+         """Substitute environment variables in configuration values.
+
+         Supports:
+         - ${VAR_NAME} - Full substitution
+         - $VAR_NAME - Simple substitution
+         """
+         if not isinstance(value, str):
+             return value
+
+         # Handle ${VAR_NAME} format
+         if value.startswith("${") and value.endswith("}"):
+             env_var = value[2:-1]
+             env_value = os.getenv(env_var)
+             if env_value is None:
+                 raise NodeExecutionError(f"Environment variable '{env_var}' not found")
+             return env_value
+
+         # Handle $VAR_NAME and ${VAR_NAME} formats in connection strings
+         import re
+
+         # Pattern to match both $VAR_NAME and ${VAR_NAME}
+         pattern = r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}|\$([A-Za-z_][A-Za-z0-9_]*)"
+
+         def replace_var(match):
+             # Group 1 is for ${VAR_NAME}, group 2 is for $VAR_NAME
+             var_name = match.group(1) or match.group(2)
+             var_value = os.getenv(var_name)
+             if var_value is None:
+                 raise NodeExecutionError(f"Environment variable '{var_name}' not found")
+             return var_value
+
+         return re.sub(pattern, replace_var, value)
+
+     def list_connections(self) -> list[str]:
+         """List all available database connections."""
+         config = self._load_config()
+         databases = config.get("databases", {})
+         return list(databases.keys())
+
+     def validate_config(self) -> None:
+         """Validate the configuration file."""
+         config = self._load_config()
+         databases = config.get("databases", {})
+
+         for name, db_config in databases.items():
+             if not isinstance(db_config, dict):
+                 raise NodeValidationError(
+                     f"Database '{name}' configuration must be a dictionary"
+                 )
+
+             # Must have connection string
+             if "connection_string" not in db_config and "url" not in db_config:
+                 raise NodeValidationError(
+                     f"Database '{name}' must have 'connection_string' or 'url'"
+                 )
+
+
  @register_node()
  class AsyncSQLDatabaseNode(AsyncNode):
      """Asynchronous SQL database node for high-concurrency database operations.
@@ -454,30 +993,157 @@ class AsyncSQLDatabaseNode(AsyncNode):
          pool_size: Initial connection pool size
          max_pool_size: Maximum connection pool size
          timeout: Query timeout in seconds
+         transaction_mode: Transaction handling mode ('auto', 'manual', 'none')
+         share_pool: Whether to share connection pool across instances (default: True)
+
+     Transaction Modes:
+         - 'auto' (default): Each query runs in its own transaction, automatically
+           committed on success or rolled back on error
+         - 'manual': Transactions must be explicitly managed using begin_transaction(),
+           commit(), and rollback() methods
+         - 'none': No transaction wrapping, queries execute immediately

-     Example:
+     Example (auto transaction):
          >>> node = AsyncSQLDatabaseNode(
-         ...     name="fetch_users",
+         ...     name="update_users",
+         ...     database_type="postgresql",
+         ...     host="localhost",
+         ...     database="myapp",
+         ...     user="dbuser",
+         ...     password="dbpass"
+         ... )
+         >>> # This will automatically rollback on error
+         >>> await node.async_run(query="INSERT INTO users VALUES (1, 'test')")
+         >>> await node.async_run(query="INVALID SQL")  # Previous insert rolled back
+
+     Example (manual transaction):
+         >>> node = AsyncSQLDatabaseNode(
+         ...     name="transfer_funds",
          ...     database_type="postgresql",
          ...     host="localhost",
          ...     database="myapp",
          ...     user="dbuser",
          ...     password="dbpass",
-         ...     query="SELECT * FROM users WHERE active = :active",
-         ...     params={"active": True},
-         ...     fetch_mode="all"
+         ...     transaction_mode="manual"
          ... )
-         >>> result = await node.async_run()
-         >>> users = result["data"]
+         >>> await node.begin_transaction()
+         >>> try:
+         ...     await node.async_run(query="UPDATE accounts SET balance = balance - 100 WHERE id = 1")
+         ...     await node.async_run(query="UPDATE accounts SET balance = balance + 100 WHERE id = 2")
+         ...     await node.commit()
+         >>> except Exception:
+         ...     await node.rollback()
+         ...     raise
      """

+     # Class-level pool storage for sharing across instances
+     _shared_pools: dict[str, tuple[DatabaseAdapter, int]] = {}
+     _pool_lock: Optional[asyncio.Lock] = None
+
+     @classmethod
+     def _get_pool_lock(cls) -> asyncio.Lock:
+         """Get or create pool lock for the current event loop."""
+         # Check if we have a lock and if it's for the current loop
+         try:
+             loop = asyncio.get_running_loop()
+         except RuntimeError:
+             # No running loop, create a new lock
+             cls._pool_lock = asyncio.Lock()
+             return cls._pool_lock
+
+         # Check if existing lock is for current loop
+         if cls._pool_lock is None:
+             cls._pool_lock = asyncio.Lock()
+             cls._pool_lock_loop_id = id(loop)
+         else:
+             # Verify the lock is for the current event loop
+             # Just create a new lock if we're in a different loop
+             # The simplest approach is to store the loop ID with the lock
+             if not hasattr(cls, "_pool_lock_loop_id"):
+                 cls._pool_lock_loop_id = id(loop)
+             elif cls._pool_lock_loop_id != id(loop):
+                 # Different event loop, clear everything
+                 cls._pool_lock = asyncio.Lock()
+                 cls._pool_lock_loop_id = id(loop)
+                 cls._shared_pools.clear()
+
+         return cls._pool_lock
+
      def __init__(self, **config):
          self._adapter: Optional[DatabaseAdapter] = None
          self._connected = False
          # Extract access control manager before passing to parent
          self.access_control_manager = config.pop("access_control_manager", None)
+
+         # Transaction state management
+         self._active_transaction = None
+         self._transaction_connection = None
+         self._transaction_mode = config.get("transaction_mode", "auto")
+
+         # Pool sharing configuration
+         self._share_pool = config.get("share_pool", True)
+         self._pool_key = None
+
+         # Security configuration
+         self._validate_queries = config.get("validate_queries", True)
+         self._allow_admin = config.get("allow_admin", False)
+
+         # Retry configuration
+         retry_config = config.get("retry_config")
+         if retry_config:
+             if isinstance(retry_config, dict):
+                 self._retry_config = RetryConfig(**retry_config)
+             else:
+                 self._retry_config = retry_config
+         else:
+             # Build from individual parameters
+             self._retry_config = RetryConfig(
+                 max_retries=config.get("max_retries", 3),
+                 initial_delay=config.get("retry_delay", 1.0),
+             )
+
+         # Optimistic locking configuration
+         self._enable_optimistic_locking = config.get("enable_optimistic_locking", False)
+         self._version_field = config.get("version_field", "version")
+         self._conflict_resolution = config.get("conflict_resolution", "fail_fast")
+         self._version_retry_attempts = config.get("version_retry_attempts", 3)
+
          super().__init__(**config)

+     def _reinitialize_from_config(self):
+         """Re-initialize instance variables from config after config file loading."""
+         # Update transaction mode
+         self._transaction_mode = self.config.get("transaction_mode", "auto")
+
+         # Update pool sharing configuration
+         self._share_pool = self.config.get("share_pool", True)
+
+         # Update security configuration
+         self._validate_queries = self.config.get("validate_queries", True)
+         self._allow_admin = self.config.get("allow_admin", False)
+
+         # Update retry configuration
+         retry_config = self.config.get("retry_config")
+         if retry_config:
+             if isinstance(retry_config, dict):
+                 self._retry_config = RetryConfig(**retry_config)
+             else:
+                 self._retry_config = retry_config
+         else:
+             # Build from individual parameters
+             self._retry_config = RetryConfig(
+                 max_retries=self.config.get("max_retries", 3),
+                 initial_delay=self.config.get("retry_delay", 1.0),
+             )
+
+         # Update optimistic locking configuration
+         self._enable_optimistic_locking = self.config.get(
+             "enable_optimistic_locking", False
+         )
+         self._version_field = self.config.get("version_field", "version")
+         self._conflict_resolution = self.config.get("conflict_resolution", "fail_fast")
+         self._version_retry_attempts = self.config.get("version_retry_attempts", 3)
+
      def get_parameters(self) -> dict[str, NodeParameter]:
          """Define the parameters this node accepts."""
          params = [
@@ -494,6 +1160,18 @@ class AsyncSQLDatabaseNode(AsyncNode):
                  required=False,
                  description="Full database connection string (overrides individual params)",
              ),
+             NodeParameter(
+                 name="connection_name",
+                 type=str,
+                 required=False,
+                 description="Name of database connection from config file",
+             ),
+             NodeParameter(
+                 name="config_file",
+                 type=str,
+                 required=False,
+                 description="Path to YAML configuration file (default: database.yaml)",
+             ),
              NodeParameter(
                  name="host", type=str, required=False, description="Database host"
              ),
@@ -564,6 +1242,89 @@ class AsyncSQLDatabaseNode(AsyncNode):
                  required=False,
                  description="User context for access control",
              ),
+             NodeParameter(
+                 name="transaction_mode",
+                 type=str,
+                 required=False,
+                 default="auto",
+                 description="Transaction mode: 'auto' (default), 'manual', or 'none'",
+             ),
+             NodeParameter(
+                 name="share_pool",
+                 type=bool,
+                 required=False,
+                 default=True,
+                 description="Whether to share connection pool across instances with same config",
+             ),
+             NodeParameter(
+                 name="validate_queries",
+                 type=bool,
+                 required=False,
+                 default=True,
+                 description="Whether to validate queries for SQL injection attempts",
+             ),
+             NodeParameter(
+                 name="allow_admin",
+                 type=bool,
+                 required=False,
+                 default=False,
+                 description="Whether to allow administrative SQL commands (CREATE, DROP, etc.)",
+             ),
+             NodeParameter(
+                 name="retry_config",
+                 type=Any,
+                 required=False,
+                 description="Retry configuration dict or RetryConfig object",
+             ),
+             NodeParameter(
+                 name="max_retries",
+                 type=int,
+                 required=False,
+                 default=3,
+                 description="Maximum number of retry attempts for transient failures",
+             ),
+             NodeParameter(
+                 name="retry_delay",
+                 type=float,
+                 required=False,
+                 default=1.0,
+                 description="Initial retry delay in seconds",
+             ),
+             NodeParameter(
+                 name="enable_optimistic_locking",
+                 type=bool,
+                 required=False,
+                 default=False,
+                 description="Enable optimistic locking for version control",
+             ),
+             NodeParameter(
+                 name="version_field",
+                 type=str,
+                 required=False,
+                 default="version",
+                 description="Column name for version tracking",
+             ),
+             NodeParameter(
+                 name="conflict_resolution",
+                 type=str,
+                 required=False,
+                 default="fail_fast",
+                 description="How to handle version conflicts: fail_fast, retry, last_writer_wins",
+             ),
+             NodeParameter(
+                 name="version_retry_attempts",
+                 type=int,
+                 required=False,
+                 default=3,
+                 description="Maximum retries for version conflicts",
+             ),
+             NodeParameter(
+                 name="result_format",
+                 type=str,
+                 required=False,
+                 default="dict",
+                 description="Result format: 'dict' (default), 'list', or 'dataframe'",
+             ),
          ]

          # Convert list to dict as required by base class
@@ -573,6 +1334,39 @@ class AsyncSQLDatabaseNode(AsyncNode):
          """Validate node configuration."""
          super()._validate_config()

+         # Handle config file loading
+         connection_name = self.config.get("connection_name")
+         config_file = self.config.get("config_file")
+
+         if connection_name:
+             # Load from config file
+             config_manager = DatabaseConfigManager(config_file)
+             try:
+                 conn_string, db_config = config_manager.get_database_config(
+                     connection_name
+                 )
+                 # Update config with values from file
+                 self.config["connection_string"] = conn_string
+                 # Merge additional config
+                 # Config file values should override defaults but not explicit params
+                 for key, value in db_config.items():
+                     # Check if this was explicitly provided by user
+                     param_info = self.get_parameters().get(key)
+                     if param_info and key in self.config:
+                         # If it equals the default, it wasn't explicitly set
+                         if self.config[key] == param_info.default:
+                             self.config[key] = value
+                     else:
+                         # Not a parameter or not in config yet
+                         self.config[key] = value
+             except Exception as e:
+                 raise NodeValidationError(
+                     f"Failed to load config '{connection_name}': {e}"
+                 )
+
+         # Re-initialize instance variables with updated config
+         self._reinitialize_from_config()
+
          # Validate database type
          db_type = self.config.get("database_type", "").lower()
          if db_type not in ["postgresql", "mysql", "sqlite"]:
@@ -582,7 +1376,18 @@ class AsyncSQLDatabaseNode(AsyncNode):
              )

          # Validate connection parameters
-         if not self.config.get("connection_string"):
+         connection_string = self.config.get("connection_string")
+         if connection_string:
+             # Validate connection string for security
+             if self._validate_queries:
+                 try:
+                     QueryValidator.validate_connection_string(connection_string)
+                 except NodeValidationError:
+                     raise NodeValidationError(
+                         "Connection string failed security validation. "
+                         "Set validate_queries=False to bypass (not recommended)."
+                     )
+         else:
              if db_type != "sqlite":
                  if not self.config.get("host") or not self.config.get("database"):
                      raise NodeValidationError(
@@ -603,38 +1408,116 @@ class AsyncSQLDatabaseNode(AsyncNode):
          if fetch_mode == "many" and not self.config.get("fetch_size"):
              raise NodeValidationError("fetch_size required when fetch_mode is 'many'")

+         # Validate initial query if provided
+         if self.config.get("query") and self._validate_queries:
+             try:
+                 QueryValidator.validate_query(
+                     self.config["query"], allow_admin=self._allow_admin
+                 )
+             except NodeValidationError as e:
+                 raise NodeValidationError(
+                     f"Initial query validation failed: {e}. "
+                     "Set validate_queries=False to bypass (not recommended)."
+                 )
+
+     def _generate_pool_key(self) -> str:
+         """Generate a unique key for connection pool sharing."""
+         # Create a unique key based on connection parameters
+         key_parts = [
+             self.config.get("database_type", ""),
+             self.config.get("connection_string", "")
+             or (
+                 f"{self.config.get('host', '')}:"
+                 f"{self.config.get('port', '')}:"
+                 f"{self.config.get('database', '')}:"
+                 f"{self.config.get('user', '')}"
+             ),
+             str(self.config.get("pool_size", 10)),
+             str(self.config.get("max_pool_size", 20)),
+         ]
+         return "|".join(key_parts)
+
      async def _get_adapter(self) -> DatabaseAdapter:
-         """Get or create database adapter."""
+         """Get or create database adapter with optional pool sharing."""
          if not self._adapter:
-             db_type = DatabaseType(self.config["database_type"].lower())
-             db_config = DatabaseConfig(
-                 type=db_type,
-                 host=self.config.get("host"),
-                 port=self.config.get("port"),
-                 database=self.config.get("database"),
-                 user=self.config.get("user"),
-                 password=self.config.get("password"),
-                 connection_string=self.config.get("connection_string"),
-                 pool_size=self.config.get("pool_size", 10),
-                 max_pool_size=self.config.get("max_pool_size", 20),
-                 command_timeout=self.config.get("timeout", 60.0),
-             )
-
-             if db_type == DatabaseType.POSTGRESQL:
-                 self._adapter = PostgreSQLAdapter(db_config)
-             elif db_type == DatabaseType.MYSQL:
-                 self._adapter = MySQLAdapter(db_config)
-             elif db_type == DatabaseType.SQLITE:
-                 self._adapter = SQLiteAdapter(db_config)
+             if self._share_pool:
+                 # Use shared pool if available
+                 async with self._get_pool_lock():
+                     self._pool_key = self._generate_pool_key()
+
+                     if self._pool_key in self._shared_pools:
+                         # Reuse existing pool
+                         adapter, ref_count = self._shared_pools[self._pool_key]
+                         self._shared_pools[self._pool_key] = (adapter, ref_count + 1)
+                         self._adapter = adapter
+                         self._connected = True
+                         return self._adapter
+
+                     # Create new shared pool
+                     self._adapter = await self._create_adapter()
+                     self._shared_pools[self._pool_key] = (self._adapter, 1)
              else:
-                 raise NodeExecutionError(f"Unsupported database type: {db_type}")
-
-             if not self._connected:
-                 await self._adapter.connect()
-                 self._connected = True
+                 # Create dedicated pool
+                 self._adapter = await self._create_adapter()

          return self._adapter

+     async def _create_adapter(self) -> DatabaseAdapter:
+         """Create a new database adapter with retry logic for initial connection."""
+         db_type = DatabaseType(self.config["database_type"].lower())
+         db_config = DatabaseConfig(
+             type=db_type,
+             host=self.config.get("host"),
+             port=self.config.get("port"),
+             database=self.config.get("database"),
+             user=self.config.get("user"),
+             password=self.config.get("password"),
+             connection_string=self.config.get("connection_string"),
+             pool_size=self.config.get("pool_size", 10),
+             max_pool_size=self.config.get("max_pool_size", 20),
+             command_timeout=self.config.get("timeout", 60.0),
+         )
+
+         if db_type == DatabaseType.POSTGRESQL:
+             adapter = PostgreSQLAdapter(db_config)
+         elif db_type == DatabaseType.MYSQL:
+             adapter = MySQLAdapter(db_config)
+         elif db_type == DatabaseType.SQLITE:
+             adapter = SQLiteAdapter(db_config)
+         else:
+             raise NodeExecutionError(f"Unsupported database type: {db_type}")
+
+         # Retry connection with exponential backoff
+         last_error = None
+         for attempt in range(self._retry_config.max_retries):
+             try:
+                 await adapter.connect()
+                 self._connected = True
+                 return adapter
+             except Exception as e:
+                 last_error = e
+
+                 # Check if error is retryable
+                 if not self._retry_config.should_retry(e):
+                     raise
+
+                 # Check if we have more attempts
+                 if attempt >= self._retry_config.max_retries - 1:
+                     raise NodeExecutionError(
+                         f"Failed to connect after {self._retry_config.max_retries} attempts: {e}"
+                     )
+
+                 # Calculate delay
+                 delay = self._retry_config.get_delay(attempt)
+
+                 # Wait before retry
+                 await asyncio.sleep(delay)
+
+         # Should not reach here, but just in case
+         raise NodeExecutionError(
+             f"Failed to connect after {self._retry_config.max_retries} attempts: {last_error}"
+         )
+
      async def async_run(self, **inputs) -> dict[str, Any]:
          """Execute database query asynchronously with optional access control."""
          try:
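Sketch of the pool-sharing key: _generate_pool_key() above joins the connection parameters, so two instances configured identically (with share_pool=True, the default) resolve to the same adapter and pool. Values below are illustrative:

    key = "|".join([
        "postgresql",                   # database_type
        "localhost:5432:myapp:dbuser",  # host:port:database:user (used when no connection_string)
        "10",                           # pool_size
        "20",                           # max_pool_size
    ])
    # key == "postgresql|localhost:5432:myapp:dbuser|10|20"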
@@ -645,11 +1528,33 @@ class AsyncSQLDatabaseNode(AsyncNode):
                  inputs.get("fetch_mode", self.config.get("fetch_mode", "all")).lower()
              )
              fetch_size = inputs.get("fetch_size", self.config.get("fetch_size"))
+             result_format = inputs.get(
+                 "result_format", self.config.get("result_format", "dict")
+             )
              user_context = inputs.get("user_context")

              if not query:
                  raise NodeExecutionError("No query provided")

+             # Handle parameter style conversion
+             if params is not None:
+                 if isinstance(params, (list, tuple)):
+                     # Convert positional parameters to named parameters
+                     query, params = self._convert_to_named_parameters(query, params)
+                 elif not isinstance(params, dict):
+                     # Single parameter - wrap in list and convert
+                     query, params = self._convert_to_named_parameters(query, [params])
+
+             # Validate query for security
+             if self._validate_queries:
+                 try:
+                     QueryValidator.validate_query(query, allow_admin=self._allow_admin)
+                 except NodeValidationError as e:
+                     raise NodeExecutionError(
+                         f"Query validation failed: {e}. "
+                         "Set validate_queries=False to bypass (not recommended)."
+                     )
+
              # Check access control if enabled
              if self.access_control_manager and user_context:
                  from kailash.access_control import NodePermission
@@ -660,28 +1565,269 @@ class AsyncSQLDatabaseNode(AsyncNode):
660
1565
  if not decision.allowed:
661
1566
  raise NodeExecutionError(f"Access denied: {decision.reason}")
662
1567
 
663
- # Get adapter and execute query
1568
+ # Get adapter
664
1569
  adapter = await self._get_adapter()
665
1570
 
666
1571
  # Execute query with retry logic
667
- max_retries = 3
668
- retry_delay = 1.0
1572
+ result = await self._execute_with_retry(
1573
+ adapter=adapter,
1574
+ query=query,
1575
+ params=params,
1576
+ fetch_mode=fetch_mode,
1577
+ fetch_size=fetch_size,
1578
+ user_context=user_context,
1579
+ )
1580
+
1581
+ # Format results based on requested format
1582
+ formatted_data = self._format_results(result, result_format)
669
1583
 
670
- for attempt in range(max_retries):
1584
+ # For DataFrame, we need special handling for row count
1585
+ row_count = 0
1586
+ if result_format == "dataframe":
671
1587
  try:
672
- result = await adapter.execute(
673
- query=query,
674
- params=params,
675
- fetch_mode=fetch_mode,
676
- fetch_size=fetch_size,
1588
+ row_count = len(formatted_data)
1589
+ except:
1590
+ # If pandas isn't available, formatted_data is still a list
1591
+ row_count = (
1592
+ len(result)
1593
+ if isinstance(result, list)
1594
+ else (1 if result else 0)
677
1595
  )
1596
+ else:
1597
+ row_count = (
1598
+ len(result) if isinstance(result, list) else (1 if result else 0)
1599
+ )
678
1600
 
679
- # Apply data masking if access control is enabled
680
- if (
681
- self.access_control_manager
682
- and user_context
683
- and isinstance(result, list)
684
- ):
1601
+ # Extract column names if available
1602
+ columns = []
1603
+ if result and isinstance(result, list) and result:
1604
+ if isinstance(result[0], dict):
1605
+ columns = list(result[0].keys())
1606
+
1607
+ # Handle DataFrame serialization for JSON compatibility
1608
+ if result_format == "dataframe":
1609
+ try:
1610
+ import pandas as pd
1611
+
1612
+ if isinstance(formatted_data, pd.DataFrame):
1613
+ # Convert DataFrame to JSON-compatible format
1614
+ serializable_data = {
1615
+ "dataframe": formatted_data.to_dict("records"),
1616
+ "columns": formatted_data.columns.tolist(),
1617
+ "index": formatted_data.index.tolist(),
1618
+ "_type": "dataframe",
1619
+ }
1620
+ else:
1621
+ # pandas not available, use regular data
1622
+ serializable_data = formatted_data
1623
+ except ImportError:
1624
+ serializable_data = formatted_data
1625
+ else:
1626
+ serializable_data = formatted_data
1627
+
1628
+ result_dict = {
1629
+ "result": {
1630
+ "data": serializable_data,
1631
+ "row_count": row_count,
1632
+ "query": query,
1633
+ "database_type": self.config["database_type"],
1634
+ "format": result_format,
1635
+ }
1636
+ }
1637
+
1638
+ # Add columns info for list format
1639
+ if result_format == "list" and columns:
1640
+ result_dict["result"]["columns"] = columns
1641
+
1642
+ return result_dict
1643
+
1644
+ except NodeExecutionError:
1645
+ # Re-raise our own errors
1646
+ raise
1647
+ except Exception as e:
1648
+ # Wrap other errors
1649
+ raise NodeExecutionError(f"Database query failed: {str(e)}")
1650
+
1651
+ async def process(self, inputs: dict[str, Any]) -> dict[str, Any]:
1652
+ """Async process method for middleware compatibility."""
1653
+ return await self.async_run(**inputs)
1654
+
+ async def execute_many_async(
+ self, query: str, params_list: list[dict[str, Any]]
+ ) -> dict[str, Any]:
+ """Execute the same query multiple times with different parameters.
+
+ This is useful for bulk inserts, updates, or deletes. The operation
+ runs in a single transaction (in auto or manual mode) for better
+ performance and atomicity.
+
+ Args:
+ query: SQL query to execute multiple times
+ params_list: List of parameter dictionaries
+
+ Returns:
+ dict: Result with affected row count
+
+ Example:
+ >>> params_list = [
+ ... {"name": "Alice", "age": 30},
+ ... {"name": "Bob", "age": 25},
+ ... {"name": "Charlie", "age": 35},
+ ... ]
+ >>> result = await node.execute_many_async(
+ ... query="INSERT INTO users (name, age) VALUES (:name, :age)",
+ ... params_list=params_list
+ ... )
+ >>> print(result["result"]["affected_rows"]) # 3
+ """
+ if not params_list:
+ return {
+ "result": {
+ "affected_rows": 0,
+ "query": query,
+ "database_type": self.config["database_type"],
+ }
+ }
+
+ # Validate query if security is enabled
+ if self._validate_queries:
+ try:
+ QueryValidator.validate_query(query, allow_admin=self._allow_admin)
+ except NodeValidationError as e:
+ raise NodeExecutionError(
+ f"Query validation failed: {e}. "
+ "Set validate_queries=False to bypass (not recommended)."
+ )
+
+ try:
+ # Get adapter
+ adapter = await self._get_adapter()
+
+ # Execute batch with retry logic
+ affected_rows = await self._execute_many_with_retry(
+ adapter=adapter,
+ query=query,
+ params_list=params_list,
+ )
+
+ return {
+ "result": {
+ "affected_rows": affected_rows,
+ "batch_size": len(params_list),
+ "query": query,
+ "database_type": self.config["database_type"],
+ }
+ }
+
+ except NodeExecutionError:
+ raise
+ except Exception as e:
+ raise NodeExecutionError(f"Batch operation failed: {str(e)}")
+
+ async def begin_transaction(self):
+ """Begin a manual transaction.
+
+ Returns:
+ Transaction context that can be used for manual control
+
+ Raises:
+ NodeExecutionError: If a transaction is already active or mode is not 'manual'
+ """
+ if self._transaction_mode != "manual":
+ raise NodeExecutionError(
+ "begin_transaction() can only be called in 'manual' transaction mode"
+ )
+
+ if self._active_transaction:
+ raise NodeExecutionError("Transaction already active")
+
+ adapter = await self._get_adapter()
+ self._active_transaction = await adapter.begin_transaction()
+ return self._active_transaction
+
+ async def commit(self):
+ """Commit the active transaction.
+
+ Raises:
+ NodeExecutionError: If no active transaction or mode is not 'manual'
+ """
+ if self._transaction_mode != "manual":
+ raise NodeExecutionError(
+ "commit() can only be called in 'manual' transaction mode"
+ )
+
+ if not self._active_transaction:
+ raise NodeExecutionError("No active transaction to commit")
+
+ adapter = await self._get_adapter()
+ try:
+ await adapter.commit_transaction(self._active_transaction)
+ finally:
+ # Always clear transaction, even on error
+ self._active_transaction = None
+
+ async def rollback(self):
+ """Roll back the active transaction.
+
+ Raises:
+ NodeExecutionError: If no active transaction or mode is not 'manual'
+ """
+ if self._transaction_mode != "manual":
+ raise NodeExecutionError(
+ "rollback() can only be called in 'manual' transaction mode"
+ )
+
+ if not self._active_transaction:
+ raise NodeExecutionError("No active transaction to rollback")
+
+ adapter = await self._get_adapter()
+ try:
+ await adapter.rollback_transaction(self._active_transaction)
+ finally:
+ # Always clear transaction, even on error
+ self._active_transaction = None
+
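Taken together, begin_transaction()/commit()/rollback() give the caller full control in 'manual' mode. A minimal sketch, assuming a node constructed with transaction_mode="manual" (the constructor parameter name is inferred from self._transaction_mode above):

```python
# Hypothetical manual-transaction flow; `node` is assumed to be
# configured with transaction_mode="manual".
await node.begin_transaction()
try:
    await node.async_run(
        query="UPDATE accounts SET balance = balance - 100 WHERE id = :id",
        params={"id": 1},
    )
    await node.async_run(
        query="UPDATE accounts SET balance = balance + 100 WHERE id = :id",
        params={"id": 2},
    )
    await node.commit()
except Exception:
    await node.rollback()
    raise
```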
+ async def _execute_with_retry(
+ self,
+ adapter: DatabaseAdapter,
+ query: str,
+ params: Any,
+ fetch_mode: FetchMode,
+ fetch_size: Optional[int],
+ user_context: Any = None,
+ ) -> Any:
+ """Execute query with retry logic for transient failures.
+
+ Args:
+ adapter: Database adapter
+ query: SQL query
+ params: Query parameters
+ fetch_mode: How to fetch results
+ fetch_size: Number of rows for 'many' mode
+ user_context: User context for access control
+
+ Returns:
+ Query results
+
+ Raises:
+ NodeExecutionError: After all retry attempts are exhausted
+ """
+ last_error = None
+
+ for attempt in range(self._retry_config.max_retries):
+ try:
+ # Execute query with transaction
+ result = await self._execute_with_transaction(
+ adapter=adapter,
+ query=query,
+ params=params,
+ fetch_mode=fetch_mode,
+ fetch_size=fetch_size,
+ )
+
+ # Apply data masking if access control is enabled
+ if self.access_control_manager and user_context:
+ if isinstance(result, list):
  masked_result = []
  for row in result:
  masked_row = self.access_control_manager.apply_data_masking(
@@ -689,45 +1835,726 @@ class AsyncSQLDatabaseNode(AsyncNode):
  )
  masked_result.append(masked_row)
  result = masked_result
- elif (
- self.access_control_manager
- and user_context
- and isinstance(result, dict)
- ):
+ elif isinstance(result, dict):
  result = self.access_control_manager.apply_data_masking(
  user_context, self.metadata.name, result
  )
 
- return {
- "result": {
- "data": result,
- "row_count": (
- len(result)
- if isinstance(result, list)
- else (1 if result else 0)
- ),
- "query": query,
- "database_type": self.config["database_type"],
+ return result
+
+ except Exception as e:
+ last_error = e
+
+ # Check if error is retryable
+ if not self._retry_config.should_retry(e):
+ raise
+
+ # Check if we have more attempts
+ if attempt >= self._retry_config.max_retries - 1:
+ raise
+
+ # Calculate delay
+ delay = self._retry_config.get_delay(attempt)
+
+ # Log retry attempt (if logging is available)
+ try:
+ self.logger.warning(
+ f"Query failed (attempt {attempt + 1}/{self._retry_config.max_retries}): {e}. "
+ f"Retrying in {delay:.2f} seconds..."
+ )
+ except AttributeError:
+ # No logger available
+ pass
+
+ # Wait before retry
+ await asyncio.sleep(delay)
+
+ # For connection errors, try to reconnect
+ if "pool is closed" in str(e).lower() or "connection" in str(e).lower():
+ try:
+ # Clear existing adapter to force reconnection
+ if self._share_pool and self._pool_key:
+ # Remove from shared pools to force recreation
+ async with self._get_pool_lock():
+ if self._pool_key in self._shared_pools:
+ _, ref_count = self._shared_pools[self._pool_key]
+ if ref_count <= 1:
+ del self._shared_pools[self._pool_key]
+ else:
+ # This shouldn't happen with a closed pool
+ del self._shared_pools[self._pool_key]
+
+ self._adapter = None
+ self._connected = False
+ adapter = await self._get_adapter()
+ except Exception:
+ # If reconnection fails, continue with retry loop
+ pass
+
+ # All retries exhausted
+ raise NodeExecutionError(
+ f"Query failed after {self._retry_config.max_retries} attempts: {last_error}"
+ )
+
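The retry loop above delegates all policy decisions to self._retry_config (max_retries, should_retry, get_delay). That class is defined elsewhere in this module; the following is only a plausible sketch of its contract to make the loop concrete, not the shipped implementation:

```python
# Sketch of the RetryConfig contract assumed by _execute_with_retry.
# Exponential backoff with a cap is one reasonable policy.
class RetryConfigSketch:
    def __init__(self, max_retries: int = 3, base_delay: float = 0.5, max_delay: float = 10.0):
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay

    def should_retry(self, error: Exception) -> bool:
        # Retry only transient failures, e.g. dropped connections
        return "connection" in str(error).lower()

    def get_delay(self, attempt: int) -> float:
        # attempt is 0-based: 0.5s, 1.0s, 2.0s, ... capped at max_delay
        return min(self.base_delay * (2**attempt), self.max_delay)
```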
+ async def _execute_many_with_retry(
+ self, adapter: DatabaseAdapter, query: str, params_list: list[dict[str, Any]]
+ ) -> int:
+ """Execute batch operation with retry logic.
+
+ Args:
+ adapter: Database adapter
+ query: SQL query to execute
+ params_list: List of parameter dictionaries
+
+ Returns:
+ Number of affected rows
+
+ Raises:
+ NodeExecutionError: After all retry attempts are exhausted
+ """
+ last_error = None
+
+ for attempt in range(self._retry_config.max_retries):
+ try:
+ # Execute batch with transaction
+ return await self._execute_many_with_transaction(
+ adapter=adapter,
+ query=query,
+ params_list=params_list,
+ )
+
+ except Exception as e:
+ last_error = e
+
+ # Check if error is retryable
+ if not self._retry_config.should_retry(e):
+ raise
+
+ # Check if we have more attempts
+ if attempt >= self._retry_config.max_retries - 1:
+ raise
+
+ # Calculate delay
+ delay = self._retry_config.get_delay(attempt)
+
+ # Wait before retry
+ await asyncio.sleep(delay)
+
+ # For connection errors, try to reconnect
+ if "pool is closed" in str(e).lower() or "connection" in str(e).lower():
+ try:
+ # Clear existing adapter to force reconnection
+ if self._share_pool and self._pool_key:
+ # Remove from shared pools to force recreation
+ async with self._get_pool_lock():
+ if self._pool_key in self._shared_pools:
+ _, ref_count = self._shared_pools[self._pool_key]
+ if ref_count <= 1:
+ del self._shared_pools[self._pool_key]
+ else:
+ # This shouldn't happen with a closed pool
+ del self._shared_pools[self._pool_key]
+
+ self._adapter = None
+ self._connected = False
+ adapter = await self._get_adapter()
+ except Exception:
+ # If reconnection fails, continue with retry loop
+ pass
+
+ # All retries exhausted
+ raise NodeExecutionError(
+ f"Batch operation failed after {self._retry_config.max_retries} attempts: {last_error}"
+ )
+
+ async def _execute_many_with_transaction(
+ self, adapter: DatabaseAdapter, query: str, params_list: list[dict[str, Any]]
+ ) -> int:
+ """Execute batch operation with automatic transaction management.
+
+ Args:
+ adapter: Database adapter
+ query: SQL query to execute
+ params_list: List of parameter dictionaries
+
+ Returns:
+ Number of affected rows (estimated)
+
+ Raises:
+ Exception: Re-raises any execution errors after rollback
+ """
+ if self._active_transaction:
+ # Use existing transaction (manual mode)
+ await adapter.execute_many(query, params_list, self._active_transaction)
+ # Most adapters don't return row count from execute_many
+ return len(params_list)
+ elif self._transaction_mode == "auto":
+ # Auto-transaction mode
+ transaction = await adapter.begin_transaction()
+ try:
+ await adapter.execute_many(query, params_list, transaction)
+ await adapter.commit_transaction(transaction)
+ return len(params_list)
+ except Exception:
+ await adapter.rollback_transaction(transaction)
+ raise
+ else:
+ # No transaction mode
+ await adapter.execute_many(query, params_list)
+ return len(params_list)
+
+ async def _execute_with_transaction(
+ self,
+ adapter: DatabaseAdapter,
+ query: str,
+ params: Any,
+ fetch_mode: FetchMode,
+ fetch_size: Optional[int],
+ ) -> Any:
+ """Execute query with automatic transaction management.
+
+ Args:
+ adapter: Database adapter
+ query: SQL query
+ params: Query parameters
+ fetch_mode: How to fetch results
+ fetch_size: Number of rows for 'many' mode
+
+ Returns:
+ Query results
+
+ Raises:
+ Exception: Re-raises any execution errors after rollback
+ """
+ if self._active_transaction:
+ # Use existing transaction (manual mode)
+ return await adapter.execute(
+ query=query,
+ params=params,
+ fetch_mode=fetch_mode,
+ fetch_size=fetch_size,
+ transaction=self._active_transaction,
+ )
+ elif self._transaction_mode == "auto":
+ # Auto-transaction mode
+ transaction = await adapter.begin_transaction()
+ try:
+ result = await adapter.execute(
+ query=query,
+ params=params,
+ fetch_mode=fetch_mode,
+ fetch_size=fetch_size,
+ transaction=transaction,
+ )
+ await adapter.commit_transaction(transaction)
+ return result
+ except Exception:
+ await adapter.rollback_transaction(transaction)
+ raise
+ else:
+ # No transaction mode
+ return await adapter.execute(
+ query=query,
+ params=params,
+ fetch_mode=fetch_mode,
+ fetch_size=fetch_size,
+ )
+
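_execute_with_transaction and its execute_many counterpart route every statement through one of three branches: an already-open manual transaction, a per-call auto transaction, or no explicit transaction at all. In sketch form (the constructor keyword and the "none" value are assumptions inferred from self._transaction_mode above):

```python
# Hypothetical configuration of the three routing branches.
auto_node = AsyncSQLDatabaseNode(transaction_mode="auto")      # BEGIN/COMMIT around each call
manual_node = AsyncSQLDatabaseNode(transaction_mode="manual")  # caller drives begin/commit/rollback
bare_node = AsyncSQLDatabaseNode(transaction_mode="none")      # rely on driver-level autocommit
```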
+ @classmethod
+ async def get_pool_metrics(cls) -> dict[str, Any]:
+ """Get metrics for all shared connection pools.
+
+ Returns:
+ dict: Pool metrics including pool count, connections per pool, etc.
+ """
+ async with cls._get_pool_lock():
+ metrics = {"total_pools": len(cls._shared_pools), "pools": []}
+
+ for pool_key, (adapter, ref_count) in cls._shared_pools.items():
+ pool_info = {
+ "key": pool_key,
+ "reference_count": ref_count,
+ "type": adapter.__class__.__name__,
+ }
+
+ # Try to get pool-specific metrics if available
+ if hasattr(adapter, "_pool") and adapter._pool:
+ pool = adapter._pool
+ if hasattr(pool, "size"):
+ pool_info["pool_size"] = pool.size()
+ if hasattr(pool, "_holders"):
+ pool_info["active_connections"] = len(
+ [h for h in pool._holders if h._in_use]
+ )
+ elif hasattr(pool, "size") and hasattr(pool, "freesize"):
+ pool_info["active_connections"] = pool.size - pool.freesize
+
+ metrics["pools"].append(pool_info)
+
+ return metrics
+
+ @classmethod
+ async def clear_shared_pools(cls) -> None:
+ """Clear all shared connection pools. Use with caution!"""
+ async with cls._get_pool_lock():
+ for pool_key, (adapter, _) in list(cls._shared_pools.items()):
+ try:
+ await adapter.disconnect()
+ except Exception:
+ pass # Best effort
+ cls._shared_pools.clear()
+
+ def get_pool_info(self) -> dict[str, Any]:
+ """Get information about this instance's connection pool.
+
+ Returns:
+ dict: Pool information including shared status and metrics
+ """
+ info = {
+ "shared": self._share_pool,
+ "pool_key": self._pool_key,
+ "connected": self._connected,
+ }
+
+ if self._adapter and hasattr(self._adapter, "_pool") and self._adapter._pool:
+ pool = self._adapter._pool
+ if hasattr(pool, "size"):
+ info["pool_size"] = pool.size()
+ if hasattr(pool, "_holders"):
+ info["active_connections"] = len(
+ [h for h in pool._holders if h._in_use]
+ )
+ elif hasattr(pool, "size") and hasattr(pool, "freesize"):
+ info["active_connections"] = pool.size - pool.freesize
+
+ return info
+
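Both inspection helpers are side-effect free and useful when debugging pool sharing. The keys follow the dictionaries built above:

```python
# Sketch: inspecting pools at class and instance level.
metrics = await AsyncSQLDatabaseNode.get_pool_metrics()
print(metrics["total_pools"])
for pool in metrics["pools"]:
    print(pool["key"], pool["reference_count"], pool.get("active_connections"))

info = node.get_pool_info()  # `node` is an existing instance
print(info["shared"], info["pool_key"], info["connected"])
```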
+ async def execute_with_version_check(
+ self,
+ query: str,
+ params: dict[str, Any],
+ expected_version: Optional[int] = None,
+ record_id: Optional[Any] = None,
+ table_name: Optional[str] = None,
+ ) -> dict[str, Any]:
+ """Execute a query with optimistic locking version check.
+
+ Args:
+ query: SQL query to execute (UPDATE or DELETE)
+ params: Query parameters
+ expected_version: Expected version number for conflict detection
+ record_id: ID of the record being updated (for retry)
+ table_name: Table name (for retry to re-read current version)
+
+ Returns:
+ dict: Result with version information and conflict status
+
+ Raises:
+ NodeExecutionError: On version conflict or database error
+ """
+ if not self._enable_optimistic_locking:
+ # Just execute normally if optimistic locking is disabled
+ result = await self.execute_async(query=query, params=params)
+ return {
+ "result": result,
+ "version_checked": False,
+ "status": LockStatus.SUCCESS,
+ }
+
+ # Add version check to the query
+ if expected_version is not None:
+ # Ensure version field is in params
+ if ":expected_version" in query:
+ # Query already uses :expected_version, just ensure it's set
+ params["expected_version"] = expected_version
+ else:
+ # Use standard version field
+ params[self._version_field] = expected_version
+
+ # For UPDATE queries, also add version increment
+ if "UPDATE" in query.upper() and "SET" in query.upper():
+ # Find SET clause and add version increment
+ set_match = re.search(r"(SET\s+)(.+?)(\s+WHERE)", query, re.IGNORECASE)
+ if set_match:
+ set_clause = set_match.group(2)
+ # Add version increment if not already present
+ if self._version_field not in set_clause:
+ new_set_clause = f"{set_clause}, {self._version_field} = {self._version_field} + 1"
+ query = (
+ query[: set_match.start(2)]
+ + new_set_clause
+ + query[set_match.end(2) :]
+ )
+
+ # Modify query to include version check in WHERE clause (only if not already present)
+ # Check for version condition in WHERE clause specifically, not just anywhere in query
+ where_clause_pattern = (
+ r"WHERE\s+.*?" + re.escape(self._version_field) + r"\s*="
+ )
+ has_version_check_in_where = (
+ re.search(where_clause_pattern, query, re.IGNORECASE) is not None
+ or ":expected_version" in query
+ )
+ if not has_version_check_in_where:
+ if "WHERE" in query.upper():
+ query += f" AND {self._version_field} = :{self._version_field}"
+ else:
+ query += f" WHERE {self._version_field} = :{self._version_field}"
+
+ # Try to execute with version check
+ retry_count = 0
+ for attempt in range(self._version_retry_attempts):
+ try:
+ result = await self.execute_async(query=query, params=params)
+
+ # Check if any rows were affected
+ rows_affected = 0
+ rows_affected_found = False
+ if isinstance(result.get("result"), dict):
+ # Check if we have data array with rows_affected
+ data = result["result"].get("data", [])
+ if data and isinstance(data, list) and len(data) > 0:
+ if isinstance(data[0], dict) and "rows_affected" in data[0]:
+ rows_affected = data[0]["rows_affected"]
+ rows_affected_found = True
+
+ # Only check direct keys if we haven't found rows_affected in data
+ if not rows_affected_found:
+ rows_affected = (
+ result["result"].get("rows_affected", 0)
+ or result["result"].get("rowcount", 0)
+ or result["result"].get("affected_rows", 0)
+ or result["result"].get("row_count", 0)
+ )
+
+ if rows_affected == 0 and expected_version is not None:
+ # Version conflict detected
+ if self._conflict_resolution == "fail_fast":
+ raise NodeExecutionError(
+ f"Version conflict: expected version {expected_version} not found"
+ )
+ elif (
+ self._conflict_resolution == "retry"
+ and record_id
+ and table_name
+ ):
+ # Read current version
+ current = await self.execute_async(
+ query=f"SELECT {self._version_field} FROM {table_name} WHERE id = :id",
+ params={"id": record_id},
+ )
+
+ if current["result"]["data"]:
+ current_version = current["result"]["data"][0][
+ self._version_field
+ ]
+ params[self._version_field] = current_version
+ # Update expected version for next attempt
+ expected_version = current_version
+ retry_count += 1
+ continue
+ else:
+ return {
+ "result": None,
+ "status": LockStatus.RECORD_NOT_FOUND,
+ "version_checked": True,
+ "retry_count": retry_count,
+ }
+ elif self._conflict_resolution == "last_writer_wins":
+ # Remove version check and try again
+ params_no_version = params.copy()
+ params_no_version.pop(self._version_field, None)
+ query_no_version = query.replace(
+ f" AND {self._version_field} = :{self._version_field}", ""
+ )
+ result = await self.execute_async(
+ query=query_no_version, params=params_no_version
+ )
+ return {
+ "result": result,
+ "status": LockStatus.SUCCESS,
+ "version_checked": False,
+ "conflict_resolved": "last_writer_wins",
+ "retry_count": retry_count,
 }
+
+ # Success - increment version for UPDATE queries
+ if "UPDATE" in query.upper() and rows_affected > 0:
+ # The query should have incremented the version
+ new_version = (
+ (expected_version or 0) + 1
+ if expected_version is not None
+ else None
+ )
+ return {
+ "result": result,
+ "status": LockStatus.SUCCESS,
+ "version_checked": True,
+ "new_version": new_version,
+ "rows_affected": rows_affected,
+ "retry_count": retry_count,
+ }
+ else:
+ return {
+ "result": result,
+ "status": LockStatus.SUCCESS,
+ "version_checked": True,
+ "rows_affected": rows_affected,
+ "retry_count": retry_count,
 }
 
- except Exception as e:
- if attempt < max_retries - 1:
- await asyncio.sleep(retry_delay * (2**attempt))
- continue
+ except NodeExecutionError:
+ if attempt >= self._version_retry_attempts - 1:
  raise
+ await asyncio.sleep(0.1 * (attempt + 1)) # Linear backoff
 
- except Exception as e:
- raise NodeExecutionError(f"Database query failed: {str(e)}")
+ return {
+ "result": None,
+ "status": LockStatus.RETRY_EXHAUSTED,
+ "version_checked": True,
+ "retry_count": self._version_retry_attempts,
+ }
 
- async def process(self, inputs: dict[str, Any]) -> dict[str, Any]:
- """Async process method for middleware compatibility."""
- return await self.async_run(**inputs)
+ async def read_with_version(
+ self,
+ query: str,
+ params: Optional[dict[str, Any]] = None,
+ ) -> dict[str, Any]:
+ """Execute a SELECT query and extract version information.
+
+ Args:
+ query: SELECT query to execute
+ params: Query parameters
+
+ Returns:
+ dict: Result with version information included
+ """
+ result = await self.execute_async(query=query, params=params)
+
+ if self._enable_optimistic_locking and result.get("result", {}).get("data"):
+ # Extract version from results
+ data = result["result"]["data"]
+ if isinstance(data, list) and len(data) > 0:
+ # Single record
+ if len(data) == 1 and self._version_field in data[0]:
+ return {
+ "result": result,
+ "version": data[0][self._version_field],
+ "record": data[0],
+ }
+ # Multiple records - include version in each
+ else:
+ versions = []
+ for record in data:
+ if self._version_field in record:
+ versions.append(record[self._version_field])
+ return {
+ "result": result,
+ "versions": versions,
+ "records": data,
+ }
+
+ return result
+
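read_with_version and execute_with_version_check pair up for a full optimistic-locking round trip. A sketch, assuming optimistic locking was enabled at construction and the version field is named "version":

```python
# Hypothetical read-modify-write with conflict detection.
read = await node.read_with_version(
    query="SELECT id, name, version FROM users WHERE id = :id",
    params={"id": 42},
)
outcome = await node.execute_with_version_check(
    query="UPDATE users SET name = :name WHERE id = :id",
    params={"name": "Alice", "id": 42},
    expected_version=read["version"],
    record_id=42,
    table_name="users",
)
if outcome["status"] != LockStatus.SUCCESS:
    ...  # conflict handled per the configured resolution strategy
```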
+ def build_versioned_update_query(
+ self,
+ table_name: str,
+ update_fields: dict[str, Any],
+ where_clause: str,
+ increment_version: bool = True,
+ ) -> str:
+ """Build an UPDATE query with version increment.
+
+ Args:
+ table_name: Name of the table to update
+ update_fields: Fields to update (excluding version)
+ where_clause: WHERE clause (without WHERE keyword)
+ increment_version: Whether to increment the version field
+
+ Returns:
+ str: UPDATE query with version handling
+ """
+ if not self._enable_optimistic_locking:
+ # Build normal update query
+ set_parts = [f"{field} = :{field}" for field in update_fields]
+ return (
+ f"UPDATE {table_name} SET {', '.join(set_parts)} WHERE {where_clause}"
+ )
+
+ # Build versioned update query
+ set_parts = [f"{field} = :{field}" for field in update_fields]
+
+ if increment_version:
+ set_parts.append(f"{self._version_field} = {self._version_field} + 1")
+
+ return f"UPDATE {table_name} SET {', '.join(set_parts)} WHERE {where_clause}"
+
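With optimistic locking enabled and a version field named "version" (an assumed configuration), the builder produces an UPDATE that bumps the version in place:

```python
# Derived from the string-building logic above.
sql = node.build_versioned_update_query(
    table_name="users",
    update_fields={"name": "Alice"},
    where_clause="id = :id AND version = :version",
)
# sql == "UPDATE users SET name = :name, version = version + 1
#         WHERE id = :id AND version = :version"
```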
+ def _convert_to_named_parameters(
+ self, query: str, parameters: list
+ ) -> tuple[str, dict]:
+ """Convert positional parameters to named parameters for various SQL dialects.
+
+ This method handles conversion from different SQL parameter styles to a
+ consistent named parameter format that works with async database drivers.
+
+ Args:
+ query: SQL query with positional placeholders (?, $1, %s)
+ parameters: List of parameter values
+
+ Returns:
+ Tuple of (modified_query, parameter_dict)
+
+ Examples:
+ >>> # SQLite style
+ >>> query = "SELECT * FROM users WHERE age > ? AND active = ?"
+ >>> params = [25, True]
+ >>> new_query, param_dict = node._convert_to_named_parameters(query, params)
+ >>> # Returns: ("SELECT * FROM users WHERE age > :p0 AND active = :p1",
+ >>> # {"p0": 25, "p1": True})
+
+ >>> # PostgreSQL style
+ >>> query = "UPDATE users SET name = $1 WHERE id = $2"
+ >>> params = ["John", 123]
+ >>> new_query, param_dict = node._convert_to_named_parameters(query, params)
+ >>> # Returns: ("UPDATE users SET name = :p0 WHERE id = :p1",
+ >>> # {"p0": "John", "p1": 123})
+ """
+ # Create parameter dictionary
+ param_dict = {}
+ for i, value in enumerate(parameters):
+ param_dict[f"p{i}"] = value
+
+ # Replace different placeholder formats with named parameters
+ modified_query = query
+
+ # Handle SQLite-style ? placeholders
+ placeholder_count = 0
+
+ def replace_question_mark(match):
+ nonlocal placeholder_count
+ replacement = f":p{placeholder_count}"
+ placeholder_count += 1
+ return replacement
+
+ modified_query = re.sub(r"\?", replace_question_mark, modified_query)
+
+ # Handle PostgreSQL-style $1, $2, etc. placeholders
+ def replace_postgres_placeholder(match):
+ index = int(match.group(1)) - 1 # PostgreSQL uses 1-based indexing
+ return f":p{index}"
+
+ modified_query = re.sub(
+ r"\$(\d+)", replace_postgres_placeholder, modified_query
+ )
+
+ # Handle MySQL-style %s placeholders
+ placeholder_count = 0
+
+ def replace_mysql_placeholder(match):
+ nonlocal placeholder_count
+ replacement = f":p{placeholder_count}"
+ placeholder_count += 1
+ return replacement
+
+ modified_query = re.sub(r"%s", replace_mysql_placeholder, modified_query)
+
+ return modified_query, param_dict
+
+ def _format_results(self, data: list[dict], result_format: str) -> Any:
+ """Format query results according to specified format.
+
+ Args:
+ data: List of dictionaries from database query
+ result_format: Desired output format ('dict', 'list', 'dataframe')
+
+ Returns:
+ Formatted results
+
+ Formats:
+ - 'dict': List of dictionaries (default) - column names as keys
+ - 'list': List of lists - values only, no column names
+ - 'dataframe': Pandas DataFrame (if pandas is available)
+ """
+ if not data:
+ # Return empty structure based on format
+ if result_format == "dataframe":
+ try:
+ import pandas as pd
+
+ return pd.DataFrame()
+ except ImportError:
+ # Fall back to dict if pandas not available
+ return []
+ elif result_format == "list":
+ return []
+ else:
+ return []
+
+ if result_format == "dict":
+ # Already in dict format from adapters
+ return data
+
+ elif result_format == "list":
+ # Convert to list of lists (values only)
+ if data:
+ # Get column order from first row
+ columns = list(data[0].keys())
+ return [[row.get(col) for col in columns] for row in data]
+ return []
+
+ elif result_format == "dataframe":
+ # Convert to pandas DataFrame if available
+ try:
+ import pandas as pd
+
+ return pd.DataFrame(data)
+ except ImportError:
+ # Log warning and fall back to dict format
+ if hasattr(self, "logger"):
+ self.logger.warning(
+ "Pandas not installed. Install with: pip install pandas. "
+ "Falling back to dict format."
+ )
+ return data
+
+ else:
+ # Unknown format - default to dict with warning
+ if hasattr(self, "logger"):
+ self.logger.warning(
+ f"Unknown result_format '{result_format}', defaulting to 'dict'"
+ )
+ return data
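The three formats side by side, derived directly from the branches above:

```python
rows = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
node._format_results(rows, "dict")       # unchanged: [{"id": 1, ...}, ...]
node._format_results(rows, "list")       # [[1, "Alice"], [2, "Bob"]]
node._format_results(rows, "dataframe")  # pandas.DataFrame if pandas is installed,
                                         # otherwise falls back to the dict form
```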
 
  async def cleanup(self):
  """Clean up database connections."""
+ # Roll back any active transaction
+ if self._active_transaction and self._adapter:
+ try:
+ await self._adapter.rollback_transaction(self._active_transaction)
+ except Exception:
+ pass # Best effort cleanup
+ self._active_transaction = None
+
  if self._adapter and self._connected:
- await self._adapter.disconnect()
+ if self._share_pool and self._pool_key:
+ # Decrement reference count for shared pool
+ async with self._get_pool_lock():
+ if self._pool_key in self._shared_pools:
+ adapter, ref_count = self._shared_pools[self._pool_key]
+ if ref_count > 1:
+ # Others still using the pool
+ self._shared_pools[self._pool_key] = (
+ adapter,
+ ref_count - 1,
+ )
+ else:
+ # Last reference, close the pool
+ del self._shared_pools[self._pool_key]
+ await adapter.disconnect()
+ else:
+ # Dedicated pool, close directly
+ await self._adapter.disconnect()
+
  self._connected = False
  self._adapter = None
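Because cleanup() only decrements the shared-pool reference count, shutdown order is forgiving: the pool closes with the last reference. A sketch (node_a and node_b are assumed instances sharing one connection configuration):

```python
# Hypothetical shutdown sequence for two nodes sharing a pool.
await node_a.cleanup()  # ref count drops to 1; pool stays open for node_b
await node_b.cleanup()  # last reference: pool is disconnected

# For tests or full process shutdown, a hard reset is also available:
await AsyncSQLDatabaseNode.clear_shared_pools()
```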