kailash 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. kailash/__init__.py +5 -5
  2. kailash/channels/__init__.py +2 -1
  3. kailash/channels/mcp_channel.py +23 -4
  4. kailash/cli/validate_imports.py +202 -0
  5. kailash/core/resilience/bulkhead.py +15 -5
  6. kailash/core/resilience/circuit_breaker.py +4 -1
  7. kailash/core/resilience/health_monitor.py +312 -84
  8. kailash/edge/migration/edge_migration_service.py +384 -0
  9. kailash/mcp_server/protocol.py +26 -0
  10. kailash/mcp_server/server.py +1081 -8
  11. kailash/mcp_server/subscriptions.py +1560 -0
  12. kailash/mcp_server/transports.py +305 -0
  13. kailash/middleware/gateway/event_store.py +1 -0
  14. kailash/nodes/base.py +77 -1
  15. kailash/nodes/code/python.py +44 -3
  16. kailash/nodes/data/async_sql.py +42 -20
  17. kailash/nodes/edge/edge_migration_node.py +16 -12
  18. kailash/nodes/governance.py +410 -0
  19. kailash/nodes/rag/registry.py +1 -1
  20. kailash/nodes/transaction/distributed_transaction_manager.py +48 -1
  21. kailash/nodes/transaction/saga_state_storage.py +2 -1
  22. kailash/nodes/validation.py +8 -8
  23. kailash/runtime/local.py +30 -0
  24. kailash/runtime/validation/__init__.py +7 -15
  25. kailash/runtime/validation/import_validator.py +446 -0
  26. kailash/runtime/validation/suggestion_engine.py +5 -5
  27. kailash/utils/data_paths.py +74 -0
  28. kailash/workflow/builder.py +183 -4
  29. kailash/workflow/mermaid_visualizer.py +3 -1
  30. kailash/workflow/templates.py +6 -6
  31. kailash/workflow/validation.py +134 -3
  32. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/METADATA +20 -17
  33. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/RECORD +37 -31
  34. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/WHEEL +0 -0
  35. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/entry_points.txt +0 -0
  36. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/licenses/LICENSE +0 -0
  37. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/top_level.txt +0 -0
@@ -1084,6 +1084,310 @@ class WebSocketTransport(BaseTransport):
1084
1084
  self.websocket = None
1085
1085
 
1086
1086
 
1087
class WebSocketServerTransport(BaseTransport):
    """WebSocket server transport for accepting MCP connections.

    Listens on ``host:port`` and accepts any number of clients. Each client is
    assigned a generated ``client_id``; every JSON-RPC message it sends is
    parsed and passed to ``message_handler(request, client_id)``, and the
    handler's result is sent back to that client as the reply.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 3001,
        message_handler: Optional[
            Callable[[Dict[str, Any], str], Dict[str, Any]]
        ] = None,
        ping_interval: float = 20.0,
        ping_timeout: float = 20.0,
        max_message_size: int = 10 * 1024 * 1024,  # 10MB
        **kwargs,
    ):
        """Initialize WebSocket server transport.

        Args:
            host: Host to bind to
            port: Port to listen on
            message_handler: Handler for incoming messages, called as
                ``handler(request, client_id)``. It may return the response
                dict directly or an awaitable resolving to it (for example an
                ``async def`` handler). Returning ``None`` sends no reply.
            ping_interval: Ping interval in seconds
            ping_timeout: Ping timeout in seconds
            max_message_size: Maximum message size in bytes
            **kwargs: Base transport arguments
        """
        super().__init__("websocket_server", **kwargs)

        self.host = host
        self.port = port
        self.message_handler = message_handler
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.max_message_size = max_message_size

        # Server state
        self.server: Optional[websockets.WebSocketServer] = None
        # client_id -> live websocket connection object
        self._clients: Dict[str, Any] = {}
        # client_id -> metadata (connected_at, path, remote_address)
        self._client_sessions: Dict[str, Dict[str, Any]] = {}
        self._server_task: Optional[asyncio.Task] = None

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _request_path(websocket) -> str:
        """Return the HTTP request path of an accepted connection, or "/".

        Older websockets connection objects expose ``websocket.path``; the
        newer asyncio implementation exposes ``websocket.request.path``.
        """
        path = getattr(websocket, "path", None)
        if not path:
            request = getattr(websocket, "request", None)
            path = getattr(request, "path", None)
        return path or "/"

    @staticmethod
    def _request_id(request: Any) -> Any:
        """Return the JSON-RPC ``id`` of *request*, or None for non-objects."""
        return request.get("id") if isinstance(request, dict) else None

    @staticmethod
    def _byte_length(data) -> int:
        """Return the size of a text (UTF-8 encoded) or binary frame in bytes."""
        if isinstance(data, str):
            return len(data.encode("utf-8"))
        return len(data)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def connect(self) -> None:
        """Start the WebSocket server.

        Returns immediately if the server is already running.

        Raises:
            TransportError: If the server could not be started.
        """
        if self._connected:
            return

        # Single-argument handler, as expected by the current websockets API.
        async def connection_handler(websocket):
            await self.handle_client(websocket, self._request_path(websocket))

        try:
            self.server = await websockets.serve(
                connection_handler,
                self.host,
                self.port,
                ping_interval=self.ping_interval,
                ping_timeout=self.ping_timeout,
                max_size=self.max_message_size,
            )
        except Exception as e:
            self._update_metrics("connections_failed")
            raise TransportError(
                f"Failed to start WebSocket server: {e}",
                transport_type="websocket_server",
            ) from e

        self._connected = True
        self._update_metrics("connections_total")
        logger.info(f"WebSocket server listening on {self.host}:{self.port}")

    async def disconnect(self) -> None:
        """Stop the WebSocket server and close every client connection.

        Client closes run concurrently and a failure on one client does not
        prevent the others (or the server itself) from being shut down.
        """
        if not self._connected:
            return

        self._connected = False

        # Close all client connections; snapshot first because each close
        # lets the client's handler run and remove itself from the registry.
        clients = list(self._clients.values())
        if clients:
            results = await asyncio.gather(
                *(client.close() for client in clients), return_exceptions=True
            )
            for result in results:
                if isinstance(result, BaseException):
                    logger.warning(f"Error closing WebSocket client: {result}")

        # Stop server
        if self.server:
            self.server.close()
            await self.server.wait_closed()
            self.server = None

        # Clear client tracking
        self._clients.clear()
        self._client_sessions.clear()

        logger.info("WebSocket server stopped")

    # ------------------------------------------------------------------
    # Messaging
    # ------------------------------------------------------------------

    async def send_message(
        self, message: Dict[str, Any], client_id: Optional[str] = None
    ) -> None:
        """Send message to specific client or broadcast to all.

        Args:
            message: Message to send (must be JSON-serializable)
            client_id: Target client ID (None for broadcast)

        Raises:
            TransportError: If the transport is not running, the target
                client is unknown, or sending to the target client fails.
                Broadcast failures on individual clients are logged and
                counted in ``errors_total`` instead of raised.
        """
        if not self._connected:
            raise TransportError(
                "Transport not connected", transport_type="websocket_server"
            )

        message_data = json.dumps(message)
        size = self._byte_length(message_data)

        if client_id:
            # Send to specific client
            client = self._clients.get(client_id)
            if client is None:
                raise TransportError(
                    f"Client {client_id} not found",
                    transport_type="websocket_server",
                )
            try:
                await client.send(message_data)
            except Exception as e:
                self._update_metrics("errors_total")
                raise TransportError(
                    f"Failed to send message: {e}", transport_type="websocket_server"
                ) from e
            self._update_metrics("messages_sent")
            self._update_metrics("bytes_sent", size)
            return

        # Broadcast to all clients. Snapshot the set so clients joining or
        # leaving during the await cannot skew the delivery counts.
        clients = list(self._clients.values())
        if not clients:
            return

        results = await asyncio.gather(
            *(client.send(message_data) for client in clients),
            return_exceptions=True,
        )
        failures = [r for r in results if isinstance(r, BaseException)]
        delivered = len(results) - len(failures)

        if delivered:
            self._update_metrics("messages_sent", delivered)
            self._update_metrics("bytes_sent", size * delivered)
        if failures:
            self._update_metrics("errors_total", len(failures))
            logger.warning(
                f"Broadcast failed for {len(failures)} of {len(results)} client(s)"
            )

    async def receive_message(self) -> Dict[str, Any]:
        """Not implemented for server transport."""
        raise NotImplementedError(
            "Server transport doesn't support receive_message. "
            "Messages are handled via handle_client callback."
        )

    async def handle_client(self, websocket, path: str):
        """Handle a client connection for its whole lifetime.

        Registers the client, processes each incoming frame until the
        connection ends, then unregisters it.

        Args:
            websocket: WebSocket connection
            path: Request path
        """
        client_id = str(uuid.uuid4())
        self._clients[client_id] = websocket
        self._client_sessions[client_id] = {
            "connected_at": time.time(),
            "path": path,
            "remote_address": websocket.remote_address,
        }

        logger.info(f"Client {client_id} connected from {websocket.remote_address}")
        self._update_metrics("connections_total")

        try:
            async for message in websocket:
                try:
                    await self._process_frame(websocket, client_id, message)
                except websockets.exceptions.ConnectionClosed:
                    # Let the outer handler report the disconnect.
                    raise
                except Exception as e:
                    # One bad message must not end the whole session.
                    logger.error(f"Error handling message from client {client_id}: {e}")
                    self._update_metrics("errors_total")

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client {client_id} disconnected")
        except Exception as e:
            logger.error(f"Error in client handler for {client_id}: {e}")
        finally:
            # pop() rather than del: disconnect() may already have cleared
            # the registries while this handler was still unwinding.
            self._clients.pop(client_id, None)
            self._client_sessions.pop(client_id, None)

    async def _process_frame(self, websocket, client_id: str, message) -> None:
        """Parse one incoming frame, dispatch it and send the reply (if any)."""
        try:
            request = json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # UnicodeDecodeError: binary frames that are not valid UTF-8.
            logger.error(f"Invalid JSON from client {client_id}: {e}")
            self._update_metrics("errors_total")

            error_response = {
                "jsonrpc": "2.0",
                "error": {
                    "code": -32700,
                    "message": "Parse error: Invalid JSON",
                },
                "id": None,
            }
            await websocket.send(json.dumps(error_response))
            return

        # Update metrics
        self._update_metrics("messages_received")
        self._update_metrics("bytes_received", self._byte_length(message))

        # Handle message
        if self.message_handler:
            response = await self._handle_message_safely(request, client_id)
        else:
            response = {
                "jsonrpc": "2.0",
                "error": {
                    "code": -32601,
                    "message": "No message handler configured",
                },
                "id": self._request_id(request),
            }

        # A handler returning None (e.g. for a notification) gets no reply,
        # rather than a literal "null" frame.
        if response is None:
            return

        payload = json.dumps(response)
        await websocket.send(payload)
        self._update_metrics("messages_sent")
        self._update_metrics("bytes_sent", self._byte_length(payload))

    async def _handle_message_safely(
        self, request: Any, client_id: str
    ) -> Optional[Dict[str, Any]]:
        """Invoke the message handler, converting its exceptions to errors.

        The handler may be synchronous or return an awaitable; an awaitable
        result is awaited before being returned.

        Args:
            request: Parsed JSON-RPC request
            client_id: Client identifier

        Returns:
            The handler's response, or a JSON-RPC ``-32603`` error response
            if the handler raised.
        """
        import inspect

        try:
            result = self.message_handler(request, client_id)
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            logger.error(f"Message handler error: {e}")
            return {
                "jsonrpc": "2.0",
                "error": {
                    "code": -32603,
                    "message": f"Internal error: {str(e)}",
                },
                "id": self._request_id(request),
            }

    # ------------------------------------------------------------------
    # Client introspection / control
    # ------------------------------------------------------------------

    def get_client_info(self, client_id: str) -> Optional[Dict[str, Any]]:
        """Get information about a connected client.

        Args:
            client_id: Client identifier

        Returns:
            Client information (including seconds connected) or None if the
            client is not connected.
        """
        session = self._client_sessions.get(client_id)
        if session is None:
            return None

        return {
            "client_id": client_id,
            "connected_at": session["connected_at"],
            "connection_duration": time.time() - session["connected_at"],
            "path": session["path"],
            "remote_address": session["remote_address"],
        }

    def list_clients(self) -> List[Dict[str, Any]]:
        """List all connected clients.

        Returns:
            List of client information dicts, one per connected client
        """
        return [self.get_client_info(client_id) for client_id in self._client_sessions]

    async def close_client(
        self, client_id: str, code: int = 1000, reason: str = ""
    ) -> bool:
        """Close a specific client connection.

        Args:
            client_id: Client to disconnect
            code: WebSocket close code
            reason: Close reason

        Returns:
            True if the client was known and closed, False otherwise
        """
        client = self._clients.get(client_id)
        if client is None:
            return False
        await client.close(code, reason)
        return True
1389
+
1390
+
1087
1391
  class TransportManager:
1088
1392
  """Manager for MCP transport instances."""
1089
1393
 
@@ -1095,6 +1399,7 @@ class TransportManager:
1095
1399
  "sse": SSETransport,
1096
1400
  "streamable_http": StreamableHTTPTransport,
1097
1401
  "websocket": WebSocketTransport,
1402
+ "websocket_server": WebSocketServerTransport,
1098
1403
  }
1099
1404
 
1100
1405
  def register_transport_factory(self, transport_type: str, factory: Callable):
@@ -125,6 +125,7 @@ class EventStore:
125
125
  self._flush_task = asyncio.create_task(self._flush_loop())
126
126
  except RuntimeError:
127
127
  # If no event loop is running, defer task creation
128
+ # Don't create the coroutine here as it will never be awaited
128
129
  self._flush_task = None
129
130
 
130
131
  async def _ensure_flush_task(self):
kailash/nodes/base.py CHANGED
@@ -255,6 +255,54 @@ class Node(ABC):
255
255
  f"Failed to initialize node '{self.id}': {e}"
256
256
  ) from e
257
257
 
258
def get_workflow_context(self, key: str, default: Any = None) -> Any:
    """Get a value from the workflow context.

    Reads ``key`` from this node's ``_workflow_context`` dict, creating an
    empty dict on first access so later ``set_workflow_context`` calls write
    into the same mapping. NOTE(review): sharing this context between nodes
    of one workflow execution depends on the runtime populating it; this
    method only reads the node's own attribute.

    Args:
        key: The key to retrieve from the workflow context
        default: Default value to return if key is not found

    Returns:
        The value from the workflow context, or default if not found

    Example:
        In a transaction node (inside an ``async`` method)::

            connection = self.get_workflow_context("transaction_connection")
            if connection:
                # Use the shared connection
                result = await connection.execute(query)
    """
    if not hasattr(self, "_workflow_context"):
        self._workflow_context = {}
    return self._workflow_context.get(key, default)
283
+
284
def set_workflow_context(self, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the workflow context.

    The context is this node's ``_workflow_context`` dict, which is created
    empty on first use. Stored values can be read back with
    ``get_workflow_context()``. NOTE(review): visibility to other nodes in
    the same execution relies on the runtime sharing this dict.

    Args:
        key: The key to store the value under
        value: The value to store in the workflow context

    Example:
        In a transaction scope node (inside an ``async`` method)::

            connection = await self.get_connection()
            transaction = await connection.begin()
            self.set_workflow_context("transaction_connection", connection)
            self.set_workflow_context("active_transaction", transaction)
    """
    try:
        context = self._workflow_context
    except AttributeError:
        # First write on this node: start from an empty context.
        context = self._workflow_context = {}
    context[key] = value
305
+
258
306
  @abstractmethod
259
307
  def get_parameters(self) -> dict[str, NodeParameter]:
260
308
  """Define the parameters this node accepts.
@@ -467,9 +515,23 @@ class Node(ABC):
467
515
  # Skip type checking for Any type
468
516
  if param_def.type is Any:
469
517
  continue
518
+ # Skip validation for template expressions like ${variable_name}
519
+ if isinstance(value, str) and self._is_template_expression(value):
520
+ continue
470
521
  if not isinstance(value, param_def.type):
471
522
  try:
472
- self.config[param_name] = param_def.type(value)
523
+ # Special handling for datetime conversion from ISO strings
524
+ if param_def.type.__name__ == "datetime" and isinstance(
525
+ value, str
526
+ ):
527
+ from datetime import datetime
528
+
529
+ # Try to parse ISO format string
530
+ self.config[param_name] = datetime.fromisoformat(
531
+ value.replace("Z", "+00:00")
532
+ )
533
+ else:
534
+ self.config[param_name] = param_def.type(value)
473
535
  except (ValueError, TypeError) as e:
474
536
  raise NodeConfigurationError(
475
537
  f"Configuration parameter '{param_name}' must be of type "
@@ -477,6 +539,20 @@ class Node(ABC):
477
539
  f"Conversion failed: {e}"
478
540
  ) from e
479
541
 
542
+ def _is_template_expression(self, value: str) -> bool:
543
+ """Check if a string value is a template expression like ${variable_name}.
544
+
545
+ Args:
546
+ value: String value to check
547
+
548
+ Returns:
549
+ True if the value is a template expression, False otherwise
550
+ """
551
+ import re
552
+
553
+ # Match template expressions like ${variable_name} or ${node.output}
554
+ return bool(re.match(r"^\$\{[^}]+\}$", value))
555
+
480
556
  def _get_cached_parameters(self) -> dict[str, NodeParameter]:
481
557
  """Get cached parameter definitions.
482
558
 
@@ -389,7 +389,9 @@ class CodeExecutor:
389
389
  f"Error position: {' ' * (e.offset - 1) if e.offset else ''}^"
390
390
  )
391
391
 
392
- def execute_code(self, code: str, inputs: dict[str, Any]) -> dict[str, Any]:
392
+ def execute_code(
393
+ self, code: str, inputs: dict[str, Any], node_instance=None
394
+ ) -> dict[str, Any]:
393
395
  """Execute Python code with given inputs.
394
396
 
395
397
  Args:
@@ -476,6 +478,43 @@ class CodeExecutor:
476
478
  except ImportError:
477
479
  logger.warning(f"Module {module_name} not available")
478
480
 
481
+ # Add global utility functions to namespace
482
+ try:
483
+ from kailash.utils.data_paths import (
484
+ get_data_path,
485
+ get_input_data_path,
486
+ get_output_data_path,
487
+ )
488
+
489
+ namespace["get_input_data_path"] = get_input_data_path
490
+ namespace["get_output_data_path"] = get_output_data_path
491
+ namespace["get_data_path"] = get_data_path
492
+ except ImportError:
493
+ logger.warning(
494
+ "Could not import data path utilities - functions will not be available in PythonCodeNode execution"
495
+ )
496
+
497
+ # Add workflow context functions if node instance is available
498
+ if node_instance and hasattr(node_instance, "get_workflow_context"):
499
+ # Bind the actual node methods
500
+ namespace["get_workflow_context"] = node_instance.get_workflow_context
501
+ namespace["set_workflow_context"] = node_instance.set_workflow_context
502
+ else:
503
+ # Add placeholder functions that warn about unavailability
504
+ def _get_workflow_context(key: str, default=None):
505
+ logger.warning(
506
+ "get_workflow_context() is not available in PythonCodeNode execution context. Node instance not provided."
507
+ )
508
+ return default
509
+
510
+ def _set_workflow_context(key: str, value):
511
+ logger.warning(
512
+ "set_workflow_context() is not available in PythonCodeNode execution context. Node instance not provided."
513
+ )
514
+
515
+ namespace["get_workflow_context"] = _get_workflow_context
516
+ namespace["set_workflow_context"] = _set_workflow_context
517
+
479
518
  # Add sanitized inputs
480
519
  namespace.update(sanitized_inputs)
481
520
 
@@ -1222,7 +1261,9 @@ class PythonCodeNode(Node):
1222
1261
  try:
1223
1262
  if self.code:
1224
1263
  # Execute code string
1225
- outputs = self.executor.execute_code(self.code, kwargs)
1264
+ outputs = self.executor.execute_code(
1265
+ self.code, kwargs, node_instance=self
1266
+ )
1226
1267
  # Return 'result' variable if it exists, otherwise all outputs
1227
1268
  if "result" in outputs:
1228
1269
  return {"result": outputs["result"]}
@@ -1454,7 +1495,7 @@ class PythonCodeNode(Node):
1454
1495
  """
1455
1496
  # Execute directly based on execution type
1456
1497
  if self.code:
1457
- outputs = self.executor.execute_code(self.code, inputs)
1498
+ outputs = self.executor.execute_code(self.code, inputs, node_instance=self)
1458
1499
  return outputs.get("result", outputs)
1459
1500
  elif self.function:
1460
1501
  wrapper = FunctionWrapper(self.function, self.executor)
@@ -489,6 +489,7 @@ class PostgreSQLAdapter(DatabaseAdapter):
489
489
  or "DELETE" in query_upper
490
490
  or "INSERT" in query_upper
491
491
  )
492
+ and "SELECT" not in query_upper
492
493
  and "RETURNING" not in query_upper
493
494
  and fetch_mode == FetchMode.ALL
494
495
  ):
@@ -527,6 +528,7 @@ class PostgreSQLAdapter(DatabaseAdapter):
527
528
  or "DELETE" in query_upper
528
529
  or "INSERT" in query_upper
529
530
  )
531
+ and "SELECT" not in query_upper
530
532
  and "RETURNING" not in query_upper
531
533
  and fetch_mode == FetchMode.ALL
532
534
  ):
@@ -2617,33 +2619,53 @@ class AsyncSQLDatabaseNode(AsyncNode):
2617
2619
 
2618
2620
  async def cleanup(self):
2619
2621
  """Clean up database connections."""
2622
+ try:
2623
+ # Check if we have a running event loop
2624
+ loop = asyncio.get_running_loop()
2625
+ if loop.is_closed():
2626
+ # Event loop is closing, skip cleanup
2627
+ return
2628
+ except RuntimeError:
2629
+ # No event loop, skip cleanup
2630
+ return
2631
+
2620
2632
  # Rollback any active transaction
2621
2633
  if self._active_transaction and self._adapter:
2622
2634
  try:
2623
- await self._adapter.rollback_transaction(self._active_transaction)
2624
- except Exception:
2635
+ await asyncio.wait_for(
2636
+ self._adapter.rollback_transaction(self._active_transaction),
2637
+ timeout=1.0,
2638
+ )
2639
+ except (Exception, asyncio.TimeoutError):
2625
2640
  pass # Best effort cleanup
2626
2641
  self._active_transaction = None
2627
2642
 
2628
2643
  if self._adapter and self._connected:
2629
- if self._share_pool and self._pool_key:
2630
- # Decrement reference count for shared pool
2631
- async with self._get_pool_lock():
2632
- if self._pool_key in self._shared_pools:
2633
- adapter, ref_count = self._shared_pools[self._pool_key]
2634
- if ref_count > 1:
2635
- # Others still using the pool
2636
- self._shared_pools[self._pool_key] = (
2637
- adapter,
2638
- ref_count - 1,
2639
- )
2640
- else:
2641
- # Last reference, close the pool
2642
- del self._shared_pools[self._pool_key]
2643
- await adapter.disconnect()
2644
- else:
2645
- # Dedicated pool, close directly
2646
- await self._adapter.disconnect()
2644
+ try:
2645
+ if self._share_pool and self._pool_key:
2646
+ # Decrement reference count for shared pool with timeout
2647
+ async with await asyncio.wait_for(
2648
+ self._get_pool_lock(), timeout=1.0
2649
+ ):
2650
+ if self._pool_key in self._shared_pools:
2651
+ adapter, ref_count = self._shared_pools[self._pool_key]
2652
+ if ref_count > 1:
2653
+ # Others still using the pool
2654
+ self._shared_pools[self._pool_key] = (
2655
+ adapter,
2656
+ ref_count - 1,
2657
+ )
2658
+ else:
2659
+ # Last reference, close the pool
2660
+ del self._shared_pools[self._pool_key]
2661
+ await asyncio.wait_for(
2662
+ adapter.disconnect(), timeout=1.0
2663
+ )
2664
+ else:
2665
+ # Dedicated pool, close directly
2666
+ await asyncio.wait_for(self._adapter.disconnect(), timeout=1.0)
2667
+ except (Exception, asyncio.TimeoutError):
2668
+ pass # Best effort cleanup
2647
2669
 
2648
2670
  self._connected = False
2649
2671
  self._adapter = None
@@ -8,6 +8,7 @@ import asyncio
8
8
  from datetime import datetime
9
9
  from typing import Any, Dict, List, Optional
10
10
 
11
+ from kailash.edge.migration.edge_migration_service import EdgeMigrationService
11
12
  from kailash.edge.migration.edge_migrator import (
12
13
  EdgeMigrator,
13
14
  MigrationPhase,
@@ -59,18 +60,21 @@ class EdgeMigrationNode(AsyncNode):
59
60
  """Initialize edge migration node."""
60
61
  super().__init__(**kwargs)
61
62
 
62
- # Extract configuration
63
- checkpoint_interval = kwargs.get("checkpoint_interval", 60)
64
- sync_batch_size = kwargs.get("sync_batch_size", 1000)
65
- bandwidth_limit_mbps = kwargs.get("bandwidth_limit_mbps")
66
- enable_compression = kwargs.get("enable_compression", True)
67
-
68
- # Initialize migrator
69
- self.migrator = EdgeMigrator(
70
- checkpoint_interval=checkpoint_interval,
71
- sync_batch_size=sync_batch_size,
72
- bandwidth_limit_mbps=bandwidth_limit_mbps,
73
- enable_compression=enable_compression,
63
+ # Extract node-specific configuration
64
+ self.node_config = {
65
+ "checkpoint_interval": kwargs.get("checkpoint_interval", 60),
66
+ "sync_batch_size": kwargs.get("sync_batch_size", 1000),
67
+ "bandwidth_limit_mbps": kwargs.get("bandwidth_limit_mbps"),
68
+ "enable_compression": kwargs.get("enable_compression", True),
69
+ }
70
+
71
+ # Get reference to shared migration service
72
+ self.migration_service = EdgeMigrationService(self.node_config)
73
+
74
+ # Get migrator instance from shared service with node-specific config
75
+ self.node_id = f"edge_migration_node_{id(self)}"
76
+ self.migrator = self.migration_service.get_migrator_for_node(
77
+ self.node_id, self.node_config
74
78
  )
75
79
 
76
80
  self._migrator_started = False