kailash 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kailash/__init__.py +1 -1
  2. kailash/access_control/__init__.py +1 -1
  3. kailash/core/actors/adaptive_pool_controller.py +630 -0
  4. kailash/core/actors/connection_actor.py +3 -3
  5. kailash/core/ml/__init__.py +1 -0
  6. kailash/core/ml/query_patterns.py +544 -0
  7. kailash/core/monitoring/__init__.py +19 -0
  8. kailash/core/monitoring/connection_metrics.py +488 -0
  9. kailash/core/optimization/__init__.py +1 -0
  10. kailash/core/resilience/__init__.py +17 -0
  11. kailash/core/resilience/circuit_breaker.py +382 -0
  12. kailash/gateway/api.py +7 -5
  13. kailash/gateway/enhanced_gateway.py +1 -1
  14. kailash/middleware/auth/access_control.py +11 -11
  15. kailash/middleware/communication/ai_chat.py +7 -7
  16. kailash/middleware/communication/api_gateway.py +5 -15
  17. kailash/middleware/gateway/checkpoint_manager.py +45 -8
  18. kailash/middleware/gateway/event_store.py +66 -26
  19. kailash/middleware/mcp/enhanced_server.py +2 -2
  20. kailash/nodes/admin/permission_check.py +110 -30
  21. kailash/nodes/admin/schema.sql +387 -0
  22. kailash/nodes/admin/tenant_isolation.py +249 -0
  23. kailash/nodes/admin/transaction_utils.py +244 -0
  24. kailash/nodes/admin/user_management.py +37 -9
  25. kailash/nodes/ai/ai_providers.py +55 -3
  26. kailash/nodes/ai/llm_agent.py +115 -13
  27. kailash/nodes/data/query_pipeline.py +641 -0
  28. kailash/nodes/data/query_router.py +895 -0
  29. kailash/nodes/data/sql.py +24 -0
  30. kailash/nodes/data/workflow_connection_pool.py +451 -23
  31. kailash/nodes/monitoring/__init__.py +3 -5
  32. kailash/nodes/monitoring/connection_dashboard.py +822 -0
  33. kailash/nodes/rag/__init__.py +1 -3
  34. kailash/resources/registry.py +6 -0
  35. kailash/runtime/async_local.py +7 -0
  36. kailash/utils/export.py +152 -0
  37. kailash/workflow/builder.py +42 -0
  38. kailash/workflow/graph.py +86 -17
  39. kailash/workflow/templates.py +4 -9
  40. {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/METADATA +14 -1
  41. {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/RECORD +45 -31
  42. {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/WHEEL +0 -0
  43. {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/entry_points.txt +0 -0
  44. {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/licenses/LICENSE +0 -0
  45. {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/top_level.txt +0 -0
kailash/nodes/data/query_router.py
@@ -0,0 +1,895 @@
+"""Intelligent query routing for optimal database connection utilization.
+
+This module implements a query router that analyzes queries and routes them
+to the most appropriate connection based on query type, connection health,
+and historical performance data.
+"""
+
+import asyncio
+import hashlib
+import logging
+import re
+import time
+from collections import defaultdict, deque
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from enum import Enum
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from kailash.nodes.base import NodeParameter, register_node
+from kailash.nodes.base_async import AsyncNode
+from kailash.sdk_exceptions import NodeExecutionError
+
+logger = logging.getLogger(__name__)
+
+
+class QueryType(Enum):
+    """Query classification types."""
+
+    READ_SIMPLE = "read_simple"  # Single table SELECT
+    READ_COMPLEX = "read_complex"  # JOINs, aggregations
+    WRITE_SIMPLE = "write_simple"  # Single row INSERT/UPDATE/DELETE
+    WRITE_BULK = "write_bulk"  # Multi-row operations
+    DDL = "ddl"  # Schema modifications
+    TRANSACTION = "transaction"  # Explicit transaction blocks
+    UNKNOWN = "unknown"  # Unclassified queries
+
+
+@dataclass
+class QueryFingerprint:
+    """Normalized query representation for caching and pattern matching."""
+
+    template: str  # Query with parameters replaced
+    query_type: QueryType  # Classification
+    tables: Set[str]  # Tables involved
+    is_read_only: bool  # True for SELECT queries
+    complexity_score: float  # Estimated complexity (0-1)
+
+
+@dataclass
+class ConnectionInfo:
+    """Information about an available connection."""
+
+    connection_id: str
+    health_score: float  # 0-100
+    current_load: int  # Active queries
+    capabilities: Set[str]  # e.g., {"read", "write", "ddl"}
+    avg_latency_ms: float  # Recent average latency
+    last_used: datetime  # For LRU routing
+
+
+@dataclass
+class RoutingDecision:
+    """Result of routing decision."""
+
+    connection_id: str
+    decision_factors: Dict[str, Any]  # Why this connection was chosen
+    alternatives: List[str]  # Other viable connections
+    confidence: float  # 0-1 confidence in decision
+
+
+class QueryClassifier:
+    """Classifies SQL queries for routing decisions."""
+
+    # Regex patterns for query classification
+    SELECT_SIMPLE = re.compile(
+        r"^\s*SELECT\s+.*?\s+FROM\s+(\w+)(?:\s+WHERE|\s*;?\s*$)",
+        re.IGNORECASE | re.DOTALL,
+    )
+    SELECT_COMPLEX = re.compile(
+        r"^\s*SELECT\s+.*?\s+FROM\s+.*?(?:JOIN|GROUP\s+BY|HAVING|UNION|INTERSECT|EXCEPT)",
+        re.IGNORECASE | re.DOTALL,
+    )
+    INSERT_PATTERN = re.compile(r"^\s*INSERT\s+INTO", re.IGNORECASE)
+    UPDATE_PATTERN = re.compile(r"^\s*UPDATE\s+", re.IGNORECASE)
+    DELETE_PATTERN = re.compile(r"^\s*DELETE\s+FROM", re.IGNORECASE)
+    DDL_PATTERN = re.compile(r"^\s*(?:CREATE|ALTER|DROP|TRUNCATE)\s+", re.IGNORECASE)
+    TRANSACTION_PATTERN = re.compile(
+        r"^\s*(?:BEGIN|START\s+TRANSACTION|COMMIT|ROLLBACK)", re.IGNORECASE
+    )
+    BULK_PATTERN = re.compile(
+        r"(?:VALUES\s*\([^)]+\)(?:\s*,\s*\([^)]+\)){2,}|COPY\s+|BULK\s+INSERT)",
+        re.IGNORECASE,
+    )
+
+    def __init__(self):
+        self.classification_cache = {}
+        self.max_cache_size = 10000
+
+    def classify(self, query: str) -> QueryType:
+        """Classify a SQL query into one of the defined types."""
+        # Check cache first
+        query_hash = hashlib.md5(query.encode()).hexdigest()
+        if query_hash in self.classification_cache:
+            return self.classification_cache[query_hash]
+
+        # Clean the query
+        cleaned_query = self._clean_query(query)
+
+        # Classification logic
+        query_type = self._classify_query(cleaned_query)
+
+        # Cache the result
+        if len(self.classification_cache) >= self.max_cache_size:
+            # Simple LRU: remove oldest entries
+            oldest_keys = list(self.classification_cache.keys())[:1000]
+            for key in oldest_keys:
+                del self.classification_cache[key]
+
+        self.classification_cache[query_hash] = query_type
+        return query_type
+
+    def _clean_query(self, query: str) -> str:
+        """Remove comments and normalize whitespace."""
+        # Remove single-line comments
+        query = re.sub(r"--[^\n]*", "", query)
+        # Remove multi-line comments
+        query = re.sub(r"/\*.*?\*/", "", query, flags=re.DOTALL)
+        # Normalize whitespace
+        query = " ".join(query.split())
+        return query.strip()
+
+    def _classify_query(self, query: str) -> QueryType:
+        """Perform the actual classification."""
+        # Check for transaction commands
+        if self.TRANSACTION_PATTERN.match(query):
+            return QueryType.TRANSACTION
+
+        # Check for DDL
+        if self.DDL_PATTERN.match(query):
+            return QueryType.DDL
+
+        # Check for bulk operations
+        if self.BULK_PATTERN.search(query):
+            return QueryType.WRITE_BULK
+
+        # Check for complex SELECT
+        if self.SELECT_COMPLEX.search(query):
+            return QueryType.READ_COMPLEX
+
+        # Check for simple SELECT
+        if self.SELECT_SIMPLE.match(query):
+            return QueryType.READ_SIMPLE
+
+        # Check for INSERT/UPDATE/DELETE
+        if (
+            self.INSERT_PATTERN.match(query)
+            or self.UPDATE_PATTERN.match(query)
+            or self.DELETE_PATTERN.match(query)
+        ):
+            return QueryType.WRITE_SIMPLE
+
+        return QueryType.UNKNOWN
+
+    def fingerprint(
+        self, query: str, parameters: Optional[List[Any]] = None
+    ) -> QueryFingerprint:
+        """Create a normalized fingerprint of the query."""
+        cleaned_query = self._clean_query(query)
+        query_type = self.classify(query)
+
+        # Extract tables
+        tables = self._extract_tables(cleaned_query)
+
+        # Normalize parameters
+        template = self._create_template(cleaned_query, parameters)
+
+        # Calculate complexity
+        complexity = self._calculate_complexity(cleaned_query, query_type)
+
+        # Determine if read-only
+        is_read_only = query_type in [QueryType.READ_SIMPLE, QueryType.READ_COMPLEX]
+
+        return QueryFingerprint(
+            template=template,
+            query_type=query_type,
+            tables=tables,
+            is_read_only=is_read_only,
+            complexity_score=complexity,
+        )
+
+    def _extract_tables(self, query: str) -> Set[str]:
+        """Extract table names from query."""
+        tables = set()
+
+        # FROM clause
+        from_matches = re.findall(r"FROM\s+(\w+)", query, re.IGNORECASE)
+        tables.update(from_matches)
+
+        # JOIN clauses
+        join_matches = re.findall(r"JOIN\s+(\w+)", query, re.IGNORECASE)
+        tables.update(join_matches)
+
+        # INSERT INTO
+        insert_matches = re.findall(r"INSERT\s+INTO\s+(\w+)", query, re.IGNORECASE)
+        tables.update(insert_matches)
+
+        # UPDATE
+        update_matches = re.findall(r"UPDATE\s+(\w+)", query, re.IGNORECASE)
+        tables.update(update_matches)
+
+        # DELETE FROM
+        delete_matches = re.findall(r"DELETE\s+FROM\s+(\w+)", query, re.IGNORECASE)
+        tables.update(delete_matches)
+
+        return tables
+
+    def _create_template(self, query: str, parameters: Optional[List[Any]]) -> str:
+        """Create query template with normalized parameters."""
+        template = query
+
+        # Replace string literals
+        template = re.sub(r"'[^']*'", "?", template)
+        template = re.sub(r'"[^"]*"', "?", template)
+
+        # Replace numbers
+        template = re.sub(r"\b\d+\.?\d*\b", "?", template)
+
+        # Replace parameter placeholders
+        template = re.sub(r"%s|\$\d+|\?", "?", template)
+
+        return template
+
+    def _calculate_complexity(self, query: str, query_type: QueryType) -> float:
+        """Calculate query complexity score (0-1)."""
+        score = 0.0
+
+        # Base scores by type
+        base_scores = {
+            QueryType.READ_SIMPLE: 0.1,
+            QueryType.READ_COMPLEX: 0.5,
+            QueryType.WRITE_SIMPLE: 0.2,
+            QueryType.WRITE_BULK: 0.6,
+            QueryType.DDL: 0.8,
+            QueryType.TRANSACTION: 0.3,
+            QueryType.UNKNOWN: 0.5,
+        }
+        score = base_scores.get(query_type, 0.5)
+
+        # Adjust for query features
+        if re.search(r"\bJOIN\b", query, re.IGNORECASE):
+            score += 0.1 * len(re.findall(r"\bJOIN\b", query, re.IGNORECASE))
+
+        if re.search(r"\bGROUP\s+BY\b", query, re.IGNORECASE):
+            score += 0.15
+
+        if re.search(r"\bORDER\s+BY\b", query, re.IGNORECASE):
+            score += 0.05
+
+        if re.search(r"\bDISTINCT\b", query, re.IGNORECASE):
+            score += 0.1
+
+        # Subqueries
+        if query.count("SELECT") > 1:
+            score += 0.2 * (query.count("SELECT") - 1)
+
+        return min(score, 1.0)
+
+
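The classifier above is self-contained string analysis, so it can be exercised directly. A minimal sketch, assuming the module imports under the path this diff adds (kailash/nodes/data/query_router.py); the sample queries and expected outcomes follow from the regexes and template rules shown above:

```python
from kailash.nodes.data.query_router import QueryClassifier, QueryType

classifier = QueryClassifier()

# Single-table SELECT matches SELECT_SIMPLE -> READ_SIMPLE
assert classifier.classify("SELECT id FROM orders WHERE status = 'open'") == QueryType.READ_SIMPLE

# A JOIN trips SELECT_COMPLEX -> READ_COMPLEX
assert (
    classifier.classify("SELECT o.id FROM orders o JOIN users u ON u.id = o.user_id")
    == QueryType.READ_COMPLEX
)

# fingerprint() normalizes literals, so structurally identical queries
# share one template (the cache key used later by the router).
fp = classifier.fingerprint("SELECT * FROM orders WHERE id = 42")
print(fp.template)          # SELECT * FROM orders WHERE id = ?
print(fp.tables)            # {'orders'}
print(fp.is_read_only)      # True
print(fp.complexity_score)  # 0.1 (READ_SIMPLE base score)
```

Note that classify() memoizes by an MD5 of the raw query text, so repeated calls with the same string skip the regex pass entirely.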
+class PreparedStatementCache:
+    """LRU cache for prepared statements with connection affinity."""
+
+    def __init__(self, max_size: int = 1000):
+        self.max_size = max_size
+        self.cache: Dict[str, Dict[str, Any]] = {}  # fingerprint -> statement info
+        self.usage_order = deque()  # For LRU eviction
+        self.usage_stats = defaultdict(int)  # Track usage frequency
+
+    def get(self, fingerprint: str, connection_id: str) -> Optional[Dict[str, Any]]:
+        """Get cached statement if available for connection."""
+        if fingerprint in self.cache:
+            entry = self.cache[fingerprint]
+            if connection_id in entry.get("connections", {}):
+                # Update usage
+                self.usage_stats[fingerprint] += 1
+                self._update_usage_order(fingerprint)
+                return entry
+        return None
+
+    def put(self, fingerprint: str, connection_id: str, statement_info: Dict[str, Any]):
+        """Cache a prepared statement."""
+        if fingerprint not in self.cache:
+            # Check if we need to evict
+            if len(self.cache) >= self.max_size:
+                self._evict_lru()
+
+            self.cache[fingerprint] = {
+                "connections": {},
+                "created_at": datetime.now(),
+                "last_used": datetime.now(),
+            }
+
+        # Add connection-specific info
+        self.cache[fingerprint]["connections"][connection_id] = statement_info
+        self.cache[fingerprint]["last_used"] = datetime.now()
+        self._update_usage_order(fingerprint)
+
+    def invalidate(self, tables: Optional[Set[str]] = None):
+        """Invalidate cached statements for specific tables or all."""
+        if tables is None:
+            # Clear entire cache
+            self.cache.clear()
+            self.usage_order.clear()
+            self.usage_stats.clear()
+        else:
+            # Invalidate statements touching specified tables
+            to_remove = []
+            for fingerprint, entry in self.cache.items():
+                if "tables" in entry and entry["tables"].intersection(tables):
+                    to_remove.append(fingerprint)
+
+            for fingerprint in to_remove:
+                del self.cache[fingerprint]
+                self.usage_stats.pop(fingerprint, None)
+
+    def _update_usage_order(self, fingerprint: str):
+        """Update LRU order."""
+        if fingerprint in self.usage_order:
+            self.usage_order.remove(fingerprint)
+        self.usage_order.append(fingerprint)
+
+    def _evict_lru(self):
+        """Evict least recently used entry."""
+        if self.usage_order:
+            victim = self.usage_order.popleft()
+            del self.cache[victim]
+            self.usage_stats.pop(victim, None)
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get cache statistics."""
+        total_entries = len(self.cache)
+        total_usage = sum(self.usage_stats.values())
+
+        return {
+            "total_entries": total_entries,
+            "total_usage": total_usage,
+            "hit_rate": total_usage / (total_usage + 1) if total_entries > 0 else 0,
+            "avg_usage_per_entry": (
+                total_usage / total_entries if total_entries > 0 else 0
+            ),
+            "cache_size_bytes": sum(len(str(v)) for v in self.cache.values()),
+        }
+
+
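PreparedStatementCache keys entries by query template and stores statement info per connection, so a hit requires both to match. A hypothetical standalone run (statement names invented for illustration), under the same import-path assumption:

```python
from kailash.nodes.data.query_router import PreparedStatementCache

cache = PreparedStatementCache(max_size=2)

# Entries are keyed by template, then by connection id.
cache.put("SELECT * FROM orders WHERE id = ?", "conn_1", {"statement_name": "ps_orders"})

hit = cache.get("SELECT * FROM orders WHERE id = ?", "conn_1")
print(hit["connections"]["conn_1"]["statement_name"])  # ps_orders

# Same template on another connection misses: statements have connection affinity.
print(cache.get("SELECT * FROM orders WHERE id = ?", "conn_2"))  # None

# With max_size=2, caching a third distinct template evicts the LRU one.
cache.put("SELECT * FROM users WHERE id = ?", "conn_1", {"statement_name": "ps_users"})
cache.put("DELETE FROM users WHERE id = ?", "conn_1", {"statement_name": "ps_del"})
print(cache.get_stats()["total_entries"])  # 2
```

One caveat worth flagging: put() stores table info inside the "connections" sub-dict, while invalidate(tables) looks for a top-level "tables" key on each entry, so table-scoped invalidation appears to be a no-op in this version; only a full invalidate() clears entries.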
+class RoutingDecisionEngine:
+    """Makes intelligent routing decisions based on multiple factors."""
+
+    def __init__(
+        self, health_threshold: float = 50.0, enable_read_write_split: bool = True
+    ):
+        self.health_threshold = health_threshold
+        self.enable_read_write_split = enable_read_write_split
+        self.routing_history = deque(maxlen=1000)  # Recent routing decisions
+        self.connection_affinity = {}  # Track query -> connection affinity
+
+    def select_connection(
+        self,
+        query_fingerprint: QueryFingerprint,
+        available_connections: List[ConnectionInfo],
+        transaction_context: Optional[str] = None,
+    ) -> RoutingDecision:
+        """Select the optimal connection for the query."""
+
+        # Filter healthy connections
+        healthy_connections = [
+            c for c in available_connections if c.health_score >= self.health_threshold
+        ]
+
+        if not healthy_connections:
+            # Fall back to any available connection
+            healthy_connections = available_connections
+
+        if not healthy_connections:
+            raise NodeExecutionError("No available connections for routing")
+
+        # If in transaction, must use same connection
+        if transaction_context:
+            for conn in healthy_connections:
+                if conn.connection_id == transaction_context:
+                    return RoutingDecision(
+                        connection_id=conn.connection_id,
+                        decision_factors={"reason": "transaction_affinity"},
+                        alternatives=[],
+                        confidence=1.0,
+                    )
+            raise NodeExecutionError(
+                f"Transaction connection {transaction_context} not available"
+            )
+
+        # Apply routing strategy
+        if self.enable_read_write_split and query_fingerprint.is_read_only:
+            selected = self._route_read_query(query_fingerprint, healthy_connections)
+        else:
+            selected = self._route_write_query(query_fingerprint, healthy_connections)
+
+        # Record decision
+        self.routing_history.append(
+            {
+                "timestamp": datetime.now(),
+                "query_type": query_fingerprint.query_type,
+                "connection": selected.connection_id,
+                "confidence": selected.confidence,
+            }
+        )
+
+        return selected
+
+    def _route_read_query(
+        self, fingerprint: QueryFingerprint, connections: List[ConnectionInfo]
+    ) -> RoutingDecision:
+        """Route read queries with load balancing."""
+        # Filter connections that support reads
+        read_connections = [c for c in connections if "read" in c.capabilities]
+
+        if not read_connections:
+            read_connections = connections
+
+        # Score each connection
+        scores = []
+        for conn in read_connections:
+            score = self._calculate_connection_score(conn, fingerprint)
+            scores.append((conn, score))
+
+        # Sort by score (descending)
+        scores.sort(key=lambda x: x[1], reverse=True)
+
+        # Select best connection
+        best_conn, best_score = scores[0]
+
+        # Calculate confidence based on score distribution
+        confidence = self._calculate_confidence(scores)
+
+        return RoutingDecision(
+            connection_id=best_conn.connection_id,
+            decision_factors={
+                "strategy": "load_balanced_read",
+                "score": best_score,
+                "health": best_conn.health_score,
+                "load": best_conn.current_load,
+            },
+            alternatives=[s[0].connection_id for s in scores[1:3]],
+            confidence=confidence,
+        )
+
+    def _route_write_query(
+        self, fingerprint: QueryFingerprint, connections: List[ConnectionInfo]
+    ) -> RoutingDecision:
+        """Route write queries to primary connections."""
+        # Filter connections that support writes
+        write_connections = [c for c in connections if "write" in c.capabilities]
+
+        if not write_connections:
+            write_connections = connections
+
+        # For writes, prefer the healthiest primary connection
+        write_connections.sort(
+            key=lambda c: (c.health_score, -c.current_load), reverse=True
+        )
+
+        best_conn = write_connections[0]
+
+        return RoutingDecision(
+            connection_id=best_conn.connection_id,
+            decision_factors={
+                "strategy": "primary_write",
+                "health": best_conn.health_score,
+                "load": best_conn.current_load,
+            },
+            alternatives=[c.connection_id for c in write_connections[1:3]],
+            confidence=0.9 if best_conn.health_score > 80 else 0.7,
+        )
+
+    def _calculate_connection_score(
+        self, conn: ConnectionInfo, fingerprint: QueryFingerprint
+    ) -> float:
+        """Calculate a score for connection suitability."""
+        score = 0.0
+
+        # Health score (40% weight)
+        score += (conn.health_score / 100) * 0.4
+
+        # Load score (30% weight) - inverse relationship
+        max_load = 10  # Assume max 10 concurrent queries
+        load_score = 1.0 - (min(conn.current_load, max_load) / max_load)
+        score += load_score * 0.3
+
+        # Latency score (20% weight) - inverse relationship
+        max_latency = 100  # 100ms threshold
+        latency_score = 1.0 - (min(conn.avg_latency_ms, max_latency) / max_latency)
+        score += latency_score * 0.2
+
+        # Affinity score (10% weight)
+        query_key = fingerprint.template
+        if query_key in self.connection_affinity:
+            if self.connection_affinity[query_key] == conn.connection_id:
+                score += 0.1
+
+        return score
+
+    def _calculate_confidence(
+        self, scores: List[Tuple[ConnectionInfo, float]]
+    ) -> float:
+        """Calculate confidence in routing decision."""
+        if len(scores) < 2:
+            return 0.5
+
+        best_score = scores[0][1]
+        second_score = scores[1][1]
+
+        # Confidence based on score separation
+        score_diff = best_score - second_score
+
+        if score_diff > 0.3:
+            return 0.95
+        elif score_diff > 0.2:
+            return 0.85
+        elif score_diff > 0.1:
+            return 0.75
+        else:
+            return 0.65
+
+
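The engine combines health (40%), load (30%), latency (20%), and template affinity (10%) into one score per connection, then derives confidence from the gap between the top two scores. A small sketch with invented connection metrics, under the same import-path assumption, showing a read-only query landing on the lightly loaded replica:

```python
from datetime import datetime

from kailash.nodes.data.query_router import (
    ConnectionInfo,
    QueryClassifier,
    RoutingDecisionEngine,
)

engine = RoutingDecisionEngine(health_threshold=50.0, enable_read_write_split=True)
fp = QueryClassifier().fingerprint("SELECT * FROM orders WHERE id = 7")

conns = [
    ConnectionInfo("replica_1", health_score=95.0, current_load=1,
                   capabilities={"read"}, avg_latency_ms=12.0,
                   last_used=datetime.now()),
    ConnectionInfo("primary", health_score=80.0, current_load=6,
                   capabilities={"read", "write", "ddl"}, avg_latency_ms=30.0,
                   last_used=datetime.now()),
]

decision = engine.select_connection(fp, conns)
print(decision.connection_id)     # replica_1 (score ~0.83 vs ~0.58)
print(decision.decision_factors)  # {'strategy': 'load_balanced_read', ...}
print(decision.confidence)        # 0.85 (score gap > 0.2)
```

A write or DDL query would instead take the _route_write_query path, which sorts write-capable connections by health and inverse load rather than scoring them.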
+@register_node()
+class QueryRouterNode(AsyncNode):
+    """
+    Intelligent query routing for optimal database performance.
+
+    This node analyzes SQL queries and routes them to the most appropriate
+    connection from a WorkflowConnectionPool based on:
+    - Query type (read/write)
+    - Connection health and load
+    - Historical performance
+    - Prepared statement cache
+
+    Parameters:
+        connection_pool (str): Name of the WorkflowConnectionPool node
+        enable_read_write_split (bool): Enable read/write splitting
+        cache_size (int): Size of prepared statement cache
+        pattern_learning (bool): Enable pattern-based optimization
+        health_threshold (float): Minimum health score for routing
+
+    Example:
+        >>> router = QueryRouterNode(
+        ...     name="smart_router",
+        ...     connection_pool="db_pool",
+        ...     enable_read_write_split=True,
+        ...     cache_size=1000
+        ... )
+        >>>
+        >>> # Query is automatically routed to optimal connection
+        >>> result = await router.process({
+        ...     "query": "SELECT * FROM orders WHERE status = ?",
+        ...     "parameters": ["pending"]
+        ... })
+    """
+
+    def __init__(self, **config):
+        super().__init__(**config)
+
+        # Configuration
+        self.connection_pool_name = config.get("connection_pool")
+        if not self.connection_pool_name:
+            raise ValueError("connection_pool parameter is required")
+
+        self.enable_read_write_split = config.get("enable_read_write_split", True)
+        self.cache_size = config.get("cache_size", 1000)
+        self.pattern_learning = config.get("pattern_learning", True)
+        self.health_threshold = config.get("health_threshold", 50.0)
+
+        # Components
+        self.classifier = QueryClassifier()
+        self.statement_cache = PreparedStatementCache(max_size=self.cache_size)
+        self.routing_engine = RoutingDecisionEngine(
+            health_threshold=self.health_threshold,
+            enable_read_write_split=self.enable_read_write_split,
+        )
+
+        # Metrics
+        self.metrics = {
+            "queries_routed": 0,
+            "cache_hits": 0,
+            "cache_misses": 0,
+            "routing_errors": 0,
+            "avg_routing_time_ms": 0.0,
+        }
+
+        # Transaction tracking
+        self.active_transactions = {}  # session_id -> connection_id
+
+        # Direct pool reference
+        self._connection_pool = None
+
+    def set_connection_pool(self, pool):
+        """Set the connection pool directly.
+
+        Args:
+            pool: Connection pool instance
+        """
+        self._connection_pool = pool
+
+    @classmethod
+    def get_parameters(cls) -> Dict[str, NodeParameter]:
+        """Define node parameters."""
+        return {
+            "connection_pool": NodeParameter(
+                name="connection_pool",
+                type=str,
+                description="Name of the WorkflowConnectionPool node to use",
+                required=True,
+            ),
+            "enable_read_write_split": NodeParameter(
+                name="enable_read_write_split",
+                type=bool,
+                description="Enable routing reads to replica connections",
+                required=False,
+                default=True,
+            ),
+            "cache_size": NodeParameter(
+                name="cache_size",
+                type=int,
+                description="Maximum number of prepared statements to cache",
+                required=False,
+                default=1000,
+            ),
+            "pattern_learning": NodeParameter(
+                name="pattern_learning",
+                type=bool,
+                description="Enable pattern-based query optimization",
+                required=False,
+                default=True,
+            ),
+            "health_threshold": NodeParameter(
+                name="health_threshold",
+                type=float,
+                description="Minimum health score for connection routing (0-100)",
+                required=False,
+                default=50.0,
+            ),
+        }
+
+    async def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Route and execute a query."""
+        start_time = time.time()
+
+        try:
+            # Extract query and parameters
+            query = input_data.get("query")
+            if not query:
+                raise ValueError("'query' is required in input_data")
+
+            parameters = input_data.get("parameters", [])
+            session_id = input_data.get("session_id")
+
+            # Get connection pool
+            pool_node = self._connection_pool
+            if not pool_node:
+                # Try to get from runtime
+                if hasattr(self, "runtime") and hasattr(self.runtime, "get_node"):
+                    pool_node = self.runtime.get_node(self.connection_pool_name)
+                elif hasattr(self, "runtime") and hasattr(
+                    self.runtime, "resource_registry"
+                ):
+                    pool_node = self.runtime.resource_registry.get(
+                        self.connection_pool_name
+                    )
+                elif hasattr(self, "context") and hasattr(
+                    self.context, "resource_registry"
+                ):
+                    pool_node = self.context.resource_registry.get(
+                        self.connection_pool_name
+                    )
+
+            if not pool_node:
+                raise NodeExecutionError(
+                    f"Connection pool '{self.connection_pool_name}' not found"
+                )
+
+            # Classify and fingerprint query
+            fingerprint = self.classifier.fingerprint(query, parameters)
+
+            # Get available connections
+            pool_status = await pool_node.process({"operation": "get_status"})
+            available_connections = self._parse_pool_status(pool_status)
+
+            # Check for active transaction
+            transaction_context = None
+            if session_id and session_id in self.active_transactions:
+                transaction_context = self.active_transactions[session_id]
+
+            # Handle transaction commands
+            if fingerprint.query_type == QueryType.TRANSACTION:
+                return await self._handle_transaction_command(
+                    query, session_id, pool_node, available_connections
+                )
+
+            # Make routing decision
+            decision = self.routing_engine.select_connection(
+                fingerprint, available_connections, transaction_context
+            )
+
+            # Check cache for prepared statement
+            cache_key = fingerprint.template
+            cached_statement = self.statement_cache.get(
+                cache_key, decision.connection_id
+            )
+
+            if cached_statement:
+                self.metrics["cache_hits"] += 1
+            else:
+                self.metrics["cache_misses"] += 1
+
+            # Execute query on selected connection
+            result = await self._execute_on_connection(
+                pool_node,
+                decision.connection_id,
+                query,
+                parameters,
+                fingerprint,
+                cached_statement,
+            )
+
+            # Update metrics
+            self.metrics["queries_routed"] += 1
+            routing_time = (time.time() - start_time) * 1000
+            self.metrics["avg_routing_time_ms"] = (
+                self.metrics["avg_routing_time_ms"]
+                * (self.metrics["queries_routed"] - 1)
+                + routing_time
+            ) / self.metrics["queries_routed"]
+
+            # Add routing metadata to result
+            result["routing_metadata"] = {
+                "connection_id": decision.connection_id,
+                "query_type": fingerprint.query_type.value,
+                "complexity_score": fingerprint.complexity_score,
+                "routing_confidence": decision.confidence,
+                "cache_hit": cached_statement is not None,
+                "routing_time_ms": routing_time,
+            }
+
+            return result
+
+        except Exception as e:
+            self.metrics["routing_errors"] += 1
+            logger.error(f"Query routing error: {str(e)}")
+            raise NodeExecutionError(f"Query routing failed: {str(e)}")
+
+    def _parse_pool_status(self, pool_status: Dict[str, Any]) -> List[ConnectionInfo]:
+        """Parse pool status into ConnectionInfo objects."""
+        connections = []
+
+        for conn_id, conn_data in pool_status.get("connections", {}).items():
+            connections.append(
+                ConnectionInfo(
+                    connection_id=conn_id,
+                    health_score=conn_data.get("health_score", 0),
+                    current_load=conn_data.get("active_queries", 0),
+                    capabilities=set(conn_data.get("capabilities", ["read", "write"])),
+                    avg_latency_ms=conn_data.get("avg_latency_ms", 0),
+                    last_used=datetime.fromisoformat(
+                        conn_data.get("last_used", datetime.now().isoformat())
+                    ),
+                )
+            )
+
+        return connections
+
+    async def _handle_transaction_command(
+        self,
+        query: str,
+        session_id: Optional[str],
+        pool_node: Any,
+        connections: List[ConnectionInfo],
+    ) -> Dict[str, Any]:
+        """Handle transaction control commands."""
+        query_upper = query.upper().strip()
+
+        if query_upper.startswith(("BEGIN", "START TRANSACTION")):
+            if not session_id:
+                raise ValueError("session_id required for transactions")
+
+            # Select a connection for the transaction
+            write_connections = [c for c in connections if "write" in c.capabilities]
+            if not write_connections:
+                raise NodeExecutionError("No write-capable connections available")
+
+            # Use healthiest connection
+            best_conn = max(write_connections, key=lambda c: c.health_score)
+            self.active_transactions[session_id] = best_conn.connection_id
+
+            # Execute BEGIN on selected connection
+            result = await pool_node.process(
+                {
+                    "operation": "execute",
+                    "connection_id": best_conn.connection_id,
+                    "query": query,
+                }
+            )
+
+            result["transaction_started"] = True
+            result["connection_id"] = best_conn.connection_id
+            return result
+
+        elif query_upper.startswith(("COMMIT", "ROLLBACK")):
+            if not session_id or session_id not in self.active_transactions:
+                raise ValueError("No active transaction for session")
+
+            conn_id = self.active_transactions[session_id]
+
+            # Execute on transaction connection
+            result = await pool_node.process(
+                {"operation": "execute", "connection_id": conn_id, "query": query}
+            )
+
+            # Clear transaction state
+            del self.active_transactions[session_id]
+
+            result["transaction_ended"] = True
+            return result
+
+        else:
+            raise ValueError(f"Unknown transaction command: {query}")
+
+    async def _execute_on_connection(
+        self,
+        pool_node: Any,
+        connection_id: str,
+        query: str,
+        parameters: List[Any],
+        fingerprint: QueryFingerprint,
+        cached_statement: Optional[Dict[str, Any]],
+    ) -> Dict[str, Any]:
+        """Execute query on selected connection."""
+
+        # Build execution request
+        execution_request = {
+            "operation": "execute",
+            "connection_id": connection_id,
+            "query": query,
+            "parameters": parameters,
+        }
+
+        # Add caching hint if available
+        if cached_statement:
+            execution_request["use_prepared"] = True
+            execution_request["statement_name"] = cached_statement.get("statement_name")
+
+        # Execute query
+        result = await pool_node.process(execution_request)
+
+        # Cache prepared statement info if not cached
+        if not cached_statement and result.get("prepared_statement_name"):
+            self.statement_cache.put(
+                fingerprint.template,
+                connection_id,
+                {
+                    "statement_name": result["prepared_statement_name"],
+                    "tables": list(fingerprint.tables),
+                    "created_at": datetime.now(),
+                },
+            )
+
+        return result
+
+    async def get_metrics(self) -> Dict[str, Any]:
+        """Get router metrics and statistics."""
+        cache_stats = self.statement_cache.get_stats()
+
+        return {
+            "router_metrics": self.metrics,
+            "cache_stats": cache_stats,
+            "active_transactions": len(self.active_transactions),
+            "routing_history": {
+                "total_decisions": len(self.routing_engine.routing_history),
+                "recent_decisions": list(self.routing_engine.routing_history)[-10:],
+            },
+        }
+
+    async def invalidate_cache(self, tables: Optional[List[str]] = None):
+        """Invalidate prepared statement cache."""
+        if tables:
+            self.statement_cache.invalidate(set(tables))
+        else:
+            self.statement_cache.invalidate()
+
+        return {"invalidated": True, "tables": tables}
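Finally, a runnable end-to-end sketch of the node itself. The stub below is not the real WorkflowConnectionPool; it only implements the two operations the router issues ("get_status" and "execute"), with response shapes inferred from _parse_pool_status() and _execute_on_connection(). The constructor arguments mirror the class docstring; whether AsyncNode accepts them unchanged is an assumption here:

```python
import asyncio

from kailash.nodes.data.query_router import QueryRouterNode


class StubPool:
    """Minimal stand-in for WorkflowConnectionPool (illustration only)."""

    async def process(self, request):
        if request["operation"] == "get_status":
            # Shape inferred from _parse_pool_status()
            return {
                "connections": {
                    "conn_1": {
                        "health_score": 90,
                        "active_queries": 0,
                        "capabilities": ["read", "write"],
                        "avg_latency_ms": 5,
                    }
                }
            }
        # "execute": return a fake result set
        return {"rows": [], "connection_id": request["connection_id"]}


async def main():
    router = QueryRouterNode(name="smart_router", connection_pool="db_pool")
    router.set_connection_pool(StubPool())  # bypass runtime lookup

    result = await router.process({
        "query": "SELECT * FROM orders WHERE status = ?",
        "parameters": ["pending"],
    })
    print(result["routing_metadata"])  # connection_id, query_type, cache_hit, ...
    print(await router.get_metrics())


asyncio.run(main())
```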