kailash 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +3 -3
- kailash/api/custom_nodes_secure.py +3 -3
- kailash/api/gateway.py +1 -1
- kailash/api/studio.py +1 -1
- kailash/api/workflow_api.py +2 -2
- kailash/core/resilience/bulkhead.py +475 -0
- kailash/core/resilience/circuit_breaker.py +92 -10
- kailash/core/resilience/health_monitor.py +578 -0
- kailash/edge/discovery.py +86 -0
- kailash/mcp_server/__init__.py +309 -33
- kailash/mcp_server/advanced_features.py +1022 -0
- kailash/mcp_server/ai_registry_server.py +27 -2
- kailash/mcp_server/auth.py +789 -0
- kailash/mcp_server/client.py +645 -378
- kailash/mcp_server/discovery.py +1593 -0
- kailash/mcp_server/errors.py +673 -0
- kailash/mcp_server/oauth.py +1727 -0
- kailash/mcp_server/protocol.py +1126 -0
- kailash/mcp_server/registry_integration.py +587 -0
- kailash/mcp_server/server.py +1228 -96
- kailash/mcp_server/transports.py +1169 -0
- kailash/mcp_server/utils/__init__.py +6 -1
- kailash/mcp_server/utils/cache.py +250 -7
- kailash/middleware/auth/auth_manager.py +3 -3
- kailash/middleware/communication/api_gateway.py +1 -1
- kailash/middleware/communication/realtime.py +1 -1
- kailash/middleware/mcp/enhanced_server.py +1 -1
- kailash/nodes/__init__.py +2 -0
- kailash/nodes/admin/audit_log.py +6 -6
- kailash/nodes/admin/permission_check.py +8 -8
- kailash/nodes/admin/role_management.py +32 -28
- kailash/nodes/admin/schema.sql +6 -1
- kailash/nodes/admin/schema_manager.py +13 -13
- kailash/nodes/admin/security_event.py +15 -15
- kailash/nodes/admin/tenant_isolation.py +3 -3
- kailash/nodes/admin/transaction_utils.py +3 -3
- kailash/nodes/admin/user_management.py +21 -21
- kailash/nodes/ai/a2a.py +11 -11
- kailash/nodes/ai/ai_providers.py +9 -12
- kailash/nodes/ai/embedding_generator.py +13 -14
- kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
- kailash/nodes/ai/iterative_llm_agent.py +2 -2
- kailash/nodes/ai/llm_agent.py +210 -33
- kailash/nodes/ai/self_organizing.py +2 -2
- kailash/nodes/alerts/discord.py +4 -4
- kailash/nodes/api/graphql.py +6 -6
- kailash/nodes/api/http.py +10 -10
- kailash/nodes/api/rate_limiting.py +4 -4
- kailash/nodes/api/rest.py +15 -15
- kailash/nodes/auth/mfa.py +3 -3
- kailash/nodes/auth/risk_assessment.py +2 -2
- kailash/nodes/auth/session_management.py +5 -5
- kailash/nodes/auth/sso.py +143 -0
- kailash/nodes/base.py +8 -2
- kailash/nodes/base_async.py +16 -2
- kailash/nodes/base_with_acl.py +2 -2
- kailash/nodes/cache/__init__.py +9 -0
- kailash/nodes/cache/cache.py +1172 -0
- kailash/nodes/cache/cache_invalidation.py +874 -0
- kailash/nodes/cache/redis_pool_manager.py +595 -0
- kailash/nodes/code/async_python.py +2 -1
- kailash/nodes/code/python.py +194 -30
- kailash/nodes/compliance/data_retention.py +6 -6
- kailash/nodes/compliance/gdpr.py +5 -5
- kailash/nodes/data/__init__.py +10 -0
- kailash/nodes/data/async_sql.py +1956 -129
- kailash/nodes/data/optimistic_locking.py +906 -0
- kailash/nodes/data/readers.py +8 -8
- kailash/nodes/data/redis.py +378 -0
- kailash/nodes/data/sql.py +314 -3
- kailash/nodes/data/streaming.py +21 -0
- kailash/nodes/enterprise/__init__.py +8 -0
- kailash/nodes/enterprise/audit_logger.py +285 -0
- kailash/nodes/enterprise/batch_processor.py +22 -3
- kailash/nodes/enterprise/data_lineage.py +1 -1
- kailash/nodes/enterprise/mcp_executor.py +205 -0
- kailash/nodes/enterprise/service_discovery.py +150 -0
- kailash/nodes/enterprise/tenant_assignment.py +108 -0
- kailash/nodes/logic/async_operations.py +2 -2
- kailash/nodes/logic/convergence.py +1 -1
- kailash/nodes/logic/operations.py +1 -1
- kailash/nodes/monitoring/__init__.py +11 -1
- kailash/nodes/monitoring/health_check.py +456 -0
- kailash/nodes/monitoring/log_processor.py +817 -0
- kailash/nodes/monitoring/metrics_collector.py +627 -0
- kailash/nodes/monitoring/performance_benchmark.py +137 -11
- kailash/nodes/rag/advanced.py +7 -7
- kailash/nodes/rag/agentic.py +49 -2
- kailash/nodes/rag/conversational.py +3 -3
- kailash/nodes/rag/evaluation.py +3 -3
- kailash/nodes/rag/federated.py +3 -3
- kailash/nodes/rag/graph.py +3 -3
- kailash/nodes/rag/multimodal.py +3 -3
- kailash/nodes/rag/optimized.py +5 -5
- kailash/nodes/rag/privacy.py +3 -3
- kailash/nodes/rag/query_processing.py +6 -6
- kailash/nodes/rag/realtime.py +1 -1
- kailash/nodes/rag/registry.py +1 -1
- kailash/nodes/rag/router.py +1 -1
- kailash/nodes/rag/similarity.py +7 -7
- kailash/nodes/rag/strategies.py +4 -4
- kailash/nodes/security/abac_evaluator.py +6 -6
- kailash/nodes/security/behavior_analysis.py +5 -5
- kailash/nodes/security/credential_manager.py +1 -1
- kailash/nodes/security/rotating_credentials.py +11 -11
- kailash/nodes/security/threat_detection.py +8 -8
- kailash/nodes/testing/credential_testing.py +2 -2
- kailash/nodes/transform/processors.py +5 -5
- kailash/runtime/local.py +163 -9
- kailash/runtime/parameter_injection.py +425 -0
- kailash/runtime/parameter_injector.py +657 -0
- kailash/runtime/testing.py +2 -2
- kailash/testing/fixtures.py +2 -2
- kailash/workflow/builder.py +99 -14
- kailash/workflow/builder_improvements.py +207 -0
- kailash/workflow/input_handling.py +170 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/METADATA +22 -9
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/RECORD +122 -95
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/WHEEL +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/entry_points.txt +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/top_level.txt +0 -0
kailash/nodes/data/async_sql.py
CHANGED
@@ -27,6 +27,8 @@ Key Features:
 import asyncio
 import json
 import os
+import random
+import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import date, datetime
@@ -34,10 +36,37 @@ from decimal import Decimal
 from enum import Enum
 from typing import Any, AsyncIterator, Optional, Union
 
+import yaml
+
 from kailash.nodes.base import NodeParameter, register_node
 from kailash.nodes.base_async import AsyncNode
 from kailash.sdk_exceptions import NodeExecutionError, NodeValidationError
 
+# Import optimistic locking for version control
+try:
+    from kailash.nodes.data.optimistic_locking import (
+        ConflictResolution,
+        LockStatus,
+        OptimisticLockingNode,
+    )
+
+    OPTIMISTIC_LOCKING_AVAILABLE = True
+except ImportError:
+    OPTIMISTIC_LOCKING_AVAILABLE = False
+
+    # Define minimal enums if not available
+    class ConflictResolution:
+        FAIL_FAST = "fail_fast"
+        RETRY = "retry"
+        MERGE = "merge"
+        LAST_WRITER_WINS = "last_writer_wins"
+
+    class LockStatus:
+        SUCCESS = "success"
+        VERSION_CONFLICT = "version_conflict"
+        RECORD_NOT_FOUND = "record_not_found"
+        RETRY_EXHAUSTED = "retry_exhausted"
+
 
 class DatabaseType(Enum):
     """Supported database types."""
@@ -47,6 +76,124 @@ class DatabaseType(Enum):
     SQLITE = "sqlite"
 
 
+class QueryValidator:
+    """Validates SQL queries for common security issues."""
+
+    # Dangerous SQL patterns that could indicate injection attempts
+    DANGEROUS_PATTERNS = [
+        # Multiple statements
+        r";\s*(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|GRANT|REVOKE)",
+        # Comments that might hide malicious code
+        r"--.*$",
+        r"/\*.*\*/",
+        # Union-based injection
+        r"\bUNION\b.*\bSELECT\b",
+        # Time-based blind injection
+        r"\b(SLEEP|WAITFOR|PG_SLEEP)\b",
+        # Out-of-band injection
+        r"\b(LOAD_FILE|INTO\s+OUTFILE|INTO\s+DUMPFILE)\b",
+        # System command execution
+        r"\b(XP_CMDSHELL|EXEC\s+MASTER)",
+    ]
+
+    # Patterns that should only appear in admin queries
+    ADMIN_ONLY_PATTERNS = [
+        r"\b(CREATE|ALTER|DROP)\s+(?:\w+\s+)*(TABLE|INDEX|VIEW|PROCEDURE|FUNCTION|TRIGGER)",
+        r"\b(GRANT|REVOKE)\b",
+        r"\bTRUNCATE\b",
+    ]
+
+    @classmethod
+    def validate_query(cls, query: str, allow_admin: bool = False) -> None:
+        """Validate a SQL query for security issues.
+
+        Args:
+            query: The SQL query to validate
+            allow_admin: Whether to allow administrative commands
+
+        Raises:
+            NodeValidationError: If the query contains dangerous patterns
+        """
+        query_upper = query.upper()
+
+        # Check for dangerous patterns
+        for pattern in cls.DANGEROUS_PATTERNS:
+            if re.search(pattern, query, re.IGNORECASE | re.MULTILINE):
+                raise NodeValidationError(
+                    f"Query contains potentially dangerous pattern: {pattern}"
+                )
+
+        # Check for admin-only patterns if not allowed
+        if not allow_admin:
+            for pattern in cls.ADMIN_ONLY_PATTERNS:
+                if re.search(pattern, query, re.IGNORECASE):
+                    raise NodeValidationError(
+                        f"Query contains administrative command that is not allowed: {pattern}"
+                    )
+
+    @classmethod
+    def validate_identifier(cls, identifier: str) -> None:
+        """Validate a database identifier (table/column name).
+
+        Args:
+            identifier: The identifier to validate
+
+        Raises:
+            NodeValidationError: If the identifier is invalid
+        """
+        # Allow alphanumeric, underscore, and dot (for schema.table)
+        if not re.match(
+            r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?$", identifier
+        ):
+            raise NodeValidationError(
+                f"Invalid identifier: {identifier}. "
+                "Identifiers must start with letter/underscore and contain only letters, numbers, underscores."
+            )
+
+    @classmethod
+    def sanitize_string_literal(cls, value: str) -> str:
+        """Sanitize a string value for SQL by escaping quotes.
+
+        Args:
+            value: The string value to sanitize
+
+        Returns:
+            Escaped string safe for SQL
+        """
+        # This is a basic implementation - real escaping should be done by the driver
+        return value.replace("'", "''").replace("\\", "\\\\")
+
+    @classmethod
+    def validate_connection_string(cls, connection_string: str) -> None:
+        """Validate a database connection string.
+
+        Args:
+            connection_string: The connection string to validate
+
+        Raises:
+            NodeValidationError: If the connection string appears malicious
+        """
+        # Check for suspicious patterns in connection strings
+        suspicious_patterns = [
+            # SQL injection attempts
+            r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)",
+            # Command execution attempts
+            r';.*\bhost\s*=\s*[\'"]?\|',
+            r';.*\bhost\s*=\s*[\'"]?`',
+            r"\$\(",  # Command substitution
+            r"`",  # Backticks
+            # File access attempts
+            r'sslcert\s*=\s*[\'"]?(/etc/passwd|/etc/shadow)',
+            r'sslkey\s*=\s*[\'"]?(/etc/passwd|/etc/shadow)',
+        ]
+
+        for pattern in suspicious_patterns:
+            if re.search(pattern, connection_string, re.IGNORECASE):
+                raise NodeValidationError(
+                    "Connection string contains suspicious pattern"
+                )
+
+
 class FetchMode(Enum):
     """Result fetch modes."""
 
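For orientation, a minimal sketch of how the QueryValidator class added above could be exercised on its own (class, method, and exception names are taken from the diff; using it standalone outside the node is an assumption):

from kailash.nodes.data.async_sql import QueryValidator
from kailash.sdk_exceptions import NodeValidationError

safe = "SELECT id, email FROM users WHERE tenant_id = :tenant"
unsafe = "SELECT * FROM users; DROP TABLE users"  # second statement after the semicolon

QueryValidator.validate_query(safe)                  # passes silently
try:
    QueryValidator.validate_query(unsafe)            # matches a DANGEROUS_PATTERNS entry
except NodeValidationError as exc:
    print(f"rejected: {exc}")

QueryValidator.validate_identifier("public.users")   # schema.table identifiers are allowed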
@@ -56,6 +203,72 @@ class FetchMode(Enum):
     ITERATOR = "iterator"  # Return async iterator
 
 
+@dataclass
+class RetryConfig:
+    """Configuration for retry logic."""
+
+    max_retries: int = 3
+    initial_delay: float = 1.0
+    max_delay: float = 60.0
+    exponential_base: float = 2.0
+    jitter: bool = True
+
+    # Retryable error patterns (database-specific)
+    retryable_errors: list[str] = None
+
+    def __post_init__(self):
+        """Initialize default retryable errors."""
+        if self.retryable_errors is None:
+            self.retryable_errors = [
+                # PostgreSQL
+                "connection_refused",
+                "connection_reset",
+                "connection reset",  # Handle different cases
+                "connection_aborted",
+                "could not connect",
+                "server closed the connection",
+                "terminating connection",
+                "connectionreseterror",
+                "connectionrefusederror",
+                "brokenpipeerror",
+                # MySQL
+                "lost connection to mysql server",
+                "mysql server has gone away",
+                "can't connect to mysql server",
+                # SQLite
+                "database is locked",
+                "disk i/o error",
+                # General
+                "timeout",
+                "timed out",
+                "pool is closed",
+                # DNS/Network errors
+                "nodename nor servname provided",
+                "name or service not known",
+                "gaierror",
+                "getaddrinfo failed",
+                "temporary failure in name resolution",
+            ]
+
+    def should_retry(self, error: Exception) -> bool:
+        """Check if an error is retryable."""
+        error_str = str(error).lower()
+        return any(pattern.lower() in error_str for pattern in self.retryable_errors)
+
+    def get_delay(self, attempt: int) -> float:
+        """Calculate delay for a retry attempt."""
+        delay = min(
+            self.initial_delay * (self.exponential_base**attempt), self.max_delay
+        )
+
+        if self.jitter:
+            # Add random jitter (±25%)
+            jitter_amount = delay * 0.25
+            delay += random.uniform(-jitter_amount, jitter_amount)
+
+        return max(0, delay)  # Ensure non-negative
+
+
 @dataclass
 class DatabaseConfig:
     """Database connection configuration."""
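As a quick illustration of the backoff arithmetic in RetryConfig.get_delay() above: with the defaults (initial_delay=1.0, exponential_base=2.0, max_delay=60.0) and jitter disabled, attempts 0 through 3 wait 1, 2, 4, and 8 seconds; with jitter=True each value is perturbed by up to ±25%. A small sketch using only the fields and methods shown in the diff:

from kailash.nodes.data.async_sql import RetryConfig

cfg = RetryConfig(jitter=False)
print([cfg.get_delay(a) for a in range(4)])                     # [1.0, 2.0, 4.0, 8.0]
print(cfg.should_retry(TimeoutError("connection timed out")))   # True: "timed out" is retryable
print(cfg.should_retry(ValueError("bad input")))                # False: not a transient error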
@@ -96,19 +309,43 @@ class DatabaseAdapter(ABC):
         """Convert database-specific types to JSON-serializable types."""
         converted = {}
         for key, value in row.items():
-
-                # Convert Decimal to float for JSON serialization
-                converted[key] = float(value)
-            elif isinstance(value, datetime):
-                # Convert datetime to ISO format string
-                converted[key] = value.isoformat()
-            elif isinstance(value, date):
-                # Convert date to ISO format string
-                converted[key] = value.isoformat()
-            else:
-                converted[key] = value
+            converted[key] = self._serialize_value(value)
         return converted
 
+    def _serialize_value(self, value: Any) -> Any:
+        """Convert database-specific types to JSON-serializable types."""
+        if value is None:
+            return None
+        elif isinstance(value, bool):
+            # Handle bool before int (bool is subclass of int in Python)
+            return value
+        elif isinstance(value, (int, float)):
+            # Return numeric types as-is
+            return value
+        elif isinstance(value, str):
+            # Return strings as-is
+            return value
+        elif isinstance(value, bytes):
+            import base64
+
+            result = base64.b64encode(value).decode("utf-8")
+            return result
+        elif isinstance(value, Decimal):
+            return float(value)
+        elif isinstance(value, datetime):
+            return value.isoformat()
+        elif isinstance(value, date):
+            return value.isoformat()
+        elif hasattr(value, "total_seconds"):  # timedelta
+            return value.total_seconds()
+        elif hasattr(value, "hex"):  # UUID
+            return str(value)
+        elif isinstance(value, (list, tuple)):
+            return [self._serialize_value(item) for item in value]
+        elif isinstance(value, dict):
+            return {k: self._serialize_value(v) for k, v in value.items()}
+        return value
+
     @abstractmethod
     async def connect(self) -> None:
         """Establish connection pool."""
@@ -126,8 +363,9 @@ class DatabaseAdapter(ABC):
         params: Optional[Union[tuple, dict]] = None,
         fetch_mode: FetchMode = FetchMode.ALL,
         fetch_size: Optional[int] = None,
+        transaction: Optional[Any] = None,
     ) -> Any:
-        """Execute query and return results."""
+        """Execute query and return results, optionally within a transaction."""
         pass
 
     @abstractmethod
@@ -192,24 +430,57 @@ class PostgreSQLAdapter(DatabaseAdapter):
         params: Optional[Union[tuple, dict]] = None,
         fetch_mode: FetchMode = FetchMode.ALL,
         fetch_size: Optional[int] = None,
+        transaction: Optional[Any] = None,
     ) -> Any:
         """Execute query and return results."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Convert dict params to positional for asyncpg
+        if isinstance(params, dict):
+            # Simple parameter substitution for named params
+            # In production, use a proper SQL parser
+            import json
+
+            query_params = []
+            for i, (key, value) in enumerate(params.items(), 1):
+                query = query.replace(f":{key}", f"${i}")
+                # For PostgreSQL, lists should remain as lists for array operations
+                # Only convert dicts to JSON strings
+                if isinstance(value, dict):
+                    value = json.dumps(value)
+                query_params.append(value)
+            params = query_params
+
+        # Ensure params is a list/tuple for asyncpg
+        if params is None:
+            params = []
+        elif not isinstance(params, (list, tuple)):
+            params = [params]
+
+        # Execute query on appropriate connection
+        if transaction:
+            # Use transaction connection
+            conn, tx = transaction
+
+            # For UPDATE/DELETE queries without RETURNING, use execute() to get affected rows
+            query_upper = query.upper()
+            if (
+                (
+                    "UPDATE" in query_upper
+                    or "DELETE" in query_upper
+                    or "INSERT" in query_upper
+                )
+                and "RETURNING" not in query_upper
+                and fetch_mode == FetchMode.ALL
+            ):
+                result = await conn.execute(query, *params)
+                # asyncpg returns a string like "UPDATE 1", extract the count
+                if isinstance(result, str):
+                    parts = result.split()
+                    if len(parts) >= 2 and parts[1].isdigit():
+                        rows_affected = int(parts[1])
+                    else:
+                        rows_affected = 0
+                    return [{"rows_affected": rows_affected}]
+                return []
 
             if fetch_mode == FetchMode.ONE:
                 row = await conn.fetchrow(query, *params)
@@ -224,26 +495,78 @@ class PostgreSQLAdapter(DatabaseAdapter):
                 return [self._convert_row(dict(row)) for row in rows[:fetch_size]]
             elif fetch_mode == FetchMode.ITERATOR:
                 raise NotImplementedError("Iterator mode not yet implemented")
+        else:
+            # Use pool connection
+            async with self._pool.acquire() as conn:
+                # For UPDATE/DELETE queries without RETURNING, use execute() to get affected rows
+                query_upper = query.upper()
+                if (
+                    (
+                        "UPDATE" in query_upper
+                        or "DELETE" in query_upper
+                        or "INSERT" in query_upper
+                    )
+                    and "RETURNING" not in query_upper
+                    and fetch_mode == FetchMode.ALL
+                ):
+                    result = await conn.execute(query, *params)
+                    # asyncpg returns a string like "UPDATE 1", extract the count
+                    if isinstance(result, str):
+                        parts = result.split()
+                        if len(parts) >= 2 and parts[1].isdigit():
+                            rows_affected = int(parts[1])
+                        else:
+                            rows_affected = 0
+                        return [{"rows_affected": rows_affected}]
+                    return []
+
+                if fetch_mode == FetchMode.ONE:
+                    row = await conn.fetchrow(query, *params)
+                    return self._convert_row(dict(row)) if row else None
+                elif fetch_mode == FetchMode.ALL:
+                    rows = await conn.fetch(query, *params)
+                    return [self._convert_row(dict(row)) for row in rows]
+                elif fetch_mode == FetchMode.MANY:
+                    if not fetch_size:
+                        raise ValueError("fetch_size required for MANY mode")
+                    rows = await conn.fetch(query, *params)
+                    return [self._convert_row(dict(row)) for row in rows[:fetch_size]]
+                elif fetch_mode == FetchMode.ITERATOR:
+                    raise NotImplementedError("Iterator mode not yet implemented")
 
     async def execute_many(
-        self,
+        self,
+        query: str,
+        params_list: list[Union[tuple, dict]],
+        transaction: Optional[Any] = None,
     ) -> None:
         """Execute query multiple times with different parameters."""
-
-        # Convert all dict params to tuples
-        converted_params = []
-        for params in params_list:
-            if isinstance(params, dict):
-                query_params = []
-                for i, (key, value) in enumerate(params.items(), 1):
-                    if i == 1:  # Only replace on first iteration
-                        query = query.replace(f":{key}", f"${i}")
-                    query_params.append(value)
-                converted_params.append(query_params)
-            else:
-                converted_params.append(params)
+        # Convert all dict params to tuples
 
-
+        converted_params = []
+        query_converted = query
+        for params in params_list:
+            if isinstance(params, dict):
+                query_params = []
+                for i, (key, value) in enumerate(params.items(), 1):
+                    if converted_params == []:  # Only replace on first iteration
+                        query_converted = query_converted.replace(f":{key}", f"${i}")
+                    # Serialize complex objects to JSON strings for PostgreSQL
+                    if isinstance(value, (dict, list)):
+                        value = json.dumps(value)
+                    query_params.append(value)
+                converted_params.append(query_params)
+            else:
+                converted_params.append(params)
+
+        if transaction:
+            # Use transaction connection
+            conn, tx = transaction
+            await conn.executemany(query_converted, converted_params)
+        else:
+            # Use pool connection
+            async with self._pool.acquire() as conn:
+                await conn.executemany(query_converted, converted_params)
 
     async def begin_transaction(self) -> Any:
         """Begin a transaction."""
@@ -300,9 +623,12 @@ class MySQLAdapter(DatabaseAdapter):
         params: Optional[Union[tuple, dict]] = None,
         fetch_mode: FetchMode = FetchMode.ALL,
         fetch_size: Optional[int] = None,
+        transaction: Optional[Any] = None,
     ) -> Any:
         """Execute query and return results."""
-
+        # Use transaction connection if provided, otherwise get from pool
+        if transaction:
+            conn = transaction
             async with conn.cursor() as cursor:
                 await cursor.execute(query, params)
 
@@ -330,15 +656,56 @@ class MySQLAdapter(DatabaseAdapter):
                         self._convert_row(dict(zip(columns, row))) for row in rows
                     ]
                 return []
+        else:
+            async with self._pool.acquire() as conn:
+                async with conn.cursor() as cursor:
+                    await cursor.execute(query, params)
+
+                    if fetch_mode == FetchMode.ONE:
+                        row = await cursor.fetchone()
+                        if row and cursor.description:
+                            columns = [desc[0] for desc in cursor.description]
+                            return self._convert_row(dict(zip(columns, row)))
+                        return None
+                    elif fetch_mode == FetchMode.ALL:
+                        rows = await cursor.fetchall()
+                        if rows and cursor.description:
+                            columns = [desc[0] for desc in cursor.description]
+                            return [
+                                self._convert_row(dict(zip(columns, row)))
+                                for row in rows
+                            ]
+                        return []
+                    elif fetch_mode == FetchMode.MANY:
+                        if not fetch_size:
+                            raise ValueError("fetch_size required for MANY mode")
+                        rows = await cursor.fetchmany(fetch_size)
+                        if rows and cursor.description:
+                            columns = [desc[0] for desc in cursor.description]
+                            return [
+                                self._convert_row(dict(zip(columns, row)))
+                                for row in rows
+                            ]
+                        return []
 
     async def execute_many(
-        self,
+        self,
+        query: str,
+        params_list: list[Union[tuple, dict]],
+        transaction: Optional[Any] = None,
     ) -> None:
         """Execute query multiple times with different parameters."""
-
-
+        if transaction:
+            # Use transaction connection
+            async with transaction.cursor() as cursor:
                 await cursor.executemany(query, params_list)
-
+                # Don't commit here - let transaction handling do it
+        else:
+            # Use pool connection
+            async with self._pool.acquire() as conn:
+                async with conn.cursor() as cursor:
+                    await cursor.executemany(query, params_list)
+                    await conn.commit()
 
     async def begin_transaction(self) -> Any:
         """Begin a transaction."""
@@ -385,10 +752,12 @@ class SQLiteAdapter(DatabaseAdapter):
         params: Optional[Union[tuple, dict]] = None,
         fetch_mode: FetchMode = FetchMode.ALL,
         fetch_size: Optional[int] = None,
+        transaction: Optional[Any] = None,
     ) -> Any:
         """Execute query and return results."""
-
-
+        if transaction:
+            # Use existing transaction connection
+            db = transaction
             cursor = await db.execute(query, params or [])
 
             if fetch_mode == FetchMode.ONE:
@@ -402,16 +771,42 @@ class SQLiteAdapter(DatabaseAdapter):
                 raise ValueError("fetch_size required for MANY mode")
             rows = await cursor.fetchmany(fetch_size)
             return [self._convert_row(dict(row)) for row in rows]
+        else:
+            # Create new connection for non-transactional queries
+            async with self._aiosqlite.connect(self._db_path) as db:
+                db.row_factory = self._aiosqlite.Row
+                cursor = await db.execute(query, params or [])
 
-
+                if fetch_mode == FetchMode.ONE:
+                    row = await cursor.fetchone()
+                    return self._convert_row(dict(row)) if row else None
+                elif fetch_mode == FetchMode.ALL:
+                    rows = await cursor.fetchall()
+                    return [self._convert_row(dict(row)) for row in rows]
+                elif fetch_mode == FetchMode.MANY:
+                    if not fetch_size:
+                        raise ValueError("fetch_size required for MANY mode")
+                    rows = await cursor.fetchmany(fetch_size)
+                    return [self._convert_row(dict(row)) for row in rows]
+
+                await db.commit()
 
     async def execute_many(
-        self,
+        self,
+        query: str,
+        params_list: list[Union[tuple, dict]],
+        transaction: Optional[Any] = None,
     ) -> None:
         """Execute query multiple times with different parameters."""
-
-
-        await
+        if transaction:
+            # Use existing transaction connection
+            await transaction.executemany(query, params_list)
+            # Don't commit here - let transaction handling do it
+        else:
+            # Create new connection for non-transactional queries
+            async with self._aiosqlite.connect(self._db_path) as db:
+                await db.executemany(query, params_list)
+                await db.commit()
 
     async def begin_transaction(self) -> Any:
         """Begin a transaction."""
@@ -431,6 +826,150 @@ class SQLiteAdapter(DatabaseAdapter):
         await transaction.close()
 
 
+class DatabaseConfigManager:
+    """Manager for database configurations from YAML files."""
+
+    def __init__(self, config_path: Optional[str] = None):
+        """Initialize with configuration file path.
+
+        Args:
+            config_path: Path to YAML configuration file. If not provided,
+                looks for 'database.yaml' in current directory.
+        """
+        self.config_path = config_path or "database.yaml"
+        self._config: Optional[dict[str, Any]] = None
+        self._config_cache: dict[str, tuple[str, dict[str, Any]]] = {}
+
+    def _load_config(self) -> dict[str, Any]:
+        """Load configuration from YAML file."""
+        if self._config is not None:
+            return self._config
+
+        if not os.path.exists(self.config_path):
+            # No config file, return empty config
+            self._config = {}
+            return self._config
+
+        try:
+            with open(self.config_path, "r") as f:
+                self._config = yaml.safe_load(f) or {}
+            return self._config
+        except yaml.YAMLError as e:
+            raise NodeValidationError(f"Invalid YAML in configuration file: {e}")
+        except Exception as e:
+            raise NodeExecutionError(f"Failed to load configuration file: {e}")
+
+    def get_database_config(self, connection_name: str) -> tuple[str, dict[str, Any]]:
+        """Get database configuration by connection name.
+
+        Args:
+            connection_name: Name of the database connection from config
+
+        Returns:
+            Tuple of (connection_string, additional_config)
+
+        Raises:
+            NodeExecutionError: If connection not found
+        """
+        # Check cache first
+        if connection_name in self._config_cache:
+            return self._config_cache[connection_name]
+
+        config = self._load_config()
+        databases = config.get("databases", {})
+
+        if connection_name in databases:
+            db_config = databases[connection_name].copy()
+            connection_string = db_config.pop(
+                "connection_string", db_config.pop("url", None)
+            )
+
+            if not connection_string:
+                raise NodeExecutionError(
+                    f"No 'connection_string' or 'url' specified for database '{connection_name}'"
+                )
+
+            # Handle environment variable substitution
+            connection_string = self._substitute_env_vars(connection_string)
+
+            # Process other config values
+            for key, value in db_config.items():
+                if isinstance(value, str):
+                    db_config[key] = self._substitute_env_vars(value)
+
+            # Cache the result
+            self._config_cache[connection_name] = (connection_string, db_config)
+            return connection_string, db_config
+
+        # Try default connection
+        if "default" in databases:
+            return self.get_database_config("default")
+
+        # No configuration found
+        available = list(databases.keys()) if databases else []
+        raise NodeExecutionError(
+            f"Database connection '{connection_name}' not found in configuration. "
+            f"Available connections: {available}"
+        )
+
+    def _substitute_env_vars(self, value: str) -> str:
+        """Substitute environment variables in configuration values.
+
+        Supports:
+        - ${VAR_NAME} - Full substitution
+        - $VAR_NAME - Simple substitution
+        """
+        if not isinstance(value, str):
+            return value
+
+        # Handle ${VAR_NAME} format
+        if value.startswith("${") and value.endswith("}"):
+            env_var = value[2:-1]
+            env_value = os.getenv(env_var)
+            if env_value is None:
+                raise NodeExecutionError(f"Environment variable '{env_var}' not found")
+            return env_value
+
+        # Handle $VAR_NAME and ${VAR_NAME} formats in connection strings
+        import re
+
+        # Pattern to match both $VAR_NAME and ${VAR_NAME}
+        pattern = r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}|\$([A-Za-z_][A-Za-z0-9_]*)"
+
+        def replace_var(match):
+            # Group 1 is for ${VAR_NAME}, group 2 is for $VAR_NAME
+            var_name = match.group(1) or match.group(2)
+            var_value = os.getenv(var_name)
+            if var_value is None:
+                raise NodeExecutionError(f"Environment variable '{var_name}' not found")
+            return var_value
+
+        return re.sub(pattern, replace_var, value)
+
+    def list_connections(self) -> list[str]:
+        """List all available database connections."""
+        config = self._load_config()
+        databases = config.get("databases", {})
+        return list(databases.keys())
+
+    def validate_config(self) -> None:
+        """Validate the configuration file."""
+        config = self._load_config()
+        databases = config.get("databases", {})
+
+        for name, db_config in databases.items():
+            if not isinstance(db_config, dict):
+                raise NodeValidationError(
+                    f"Database '{name}' configuration must be a dictionary"
+                )
+
+            # Must have connection string
+            if "connection_string" not in db_config and "url" not in db_config:
+                raise NodeValidationError(
+                    f"Database '{name}' must have 'connection_string' or 'url'"
+                )
+
+
 @register_node()
 class AsyncSQLDatabaseNode(AsyncNode):
     """Asynchronous SQL database node for high-concurrency database operations.
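For context, a small sketch of the database.yaml layout that DatabaseConfigManager above appears to expect; the key names ("databases", "connection_string", "url", ${VAR} substitution) follow the code, while the concrete connection values are illustrative only:

from kailash.nodes.data.async_sql import DatabaseConfigManager

with open("database.yaml", "w") as f:
    f.write(
        "databases:\n"
        "  default:\n"
        "    connection_string: postgresql://app:${DB_PASSWORD}@localhost:5432/myapp\n"
        "    pool_size: 10\n"
        "  analytics:\n"
        "    url: sqlite:///./analytics.db\n"
    )

manager = DatabaseConfigManager("database.yaml")
print(manager.list_connections())                            # ['default', 'analytics']
conn_string, extra = manager.get_database_config("analytics")  # 'url' is accepted as an alias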
@@ -454,30 +993,157 @@ class AsyncSQLDatabaseNode(AsyncNode):
         pool_size: Initial connection pool size
         max_pool_size: Maximum connection pool size
         timeout: Query timeout in seconds
+        transaction_mode: Transaction handling mode ('auto', 'manual', 'none')
+        share_pool: Whether to share connection pool across instances (default: True)
+
+    Transaction Modes:
+        - 'auto' (default): Each query runs in its own transaction, automatically
+          committed on success or rolled back on error
+        - 'manual': Transactions must be explicitly managed using begin_transaction(),
+          commit(), and rollback() methods
+        - 'none': No transaction wrapping, queries execute immediately
 
-    Example:
+    Example (auto transaction):
         >>> node = AsyncSQLDatabaseNode(
-        ...     name="
+        ...     name="update_users",
+        ...     database_type="postgresql",
+        ...     host="localhost",
+        ...     database="myapp",
+        ...     user="dbuser",
+        ...     password="dbpass"
+        ... )
+        >>> # This will automatically rollback on error
+        >>> await node.async_run(query="INSERT INTO users VALUES (1, 'test')")
+        >>> await node.async_run(query="INVALID SQL")  # Previous insert rolled back
+
+    Example (manual transaction):
+        >>> node = AsyncSQLDatabaseNode(
+        ...     name="transfer_funds",
         ...     database_type="postgresql",
         ...     host="localhost",
         ...     database="myapp",
         ...     user="dbuser",
         ...     password="dbpass",
-        ...
-        ...     params={"active": True},
-        ...     fetch_mode="all"
+        ...     transaction_mode="manual"
         ... )
-        >>>
-        >>>
+        >>> await node.begin_transaction()
+        >>> try:
+        ...     await node.async_run(query="UPDATE accounts SET balance = balance - 100 WHERE id = 1")
+        ...     await node.async_run(query="UPDATE accounts SET balance = balance + 100 WHERE id = 2")
+        ...     await node.commit()
+        >>> except Exception:
+        ...     await node.rollback()
+        ...     raise
     """
 
+    # Class-level pool storage for sharing across instances
+    _shared_pools: dict[str, tuple[DatabaseAdapter, int]] = {}
+    _pool_lock: Optional[asyncio.Lock] = None
+
+    @classmethod
+    def _get_pool_lock(cls) -> asyncio.Lock:
+        """Get or create pool lock for the current event loop."""
+        # Check if we have a lock and if it's for the current loop
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            # No running loop, create a new lock
+            cls._pool_lock = asyncio.Lock()
+            return cls._pool_lock
+
+        # Check if existing lock is for current loop
+        if cls._pool_lock is None:
+            cls._pool_lock = asyncio.Lock()
+            cls._pool_lock_loop_id = id(loop)
+        else:
+            # Verify the lock is for the current event loop
+            # Just create a new lock if we're in a different loop
+            # The simplest approach is to store the loop ID with the lock
+            if not hasattr(cls, "_pool_lock_loop_id"):
+                cls._pool_lock_loop_id = id(loop)
+            elif cls._pool_lock_loop_id != id(loop):
+                # Different event loop, clear everything
+                cls._pool_lock = asyncio.Lock()
+                cls._pool_lock_loop_id = id(loop)
+                cls._shared_pools.clear()
+
+        return cls._pool_lock
+
     def __init__(self, **config):
         self._adapter: Optional[DatabaseAdapter] = None
         self._connected = False
         # Extract access control manager before passing to parent
         self.access_control_manager = config.pop("access_control_manager", None)
+
+        # Transaction state management
+        self._active_transaction = None
+        self._transaction_connection = None
+        self._transaction_mode = config.get("transaction_mode", "auto")
+
+        # Pool sharing configuration
+        self._share_pool = config.get("share_pool", True)
+        self._pool_key = None
+
+        # Security configuration
+        self._validate_queries = config.get("validate_queries", True)
+        self._allow_admin = config.get("allow_admin", False)
+
+        # Retry configuration
+        retry_config = config.get("retry_config")
+        if retry_config:
+            if isinstance(retry_config, dict):
+                self._retry_config = RetryConfig(**retry_config)
+            else:
+                self._retry_config = retry_config
+        else:
+            # Build from individual parameters
+            self._retry_config = RetryConfig(
+                max_retries=config.get("max_retries", 3),
+                initial_delay=config.get("retry_delay", 1.0),
+            )
+
+        # Optimistic locking configuration
+        self._enable_optimistic_locking = config.get("enable_optimistic_locking", False)
+        self._version_field = config.get("version_field", "version")
+        self._conflict_resolution = config.get("conflict_resolution", "fail_fast")
+        self._version_retry_attempts = config.get("version_retry_attempts", 3)
+
         super().__init__(**config)
 
+    def _reinitialize_from_config(self):
+        """Re-initialize instance variables from config after config file loading."""
+        # Update transaction mode
+        self._transaction_mode = self.config.get("transaction_mode", "auto")
+
+        # Update pool sharing configuration
+        self._share_pool = self.config.get("share_pool", True)
+
+        # Update security configuration
+        self._validate_queries = self.config.get("validate_queries", True)
+        self._allow_admin = self.config.get("allow_admin", False)
+
+        # Update retry configuration
+        retry_config = self.config.get("retry_config")
+        if retry_config:
+            if isinstance(retry_config, dict):
+                self._retry_config = RetryConfig(**retry_config)
+            else:
+                self._retry_config = retry_config
+        else:
+            # Build from individual parameters
+            self._retry_config = RetryConfig(
+                max_retries=self.config.get("max_retries", 3),
+                initial_delay=self.config.get("retry_delay", 1.0),
+            )
+
+        # Update optimistic locking configuration
+        self._enable_optimistic_locking = self.config.get(
+            "enable_optimistic_locking", False
+        )
+        self._version_field = self.config.get("version_field", "version")
+        self._conflict_resolution = self.config.get("conflict_resolution", "fail_fast")
+        self._version_retry_attempts = self.config.get("version_retry_attempts", 3)
+
     def get_parameters(self) -> dict[str, NodeParameter]:
         """Define the parameters this node accepts."""
         params = [
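To make the pool-sharing behavior introduced above concrete: two node instances built with the same connection settings and share_pool=True should end up holding the same adapter, because _generate_pool_key() produces identical keys and _get_adapter() reuses the cached entry in _shared_pools. A hedged sketch, written like the docstring examples for use inside an async context (the identity assertion reflects the intended behavior, not a guarantee):

node_a = AsyncSQLDatabaseNode(
    name="reader_a", database_type="postgresql",
    host="localhost", database="myapp", user="dbuser", password="dbpass",
)
node_b = AsyncSQLDatabaseNode(
    name="reader_b", database_type="postgresql",
    host="localhost", database="myapp", user="dbuser", password="dbpass",
)
adapter_a = await node_a._get_adapter()
adapter_b = await node_b._get_adapter()
assert adapter_a is adapter_b  # identical pool key, so the shared entry's reference count is bumped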
@@ -494,6 +1160,18 @@ class AsyncSQLDatabaseNode(AsyncNode):
                 required=False,
                 description="Full database connection string (overrides individual params)",
             ),
+            NodeParameter(
+                name="connection_name",
+                type=str,
+                required=False,
+                description="Name of database connection from config file",
+            ),
+            NodeParameter(
+                name="config_file",
+                type=str,
+                required=False,
+                description="Path to YAML configuration file (default: database.yaml)",
+            ),
             NodeParameter(
                 name="host", type=str, required=False, description="Database host"
             ),
@@ -564,6 +1242,89 @@ class AsyncSQLDatabaseNode(AsyncNode):
                 required=False,
                 description="User context for access control",
             ),
+            NodeParameter(
+                name="transaction_mode",
+                type=str,
+                required=False,
+                default="auto",
+                description="Transaction mode: 'auto' (default), 'manual', or 'none'",
+            ),
+            NodeParameter(
+                name="share_pool",
+                type=bool,
+                required=False,
+                default=True,
+                description="Whether to share connection pool across instances with same config",
+            ),
+            NodeParameter(
+                name="validate_queries",
+                type=bool,
+                required=False,
+                default=True,
+                description="Whether to validate queries for SQL injection attempts",
+            ),
+            NodeParameter(
+                name="allow_admin",
+                type=bool,
+                required=False,
+                default=False,
+                description="Whether to allow administrative SQL commands (CREATE, DROP, etc.)",
+            ),
+            NodeParameter(
+                name="retry_config",
+                type=Any,
+                required=False,
+                description="Retry configuration dict or RetryConfig object",
+            ),
+            NodeParameter(
+                name="max_retries",
+                type=int,
+                required=False,
+                default=3,
+                description="Maximum number of retry attempts for transient failures",
+            ),
+            NodeParameter(
+                name="retry_delay",
+                type=float,
+                required=False,
+                default=1.0,
+                description="Initial retry delay in seconds",
+            ),
+            NodeParameter(
+                name="enable_optimistic_locking",
+                type=bool,
+                required=False,
+                default=False,
+                description="Enable optimistic locking for version control",
+            ),
+            NodeParameter(
+                name="version_field",
+                type=str,
+                required=False,
+                default="version",
+                description="Column name for version tracking",
+            ),
+            NodeParameter(
+                name="conflict_resolution",
+                type=str,
+                required=False,
+                default="fail_fast",
+                description="How to handle version conflicts: fail_fast, retry, last_writer_wins",
+            ),
+            NodeParameter(
+                name="version_retry_attempts",
+                type=int,
+                required=False,
+                default=3,
+                description="Maximum retries for version conflicts",
+            ),
+            NodeParameter(
+                name="result_format",
+                type=str,
+                required=False,
+                default="dict",
+                description="Result format: 'dict' (default), 'list', or 'dataframe'",
+            ),
         ]
 
         # Convert list to dict as required by base class
@@ -573,6 +1334,39 @@ class AsyncSQLDatabaseNode(AsyncNode):
         """Validate node configuration."""
         super()._validate_config()
 
+        # Handle config file loading
+        connection_name = self.config.get("connection_name")
+        config_file = self.config.get("config_file")
+
+        if connection_name:
+            # Load from config file
+            config_manager = DatabaseConfigManager(config_file)
+            try:
+                conn_string, db_config = config_manager.get_database_config(
+                    connection_name
+                )
+                # Update config with values from file
+                self.config["connection_string"] = conn_string
+                # Merge additional config
+                # Config file values should override defaults but not explicit params
+                for key, value in db_config.items():
+                    # Check if this was explicitly provided by user
+                    param_info = self.get_parameters().get(key)
+                    if param_info and key in self.config:
+                        # If it equals the default, it wasn't explicitly set
+                        if self.config[key] == param_info.default:
+                            self.config[key] = value
+                    else:
+                        # Not a parameter or not in config yet
+                        self.config[key] = value
+            except Exception as e:
+                raise NodeValidationError(
+                    f"Failed to load config '{connection_name}': {e}"
+                )
+
+        # Re-initialize instance variables with updated config
+        self._reinitialize_from_config()
+
         # Validate database type
         db_type = self.config.get("database_type", "").lower()
         if db_type not in ["postgresql", "mysql", "sqlite"]:
@@ -582,7 +1376,18 @@ class AsyncSQLDatabaseNode(AsyncNode):
             )
 
         # Validate connection parameters
-
+        connection_string = self.config.get("connection_string")
+        if connection_string:
+            # Validate connection string for security
+            if self._validate_queries:
+                try:
+                    QueryValidator.validate_connection_string(connection_string)
+                except NodeValidationError:
+                    raise NodeValidationError(
+                        "Connection string failed security validation. "
+                        "Set validate_queries=False to bypass (not recommended)."
+                    )
+        else:
             if db_type != "sqlite":
                 if not self.config.get("host") or not self.config.get("database"):
                     raise NodeValidationError(
@@ -603,38 +1408,116 @@ class AsyncSQLDatabaseNode(AsyncNode):
         if fetch_mode == "many" and not self.config.get("fetch_size"):
             raise NodeValidationError("fetch_size required when fetch_mode is 'many'")
 
+        # Validate initial query if provided
+        if self.config.get("query") and self._validate_queries:
+            try:
+                QueryValidator.validate_query(
+                    self.config["query"], allow_admin=self._allow_admin
+                )
+            except NodeValidationError as e:
+                raise NodeValidationError(
+                    f"Initial query validation failed: {e}. "
+                    "Set validate_queries=False to bypass (not recommended)."
+                )
+
+    def _generate_pool_key(self) -> str:
+        """Generate a unique key for connection pool sharing."""
+        # Create a unique key based on connection parameters
+        key_parts = [
+            self.config.get("database_type", ""),
+            self.config.get("connection_string", "")
+            or (
+                f"{self.config.get('host', '')}:"
+                f"{self.config.get('port', '')}:"
+                f"{self.config.get('database', '')}:"
+                f"{self.config.get('user', '')}"
+            ),
+            str(self.config.get("pool_size", 10)),
+            str(self.config.get("max_pool_size", 20)),
+        ]
+        return "|".join(key_parts)
+
     async def _get_adapter(self) -> DatabaseAdapter:
-        """Get or create database adapter."""
+        """Get or create database adapter with optional pool sharing."""
         if not self._adapter:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            elif db_type == DatabaseType.MYSQL:
-                self._adapter = MySQLAdapter(db_config)
-            elif db_type == DatabaseType.SQLITE:
-                self._adapter = SQLiteAdapter(db_config)
+            if self._share_pool:
+                # Use shared pool if available
+                async with self._get_pool_lock():
+                    self._pool_key = self._generate_pool_key()
+
+                    if self._pool_key in self._shared_pools:
+                        # Reuse existing pool
+                        adapter, ref_count = self._shared_pools[self._pool_key]
+                        self._shared_pools[self._pool_key] = (adapter, ref_count + 1)
+                        self._adapter = adapter
+                        self._connected = True
+                        return self._adapter
+
+                    # Create new shared pool
+                    self._adapter = await self._create_adapter()
+                    self._shared_pools[self._pool_key] = (self._adapter, 1)
             else:
-
-
-        if not self._connected:
-            await self._adapter.connect()
-            self._connected = True
+                # Create dedicated pool
+                self._adapter = await self._create_adapter()
 
         return self._adapter
 
+    async def _create_adapter(self) -> DatabaseAdapter:
+        """Create a new database adapter with retry logic for initial connection."""
+        db_type = DatabaseType(self.config["database_type"].lower())
+        db_config = DatabaseConfig(
+            type=db_type,
+            host=self.config.get("host"),
+            port=self.config.get("port"),
+            database=self.config.get("database"),
+            user=self.config.get("user"),
+            password=self.config.get("password"),
+            connection_string=self.config.get("connection_string"),
+            pool_size=self.config.get("pool_size", 10),
+            max_pool_size=self.config.get("max_pool_size", 20),
+            command_timeout=self.config.get("timeout", 60.0),
+        )
+
+        if db_type == DatabaseType.POSTGRESQL:
+            adapter = PostgreSQLAdapter(db_config)
+        elif db_type == DatabaseType.MYSQL:
+            adapter = MySQLAdapter(db_config)
+        elif db_type == DatabaseType.SQLITE:
+            adapter = SQLiteAdapter(db_config)
+        else:
+            raise NodeExecutionError(f"Unsupported database type: {db_type}")
+
+        # Retry connection with exponential backoff
+        last_error = None
+        for attempt in range(self._retry_config.max_retries):
+            try:
+                await adapter.connect()
+                self._connected = True
+                return adapter
+            except Exception as e:
+                last_error = e
+
+                # Check if error is retryable
+                if not self._retry_config.should_retry(e):
+                    raise
+
+                # Check if we have more attempts
+                if attempt >= self._retry_config.max_retries - 1:
+                    raise NodeExecutionError(
+                        f"Failed to connect after {self._retry_config.max_retries} attempts: {e}"
+                    )
+
+                # Calculate delay
+                delay = self._retry_config.get_delay(attempt)
+
+                # Wait before retry
+                await asyncio.sleep(delay)
+
+        # Should not reach here, but just in case
+        raise NodeExecutionError(
+            f"Failed to connect after {self._retry_config.max_retries} attempts: {last_error}"
+        )
+
     async def async_run(self, **inputs) -> dict[str, Any]:
         """Execute database query asynchronously with optional access control."""
         try:
@@ -645,11 +1528,33 @@ class AsyncSQLDatabaseNode(AsyncNode):
                 inputs.get("fetch_mode", self.config.get("fetch_mode", "all")).lower()
             )
             fetch_size = inputs.get("fetch_size", self.config.get("fetch_size"))
+            result_format = inputs.get(
+                "result_format", self.config.get("result_format", "dict")
+            )
             user_context = inputs.get("user_context")
 
             if not query:
                 raise NodeExecutionError("No query provided")
 
+            # Handle parameter style conversion
+            if params is not None:
+                if isinstance(params, (list, tuple)):
+                    # Convert positional parameters to named parameters
+                    query, params = self._convert_to_named_parameters(query, params)
+                elif not isinstance(params, dict):
+                    # Single parameter - wrap in list and convert
+                    query, params = self._convert_to_named_parameters(query, [params])
+
+            # Validate query for security
+            if self._validate_queries:
+                try:
+                    QueryValidator.validate_query(query, allow_admin=self._allow_admin)
+                except NodeValidationError as e:
+                    raise NodeExecutionError(
+                        f"Query validation failed: {e}. "
+                        "Set validate_queries=False to bypass (not recommended)."
+                    )
+
             # Check access control if enabled
             if self.access_control_manager and user_context:
                 from kailash.access_control import NodePermission
@@ -660,28 +1565,269 @@ class AsyncSQLDatabaseNode(AsyncNode):
|
|
660
1565
|
if not decision.allowed:
|
661
1566
|
raise NodeExecutionError(f"Access denied: {decision.reason}")
|
662
1567
|
|
663
|
-
# Get adapter
|
1568
|
+
# Get adapter
|
664
1569
|
adapter = await self._get_adapter()
|
665
1570
|
|
666
1571
|
# Execute query with retry logic
|
667
|
-
|
668
|
-
|
1572
|
+
result = await self._execute_with_retry(
|
1573
|
+
adapter=adapter,
|
1574
|
+
query=query,
|
1575
|
+
params=params,
|
1576
|
+
fetch_mode=fetch_mode,
|
1577
|
+
fetch_size=fetch_size,
|
1578
|
+
user_context=user_context,
|
1579
|
+
)
|
1580
|
+
|
1581
|
+
# Format results based on requested format
|
1582
|
+
formatted_data = self._format_results(result, result_format)
|
669
1583
|
|
670
|
-
for
|
1584
|
+
# For DataFrame, we need special handling for row count
|
1585
|
+
row_count = 0
|
1586
|
+
if result_format == "dataframe":
|
671
1587
|
try:
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
1588
|
+
row_count = len(formatted_data)
|
1589
|
+
except:
|
1590
|
+
# If pandas isn't available, formatted_data is still a list
|
1591
|
+
row_count = (
|
1592
|
+
len(result)
|
1593
|
+
if isinstance(result, list)
|
1594
|
+
else (1 if result else 0)
|
677
1595
|
)
|
1596
|
+
else:
|
1597
|
+
row_count = (
|
1598
|
+
len(result) if isinstance(result, list) else (1 if result else 0)
|
1599
|
+
)
|
678
1600
|
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
1601
|
+
# Extract column names if available
|
1602
|
+
columns = []
|
1603
|
+
if result and isinstance(result, list) and result:
|
1604
|
+
if isinstance(result[0], dict):
|
1605
|
+
columns = list(result[0].keys())
|
1606
|
+
|
1607
|
+
# Handle DataFrame serialization for JSON compatibility
|
1608
|
+
if result_format == "dataframe":
|
1609
|
+
try:
|
1610
|
+
import pandas as pd
|
1611
|
+
|
1612
|
+
if isinstance(formatted_data, pd.DataFrame):
|
1613
|
+
# Convert DataFrame to JSON-compatible format
|
1614
|
+
serializable_data = {
|
1615
|
+
"dataframe": formatted_data.to_dict("records"),
|
1616
|
+
"columns": formatted_data.columns.tolist(),
|
1617
|
+
"index": formatted_data.index.tolist(),
|
1618
|
+
"_type": "dataframe",
|
1619
|
+
}
|
1620
|
+
else:
|
1621
|
+
# pandas not available, use regular data
|
1622
|
+
serializable_data = formatted_data
|
1623
|
+
except ImportError:
|
1624
|
+
serializable_data = formatted_data
|
1625
|
+
else:
|
1626
|
+
serializable_data = formatted_data
|
1627
|
+
|
1628
|
+
result_dict = {
|
1629
|
+
"result": {
|
1630
|
+
"data": serializable_data,
|
1631
|
+
"row_count": row_count,
|
1632
|
+
"query": query,
|
1633
|
+
"database_type": self.config["database_type"],
|
1634
|
+
"format": result_format,
|
1635
|
+
}
|
1636
|
+
}
|
1637
|
+
|
1638
|
+
# Add columns info for list format
|
1639
|
+
if result_format == "list" and columns:
|
1640
|
+
result_dict["result"]["columns"] = columns
|
1641
|
+
|
1642
|
+
return result_dict
|
1643
|
+
|
1644
|
+
except NodeExecutionError:
|
1645
|
+
# Re-raise our own errors
|
1646
|
+
raise
|
1647
|
+
except Exception as e:
|
1648
|
+
# Wrap other errors
|
1649
|
+
raise NodeExecutionError(f"Database query failed: {str(e)}")
|
1650
|
+
|
1651
|
+
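The new `async_run` path above accepts positional parameter styles (`?`, `$1`, `%s`), validates the query, retries transient failures, and formats the result payload. A minimal usage sketch follows; the constructor keywords shown are illustrative assumptions, while `query`, `params`, `result_format`, and the `result` payload keys come from the code above.

```python
import asyncio

from kailash.nodes.data.async_sql import AsyncSQLDatabaseNode


async def main():
    # Connection settings below are illustrative assumptions.
    node = AsyncSQLDatabaseNode(
        name="orders_db",
        database_type="postgresql",
        host="localhost",
        database="shop",
        user="app",
        password="secret",
    )

    # Positional "?" placeholders are rewritten to named ":p0"/":p1" parameters
    # before execution (see _convert_to_named_parameters later in this diff).
    result = await node.async_run(
        query="SELECT id, total FROM orders WHERE status = ? AND total > ?",
        params=["open", 100],
        result_format="dict",
    )
    payload = result["result"]
    print(payload["row_count"], payload["format"])
    for row in payload["data"]:
        print(row["id"], row["total"])

    await node.cleanup()


asyncio.run(main())
```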
+    async def process(self, inputs: dict[str, Any]) -> dict[str, Any]:
+        """Async process method for middleware compatibility."""
+        return await self.async_run(**inputs)
+
+    async def execute_many_async(
+        self, query: str, params_list: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        """Execute the same query multiple times with different parameters.
+
+        This is useful for bulk inserts, updates, or deletes. The operation
+        runs in a single transaction (in auto or manual mode) for better
+        performance and atomicity.
+
+        Args:
+            query: SQL query to execute multiple times
+            params_list: List of parameter dictionaries
+
+        Returns:
+            dict: Result with affected row count
+
+        Example:
+            >>> params_list = [
+            ...     {"name": "Alice", "age": 30},
+            ...     {"name": "Bob", "age": 25},
+            ...     {"name": "Charlie", "age": 35},
+            ... ]
+            >>> result = await node.execute_many_async(
+            ...     query="INSERT INTO users (name, age) VALUES (:name, :age)",
+            ...     params_list=params_list
+            ... )
+            >>> print(result["result"]["affected_rows"])  # 3
+        """
+        if not params_list:
+            return {
+                "result": {
+                    "affected_rows": 0,
+                    "query": query,
+                    "database_type": self.config["database_type"],
+                }
+            }
+
+        # Validate query if security is enabled
+        if self._validate_queries:
+            try:
+                QueryValidator.validate_query(query, allow_admin=self._allow_admin)
+            except NodeValidationError as e:
+                raise NodeExecutionError(
+                    f"Query validation failed: {e}. "
+                    "Set validate_queries=False to bypass (not recommended)."
+                )
+
+        try:
+            # Get adapter
+            adapter = await self._get_adapter()
+
+            # Execute batch with retry logic
+            affected_rows = await self._execute_many_with_retry(
+                adapter=adapter,
+                query=query,
+                params_list=params_list,
+            )
+
+            return {
+                "result": {
+                    "affected_rows": affected_rows,
+                    "batch_size": len(params_list),
+                    "query": query,
+                    "database_type": self.config["database_type"],
+                }
+            }
+
+        except NodeExecutionError:
+            raise
+        except Exception as e:
+            raise NodeExecutionError(f"Batch operation failed: {str(e)}")
+
+    async def begin_transaction(self):
+        """Begin a manual transaction.
+
+        Returns:
+            Transaction context that can be used for manual control
+
+        Raises:
+            NodeExecutionError: If transaction already active or mode is 'auto'
+        """
+        if self._transaction_mode != "manual":
+            raise NodeExecutionError(
+                "begin_transaction() can only be called in 'manual' transaction mode"
+            )
+
+        if self._active_transaction:
+            raise NodeExecutionError("Transaction already active")
+
+        adapter = await self._get_adapter()
+        self._active_transaction = await adapter.begin_transaction()
+        return self._active_transaction
+
+    async def commit(self):
+        """Commit the active transaction.
+
+        Raises:
+            NodeExecutionError: If no active transaction or mode is not 'manual'
+        """
+        if self._transaction_mode != "manual":
+            raise NodeExecutionError(
+                "commit() can only be called in 'manual' transaction mode"
+            )
+
+        if not self._active_transaction:
+            raise NodeExecutionError("No active transaction to commit")
+
+        adapter = await self._get_adapter()
+        try:
+            await adapter.commit_transaction(self._active_transaction)
+        finally:
+            # Always clear transaction, even on error
+            self._active_transaction = None
+
+    async def rollback(self):
+        """Rollback the active transaction.
+
+        Raises:
+            NodeExecutionError: If no active transaction or mode is not 'manual'
+        """
+        if self._transaction_mode != "manual":
+            raise NodeExecutionError(
+                "rollback() can only be called in 'manual' transaction mode"
+            )
+
+        if not self._active_transaction:
+            raise NodeExecutionError("No active transaction to rollback")
+
+        adapter = await self._get_adapter()
+        try:
+            await adapter.rollback_transaction(self._active_transaction)
+        finally:
+            # Always clear transaction, even on error
+            self._active_transaction = None
+
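With the manual transaction methods above, several statements can be grouped atomically. A minimal sketch, assuming the node was configured for the 'manual' transaction mode (the way that mode is set at construction is an assumption; `begin_transaction`, `commit`, and `rollback` are the methods added above, and queries issued while a transaction is active reuse it):

```python
async def transfer(node, from_id: int, to_id: int, amount: float):
    # Requires 'manual' transaction mode; otherwise begin_transaction()/commit()
    # /rollback() raise NodeExecutionError.
    await node.begin_transaction()
    try:
        await node.async_run(
            query="UPDATE accounts SET balance = balance - :amt WHERE id = :id",
            params={"amt": amount, "id": from_id},
        )
        await node.async_run(
            query="UPDATE accounts SET balance = balance + :amt WHERE id = :id",
            params={"amt": amount, "id": to_id},
        )
        await node.commit()
    except Exception:
        # Any failure rolls back both statements.
        await node.rollback()
        raise
```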
+    async def _execute_with_retry(
+        self,
+        adapter: DatabaseAdapter,
+        query: str,
+        params: Any,
+        fetch_mode: FetchMode,
+        fetch_size: Optional[int],
+        user_context: Any = None,
+    ) -> Any:
+        """Execute query with retry logic for transient failures.
+
+        Args:
+            adapter: Database adapter
+            query: SQL query
+            params: Query parameters
+            fetch_mode: How to fetch results
+            fetch_size: Number of rows for 'many' mode
+            user_context: User context for access control
+
+        Returns:
+            Query results
+
+        Raises:
+            NodeExecutionError: After all retry attempts are exhausted
+        """
+        last_error = None
+
+        for attempt in range(self._retry_config.max_retries):
+            try:
+                # Execute query with transaction
+                result = await self._execute_with_transaction(
+                    adapter=adapter,
+                    query=query,
+                    params=params,
+                    fetch_mode=fetch_mode,
+                    fetch_size=fetch_size,
+                )
+
+                # Apply data masking if access control is enabled
+                if self.access_control_manager and user_context:
+                    if isinstance(result, list):
                         masked_result = []
                         for row in result:
                             masked_row = self.access_control_manager.apply_data_masking(
@@ -689,45 +1835,726 @@ class AsyncSQLDatabaseNode(AsyncNode):
                             )
                             masked_result.append(masked_row)
                         result = masked_result
-            elif (
-                self.access_control_manager
-                and user_context
-                and isinstance(result, dict)
-            ):
+                    elif isinstance(result, dict):
                         result = self.access_control_manager.apply_data_masking(
                             user_context, self.metadata.name, result
                         )

-
-
-
-
-
-
-
-
-
-
+                return result
+
+            except Exception as e:
+                last_error = e
+
+                # Check if error is retryable
+                if not self._retry_config.should_retry(e):
+                    raise
+
+                # Check if we have more attempts
+                if attempt >= self._retry_config.max_retries - 1:
+                    raise
+
+                # Calculate delay
+                delay = self._retry_config.get_delay(attempt)
+
+                # Log retry attempt (if logging is available)
+                try:
+                    self.logger.warning(
+                        f"Query failed (attempt {attempt + 1}/{self._retry_config.max_retries}): {e}. "
+                        f"Retrying in {delay:.2f} seconds..."
+                    )
+                except AttributeError:
+                    # No logger available
+                    pass
+
+                # Wait before retry
+                await asyncio.sleep(delay)
+
+                # For connection errors, try to reconnect
+                if "pool is closed" in str(e).lower() or "connection" in str(e).lower():
+                    try:
+                        # Clear existing adapter to force reconnection
+                        if self._share_pool and self._pool_key:
+                            # Remove from shared pools to force recreation
+                            async with self._get_pool_lock():
+                                if self._pool_key in self._shared_pools:
+                                    _, ref_count = self._shared_pools[self._pool_key]
+                                    if ref_count <= 1:
+                                        del self._shared_pools[self._pool_key]
+                                    else:
+                                        # This shouldn't happen with a closed pool
+                                        del self._shared_pools[self._pool_key]
+
+                        self._adapter = None
+                        self._connected = False
+                        adapter = await self._get_adapter()
+                    except Exception:
+                        # If reconnection fails, continue with retry loop
+                        pass
+
+        # All retries exhausted
+        raise NodeExecutionError(
+            f"Query failed after {self._retry_config.max_retries} attempts: {last_error}"
+        )
+
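The retry loop above only relies on three members of the retry configuration: `max_retries`, `should_retry(error)`, and `get_delay(attempt)`. The actual retry configuration class ships elsewhere in this module and is not shown in this hunk; the sketch below is an illustrative stand-in for that contract, not the released implementation.

```python
import random
from dataclasses import dataclass


@dataclass
class IllustrativeRetryConfig:
    """Assumed shape of the object consumed by _execute_with_retry."""

    max_retries: int = 3
    base_delay: float = 0.5
    max_delay: float = 10.0
    retryable_markers: tuple = ("connection", "pool is closed", "timeout")

    def should_retry(self, error: Exception) -> bool:
        # Retry only errors that look transient.
        message = str(error).lower()
        return any(marker in message for marker in self.retryable_markers)

    def get_delay(self, attempt: int) -> float:
        # Exponential backoff with a little jitter, capped at max_delay.
        delay = min(self.base_delay * (2 ** attempt), self.max_delay)
        return delay + random.uniform(0, delay * 0.1)
```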
+    async def _execute_many_with_retry(
+        self, adapter: DatabaseAdapter, query: str, params_list: list[dict[str, Any]]
+    ) -> int:
+        """Execute batch operation with retry logic.
+
+        Args:
+            adapter: Database adapter
+            query: SQL query to execute
+            params_list: List of parameter dictionaries
+
+        Returns:
+            Number of affected rows
+
+        Raises:
+            NodeExecutionError: After all retry attempts are exhausted
+        """
+        last_error = None
+
+        for attempt in range(self._retry_config.max_retries):
+            try:
+                # Execute batch with transaction
+                return await self._execute_many_with_transaction(
+                    adapter=adapter,
+                    query=query,
+                    params_list=params_list,
+                )
+
+            except Exception as e:
+                last_error = e
+
+                # Check if error is retryable
+                if not self._retry_config.should_retry(e):
+                    raise
+
+                # Check if we have more attempts
+                if attempt >= self._retry_config.max_retries - 1:
+                    raise
+
+                # Calculate delay
+                delay = self._retry_config.get_delay(attempt)
+
+                # Wait before retry
+                await asyncio.sleep(delay)
+
+                # For connection errors, try to reconnect
+                if "pool is closed" in str(e).lower() or "connection" in str(e).lower():
+                    try:
+                        # Clear existing adapter to force reconnection
+                        if self._share_pool and self._pool_key:
+                            # Remove from shared pools to force recreation
+                            async with self._get_pool_lock():
+                                if self._pool_key in self._shared_pools:
+                                    _, ref_count = self._shared_pools[self._pool_key]
+                                    if ref_count <= 1:
+                                        del self._shared_pools[self._pool_key]
+                                    else:
+                                        # This shouldn't happen with a closed pool
+                                        del self._shared_pools[self._pool_key]
+
+                        self._adapter = None
+                        self._connected = False
+                        adapter = await self._get_adapter()
+                    except Exception:
+                        # If reconnection fails, continue with retry loop
+                        pass
+
+        # All retries exhausted
+        raise NodeExecutionError(
+            f"Batch operation failed after {self._retry_config.max_retries} attempts: {last_error}"
+        )
+
+    async def _execute_many_with_transaction(
+        self, adapter: DatabaseAdapter, query: str, params_list: list[dict[str, Any]]
+    ) -> int:
+        """Execute batch operation with automatic transaction management.
+
+        Args:
+            adapter: Database adapter
+            query: SQL query to execute
+            params_list: List of parameter dictionaries
+
+        Returns:
+            Number of affected rows (estimated)
+
+        Raises:
+            Exception: Re-raises any execution errors after rollback
+        """
+        if self._active_transaction:
+            # Use existing transaction (manual mode)
+            await adapter.execute_many(query, params_list, self._active_transaction)
+            # Most adapters don't return row count from execute_many
+            return len(params_list)
+        elif self._transaction_mode == "auto":
+            # Auto-transaction mode
+            transaction = await adapter.begin_transaction()
+            try:
+                await adapter.execute_many(query, params_list, transaction)
+                await adapter.commit_transaction(transaction)
+                return len(params_list)
+            except Exception:
+                await adapter.rollback_transaction(transaction)
+                raise
+        else:
+            # No transaction mode
+            await adapter.execute_many(query, params_list)
+            return len(params_list)
+
+    async def _execute_with_transaction(
+        self,
+        adapter: DatabaseAdapter,
+        query: str,
+        params: Any,
+        fetch_mode: FetchMode,
+        fetch_size: Optional[int],
+    ) -> Any:
+        """Execute query with automatic transaction management.
+
+        Args:
+            adapter: Database adapter
+            query: SQL query
+            params: Query parameters
+            fetch_mode: How to fetch results
+            fetch_size: Number of rows for 'many' mode
+
+        Returns:
+            Query results
+
+        Raises:
+            Exception: Re-raises any execution errors after rollback
+        """
+        if self._active_transaction:
+            # Use existing transaction (manual mode)
+            return await adapter.execute(
+                query=query,
+                params=params,
+                fetch_mode=fetch_mode,
+                fetch_size=fetch_size,
+                transaction=self._active_transaction,
+            )
+        elif self._transaction_mode == "auto":
+            # Auto-transaction mode
+            transaction = await adapter.begin_transaction()
+            try:
+                result = await adapter.execute(
+                    query=query,
+                    params=params,
+                    fetch_mode=fetch_mode,
+                    fetch_size=fetch_size,
+                    transaction=transaction,
+                )
+                await adapter.commit_transaction(transaction)
+                return result
+            except Exception:
+                await adapter.rollback_transaction(transaction)
+                raise
+        else:
+            # No transaction mode
+            return await adapter.execute(
+                query=query,
+                params=params,
+                fetch_mode=fetch_mode,
+                fetch_size=fetch_size,
+            )
+
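`_execute_with_transaction` above picks one of three paths: reuse an explicitly opened transaction, wrap the single call in its own transaction ("auto"), or run outside any explicit transaction. A short sketch of how the modes differ from the caller's side; the literal mode value for the third branch and the way the mode is supplied at construction are assumptions, while the three behaviors come from the code above.

```python
# "auto": every async_run() call gets its own transaction, committed on
# success and rolled back on error.
#
# "manual": nothing is committed until commit() is called; begin_transaction()
# must be used to open the transaction first.
#
# Any other mode (the code's final branch): statements run without an explicit
# transaction, relying on the driver's autocommit behavior.
#
# Illustrative only -- how transaction_mode is passed in is an assumption:
node = AsyncSQLDatabaseNode(
    name="db",
    database_type="postgresql",
    transaction_mode="manual",  # assumed constructor keyword
)
```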
+    @classmethod
+    async def get_pool_metrics(cls) -> dict[str, Any]:
+        """Get metrics for all shared connection pools.
+
+        Returns:
+            dict: Pool metrics including pool count, connections per pool, etc.
+        """
+        async with cls._get_pool_lock():
+            metrics = {"total_pools": len(cls._shared_pools), "pools": []}
+
+            for pool_key, (adapter, ref_count) in cls._shared_pools.items():
+                pool_info = {
+                    "key": pool_key,
+                    "reference_count": ref_count,
+                    "type": adapter.__class__.__name__,
+                }
+
+                # Try to get pool-specific metrics if available
+                if hasattr(adapter, "_pool") and adapter._pool:
+                    pool = adapter._pool
+                    if hasattr(pool, "size"):
+                        pool_info["pool_size"] = pool.size()
+                    if hasattr(pool, "_holders"):
+                        pool_info["active_connections"] = len(
+                            [h for h in pool._holders if h._in_use]
+                        )
+                    elif hasattr(pool, "size") and hasattr(pool, "freesize"):
+                        pool_info["active_connections"] = pool.size - pool.freesize
+
+                metrics["pools"].append(pool_info)
+
+            return metrics
+
+    @classmethod
+    async def clear_shared_pools(cls) -> None:
+        """Clear all shared connection pools. Use with caution!"""
+        async with cls._get_pool_lock():
+            for pool_key, (adapter, _) in list(cls._shared_pools.items()):
+                try:
+                    await adapter.disconnect()
+                except Exception:
+                    pass  # Best effort
+            cls._shared_pools.clear()
+
+    def get_pool_info(self) -> dict[str, Any]:
+        """Get information about this instance's connection pool.
+
+        Returns:
+            dict: Pool information including shared status and metrics
+        """
+        info = {
+            "shared": self._share_pool,
+            "pool_key": self._pool_key,
+            "connected": self._connected,
+        }
+
+        if self._adapter and hasattr(self._adapter, "_pool") and self._adapter._pool:
+            pool = self._adapter._pool
+            if hasattr(pool, "size"):
+                info["pool_size"] = pool.size()
+            if hasattr(pool, "_holders"):
+                info["active_connections"] = len(
+                    [h for h in pool._holders if h._in_use]
+                )
+            elif hasattr(pool, "size") and hasattr(pool, "freesize"):
+                info["active_connections"] = pool.size - pool.freesize
+
+        return info
+
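`get_pool_metrics` and `get_pool_info` above expose the shared-pool bookkeeping. A minimal inspection sketch; the printed keys are the ones built by the dictionaries above, everything else is illustrative.

```python
async def report_pools(node):
    # Per-instance view: whether this node shares a pool and how busy it is.
    info = node.get_pool_info()
    print("shared:", info["shared"], "connected:", info["connected"])
    print("active connections:", info.get("active_connections", "n/a"))

    # Class-wide view across every shared pool created by AsyncSQLDatabaseNode.
    metrics = await AsyncSQLDatabaseNode.get_pool_metrics()
    print("total pools:", metrics["total_pools"])
    for pool in metrics["pools"]:
        print(pool["key"], pool["reference_count"], pool["type"])
```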
+    async def execute_with_version_check(
+        self,
+        query: str,
+        params: dict[str, Any],
+        expected_version: Optional[int] = None,
+        record_id: Optional[Any] = None,
+        table_name: Optional[str] = None,
+    ) -> dict[str, Any]:
+        """Execute a query with optimistic locking version check.
+
+        Args:
+            query: SQL query to execute (UPDATE or DELETE)
+            params: Query parameters
+            expected_version: Expected version number for conflict detection
+            record_id: ID of the record being updated (for retry)
+            table_name: Table name (for retry to re-read current version)
+
+        Returns:
+            dict: Result with version information and conflict status
+
+        Raises:
+            NodeExecutionError: On version conflict or database error
+        """
+        if not self._enable_optimistic_locking:
+            # Just execute normally if optimistic locking is disabled
+            result = await self.execute_async(query=query, params=params)
+            return {
+                "result": result,
+                "version_checked": False,
+                "status": LockStatus.SUCCESS,
+            }
+
+        # Add version check to the query
+        if expected_version is not None:
+            # Ensure version field is in params
+            if "expected_version" in query:
+                # Query already uses :expected_version, just ensure it's set
+                params["expected_version"] = expected_version
+            else:
+                # Use standard version field
+                params[self._version_field] = expected_version
+
+            # For UPDATE queries, also add version increment
+            if "UPDATE" in query.upper() and "SET" in query.upper():
+                # Find SET clause and add version increment
+                set_match = re.search(r"(SET\s+)(.+?)(\s+WHERE)", query, re.IGNORECASE)
+                if set_match:
+                    set_clause = set_match.group(2)
+                    # Add version increment if not already present
+                    if self._version_field not in set_clause:
+                        new_set_clause = f"{set_clause}, {self._version_field} = {self._version_field} + 1"
+                        query = (
+                            query[: set_match.start(2)]
+                            + new_set_clause
+                            + query[set_match.end(2) :]
+                        )
+
+            # Modify query to include version check in WHERE clause (only if not already present)
+            # Check for version condition in WHERE clause specifically, not just anywhere in query
+            where_clause_pattern = (
+                r"WHERE\s+.*?" + re.escape(self._version_field) + r"\s*="
+            )
+            has_version_check_in_where = (
+                re.search(where_clause_pattern, query, re.IGNORECASE) is not None
+                or ":expected_version" in query
+            )
+            if not has_version_check_in_where:
+                if "WHERE" in query.upper():
+                    query += f" AND {self._version_field} = :{self._version_field}"
+                else:
+                    query += f" WHERE {self._version_field} = :{self._version_field}"
+
+        # Try to execute with version check
+        retry_count = 0
+        for attempt in range(self._version_retry_attempts):
+            try:
+                result = await self.execute_async(query=query, params=params)
+
+                # Check if any rows were affected
+                rows_affected = 0
+                rows_affected_found = False
+                if isinstance(result.get("result"), dict):
+                    # Check if we have data array with rows_affected
+                    data = result["result"].get("data", [])
+                    if data and isinstance(data, list) and len(data) > 0:
+                        if isinstance(data[0], dict) and "rows_affected" in data[0]:
+                            rows_affected = data[0]["rows_affected"]
+                            rows_affected_found = True
+
+                    # Only check direct keys if we haven't found rows_affected in data
+                    if not rows_affected_found:
+                        rows_affected = (
+                            result["result"].get("rows_affected", 0)
+                            or result["result"].get("rowcount", 0)
+                            or result["result"].get("affected_rows", 0)
+                            or result["result"].get("row_count", 0)
+                        )
+
+                if rows_affected == 0 and expected_version is not None:
+                    # Version conflict detected
+                    if self._conflict_resolution == "fail_fast":
+                        raise NodeExecutionError(
+                            f"Version conflict: expected version {expected_version} not found"
+                        )
+                    elif (
+                        self._conflict_resolution == "retry"
+                        and record_id
+                        and table_name
+                    ):
+                        # Read current version
+                        current = await self.execute_async(
+                            query=f"SELECT {self._version_field} FROM {table_name} WHERE id = :id",
+                            params={"id": record_id},
+                        )
+
+                        if current["result"]["data"]:
+                            current_version = current["result"]["data"][0][
+                                self._version_field
+                            ]
+                            params[self._version_field] = current_version
+                            # Update expected version for next attempt
+                            expected_version = current_version
+                            retry_count += 1
+                            continue
+                        else:
+                            return {
+                                "result": None,
+                                "status": LockStatus.RECORD_NOT_FOUND,
+                                "version_checked": True,
+                                "retry_count": retry_count,
+                            }
+                    elif self._conflict_resolution == "last_writer_wins":
+                        # Remove version check and try again
+                        params_no_version = params.copy()
+                        params_no_version.pop(self._version_field, None)
+                        query_no_version = query.replace(
+                            f" AND {self._version_field} = :{self._version_field}", ""
+                        )
+                        result = await self.execute_async(
+                            query=query_no_version, params=params_no_version
+                        )
+                        return {
+                            "result": result,
+                            "status": LockStatus.SUCCESS,
+                            "version_checked": False,
+                            "conflict_resolved": "last_writer_wins",
+                            "retry_count": retry_count,
                         }

+                # Success - increment version for UPDATE queries
+                if "UPDATE" in query.upper() and rows_affected > 0:
+                    # The query should have incremented the version
+                    new_version = (
+                        (expected_version or 0) + 1
+                        if expected_version is not None
+                        else None
+                    )
+                    return {
+                        "result": result,
+                        "status": LockStatus.SUCCESS,
+                        "version_checked": True,
+                        "new_version": new_version,
+                        "rows_affected": rows_affected,
+                        "retry_count": retry_count,
+                    }
+                else:
+                    return {
+                        "result": result,
+                        "status": LockStatus.SUCCESS,
+                        "version_checked": True,
+                        "rows_affected": rows_affected,
+                        "retry_count": retry_count,
                    }

-
-
-                await asyncio.sleep(retry_delay * (2**attempt))
-                continue
+            except NodeExecutionError:
+                if attempt >= self._version_retry_attempts - 1:
                    raise
+                await asyncio.sleep(0.1 * (attempt + 1))  # Exponential backoff

-
-
+        return {
+            "result": None,
+            "status": LockStatus.RETRY_EXHAUSTED,
+            "version_checked": True,
+            "retry_count": self._version_retry_attempts,
+        }

-        async def
-
-
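`execute_with_version_check` combines a conditional UPDATE with the version bookkeeping shown above. A minimal sketch of the read-modify-write cycle, assuming a `products` table whose version column matches the node's default version field (table and column names are illustrative; `read_with_version` is the method defined just below):

```python
async def rename_product(node, product_id: int, new_name: str):
    # Read the row together with its current version number.
    current = await node.read_with_version(
        query="SELECT id, name, version FROM products WHERE id = :id",
        params={"id": product_id},
    )
    expected = current["version"]

    # The WHERE version = ... predicate and the version increment are added
    # automatically by execute_with_version_check when they are missing.
    outcome = await node.execute_with_version_check(
        query="UPDATE products SET name = :name WHERE id = :id",
        params={"name": new_name, "id": product_id},
        expected_version=expected,
        record_id=product_id,
        table_name="products",
    )
    return outcome["status"], outcome.get("new_version")
```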
+    async def read_with_version(
+        self,
+        query: str,
+        params: Optional[dict[str, Any]] = None,
+    ) -> dict[str, Any]:
+        """Execute a SELECT query and extract version information.
+
+        Args:
+            query: SELECT query to execute
+            params: Query parameters
+
+        Returns:
+            dict: Result with version information included
+        """
+        result = await self.execute_async(query=query, params=params)
+
+        if self._enable_optimistic_locking and result.get("result", {}).get("data"):
+            # Extract version from results
+            data = result["result"]["data"]
+            if isinstance(data, list) and len(data) > 0:
+                # Single record
+                if len(data) == 1 and self._version_field in data[0]:
+                    return {
+                        "result": result,
+                        "version": data[0][self._version_field],
+                        "record": data[0],
+                    }
+                # Multiple records - include version in each
+                else:
+                    versions = []
+                    for record in data:
+                        if self._version_field in record:
+                            versions.append(record[self._version_field])
+                    return {
+                        "result": result,
+                        "versions": versions,
+                        "records": data,
+                    }
+
+        return result
+
+    def build_versioned_update_query(
+        self,
+        table_name: str,
+        update_fields: dict[str, Any],
+        where_clause: str,
+        increment_version: bool = True,
+    ) -> str:
+        """Build an UPDATE query with version increment.
+
+        Args:
+            table_name: Name of the table to update
+            update_fields: Fields to update (excluding version)
+            where_clause: WHERE clause (without WHERE keyword)
+            increment_version: Whether to increment the version field
+
+        Returns:
+            str: UPDATE query with version handling
+        """
+        if not self._enable_optimistic_locking:
+            # Build normal update query
+            set_parts = [f"{field} = :{field}" for field in update_fields]
+            return (
+                f"UPDATE {table_name} SET {', '.join(set_parts)} WHERE {where_clause}"
+            )
+
+        # Build versioned update query
+        set_parts = [f"{field} = :{field}" for field in update_fields]
+
+        if increment_version:
+            set_parts.append(f"{self._version_field} = {self._version_field} + 1")
+
+        return f"UPDATE {table_name} SET {', '.join(set_parts)} WHERE {where_clause}"
+
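`build_versioned_update_query` only assembles the SQL string; parameters are still passed separately. A small sketch of what it produces when optimistic locking is enabled, assuming the default version field is named `version` (table, columns, and values are illustrative):

```python
async def update_product(node):
    query = node.build_versioned_update_query(
        table_name="products",
        update_fields={"name": "Widget", "price": 9.99},
        where_clause="id = :id AND version = :version",
    )
    # With optimistic locking enabled this yields:
    #   UPDATE products SET name = :name, price = :price, version = version + 1
    #   WHERE id = :id AND version = :version
    return await node.async_run(
        query=query,
        params={"name": "Widget", "price": 9.99, "id": 42, "version": 3},
    )
```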
+    def _convert_to_named_parameters(
+        self, query: str, parameters: list
+    ) -> tuple[str, dict]:
+        """Convert positional parameters to named parameters for various SQL dialects.
+
+        This method handles conversion from different SQL parameter styles to a
+        consistent named parameter format that works with async database drivers.
+
+        Args:
+            query: SQL query with positional placeholders (?, $1, %s)
+            parameters: List of parameter values
+
+        Returns:
+            Tuple of (modified_query, parameter_dict)
+
+        Examples:
+            >>> # SQLite style
+            >>> query = "SELECT * FROM users WHERE age > ? AND active = ?"
+            >>> params = [25, True]
+            >>> new_query, param_dict = node._convert_to_named_parameters(query, params)
+            >>> # Returns: ("SELECT * FROM users WHERE age > :p0 AND active = :p1",
+            >>> #           {"p0": 25, "p1": True})
+
+            >>> # PostgreSQL style
+            >>> query = "UPDATE users SET name = $1 WHERE id = $2"
+            >>> params = ["John", 123]
+            >>> new_query, param_dict = node._convert_to_named_parameters(query, params)
+            >>> # Returns: ("UPDATE users SET name = :p0 WHERE id = :p1",
+            >>> #           {"p0": "John", "p1": 123})
+        """
+        # Create parameter dictionary
+        param_dict = {}
+        for i, value in enumerate(parameters):
+            param_dict[f"p{i}"] = value
+
+        # Replace different placeholder formats with named parameters
+        modified_query = query
+
+        # Handle SQLite-style ? placeholders
+        placeholder_count = 0
+
+        def replace_question_mark(match):
+            nonlocal placeholder_count
+            replacement = f":p{placeholder_count}"
+            placeholder_count += 1
+            return replacement
+
+        modified_query = re.sub(r"\?", replace_question_mark, modified_query)
+
+        # Handle PostgreSQL-style $1, $2, etc. placeholders
+        def replace_postgres_placeholder(match):
+            index = int(match.group(1)) - 1  # PostgreSQL uses 1-based indexing
+            return f":p{index}"
+
+        modified_query = re.sub(
+            r"\$(\d+)", replace_postgres_placeholder, modified_query
+        )
+
+        # Handle MySQL-style %s placeholders
+        placeholder_count = 0
+
+        def replace_mysql_placeholder(match):
+            nonlocal placeholder_count
+            replacement = f":p{placeholder_count}"
+            placeholder_count += 1
+            return replacement
+
+        modified_query = re.sub(r"%s", replace_mysql_placeholder, modified_query)
+
+        return modified_query, param_dict
+
+    def _format_results(self, data: list[dict], result_format: str) -> Any:
+        """Format query results according to specified format.
+
+        Args:
+            data: List of dictionaries from database query
+            result_format: Desired output format ('dict', 'list', 'dataframe')
+
+        Returns:
+            Formatted results
+
+        Formats:
+            - 'dict': List of dictionaries (default) - column names as keys
+            - 'list': List of lists - values only, no column names
+            - 'dataframe': Pandas DataFrame (if pandas is available)
+        """
+        if not data:
+            # Return empty structure based on format
+            if result_format == "dataframe":
+                try:
+                    import pandas as pd
+
+                    return pd.DataFrame()
+                except ImportError:
+                    # Fall back to dict if pandas not available
+                    return []
+            elif result_format == "list":
+                return []
+            else:
+                return []
+
+        if result_format == "dict":
+            # Already in dict format from adapters
+            return data
+
+        elif result_format == "list":
+            # Convert to list of lists (values only)
+            if data:
+                # Get column order from first row
+                columns = list(data[0].keys())
+                return [[row.get(col) for col in columns] for row in data]
+            return []
+
+        elif result_format == "dataframe":
+            # Convert to pandas DataFrame if available
+            try:
+                import pandas as pd
+
+                return pd.DataFrame(data)
+            except ImportError:
+                # Log warning and fall back to dict format
+                if hasattr(self, "logger"):
+                    self.logger.warning(
+                        "Pandas not installed. Install with: pip install pandas. "
+                        "Falling back to dict format."
+                    )
+                return data
+
+        else:
+            # Unknown format - default to dict with warning
+            if hasattr(self, "logger"):
+                self.logger.warning(
+                    f"Unknown result_format '{result_format}', defaulting to 'dict'"
+                )
+            return data

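`_format_results` drives the `result_format` input accepted by `async_run`. A short sketch of the three output shapes; the column names are illustrative, and the shapes follow the code above.

```python
# result_format="dict" (default): list of mappings, column names as keys, e.g.
#   [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
#
# result_format="list": values only; the column order is reported separately
# under result["result"]["columns"], e.g. [[1, "Alice"], [2, "Bob"]]
#
# result_format="dataframe": a pandas DataFrame, serialized in the node output
# as {"dataframe": [...], "columns": [...], "index": [...], "_type": "dataframe"};
# falls back to the dict format when pandas is not installed.

async def list_users(node):
    rows = await node.async_run(
        query="SELECT id, name FROM users",
        result_format="list",
    )
    print(rows["result"]["columns"])  # e.g. ["id", "name"]
    print(rows["result"]["data"])     # e.g. [[1, "Alice"], [2, "Bob"]]
```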
     async def cleanup(self):
         """Clean up database connections."""
+        # Rollback any active transaction
+        if self._active_transaction and self._adapter:
+            try:
+                await self._adapter.rollback_transaction(self._active_transaction)
+            except Exception:
+                pass  # Best effort cleanup
+            self._active_transaction = None
+
         if self._adapter and self._connected:
-
+            if self._share_pool and self._pool_key:
+                # Decrement reference count for shared pool
+                async with self._get_pool_lock():
+                    if self._pool_key in self._shared_pools:
+                        adapter, ref_count = self._shared_pools[self._pool_key]
+                        if ref_count > 1:
+                            # Others still using the pool
+                            self._shared_pools[self._pool_key] = (
+                                adapter,
+                                ref_count - 1,
+                            )
+                        else:
+                            # Last reference, close the pool
+                            del self._shared_pools[self._pool_key]
+                            await adapter.disconnect()
+            else:
+                # Dedicated pool, close directly
+                await self._adapter.disconnect()
+
             self._connected = False
             self._adapter = None
