kailash 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- kailash/__init__.py +1 -1
- kailash/access_control/__init__.py +1 -1
- kailash/core/actors/adaptive_pool_controller.py +630 -0
- kailash/core/actors/connection_actor.py +3 -3
- kailash/core/ml/__init__.py +1 -0
- kailash/core/ml/query_patterns.py +544 -0
- kailash/core/monitoring/__init__.py +19 -0
- kailash/core/monitoring/connection_metrics.py +488 -0
- kailash/core/optimization/__init__.py +1 -0
- kailash/core/resilience/__init__.py +17 -0
- kailash/core/resilience/circuit_breaker.py +382 -0
- kailash/gateway/api.py +7 -5
- kailash/gateway/enhanced_gateway.py +1 -1
- kailash/middleware/auth/access_control.py +11 -11
- kailash/middleware/communication/ai_chat.py +7 -7
- kailash/middleware/communication/api_gateway.py +5 -15
- kailash/middleware/gateway/checkpoint_manager.py +45 -8
- kailash/middleware/gateway/event_store.py +66 -26
- kailash/middleware/mcp/enhanced_server.py +2 -2
- kailash/nodes/admin/permission_check.py +110 -30
- kailash/nodes/admin/schema.sql +387 -0
- kailash/nodes/admin/tenant_isolation.py +249 -0
- kailash/nodes/admin/transaction_utils.py +244 -0
- kailash/nodes/admin/user_management.py +37 -9
- kailash/nodes/ai/ai_providers.py +55 -3
- kailash/nodes/ai/llm_agent.py +115 -13
- kailash/nodes/data/query_pipeline.py +641 -0
- kailash/nodes/data/query_router.py +895 -0
- kailash/nodes/data/sql.py +24 -0
- kailash/nodes/data/workflow_connection_pool.py +451 -23
- kailash/nodes/monitoring/__init__.py +3 -5
- kailash/nodes/monitoring/connection_dashboard.py +822 -0
- kailash/nodes/rag/__init__.py +1 -3
- kailash/resources/registry.py +6 -0
- kailash/runtime/async_local.py +7 -0
- kailash/utils/export.py +152 -0
- kailash/workflow/builder.py +42 -0
- kailash/workflow/graph.py +86 -17
- kailash/workflow/templates.py +4 -9
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/METADATA +14 -1
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/RECORD +45 -31
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/WHEEL +0 -0
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/entry_points.txt +0 -0
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/top_level.txt +0 -0
kailash/nodes/data/query_router.py (new file)
@@ -0,0 +1,895 @@
"""Intelligent query routing for optimal database connection utilization.

This module implements a query router that analyzes queries and routes them
to the most appropriate connection based on query type, connection health,
and historical performance data.
"""

import asyncio
import hashlib
import logging
import re
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple

from kailash.nodes.base import NodeParameter, register_node
from kailash.nodes.base_async import AsyncNode
from kailash.sdk_exceptions import NodeExecutionError

logger = logging.getLogger(__name__)


class QueryType(Enum):
    """Query classification types."""

    READ_SIMPLE = "read_simple"  # Single table SELECT
    READ_COMPLEX = "read_complex"  # JOINs, aggregations
    WRITE_SIMPLE = "write_simple"  # Single row INSERT/UPDATE/DELETE
    WRITE_BULK = "write_bulk"  # Multi-row operations
    DDL = "ddl"  # Schema modifications
    TRANSACTION = "transaction"  # Explicit transaction blocks
    UNKNOWN = "unknown"  # Unclassified queries


@dataclass
class QueryFingerprint:
    """Normalized query representation for caching and pattern matching."""

    template: str  # Query with parameters replaced
    query_type: QueryType  # Classification
    tables: Set[str]  # Tables involved
    is_read_only: bool  # True for SELECT queries
    complexity_score: float  # Estimated complexity (0-1)


@dataclass
class ConnectionInfo:
    """Information about an available connection."""

    connection_id: str
    health_score: float  # 0-100
    current_load: int  # Active queries
    capabilities: Set[str]  # e.g., {"read", "write", "ddl"}
    avg_latency_ms: float  # Recent average latency
    last_used: datetime  # For LRU routing


@dataclass
class RoutingDecision:
    """Result of routing decision."""

    connection_id: str
    decision_factors: Dict[str, Any]  # Why this connection was chosen
    alternatives: List[str]  # Other viable connections
    confidence: float  # 0-1 confidence in decision


class QueryClassifier:
    """Classifies SQL queries for routing decisions."""

    # Regex patterns for query classification
    SELECT_SIMPLE = re.compile(
        r"^\s*SELECT\s+.*?\s+FROM\s+(\w+)(?:\s+WHERE|\s*;?\s*$)",
        re.IGNORECASE | re.DOTALL,
    )
    SELECT_COMPLEX = re.compile(
        r"^\s*SELECT\s+.*?\s+FROM\s+.*?(?:JOIN|GROUP\s+BY|HAVING|UNION|INTERSECT|EXCEPT)",
        re.IGNORECASE | re.DOTALL,
    )
    INSERT_PATTERN = re.compile(r"^\s*INSERT\s+INTO", re.IGNORECASE)
    UPDATE_PATTERN = re.compile(r"^\s*UPDATE\s+", re.IGNORECASE)
    DELETE_PATTERN = re.compile(r"^\s*DELETE\s+FROM", re.IGNORECASE)
    DDL_PATTERN = re.compile(r"^\s*(?:CREATE|ALTER|DROP|TRUNCATE)\s+", re.IGNORECASE)
    TRANSACTION_PATTERN = re.compile(
        r"^\s*(?:BEGIN|START\s+TRANSACTION|COMMIT|ROLLBACK)", re.IGNORECASE
    )
    BULK_PATTERN = re.compile(
        r"(?:VALUES\s*\([^)]+\)(?:\s*,\s*\([^)]+\)){2,}|COPY\s+|BULK\s+INSERT)",
        re.IGNORECASE,
    )

    def __init__(self):
        self.classification_cache = {}
        self.max_cache_size = 10000

    def classify(self, query: str) -> QueryType:
        """Classify a SQL query into one of the defined types."""
        # Check cache first
        query_hash = hashlib.md5(query.encode()).hexdigest()
        if query_hash in self.classification_cache:
            return self.classification_cache[query_hash]

        # Clean the query
        cleaned_query = self._clean_query(query)

        # Classification logic
        query_type = self._classify_query(cleaned_query)

        # Cache the result
        if len(self.classification_cache) >= self.max_cache_size:
            # Simple LRU: remove oldest entries
            oldest_keys = list(self.classification_cache.keys())[:1000]
            for key in oldest_keys:
                del self.classification_cache[key]

        self.classification_cache[query_hash] = query_type
        return query_type

    def _clean_query(self, query: str) -> str:
        """Remove comments and normalize whitespace."""
        # Remove single-line comments
        query = re.sub(r"--[^\n]*", "", query)
        # Remove multi-line comments
        query = re.sub(r"/\*.*?\*/", "", query, flags=re.DOTALL)
        # Normalize whitespace
        query = " ".join(query.split())
        return query.strip()

    def _classify_query(self, query: str) -> QueryType:
        """Perform the actual classification."""
        # Check for transaction commands
        if self.TRANSACTION_PATTERN.match(query):
            return QueryType.TRANSACTION

        # Check for DDL
        if self.DDL_PATTERN.match(query):
            return QueryType.DDL

        # Check for bulk operations
        if self.BULK_PATTERN.search(query):
            return QueryType.WRITE_BULK

        # Check for complex SELECT
        if self.SELECT_COMPLEX.search(query):
            return QueryType.READ_COMPLEX

        # Check for simple SELECT
        if self.SELECT_SIMPLE.match(query):
            return QueryType.READ_SIMPLE

        # Check for INSERT/UPDATE/DELETE
        if (
            self.INSERT_PATTERN.match(query)
            or self.UPDATE_PATTERN.match(query)
            or self.DELETE_PATTERN.match(query)
        ):
            return QueryType.WRITE_SIMPLE

        return QueryType.UNKNOWN

    def fingerprint(
        self, query: str, parameters: Optional[List[Any]] = None
    ) -> QueryFingerprint:
        """Create a normalized fingerprint of the query."""
        cleaned_query = self._clean_query(query)
        query_type = self.classify(query)

        # Extract tables
        tables = self._extract_tables(cleaned_query)

        # Normalize parameters
        template = self._create_template(cleaned_query, parameters)

        # Calculate complexity
        complexity = self._calculate_complexity(cleaned_query, query_type)

        # Determine if read-only
        is_read_only = query_type in [QueryType.READ_SIMPLE, QueryType.READ_COMPLEX]

        return QueryFingerprint(
            template=template,
            query_type=query_type,
            tables=tables,
            is_read_only=is_read_only,
            complexity_score=complexity,
        )

    def _extract_tables(self, query: str) -> Set[str]:
        """Extract table names from query."""
        tables = set()

        # FROM clause
        from_matches = re.findall(r"FROM\s+(\w+)", query, re.IGNORECASE)
        tables.update(from_matches)

        # JOIN clauses
        join_matches = re.findall(r"JOIN\s+(\w+)", query, re.IGNORECASE)
        tables.update(join_matches)

        # INSERT INTO
        insert_matches = re.findall(r"INSERT\s+INTO\s+(\w+)", query, re.IGNORECASE)
        tables.update(insert_matches)

        # UPDATE
        update_matches = re.findall(r"UPDATE\s+(\w+)", query, re.IGNORECASE)
        tables.update(update_matches)

        # DELETE FROM
        delete_matches = re.findall(r"DELETE\s+FROM\s+(\w+)", query, re.IGNORECASE)
        tables.update(delete_matches)

        return tables

    def _create_template(self, query: str, parameters: Optional[List[Any]]) -> str:
        """Create query template with normalized parameters."""
        template = query

        # Replace string literals
        template = re.sub(r"'[^']*'", "?", template)
        template = re.sub(r'"[^"]*"', "?", template)

        # Replace numbers
        template = re.sub(r"\b\d+\.?\d*\b", "?", template)

        # Replace parameter placeholders
        template = re.sub(r"%s|\$\d+|\?", "?", template)

        return template

    def _calculate_complexity(self, query: str, query_type: QueryType) -> float:
        """Calculate query complexity score (0-1)."""
        score = 0.0

        # Base scores by type
        base_scores = {
            QueryType.READ_SIMPLE: 0.1,
            QueryType.READ_COMPLEX: 0.5,
            QueryType.WRITE_SIMPLE: 0.2,
            QueryType.WRITE_BULK: 0.6,
            QueryType.DDL: 0.8,
            QueryType.TRANSACTION: 0.3,
            QueryType.UNKNOWN: 0.5,
        }
        score = base_scores.get(query_type, 0.5)

        # Adjust for query features
        if re.search(r"\bJOIN\b", query, re.IGNORECASE):
            score += 0.1 * len(re.findall(r"\bJOIN\b", query, re.IGNORECASE))

        if re.search(r"\bGROUP\s+BY\b", query, re.IGNORECASE):
            score += 0.15

        if re.search(r"\bORDER\s+BY\b", query, re.IGNORECASE):
            score += 0.05

        if re.search(r"\bDISTINCT\b", query, re.IGNORECASE):
            score += 0.1

        # Subqueries
        if query.count("SELECT") > 1:
            score += 0.2 * (query.count("SELECT") - 1)

        return min(score, 1.0)
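A minimal sketch of how the classifier above behaves on a sample query; this is an editor's illustration, not part of the released file:

# Sketch (not part of kailash 0.6.2): exercising QueryClassifier directly.
classifier = QueryClassifier()
sample = "SELECT o.id FROM orders o JOIN users u ON u.id = o.user_id WHERE u.active = 1"
print(classifier.classify(sample))           # QueryType.READ_COMPLEX (the JOIN matches SELECT_COMPLEX)
fp = classifier.fingerprint(sample)
print(fp.tables)                             # {'orders', 'users'} extracted by the FROM/JOIN regexes
print(fp.is_read_only, fp.complexity_score)  # True, 0.5 READ_COMPLEX base + 0.1 per JOIN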

class PreparedStatementCache:
    """LRU cache for prepared statements with connection affinity."""

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        self.cache: Dict[str, Dict[str, Any]] = {}  # fingerprint -> statement info
        self.usage_order = deque()  # For LRU eviction
        self.usage_stats = defaultdict(int)  # Track usage frequency

    def get(self, fingerprint: str, connection_id: str) -> Optional[Dict[str, Any]]:
        """Get cached statement if available for connection."""
        if fingerprint in self.cache:
            entry = self.cache[fingerprint]
            if connection_id in entry.get("connections", {}):
                # Update usage
                self.usage_stats[fingerprint] += 1
                self._update_usage_order(fingerprint)
                return entry
        return None

    def put(self, fingerprint: str, connection_id: str, statement_info: Dict[str, Any]):
        """Cache a prepared statement."""
        if fingerprint not in self.cache:
            # Check if we need to evict
            if len(self.cache) >= self.max_size:
                self._evict_lru()

            self.cache[fingerprint] = {
                "connections": {},
                "created_at": datetime.now(),
                "last_used": datetime.now(),
            }

        # Add connection-specific info
        self.cache[fingerprint]["connections"][connection_id] = statement_info
        self.cache[fingerprint]["last_used"] = datetime.now()
        self._update_usage_order(fingerprint)

    def invalidate(self, tables: Optional[Set[str]] = None):
        """Invalidate cached statements for specific tables or all."""
        if tables is None:
            # Clear entire cache
            self.cache.clear()
            self.usage_order.clear()
            self.usage_stats.clear()
        else:
            # Invalidate statements touching specified tables
            to_remove = []
            for fingerprint, entry in self.cache.items():
                if "tables" in entry and entry["tables"].intersection(tables):
                    to_remove.append(fingerprint)

            for fingerprint in to_remove:
                del self.cache[fingerprint]
                self.usage_stats.pop(fingerprint, None)

    def _update_usage_order(self, fingerprint: str):
        """Update LRU order."""
        if fingerprint in self.usage_order:
            self.usage_order.remove(fingerprint)
        self.usage_order.append(fingerprint)

    def _evict_lru(self):
        """Evict least recently used entry."""
        if self.usage_order:
            victim = self.usage_order.popleft()
            del self.cache[victim]
            self.usage_stats.pop(victim, None)

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        total_entries = len(self.cache)
        total_usage = sum(self.usage_stats.values())

        return {
            "total_entries": total_entries,
            "total_usage": total_usage,
            "hit_rate": total_usage / (total_usage + 1) if total_entries > 0 else 0,
            "avg_usage_per_entry": (
                total_usage / total_entries if total_entries > 0 else 0
            ),
            "cache_size_bytes": sum(len(str(v)) for v in self.cache.values()),
        }
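A short sketch of the statement-cache round trip implemented above; illustrative only, not part of the released file:

# Sketch (not part of kailash 0.6.2): PreparedStatementCache round trip.
cache = PreparedStatementCache(max_size=100)
key = "SELECT * FROM orders WHERE id = ?"
cache.put(key, "conn_1", {"statement_name": "stmt_orders_by_id"})
cache.get(key, "conn_1")   # returns the cached entry (connection affinity matches)
cache.get(key, "conn_2")   # returns None: nothing prepared on conn_2 yet
cache.invalidate()         # clears everything; pass a set of table names to be selective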

class RoutingDecisionEngine:
    """Makes intelligent routing decisions based on multiple factors."""

    def __init__(
        self, health_threshold: float = 50.0, enable_read_write_split: bool = True
    ):
        self.health_threshold = health_threshold
        self.enable_read_write_split = enable_read_write_split
        self.routing_history = deque(maxlen=1000)  # Recent routing decisions
        self.connection_affinity = {}  # Track query -> connection affinity

    def select_connection(
        self,
        query_fingerprint: QueryFingerprint,
        available_connections: List[ConnectionInfo],
        transaction_context: Optional[str] = None,
    ) -> RoutingDecision:
        """Select the optimal connection for the query."""

        # Filter healthy connections
        healthy_connections = [
            c for c in available_connections if c.health_score >= self.health_threshold
        ]

        if not healthy_connections:
            # Fall back to any available connection
            healthy_connections = available_connections

        if not healthy_connections:
            raise NodeExecutionError("No available connections for routing")

        # If in transaction, must use same connection
        if transaction_context:
            for conn in healthy_connections:
                if conn.connection_id == transaction_context:
                    return RoutingDecision(
                        connection_id=conn.connection_id,
                        decision_factors={"reason": "transaction_affinity"},
                        alternatives=[],
                        confidence=1.0,
                    )
            raise NodeExecutionError(
                f"Transaction connection {transaction_context} not available"
            )

        # Apply routing strategy
        if self.enable_read_write_split and query_fingerprint.is_read_only:
            selected = self._route_read_query(query_fingerprint, healthy_connections)
        else:
            selected = self._route_write_query(query_fingerprint, healthy_connections)

        # Record decision
        self.routing_history.append(
            {
                "timestamp": datetime.now(),
                "query_type": query_fingerprint.query_type,
                "connection": selected.connection_id,
                "confidence": selected.confidence,
            }
        )

        return selected

    def _route_read_query(
        self, fingerprint: QueryFingerprint, connections: List[ConnectionInfo]
    ) -> RoutingDecision:
        """Route read queries with load balancing."""
        # Filter connections that support reads
        read_connections = [c for c in connections if "read" in c.capabilities]

        if not read_connections:
            read_connections = connections

        # Score each connection
        scores = []
        for conn in read_connections:
            score = self._calculate_connection_score(conn, fingerprint)
            scores.append((conn, score))

        # Sort by score (descending)
        scores.sort(key=lambda x: x[1], reverse=True)

        # Select best connection
        best_conn, best_score = scores[0]

        # Calculate confidence based on score distribution
        confidence = self._calculate_confidence(scores)

        return RoutingDecision(
            connection_id=best_conn.connection_id,
            decision_factors={
                "strategy": "load_balanced_read",
                "score": best_score,
                "health": best_conn.health_score,
                "load": best_conn.current_load,
            },
            alternatives=[s[0].connection_id for s in scores[1:3]],
            confidence=confidence,
        )

    def _route_write_query(
        self, fingerprint: QueryFingerprint, connections: List[ConnectionInfo]
    ) -> RoutingDecision:
        """Route write queries to primary connections."""
        # Filter connections that support writes
        write_connections = [c for c in connections if "write" in c.capabilities]

        if not write_connections:
            write_connections = connections

        # For writes, prefer the healthiest primary connection
        write_connections.sort(
            key=lambda c: (c.health_score, -c.current_load), reverse=True
        )

        best_conn = write_connections[0]

        return RoutingDecision(
            connection_id=best_conn.connection_id,
            decision_factors={
                "strategy": "primary_write",
                "health": best_conn.health_score,
                "load": best_conn.current_load,
            },
            alternatives=[c.connection_id for c in write_connections[1:3]],
            confidence=0.9 if best_conn.health_score > 80 else 0.7,
        )

    def _calculate_connection_score(
        self, conn: ConnectionInfo, fingerprint: QueryFingerprint
    ) -> float:
        """Calculate a score for connection suitability."""
        score = 0.0

        # Health score (40% weight)
        score += (conn.health_score / 100) * 0.4

        # Load score (30% weight) - inverse relationship
        max_load = 10  # Assume max 10 concurrent queries
        load_score = 1.0 - (min(conn.current_load, max_load) / max_load)
        score += load_score * 0.3

        # Latency score (20% weight) - inverse relationship
        max_latency = 100  # 100ms threshold
        latency_score = 1.0 - (min(conn.avg_latency_ms, max_latency) / max_latency)
        score += latency_score * 0.2

        # Affinity score (10% weight)
        query_key = fingerprint.template
        if query_key in self.connection_affinity:
            if self.connection_affinity[query_key] == conn.connection_id:
                score += 0.1

        return score

    def _calculate_confidence(
        self, scores: List[Tuple[ConnectionInfo, float]]
    ) -> float:
        """Calculate confidence in routing decision."""
        if len(scores) < 2:
            return 0.5

        best_score = scores[0][1]
        second_score = scores[1][1]

        # Confidence based on score separation
        score_diff = best_score - second_score

        if score_diff > 0.3:
            return 0.95
        elif score_diff > 0.2:
            return 0.85
        elif score_diff > 0.1:
            return 0.75
        else:
            return 0.65
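A small sketch of a routing decision made by the engine above, using two hypothetical connections; illustrative only, not part of the released file:

# Sketch (not part of kailash 0.6.2): scoring a read query against two connections.
engine = RoutingDecisionEngine(health_threshold=50.0, enable_read_write_split=True)
replica = ConnectionInfo(
    connection_id="conn_replica", health_score=90.0, current_load=1,
    capabilities={"read"}, avg_latency_ms=5.0, last_used=datetime.now(),
)
primary = ConnectionInfo(
    connection_id="conn_primary", health_score=80.0, current_load=6,
    capabilities={"read", "write"}, avg_latency_ms=12.0, last_used=datetime.now(),
)
fp = QueryClassifier().fingerprint("SELECT * FROM orders WHERE status = 'open'")
decision = engine.select_connection(fp, [replica, primary])
# The read/write split plus the health/load/latency weights favor the replica here.
print(decision.connection_id, decision.confidence, decision.decision_factors)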

@register_node()
class QueryRouterNode(AsyncNode):
    """
    Intelligent query routing for optimal database performance.

    This node analyzes SQL queries and routes them to the most appropriate
    connection from a WorkflowConnectionPool based on:
    - Query type (read/write)
    - Connection health and load
    - Historical performance
    - Prepared statement cache

    Parameters:
        connection_pool (str): Name of the WorkflowConnectionPool node
        enable_read_write_split (bool): Enable read/write splitting
        cache_size (int): Size of prepared statement cache
        pattern_learning (bool): Enable pattern-based optimization
        health_threshold (float): Minimum health score for routing

    Example:
        >>> router = QueryRouterNode(
        ...     name="smart_router",
        ...     connection_pool="db_pool",
        ...     enable_read_write_split=True,
        ...     cache_size=1000
        ... )
        >>>
        >>> # Query is automatically routed to optimal connection
        >>> result = await router.process({
        ...     "query": "SELECT * FROM orders WHERE status = ?",
        ...     "parameters": ["pending"]
        ... })
    """

    def __init__(self, **config):
        super().__init__(**config)

        # Configuration
        self.connection_pool_name = config.get("connection_pool")
        if not self.connection_pool_name:
            raise ValueError("connection_pool parameter is required")

        self.enable_read_write_split = config.get("enable_read_write_split", True)
        self.cache_size = config.get("cache_size", 1000)
        self.pattern_learning = config.get("pattern_learning", True)
        self.health_threshold = config.get("health_threshold", 50.0)

        # Components
        self.classifier = QueryClassifier()
        self.statement_cache = PreparedStatementCache(max_size=self.cache_size)
        self.routing_engine = RoutingDecisionEngine(
            health_threshold=self.health_threshold,
            enable_read_write_split=self.enable_read_write_split,
        )

        # Metrics
        self.metrics = {
            "queries_routed": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "routing_errors": 0,
            "avg_routing_time_ms": 0.0,
        }

        # Transaction tracking
        self.active_transactions = {}  # session_id -> connection_id

        # Direct pool reference
        self._connection_pool = None

    def set_connection_pool(self, pool):
        """Set the connection pool directly.

        Args:
            pool: Connection pool instance
        """
        self._connection_pool = pool

    @classmethod
    def get_parameters(cls) -> Dict[str, NodeParameter]:
        """Define node parameters."""
        return {
            "connection_pool": NodeParameter(
                name="connection_pool",
                type=str,
                description="Name of the WorkflowConnectionPool node to use",
                required=True,
            ),
            "enable_read_write_split": NodeParameter(
                name="enable_read_write_split",
                type=bool,
                description="Enable routing reads to replica connections",
                required=False,
                default=True,
            ),
            "cache_size": NodeParameter(
                name="cache_size",
                type=int,
                description="Maximum number of prepared statements to cache",
                required=False,
                default=1000,
            ),
            "pattern_learning": NodeParameter(
                name="pattern_learning",
                type=bool,
                description="Enable pattern-based query optimization",
                required=False,
                default=True,
            ),
            "health_threshold": NodeParameter(
                name="health_threshold",
                type=float,
                description="Minimum health score for connection routing (0-100)",
                required=False,
                default=50.0,
            ),
        }

    async def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Route and execute a query."""
        start_time = time.time()

        try:
            # Extract query and parameters
            query = input_data.get("query")
            if not query:
                raise ValueError("'query' is required in input_data")

            parameters = input_data.get("parameters", [])
            session_id = input_data.get("session_id")

            # Get connection pool
            pool_node = self._connection_pool
            if not pool_node:
                # Try to get from runtime
                if hasattr(self, "runtime") and hasattr(self.runtime, "get_node"):
                    pool_node = self.runtime.get_node(self.connection_pool_name)
                elif hasattr(self, "runtime") and hasattr(
                    self.runtime, "resource_registry"
                ):
                    pool_node = self.runtime.resource_registry.get(
                        self.connection_pool_name
                    )
                elif hasattr(self, "context") and hasattr(
                    self.context, "resource_registry"
                ):
                    pool_node = self.context.resource_registry.get(
                        self.connection_pool_name
                    )

            if not pool_node:
                raise NodeExecutionError(
                    f"Connection pool '{self.connection_pool_name}' not found"
                )

            # Classify and fingerprint query
            fingerprint = self.classifier.fingerprint(query, parameters)

            # Get available connections
            pool_status = await pool_node.process({"operation": "get_status"})
            available_connections = self._parse_pool_status(pool_status)

            # Check for active transaction
            transaction_context = None
            if session_id and session_id in self.active_transactions:
                transaction_context = self.active_transactions[session_id]

            # Handle transaction commands
            if fingerprint.query_type == QueryType.TRANSACTION:
                return await self._handle_transaction_command(
                    query, session_id, pool_node, available_connections
                )

            # Make routing decision
            decision = self.routing_engine.select_connection(
                fingerprint, available_connections, transaction_context
            )

            # Check cache for prepared statement
            cache_key = fingerprint.template
            cached_statement = self.statement_cache.get(
                cache_key, decision.connection_id
            )

            if cached_statement:
                self.metrics["cache_hits"] += 1
            else:
                self.metrics["cache_misses"] += 1

            # Execute query on selected connection
            result = await self._execute_on_connection(
                pool_node,
                decision.connection_id,
                query,
                parameters,
                fingerprint,
                cached_statement,
            )

            # Update metrics
            self.metrics["queries_routed"] += 1
            routing_time = (time.time() - start_time) * 1000
            self.metrics["avg_routing_time_ms"] = (
                self.metrics["avg_routing_time_ms"]
                * (self.metrics["queries_routed"] - 1)
                + routing_time
            ) / self.metrics["queries_routed"]

            # Add routing metadata to result
            result["routing_metadata"] = {
                "connection_id": decision.connection_id,
                "query_type": fingerprint.query_type.value,
                "complexity_score": fingerprint.complexity_score,
                "routing_confidence": decision.confidence,
                "cache_hit": cached_statement is not None,
                "routing_time_ms": routing_time,
            }

            return result

        except Exception as e:
            self.metrics["routing_errors"] += 1
            logger.error(f"Query routing error: {str(e)}")
            raise NodeExecutionError(f"Query routing failed: {str(e)}")

    def _parse_pool_status(self, pool_status: Dict[str, Any]) -> List[ConnectionInfo]:
        """Parse pool status into ConnectionInfo objects."""
        connections = []

        for conn_id, conn_data in pool_status.get("connections", {}).items():
            connections.append(
                ConnectionInfo(
                    connection_id=conn_id,
                    health_score=conn_data.get("health_score", 0),
                    current_load=conn_data.get("active_queries", 0),
                    capabilities=set(conn_data.get("capabilities", ["read", "write"])),
                    avg_latency_ms=conn_data.get("avg_latency_ms", 0),
                    last_used=datetime.fromisoformat(
                        conn_data.get("last_used", datetime.now().isoformat())
                    ),
                )
            )

        return connections

    async def _handle_transaction_command(
        self,
        query: str,
        session_id: Optional[str],
        pool_node: Any,
        connections: List[ConnectionInfo],
    ) -> Dict[str, Any]:
        """Handle transaction control commands."""
        query_upper = query.upper().strip()

        if query_upper.startswith(("BEGIN", "START TRANSACTION")):
            if not session_id:
                raise ValueError("session_id required for transactions")

            # Select a connection for the transaction
            write_connections = [c for c in connections if "write" in c.capabilities]
            if not write_connections:
                raise NodeExecutionError("No write-capable connections available")

            # Use healthiest connection
            best_conn = max(write_connections, key=lambda c: c.health_score)
            self.active_transactions[session_id] = best_conn.connection_id

            # Execute BEGIN on selected connection
            result = await pool_node.process(
                {
                    "operation": "execute",
                    "connection_id": best_conn.connection_id,
                    "query": query,
                }
            )

            result["transaction_started"] = True
            result["connection_id"] = best_conn.connection_id
            return result

        elif query_upper.startswith(("COMMIT", "ROLLBACK")):
            if not session_id or session_id not in self.active_transactions:
                raise ValueError("No active transaction for session")

            conn_id = self.active_transactions[session_id]

            # Execute on transaction connection
            result = await pool_node.process(
                {"operation": "execute", "connection_id": conn_id, "query": query}
            )

            # Clear transaction state
            del self.active_transactions[session_id]

            result["transaction_ended"] = True
            return result

        else:
            raise ValueError(f"Unknown transaction command: {query}")

    async def _execute_on_connection(
        self,
        pool_node: Any,
        connection_id: str,
        query: str,
        parameters: List[Any],
        fingerprint: QueryFingerprint,
        cached_statement: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Execute query on selected connection."""

        # Build execution request
        execution_request = {
            "operation": "execute",
            "connection_id": connection_id,
            "query": query,
            "parameters": parameters,
        }

        # Add caching hint if available
        if cached_statement:
            execution_request["use_prepared"] = True
            execution_request["statement_name"] = cached_statement.get("statement_name")

        # Execute query
        result = await pool_node.process(execution_request)

        # Cache prepared statement info if not cached
        if not cached_statement and result.get("prepared_statement_name"):
            self.statement_cache.put(
                fingerprint.template,
                connection_id,
                {
                    "statement_name": result["prepared_statement_name"],
                    "tables": list(fingerprint.tables),
                    "created_at": datetime.now(),
                },
            )

        return result

    async def get_metrics(self) -> Dict[str, Any]:
        """Get router metrics and statistics."""
        cache_stats = self.statement_cache.get_stats()

        return {
            "router_metrics": self.metrics,
            "cache_stats": cache_stats,
            "active_transactions": len(self.active_transactions),
            "routing_history": {
                "total_decisions": len(self.routing_engine.routing_history),
                "recent_decisions": list(self.routing_engine.routing_history)[-10:],
            },
        }

    async def invalidate_cache(self, tables: Optional[List[str]] = None):
        """Invalidate prepared statement cache."""
        if tables:
            self.statement_cache.invalidate(set(tables))
        else:
            self.statement_cache.invalidate()

        return {"invalidated": True, "tables": tables}
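A minimal end-to-end sketch of how the new router might be driven, using only calls visible in this diff (`set_connection_pool`, `process`, `get_metrics`) and the class docstring; the pool is assumed to be an already-configured WorkflowConnectionPool that answers the router's "get_status" and "execute" requests, and names like "db_pool" are illustrative:

# Sketch (not part of kailash 0.6.2): driving QueryRouterNode against an existing pool.
from kailash.nodes.data.query_router import QueryRouterNode

async def route_example(pool):
    # `pool` is assumed to be a configured WorkflowConnectionPool instance.
    router = QueryRouterNode(
        name="smart_router",
        connection_pool="db_pool",     # name used for runtime lookup if no direct pool is set
        enable_read_write_split=True,
        cache_size=1000,
    )
    router.set_connection_pool(pool)   # direct reference; skips the runtime/registry lookup

    result = await router.process({
        "query": "SELECT * FROM orders WHERE status = ?",
        "parameters": ["pending"],
        "session_id": "session-1",     # only needed when issuing BEGIN/COMMIT/ROLLBACK
    })
    print(result["routing_metadata"])  # connection id, query type, confidence, cache hit
    print(await router.get_metrics())  # router counters plus statement-cache statistics
    # run with: asyncio.run(route_example(pool)) once a pool instance is available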