kailash 0.6.5__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. kailash/__init__.py +35 -4
  2. kailash/adapters/__init__.py +5 -0
  3. kailash/adapters/mcp_platform_adapter.py +273 -0
  4. kailash/channels/__init__.py +21 -0
  5. kailash/channels/api_channel.py +409 -0
  6. kailash/channels/base.py +271 -0
  7. kailash/channels/cli_channel.py +661 -0
  8. kailash/channels/event_router.py +496 -0
  9. kailash/channels/mcp_channel.py +648 -0
  10. kailash/channels/session.py +423 -0
  11. kailash/mcp_server/discovery.py +1 -1
  12. kailash/middleware/core/agent_ui.py +5 -0
  13. kailash/middleware/mcp/enhanced_server.py +22 -16
  14. kailash/nexus/__init__.py +21 -0
  15. kailash/nexus/factory.py +413 -0
  16. kailash/nexus/gateway.py +545 -0
  17. kailash/nodes/__init__.py +2 -0
  18. kailash/nodes/ai/iterative_llm_agent.py +988 -17
  19. kailash/nodes/ai/llm_agent.py +29 -9
  20. kailash/nodes/api/__init__.py +2 -2
  21. kailash/nodes/api/monitoring.py +1 -1
  22. kailash/nodes/base_async.py +54 -14
  23. kailash/nodes/code/async_python.py +1 -1
  24. kailash/nodes/data/bulk_operations.py +939 -0
  25. kailash/nodes/data/query_builder.py +373 -0
  26. kailash/nodes/data/query_cache.py +512 -0
  27. kailash/nodes/monitoring/__init__.py +10 -0
  28. kailash/nodes/monitoring/deadlock_detector.py +964 -0
  29. kailash/nodes/monitoring/performance_anomaly.py +1078 -0
  30. kailash/nodes/monitoring/race_condition_detector.py +1151 -0
  31. kailash/nodes/monitoring/transaction_metrics.py +790 -0
  32. kailash/nodes/monitoring/transaction_monitor.py +931 -0
  33. kailash/nodes/system/__init__.py +17 -0
  34. kailash/nodes/system/command_parser.py +820 -0
  35. kailash/nodes/transaction/__init__.py +48 -0
  36. kailash/nodes/transaction/distributed_transaction_manager.py +983 -0
  37. kailash/nodes/transaction/saga_coordinator.py +652 -0
  38. kailash/nodes/transaction/saga_state_storage.py +411 -0
  39. kailash/nodes/transaction/saga_step.py +467 -0
  40. kailash/nodes/transaction/transaction_context.py +756 -0
  41. kailash/nodes/transaction/two_phase_commit.py +978 -0
  42. kailash/nodes/transform/processors.py +17 -1
  43. kailash/nodes/validation/__init__.py +21 -0
  44. kailash/nodes/validation/test_executor.py +532 -0
  45. kailash/nodes/validation/validation_nodes.py +447 -0
  46. kailash/resources/factory.py +1 -1
  47. kailash/runtime/async_local.py +84 -21
  48. kailash/runtime/local.py +21 -2
  49. kailash/runtime/parameter_injector.py +187 -31
  50. kailash/security.py +16 -1
  51. kailash/servers/__init__.py +32 -0
  52. kailash/servers/durable_workflow_server.py +430 -0
  53. kailash/servers/enterprise_workflow_server.py +466 -0
  54. kailash/servers/gateway.py +183 -0
  55. kailash/servers/workflow_server.py +290 -0
  56. kailash/utils/data_validation.py +192 -0
  57. kailash/workflow/builder.py +291 -12
  58. kailash/workflow/validation.py +144 -8
  59. {kailash-0.6.5.dist-info → kailash-0.7.0.dist-info}/METADATA +1 -1
  60. {kailash-0.6.5.dist-info → kailash-0.7.0.dist-info}/RECORD +64 -26
  61. {kailash-0.6.5.dist-info → kailash-0.7.0.dist-info}/WHEEL +0 -0
  62. {kailash-0.6.5.dist-info → kailash-0.7.0.dist-info}/entry_points.txt +0 -0
  63. {kailash-0.6.5.dist-info → kailash-0.7.0.dist-info}/licenses/LICENSE +0 -0
  64. {kailash-0.6.5.dist-info → kailash-0.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,978 @@
1
+ """Two-Phase Commit (2PC) Transaction Coordinator Node.
2
+
3
+ This module implements the Two-Phase Commit protocol for distributed transactions,
4
+ ensuring atomicity across multiple resources. Unlike sagas which use compensation,
5
+ 2PC uses a prepare/commit protocol to achieve ACID properties.
6
+
7
+ The Two-Phase Commit protocol consists of:
8
+ 1. Phase 1 (Prepare): All participants prepare to commit and vote
9
+ 2. Phase 2 (Commit/Abort): Coordinator decides based on votes
10
+
11
+ Examples:
12
+ Basic 2PC transaction:
13
+
14
+ >>> coordinator = TwoPhaseCommitCoordinatorNode(
15
+ ... transaction_name="order_processing",
16
+ ... participants=["database", "payment", "inventory"]
17
+ ... )
18
+ >>> result = coordinator.execute(
19
+ ... operation="begin_transaction",
20
+ ... context={"order_id": "order_123", "amount": 100.00}
21
+ ... )
22
+
23
+ Adding participants:
24
+
25
+ >>> coordinator.execute(
26
+ ... operation="add_participant",
27
+ ... participant_id="audit_service",
28
+ ... endpoint="http://audit:8080/prepare"
29
+ ... )
30
+
31
+ Executing transaction:
32
+
33
+ >>> result = coordinator.execute(operation="execute_transaction")
34
+ # Returns success if all participants commit, failure if any abort
35
+ """
36
+
37
+ import asyncio
38
+ import json
39
+ import logging
40
+ import time
41
+ import uuid
42
+ from datetime import UTC, datetime
43
+ from enum import Enum
44
+ from typing import Any, Dict, List, Optional
45
+
46
+ from kailash.nodes.base import NodeMetadata, NodeParameter, register_node
47
+ from kailash.nodes.base_async import AsyncNode
48
+ from kailash.sdk_exceptions import NodeConfigurationError, NodeExecutionError
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+
53
class TransactionState(Enum):
    """Lifecycle states of a two-phase commit transaction.

    The member values are the lowercase strings that the coordinator
    writes into persisted snapshots and operation results.
    """

    # Freshly created coordinator; no phase has been started.
    INIT = "init"

    # Phase 1: prepare requests in flight, then all votes collected.
    PREPARING = "preparing"
    PREPARED = "prepared"

    # Phase 2: commit requests in flight, then all commits confirmed.
    COMMITTING = "committing"
    COMMITTED = "committed"

    # Abort path: abort requests in flight, then transaction rolled back.
    ABORTING = "aborting"
    ABORTED = "aborted"

    # Terminal error states.
    FAILED = "failed"
    TIMEOUT = "timeout"
65
+
66
+
67
class ParticipantVote(Enum):
    """Possible answers a participant can give during the prepare phase.

    The member values are the strings stored when a participant is
    serialized.
    """

    # Participant is ready to commit.
    PREPARED = "prepared"
    # Participant refuses (or could not be asked) and wants a rollback.
    ABORT = "abort"
    # Participant did not answer in time.
    TIMEOUT = "timeout"
73
+
74
+
75
class TwoPhaseCommitParticipant:
    """A single resource taking part in a two-phase commit transaction.

    Stores the participant's connection settings together with the
    bookkeeping the coordinator records about it: its vote and the
    timestamps of the last contact, prepare and commit.
    """

    def __init__(
        self,
        participant_id: str,
        endpoint: str,
        timeout: int = 30,
        retry_count: int = 3,
    ):
        # Static configuration supplied by the caller.
        self.participant_id, self.endpoint = participant_id, endpoint
        self.timeout, self.retry_count = timeout, retry_count

        # Protocol bookkeeping; everything starts out unknown.
        self.vote: Optional[ParticipantVote] = None
        self.last_contact: Optional[datetime] = None
        self.prepare_time: Optional[datetime] = None
        self.commit_time: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a serializable snapshot of this participant.

        Timestamps are rendered as ISO-8601 strings and the vote as its
        enum value; fields that were never set are emitted as ``None``.
        """

        def iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        snapshot: Dict[str, Any] = {
            "participant_id": self.participant_id,
            "endpoint": self.endpoint,
            "timeout": self.timeout,
            "retry_count": self.retry_count,
        }
        snapshot["vote"] = self.vote.value if self.vote else None
        snapshot["last_contact"] = iso(self.last_contact)
        snapshot["prepare_time"] = iso(self.prepare_time)
        snapshot["commit_time"] = iso(self.commit_time)
        return snapshot

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TwoPhaseCommitParticipant":
        """Rebuild a participant from a snapshot produced by ``to_dict``.

        ``participant_id`` and ``endpoint`` are mandatory keys; ``timeout``
        and ``retry_count`` fall back to 30 and 3. Vote and timestamps are
        only restored when the snapshot holds a truthy value for them.
        """
        restored = cls(
            participant_id=data["participant_id"],
            endpoint=data["endpoint"],
            timeout=data.get("timeout", 30),
            retry_count=data.get("retry_count", 3),
        )

        raw_vote = data.get("vote")
        if raw_vote:
            restored.vote = ParticipantVote(raw_vote)

        # The three timestamps share the same ISO-8601 decoding.
        for stamp in ("last_contact", "prepare_time", "commit_time"):
            raw_stamp = data.get(stamp)
            if raw_stamp:
                setattr(restored, stamp, datetime.fromisoformat(raw_stamp))

        return restored
131
+
132
+
133
+ @register_node("TwoPhaseCommitCoordinatorNode")
134
+ class TwoPhaseCommitCoordinatorNode(AsyncNode):
135
+ """Node implementing Two-Phase Commit coordinator functionality.
136
+
137
+ This node orchestrates distributed transactions using the 2PC protocol,
138
+ ensuring atomicity across multiple participants. Unlike saga patterns,
139
+ 2PC provides stronger consistency guarantees but requires all participants
140
+ to be available during the transaction.
141
+
142
+ Key Features:
143
+ - Atomic distributed transactions
144
+ - Automatic timeout handling
145
+ - Participant failure detection
146
+ - Transaction state persistence
147
+ - Recovery after coordinator failure
148
+ - Configurable retry policies
149
+
150
+ Operations:
151
+ - begin_transaction: Start new 2PC transaction
152
+ - add_participant: Add participant to transaction
153
+ - execute_transaction: Execute prepare/commit phases
154
+ - abort_transaction: Abort transaction
155
+ - get_status: Get transaction status
156
+ - recover_transaction: Recover from coordinator failure
157
+ """
158
+
159
def __init__(
    self,
    transaction_name: Optional[str] = None,
    transaction_id: Optional[str] = None,
    participants: Optional[List[str]] = None,
    timeout: int = 300,
    prepare_timeout: int = 30,
    commit_timeout: int = 30,
    max_retries: int = 3,
    state_storage: str = "memory",
    storage_config: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """Initialize Two-Phase Commit coordinator.

    Args:
        transaction_name: Human-readable transaction name. Defaults to
            ``2pc_<unix timestamp>``.
        transaction_id: Unique transaction identifier. Defaults to a
            random UUID4 string.
        participants: List of participant identifiers. Each one is
            registered with the default endpoint ``http://<id>/2pc``.
        timeout: Overall transaction timeout in seconds
        prepare_timeout: Timeout for prepare phase in seconds
        commit_timeout: Timeout for commit phase in seconds
        max_retries: Maximum retry attempts per participant
        state_storage: Storage backend ("memory", "redis", "database")
        storage_config: Configuration for state storage
        **kwargs: Additional node configuration (``name`` is used for the
            node metadata; everything is forwarded to the base class)
    """
    # Set node metadata
    metadata = NodeMetadata(
        name=kwargs.get("name", "two_phase_commit_coordinator"),
        description="Coordinates distributed transactions using Two-Phase Commit protocol",
        version="1.0.0",
        tags={"transaction", "2pc", "distributed", "coordinator"},
    )

    # Initialize AsyncNode
    super().__init__(metadata=metadata, **kwargs)

    # Transaction configuration
    self.transaction_name = transaction_name or f"2pc_{int(time.time())}"
    self.transaction_id = transaction_id or str(uuid.uuid4())
    self.timeout = timeout
    self.prepare_timeout = prepare_timeout
    self.commit_timeout = commit_timeout
    self.max_retries = max_retries

    # Transaction state
    self.state = TransactionState.INIT
    self.participants: Dict[str, TwoPhaseCommitParticipant] = {}
    self.context: Dict[str, Any] = {}
    self.started_at: Optional[datetime] = None
    self.prepared_at: Optional[datetime] = None
    self.completed_at: Optional[datetime] = None
    self.error_message: Optional[str] = None

    # Register participants passed to the constructor. They receive the
    # same timeout/retry settings that _add_participant applies, so both
    # registration paths produce identically configured participants.
    for p_id in participants or ():
        self.participants[p_id] = TwoPhaseCommitParticipant(
            participant_id=p_id,
            endpoint=f"http://{p_id}/2pc",  # Default endpoint
            timeout=prepare_timeout,
            retry_count=max_retries,
        )

    # State persistence; the backend object is created lazily on first use.
    self.state_storage = state_storage
    self.storage_config = storage_config or {}
    self._storage = None

    logger.info(f"Initialized 2PC coordinator: {self.transaction_id}")
229
+
230
def get_parameters(self) -> Dict[str, NodeParameter]:
    """Describe the (all optional) inputs this node accepts."""
    # One row per parameter: (name, type, description, extra keyword args).
    rows = (
        (
            "operation",
            str,
            "2PC operation to execute",
            {"default": "begin_transaction"},
        ),
        (
            "participant_id",
            str,
            "Participant ID for add_participant operation",
            {},
        ),
        (
            "endpoint",
            str,
            "Participant endpoint for add_participant operation",
            {},
        ),
        ("context", dict, "Transaction context data", {}),
        (
            "transaction_id",
            str,
            "Transaction ID for recovery operations",
            {},
        ),
    )
    return {
        key: NodeParameter(
            name=key,
            type=kind,
            required=False,
            description=text,
            **extra,
        )
        for key, kind, text, extra in rows
    }
265
+
266
def get_outputs(self) -> Dict[str, NodeParameter]:
    """Describe the fields that operation results may contain."""
    # One row per output: (name, type, required, description).
    rows = (
        ("status", str, True, "Operation status (success, failed, aborted)"),
        ("transaction_id", str, True, "Transaction identifier"),
        ("state", str, True, "Current transaction state"),
        ("participants", list, False, "List of transaction participants"),
        ("result", dict, False, "Transaction result data"),
        ("error", str, False, "Error message if transaction failed"),
    )
    outputs: Dict[str, NodeParameter] = {}
    for key, kind, mandatory, text in rows:
        outputs[key] = NodeParameter(
            name=key,
            type=kind,
            required=mandatory,
            description=text,
        )
    return outputs
306
+
307
async def async_run(self, **kwargs) -> Dict[str, Any]:
    """Execute 2PC operation asynchronously.

    Args:
        **kwargs: Node inputs. ``operation`` selects the action and
            defaults to ``"begin_transaction"``; the remaining inputs are
            forwarded to the operations that accept them.

    Returns:
        The result dictionary of the selected operation. If the operation
        raises (including an unknown operation name), a dictionary with
        ``status == "error"``, the transaction id, the current state and
        the error text is returned instead of propagating the exception.
    """
    operation = kwargs.get("operation", "begin_transaction")

    # Operations that consume the node inputs vs. those that take none.
    with_inputs = {
        "begin_transaction": self._begin_transaction,
        "add_participant": self._add_participant,
        "recover_transaction": self._recover_transaction,
    }
    without_inputs = {
        "execute_transaction": self._execute_transaction,
        "abort_transaction": self._abort_transaction,
        "get_status": self._get_status,
    }

    try:
        if operation in with_inputs:
            return await with_inputs[operation](**kwargs)
        if operation in without_inputs:
            return await without_inputs[operation]()
        raise NodeExecutionError(f"Unknown 2PC operation: {operation}")

    except Exception as e:
        logger.error(f"2PC coordinator error: {e}")
        self.error_message = str(e)
        # Persisting is best-effort here: a storage failure must not mask
        # the original error or break the "always return a dict" contract.
        try:
            await self._persist_state()
        except Exception as persist_error:
            logger.error(f"Failed to persist 2PC error state: {persist_error}")
        return {
            "status": "error",
            "transaction_id": self.transaction_id,
            "state": self.state.value,
            "error": str(e),
        }
337
+
338
async def _begin_transaction(self, **kwargs) -> Dict[str, Any]:
    """Begin a new 2PC transaction.

    Args:
        **kwargs: Node inputs; ``context`` (optional mapping) is merged
            into the transaction context.

    Returns:
        Dictionary with status, transaction id, state, participant ids
        and the ISO-formatted start time.

    Raises:
        NodeExecutionError: If the transaction is not in the INIT state.
    """
    if self.state != TransactionState.INIT:
        raise NodeExecutionError(
            f"Transaction already in state: {self.state.value}"
        )

    # Update context. ``context`` is an optional input, so it may be
    # missing or explicitly None; both mean "nothing to merge".
    context = kwargs.get("context") or {}
    self.context.update(context)

    # Set transaction start time
    self.started_at = datetime.now(UTC)

    logger.info(f"Beginning 2PC transaction: {self.transaction_id}")

    # Persist initial state
    await self._persist_state()

    return {
        "status": "success",
        "transaction_id": self.transaction_id,
        "state": self.state.value,
        "participants": list(self.participants.keys()),
        "started_at": self.started_at.isoformat(),
    }
364
+
365
async def _add_participant(self, **kwargs) -> Dict[str, Any]:
    """Register one more participant with this transaction.

    Reads ``participant_id`` (mandatory) and ``endpoint`` (optional,
    defaults to ``http://<participant_id>/2pc``) from the inputs. An
    already-known id is reported with status ``"exists"`` and left
    untouched; otherwise the participant is stored and state persisted.

    Raises:
        NodeExecutionError: If ``participant_id`` is missing or empty.
    """
    new_id = kwargs.get("participant_id")
    if not new_id:
        raise NodeExecutionError(
            "participant_id is required for add_participant operation"
        )

    # Fall back to the conventional endpoint when none was supplied.
    target = kwargs.get("endpoint") or f"http://{new_id}/2pc"

    if new_id in self.participants:
        logger.warning(f"Participant {new_id} already exists")
        return {
            "status": "exists",
            "transaction_id": self.transaction_id,
            "participant_id": new_id,
        }

    self.participants[new_id] = TwoPhaseCommitParticipant(
        participant_id=new_id,
        endpoint=target,
        timeout=self.prepare_timeout,
        retry_count=self.max_retries,
    )
    logger.info(
        f"Added participant {new_id} to transaction {self.transaction_id}"
    )

    await self._persist_state()

    response = {"status": "success", "transaction_id": self.transaction_id}
    response["participant_id"] = new_id
    response["total_participants"] = len(self.participants)
    return response
410
+
411
async def _execute_transaction(self) -> Dict[str, Any]:
    """Execute the two-phase commit protocol.

    Runs the prepare phase; if any participant does not vote PREPARED the
    participants are told to abort and an ``"aborted"`` result is
    returned. Otherwise the commit phase runs and the result is
    ``"success"`` or, if a commit fails, ``"failed"``. State is persisted
    after every state transition.

    Returns:
        Result dictionary with ``status``, ``transaction_id`` and
        ``state`` plus outcome-specific fields.

    Raises:
        NodeExecutionError: If no participants are registered.
    """
    if not self.participants:
        raise NodeExecutionError("No participants defined for transaction")

    try:
        # Phase 1: Prepare
        logger.info(f"Starting prepare phase for transaction {self.transaction_id}")
        self.state = TransactionState.PREPARING
        await self._persist_state()

        prepare_success = await self._execute_prepare_phase()

        if not prepare_success:
            # Some participants voted to abort
            logger.warning(
                f"Prepare phase failed for transaction {self.transaction_id}"
            )
            await self._abort_all_participants()
            self.state = TransactionState.ABORTED
            await self._persist_state()

            return {
                "status": "aborted",
                "transaction_id": self.transaction_id,
                "state": self.state.value,
                "reason": "One or more participants voted to abort",
            }

        # All participants prepared successfully
        self.state = TransactionState.PREPARED
        self.prepared_at = datetime.now(UTC)
        await self._persist_state()

        # Phase 2: Commit
        logger.info(f"Starting commit phase for transaction {self.transaction_id}")
        self.state = TransactionState.COMMITTING
        await self._persist_state()

        commit_success = await self._execute_commit_phase()

        if commit_success:
            self.state = TransactionState.COMMITTED
            self.completed_at = datetime.now(UTC)

            logger.info(f"Transaction {self.transaction_id} committed successfully")

            await self._persist_state()

            return {
                "status": "success",
                "transaction_id": self.transaction_id,
                "state": self.state.value,
                "participants_committed": len(
                    [p for p in self.participants.values() if p.commit_time]
                ),
                "completed_at": self.completed_at.isoformat(),
            }
        else:
            # Commit phase failed - this is a serious problem in 2PC
            self.state = TransactionState.FAILED
            self.error_message = (
                "Commit phase failed - system in inconsistent state"
            )

            logger.error(
                f"CRITICAL: Commit phase failed for transaction {self.transaction_id}"
            )

            await self._persist_state()

            return {
                "status": "failed",
                "transaction_id": self.transaction_id,
                "state": self.state.value,
                "error": self.error_message,
            }

    except Exception as e:
        logger.error(f"Transaction execution failed: {e}")

        # Remember which phase we were in BEFORE marking the transaction
        # failed; checking self.state after overwriting it would never
        # match and the participants would never be told to abort.
        state_at_failure = self.state

        self.state = TransactionState.FAILED
        self.error_message = str(e)
        await self._persist_state()

        # Try to abort if we were still in the prepare phase
        if state_at_failure in (
            TransactionState.PREPARING,
            TransactionState.PREPARED,
        ):
            await self._abort_all_participants()

        return {
            "status": "failed",
            "transaction_id": self.transaction_id,
            "state": self.state.value,
            "error": str(e),
        }
505
+
506
async def _execute_prepare_phase(self) -> bool:
    """Execute prepare phase of 2PC protocol.

    Sends a prepare request to every participant concurrently and waits
    at most ``self.prepare_timeout`` seconds for all of them.

    Returns:
        True only if every participant ended up with a PREPARED vote;
        False on timeout or if any participant voted otherwise.
    """
    # Send prepare requests to all participants
    prepare_tasks = [
        asyncio.create_task(self._send_prepare_request(participant))
        for participant in self.participants.values()
    ]

    # Wait for all prepare responses with timeout
    try:
        await asyncio.wait_for(
            asyncio.gather(*prepare_tasks, return_exceptions=True),
            timeout=self.prepare_timeout,
        )
    except asyncio.TimeoutError:
        logger.error(f"Prepare phase timeout for transaction {self.transaction_id}")
        # Record which participants never answered so that status output
        # and persisted state distinguish "timed out" from "never asked".
        for participant in self.participants.values():
            if participant.vote is None:
                participant.vote = ParticipantVote.TIMEOUT
        return False

    # Check if all participants voted to prepare
    for participant in self.participants.values():
        if participant.vote != ParticipantVote.PREPARED:
            logger.warning(
                f"Participant {participant.participant_id} voted {participant.vote}"
            )
            return False

    return True
534
+
535
async def _execute_commit_phase(self) -> bool:
    """Run phase 2: ask every participant to commit.

    All commit requests are issued concurrently and awaited for at most
    ``self.commit_timeout`` seconds.

    Returns:
        True if every request finished without raising; False if the
        phase timed out or any request produced an exception.
    """
    pending = [
        asyncio.create_task(self._send_commit_request(member))
        for member in self.participants.values()
    ]

    try:
        outcomes = await asyncio.wait_for(
            asyncio.gather(*pending, return_exceptions=True),
            timeout=self.commit_timeout,
        )
    except asyncio.TimeoutError:
        logger.error(f"Commit phase timeout for transaction {self.transaction_id}")
        return False

    # Outcomes come back in task order, i.e. in participant order.
    for participant_id, outcome in zip(self.participants, outcomes):
        if isinstance(outcome, Exception):
            logger.error(
                f"Commit failed for participant {participant_id}: {outcome}"
            )
            return False

    return True
565
+
566
async def _send_prepare_request(self, participant: TwoPhaseCommitParticipant):
    """Ask one participant to prepare and record its vote.

    This is a mock: no network call is made. After a short simulated
    delay the participant is marked as having voted PREPARED. Any
    exception is logged and converted into an ABORT vote.
    """
    try:
        pid = participant.participant_id
        logger.info(f"Sending PREPARE to {pid}")

        # Stand-in for the real HTTP/gRPC round trip.
        await asyncio.sleep(0.1)

        # A real implementation would derive the vote from the response.
        participant.vote = ParticipantVote.PREPARED
        for stamp in ("prepare_time", "last_contact"):
            setattr(participant, stamp, datetime.now(UTC))

        logger.info(f"Participant {pid} voted PREPARED")

    except Exception as exc:
        logger.error(f"Failed to send prepare to {participant.participant_id}: {exc}")
        participant.vote = ParticipantVote.ABORT
        participant.last_contact = datetime.now(UTC)
588
+
589
async def _send_commit_request(self, participant: TwoPhaseCommitParticipant):
    """Ask one participant to commit and stamp the commit time.

    This is a mock: after a short simulated delay the participant's
    ``commit_time`` and ``last_contact`` are set. Any exception is logged
    and re-raised so the commit phase can detect the failure.
    """
    try:
        pid = participant.participant_id
        logger.info(f"Sending COMMIT to {pid}")

        # Stand-in for the participant's commit work.
        await asyncio.sleep(0.1)

        for stamp in ("commit_time", "last_contact"):
            setattr(participant, stamp, datetime.now(UTC))

        logger.info(
            f"Participant {pid} committed successfully"
        )

    except Exception as exc:
        logger.error(f"Failed to send commit to {participant.participant_id}: {exc}")
        raise
608
+
609
async def _abort_all_participants(self):
    """Tell every participant to abort, waiting at most 30 seconds.

    Requests are sent concurrently. A timeout is only logged as a
    warning; it is not raised to the caller.
    """
    logger.info(f"Aborting all participants for transaction {self.transaction_id}")

    pending = [
        asyncio.create_task(self._send_abort_request(member))
        for member in self.participants.values()
    ]
    everything = asyncio.gather(*pending, return_exceptions=True)

    # Bounded wait so a stuck participant cannot hang the abort forever.
    try:
        await asyncio.wait_for(everything, timeout=30)
    except asyncio.TimeoutError:
        logger.warning("Some abort requests timed out")
625
+
626
async def _send_abort_request(self, participant: TwoPhaseCommitParticipant):
    """Ask one participant to abort (mock implementation).

    After a short simulated delay ``last_contact`` is refreshed. Failures
    are logged as warnings and swallowed.
    """
    try:
        pid = participant.participant_id
        logger.info(f"Sending ABORT to {pid}")

        # Stand-in for the participant's rollback work.
        await asyncio.sleep(0.05)

        contacted_at = datetime.now(UTC)
        participant.last_contact = contacted_at

    except Exception as exc:
        logger.warning(f"Failed to send abort to {participant.participant_id}: {exc}")
638
+
639
async def _abort_transaction(self) -> Dict[str, Any]:
    """Abort the transaction unless it has already finished.

    A COMMITTED or ABORTED transaction is reported as
    ``"already_finished"`` without further action. Otherwise the state
    moves to ABORTING, all participants are told to abort, and the state
    becomes ABORTED with ``completed_at`` set; state is persisted after
    each transition.
    """
    finished_states = (TransactionState.COMMITTED, TransactionState.ABORTED)
    if self.state in finished_states:
        return {
            "status": "already_finished",
            "transaction_id": self.transaction_id,
            "state": self.state.value,
        }

    logger.info(f"Aborting transaction {self.transaction_id}")

    # Record the intent first, then notify the participants.
    self.state = TransactionState.ABORTING
    await self._persist_state()
    await self._abort_all_participants()

    self.state = TransactionState.ABORTED
    self.completed_at = datetime.now(UTC)
    await self._persist_state()

    if self.completed_at:
        aborted_at = self.completed_at.isoformat()
    else:
        aborted_at = None

    return {
        "status": "success",
        "transaction_id": self.transaction_id,
        "state": self.state.value,
        "aborted_at": aborted_at,
    }
666
+
667
async def _get_status(self) -> Dict[str, Any]:
    """Build a status report for the transaction and its participants.

    Returns:
        Dictionary with the transaction identity, state, context,
        ISO-formatted timestamps (``None`` when unset) and one entry per
        participant. ``aborted_at`` is added for an ABORTED transaction
        with a completion time, and ``error`` when an error message is
        recorded.
    """

    def iso(moment):
        return moment.isoformat() if moment else None

    participant_info = [
        {
            "participant_id": member.participant_id,
            "vote": member.vote.value if member.vote else None,
            "prepare_time": iso(member.prepare_time),
            "commit_time": iso(member.commit_time),
            "last_contact": iso(member.last_contact),
        }
        for member in self.participants.values()
    ]

    report = {
        "status": "success",
        "transaction_id": self.transaction_id,
        "transaction_name": self.transaction_name,
        "state": self.state.value,
        "participants": participant_info,
        "context": self.context,
        "started_at": iso(self.started_at),
        "prepared_at": iso(self.prepared_at),
        "completed_at": iso(self.completed_at),
    }

    # Optional extras that only apply in certain situations.
    was_aborted = self.state == TransactionState.ABORTED
    if was_aborted and self.completed_at:
        report["aborted_at"] = self.completed_at.isoformat()
    if self.error_message:
        report["error"] = self.error_message

    return report
715
+
716
async def _recover_transaction(self, **kwargs) -> Dict[str, Any]:
    """Recover transaction from persistent state.

    Loads the stored snapshot for the requested transaction, restores it
    into this coordinator and, if the snapshot was taken in the PREPARED
    or COMMITTING state, re-runs the commit phase.

    Args:
        **kwargs: Node inputs; ``transaction_id`` selects the transaction
            to recover. When missing or None, this coordinator's own
            transaction id is used.

    Returns:
        The status report from ``_get_status`` after recovery.

    Raises:
        NodeExecutionError: If no storage backend is available or no
            state is stored for the transaction.
    """
    # ``transaction_id`` is an optional input, so treat None like "absent".
    transaction_id = kwargs.get("transaction_id") or self.transaction_id

    # Initialize storage if not already done
    if not self._storage:
        self._storage = await self._get_storage()

    if not self._storage:
        raise NodeExecutionError("State storage not configured for recovery")

    # Load transaction state
    state_data = await self._storage.load_state(transaction_id)
    if not state_data:
        raise NodeExecutionError(f"Transaction {transaction_id} not found")

    # Restore state
    self._restore_from_state(state_data)

    logger.info(
        f"Recovered transaction {transaction_id} in state {self.state.value}"
    )

    # Both states mean the commit decision had been reached, so the
    # commit phase is (re)run; only the log message differs.
    resume_messages = {
        TransactionState.PREPARED: (
            f"Continuing commit phase for recovered transaction {transaction_id}"
        ),
        TransactionState.COMMITTING: (
            f"Retrying commit phase for recovered transaction {transaction_id}"
        ),
    }
    resume_message = resume_messages.get(self.state)

    if resume_message is not None:
        logger.info(resume_message)
        if await self._execute_commit_phase():
            self.state = TransactionState.COMMITTED
            self.completed_at = datetime.now(UTC)
        else:
            self.state = TransactionState.FAILED
            # Same message as a failed commit in the normal execution
            # path, so the status report explains the FAILED state.
            self.error_message = (
                "Commit phase failed - system in inconsistent state"
            )
        await self._persist_state()

    return await self._get_status()
772
+
773
async def _persist_state(self):
    """Write the current transaction snapshot to the storage backend.

    The backend is created on first use. Nothing is saved if no backend
    could be obtained.
    """
    backend = self._storage
    if not backend:
        # Lazily create the backend and remember it for later calls.
        backend = await self._get_storage()
        self._storage = backend

    if not backend:
        return

    snapshot = self._get_state_data()
    await backend.save_state(self.transaction_id, snapshot)
782
+
783
async def _get_storage(self):
    """Create the storage backend selected by ``self.state_storage``.

    ``"memory"`` yields an in-memory store, ``"redis"`` and
    ``"database"`` build their backend from ``self.storage_config``
    (``redis_client``/``key_prefix`` and ``db_pool``/``table_name``).
    A missing client or pool, or an unknown backend name, logs a warning
    and falls back to the in-memory store.
    """

    def memory_backend():
        # Imported lazily, exactly where a memory store is needed.
        from .saga_state_storage import InMemoryStateStorage

        return InMemoryStateStorage()

    kind = self.state_storage
    settings = self.storage_config

    if kind == "memory":
        return memory_backend()

    if kind == "redis":
        from .saga_state_storage import RedisStateStorage

        client = settings.get("redis_client")
        if client:
            return RedisStateStorage(
                client, settings.get("key_prefix", "2pc:state:")
            )
        logger.warning("Redis client not provided, using memory storage")
        return memory_backend()

    if kind == "database":
        pool = settings.get("db_pool")
        if pool:
            return TwoPhaseDatabaseStorage(
                pool,
                settings.get("table_name", "two_phase_commit_states"),
            )
        logger.warning("Database pool not provided, using memory storage")
        return memory_backend()

    logger.warning(f"Unknown storage type: {self.state_storage}, using memory")
    return memory_backend()
817
+
818
def _get_state_data(self) -> Dict[str, Any]:
    """Assemble the snapshot dictionary that gets persisted.

    Participants are serialized through their ``to_dict`` method and
    timestamps are stored as ISO-8601 strings (``None`` when unset).
    """

    def iso(moment):
        return moment.isoformat() if moment else None

    members = {}
    for p_id, member in self.participants.items():
        members[p_id] = member.to_dict()

    snapshot: Dict[str, Any] = {
        "transaction_id": self.transaction_id,
        "transaction_name": self.transaction_name,
        "state": self.state.value,
        "context": self.context,
        "participants": members,
    }
    # Timing configuration.
    snapshot.update(
        timeout=self.timeout,
        prepare_timeout=self.prepare_timeout,
        commit_timeout=self.commit_timeout,
        max_retries=self.max_retries,
    )
    # Progress timestamps and last error.
    snapshot.update(
        started_at=iso(self.started_at),
        prepared_at=iso(self.prepared_at),
        completed_at=iso(self.completed_at),
        error_message=self.error_message,
    )
    return snapshot
840
+
841
def _restore_from_state(self, state_data: Dict[str, Any]):
    """Restore transaction state from persistence data.

    Args:
        state_data: Snapshot dictionary, e.g. one produced by
            ``_get_state_data``. ``transaction_id``, ``transaction_name``
            and ``state`` are required; every other key is optional.

    Raises:
        KeyError: If one of the required keys is missing.

    NOTE(review): ``TransactionState(...)`` presumably raises ``ValueError``
    for an unknown state value -- confirm against its definition.
    """

    def _parse_timestamp(field_name: str):
        # Missing or null timestamps become None.
        raw_value = state_data.get(field_name)
        return datetime.fromisoformat(raw_value) if raw_value else None

    self.transaction_id = state_data["transaction_id"]
    self.transaction_name = state_data["transaction_name"]
    self.state = TransactionState(state_data["state"])
    self.context = state_data.get("context", {})

    # Settings fall back to the values already configured on this instance.
    self.timeout = state_data.get("timeout", self.timeout)
    self.prepare_timeout = state_data.get("prepare_timeout", self.prepare_timeout)
    self.commit_timeout = state_data.get("commit_timeout", self.commit_timeout)
    self.max_retries = state_data.get("max_retries", self.max_retries)
    self.error_message = state_data.get("error_message")

    # Restore timestamps. They are always assigned (like error_message) so
    # that a value left over from a previous run of this instance cannot
    # survive a restore whose snapshot stores the timestamp as null.
    self.started_at = _parse_timestamp("started_at")
    self.prepared_at = _parse_timestamp("prepared_at")
    self.completed_at = _parse_timestamp("completed_at")

    # Restore participants
    self.participants = {
        p_id: TwoPhaseCommitParticipant.from_dict(p_data)
        for p_id, p_data in state_data.get("participants", {}).items()
    }
865
+
866
+
867
+ class TwoPhaseDatabaseStorage:
868
+ """Database storage for Two-Phase Commit states with correct column mapping."""
869
+
870
+ def __init__(self, db_pool: Any, table_name: str = "two_phase_commit_states"):
871
+ self.db_pool = db_pool
872
+ self.table_name = table_name
873
+
874
+ async def save_state(self, transaction_id: str, state_data: Dict[str, Any]) -> bool:
875
+ """Save 2PC state to database with correct column mapping."""
876
+ try:
877
+ async with self.db_pool.acquire() as conn:
878
+ query = f"""
879
+ INSERT INTO {self.table_name}
880
+ (transaction_id, transaction_name, state, state_data, updated_at)
881
+ VALUES ($1, $2, $3, $4, $5)
882
+ ON CONFLICT (transaction_id)
883
+ DO UPDATE SET
884
+ transaction_name = EXCLUDED.transaction_name,
885
+ state = EXCLUDED.state,
886
+ state_data = EXCLUDED.state_data,
887
+ updated_at = EXCLUDED.updated_at
888
+ """
889
+
890
+ await conn.execute(
891
+ query,
892
+ transaction_id,
893
+ state_data.get("transaction_name", ""),
894
+ state_data.get("state", ""),
895
+ json.dumps(state_data),
896
+ datetime.now(UTC),
897
+ )
898
+
899
+ return True
900
+
901
+ except Exception as e:
902
+ logger.error(
903
+ f"Failed to save 2PC state to database for transaction {transaction_id}: {e}"
904
+ )
905
+ return False
906
+
907
+ async def load_state(self, transaction_id: str) -> Optional[Dict[str, Any]]:
908
+ """Load 2PC state from database."""
909
+ try:
910
+ async with self.db_pool.acquire() as conn:
911
+ query = f"""
912
+ SELECT state_data
913
+ FROM {self.table_name}
914
+ WHERE transaction_id = $1
915
+ """
916
+
917
+ row = await conn.fetchrow(query, transaction_id)
918
+
919
+ if row:
920
+ return json.loads(row["state_data"])
921
+ return None
922
+
923
+ except Exception as e:
924
+ logger.error(
925
+ f"Failed to load 2PC state from database for transaction {transaction_id}: {e}"
926
+ )
927
+ return None
928
+
929
+ async def delete_state(self, transaction_id: str) -> bool:
930
+ """Delete 2PC state from database."""
931
+ try:
932
+ async with self.db_pool.acquire() as conn:
933
+ query = f"DELETE FROM {self.table_name} WHERE transaction_id = $1"
934
+ result = await conn.execute(query, transaction_id)
935
+
936
+ # Check if any rows were deleted
937
+ return result.split()[-1] != "0"
938
+
939
+ except Exception as e:
940
+ logger.error(
941
+ f"Failed to delete 2PC state from database for transaction {transaction_id}: {e}"
942
+ )
943
+ return False
944
+
945
+ async def list_sagas(
946
+ self, filter_criteria: Optional[Dict[str, Any]] = None
947
+ ) -> List[str]:
948
+ """List transaction IDs from database (for compatibility)."""
949
+ try:
950
+ async with self.db_pool.acquire() as conn:
951
+ if not filter_criteria:
952
+ query = f"SELECT transaction_id FROM {self.table_name}"
953
+ rows = await conn.fetch(query)
954
+ else:
955
+ # Build WHERE clause
956
+ conditions = []
957
+ params = []
958
+ param_count = 0
959
+
960
+ for key, value in filter_criteria.items():
961
+ param_count += 1
962
+ if key in ["state", "transaction_name"]:
963
+ conditions.append(f"{key} = ${param_count}")
964
+ params.append(value)
965
+ else:
966
+ # For other fields, use JSONB query
967
+ conditions.append(f"state_data->'{key}' = ${param_count}")
968
+ params.append(json.dumps(value))
969
+
970
+ where_clause = " AND ".join(conditions)
971
+ query = f"SELECT transaction_id FROM {self.table_name} WHERE {where_clause}"
972
+ rows = await conn.fetch(query, *params)
973
+
974
+ return [row["transaction_id"] for row in rows]
975
+
976
+ except Exception as e:
977
+ logger.error(f"Failed to list 2PC transactions from database: {e}")
978
+ return []