edda-framework 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edda/context.py ADDED
@@ -0,0 +1,489 @@
1
+ """
2
+ Workflow context module for Edda framework.
3
+
4
+ This module provides the WorkflowContext class for workflow execution,
5
+ managing state, history, and replay during workflow execution.
6
+ """
7
+
8
+ from collections.abc import AsyncIterator
9
+ from contextlib import asynccontextmanager
10
+ from typing import TYPE_CHECKING, Any, cast
11
+
12
+ from edda.events import ReceivedEvent
13
+ from edda.storage.protocol import StorageProtocol
14
+
15
+ if TYPE_CHECKING:
16
+ from sqlalchemy.ext.asyncio import AsyncSession
17
+
18
+
19
class WorkflowContext:
    """
    Execution context for a single workflow instance.

    Bundles the instance metadata, storage backend, history/replay
    bookkeeping, and helper methods that activities use during execution
    and deterministic replay.
    """

    def __init__(
        self,
        instance_id: str,
        workflow_name: str,
        storage: "StorageProtocol",
        worker_id: str,
        is_replaying: bool = False,
        hooks: Any = None,
    ):
        """
        Initialize workflow context.

        Args:
            instance_id: Workflow instance ID
            workflow_name: Name of the workflow
            storage: Storage backend
            worker_id: Worker ID holding the lock
            is_replaying: Whether this is a replay execution
            hooks: Optional WorkflowHooks implementation for observability
        """
        self.instance_id = instance_id
        self.workflow_name = workflow_name
        self._storage = storage  # Private: go through properties/methods instead
        self.worker_id = worker_id
        self.is_replaying = is_replaying
        self.hooks = hooks

        # Activity IDs already executed, used for deterministic replay.
        self.executed_activity_ids: set[str] = set()
        # Replay cache mapping activity_id -> recorded result.
        self._history_cache: dict[str, Any] = {}
        # True once history has been loaded into the cache.
        self._history_loaded = False
        # Per-function call counters for auto-generated activity IDs.
        self._activity_call_counters: dict[str, int] = {}
        # Default retry policy from EddaApp (assigned by ReplayEngine).
        self._app_retry_policy: Any = None
72
+ @property
73
+ def storage(self) -> StorageProtocol:
74
+ """
75
+ Get storage backend (internal use only).
76
+
77
+ Warning:
78
+ This property is for framework internal use only.
79
+ Direct storage access may break deterministic replay guarantees.
80
+ Use WorkflowContext methods instead (transaction(), in_transaction()).
81
+ """
82
+ return self._storage
83
+
84
+ @property
85
+ def session(self) -> "AsyncSession":
86
+ """
87
+ Get Edda-managed database session for custom database operations.
88
+
89
+ This property provides access to the current transaction's SQLAlchemy session,
90
+ allowing you to execute custom database operations (ORM queries, raw SQL, etc.)
91
+ within the same transaction as Edda's workflow operations.
92
+
93
+ The session is automatically managed by Edda:
94
+ - Commit/rollback happens automatically at the end of @activity
95
+ - All operations are atomic (workflow history + your DB operations)
96
+ - Transaction safety is guaranteed
97
+
98
+ Returns:
99
+ AsyncSession managed by Edda's transaction context
100
+
101
+ Raises:
102
+ RuntimeError: If not inside a transaction (must use @activity or ctx.transaction())
103
+
104
+ Example:
105
+ @activity
106
+ async def create_order(ctx: WorkflowContext, order_id: str, amount: float):
107
+ # Get Edda-managed session
108
+ session = ctx.session
109
+
110
+ # Your business logic (same DB as Edda)
111
+ order = Order(order_id=order_id, amount=amount)
112
+ session.add(order)
113
+
114
+ # Event publishing (same transaction)
115
+ await send_event_transactional(
116
+ ctx, "order.created", "order-service",
117
+ {"order_id": order_id, "amount": amount}
118
+ )
119
+
120
+ # Edda commits automatically (or rolls back on error)
121
+ return {"order_id": order_id, "status": "created"}
122
+
123
+ Note:
124
+ - Requires @activity (default) or async with ctx.transaction()
125
+ - All operations commit/rollback together atomically
126
+ - Your tables must be in the same database as Edda
127
+ - Do NOT call session.commit() or session.rollback() manually
128
+ """
129
+ if not self.storage.in_transaction():
130
+ raise RuntimeError(
131
+ "ctx.session must be accessed inside a transaction. "
132
+ "Use @activity (default) or async with ctx.transaction()"
133
+ )
134
+
135
+ return cast("AsyncSession", self.storage._get_session_for_operation()) # type: ignore[attr-defined]
136
+
137
+ async def _load_history(self) -> None:
138
+ """
139
+ Load execution history from storage (internal use only).
140
+
141
+ This is called at the beginning of a replay to populate the history cache.
142
+ """
143
+ if self._history_loaded:
144
+ return
145
+
146
+ history = await self.storage.get_history(self.instance_id)
147
+
148
+ for event in history:
149
+ activity_id = event["activity_id"]
150
+ event_type = event["event_type"]
151
+ event_data = event["event_data"]
152
+
153
+ # Track executed activity IDs
154
+ self.executed_activity_ids.add(activity_id)
155
+
156
+ if event_type == "ActivityCompleted":
157
+ # Cache the activity result
158
+ self._history_cache[activity_id] = event_data.get("result")
159
+ elif event_type == "ActivityFailed":
160
+ # Cache the error for replay
161
+ self._history_cache[activity_id] = {
162
+ "_error": True,
163
+ "error_type": event_data.get("error_type"),
164
+ "error_message": event_data.get("error_message"),
165
+ }
166
+ elif event_type == "EventReceived":
167
+ # Cache the event data and metadata for wait_event replay
168
+ # Reconstruct ReceivedEvent from stored data
169
+ payload = event_data.get("payload", {})
170
+ metadata = event_data.get("metadata", {})
171
+ extensions = event_data.get("extensions", {})
172
+
173
+ # For backward compatibility: check if old format (event_data directly)
174
+ if "payload" not in event_data and "metadata" not in event_data:
175
+ # Old format: {"event_data": {...}}
176
+ payload = event_data.get("event_data", {})
177
+ metadata = {
178
+ "type": "unknown",
179
+ "source": "unknown",
180
+ "id": "unknown",
181
+ }
182
+
183
+ received_event = ReceivedEvent(
184
+ data=payload,
185
+ type=metadata.get("type", "unknown"),
186
+ source=metadata.get("source", "unknown"),
187
+ id=metadata.get("id", "unknown"),
188
+ time=metadata.get("time"),
189
+ datacontenttype=metadata.get("datacontenttype"),
190
+ subject=metadata.get("subject"),
191
+ extensions=extensions,
192
+ )
193
+ self._history_cache[activity_id] = received_event
194
+ elif event_type == "TimerExpired":
195
+ # Cache the timer result for wait_timer replay
196
+ # Timer returns None, so we cache the result field
197
+ self._history_cache[activity_id] = event_data.get("result")
198
+
199
+ self._history_loaded = True
200
+
201
+ def _get_cached_result(self, activity_id: str) -> tuple[bool, Any]:
202
+ """
203
+ Get cached result for an activity during replay (internal use only).
204
+
205
+ Args:
206
+ activity_id: Activity ID
207
+
208
+ Returns:
209
+ Tuple of (found, result) where found is True if result was cached
210
+ """
211
+ if activity_id in self._history_cache:
212
+ return True, self._history_cache[activity_id]
213
+ return False, None
214
+
215
+ def _generate_activity_id(self, function_name: str) -> str:
216
+ """
217
+ Generate a unique activity ID for auto-generation (internal use only).
218
+
219
+ Uses the format: {function_name}:{counter}
220
+
221
+ Args:
222
+ function_name: Name of the activity function
223
+
224
+ Returns:
225
+ Generated activity ID (e.g., "reserve_inventory:1")
226
+ """
227
+ # Increment counter for this function
228
+ count = self._activity_call_counters.get(function_name, 0) + 1
229
+ self._activity_call_counters[function_name] = count
230
+
231
+ activity_id = f"{function_name}:{count}"
232
+
233
+ return activity_id
234
+
235
+ def _record_activity_id(self, activity_id: str) -> None:
236
+ """
237
+ Record that an activity ID has been executed (internal use only).
238
+
239
+ Args:
240
+ activity_id: The activity ID to record
241
+ """
242
+ self.executed_activity_ids.add(activity_id)
243
+
244
+ async def _record_activity_completed(
245
+ self,
246
+ activity_id: str,
247
+ activity_name: str,
248
+ result: Any,
249
+ input_data: dict[str, Any] | None = None,
250
+ retry_metadata: Any = None,
251
+ ) -> None:
252
+ """
253
+ Record that an activity completed successfully (internal use only).
254
+
255
+ Args:
256
+ activity_id: Activity ID
257
+ activity_name: Name of the activity
258
+ result: Activity result (must be JSON-serializable)
259
+ input_data: Activity input parameters (args and kwargs)
260
+ retry_metadata: Optional retry metadata (RetryMetadata instance)
261
+ """
262
+ event_data: dict[str, Any] = {
263
+ "activity_name": activity_name,
264
+ "result": result,
265
+ "input": input_data or {},
266
+ }
267
+
268
+ # Include retry metadata if provided
269
+ if retry_metadata is not None:
270
+ event_data["retry_metadata"] = retry_metadata.to_dict()
271
+
272
+ await self.storage.append_history(
273
+ self.instance_id,
274
+ activity_id=activity_id,
275
+ event_type="ActivityCompleted",
276
+ event_data=event_data,
277
+ )
278
+
279
+ # Update current activity ID
280
+ await self.storage.update_instance_activity(self.instance_id, activity_id)
281
+
282
+ async def _record_activity_failed(
283
+ self,
284
+ activity_id: str,
285
+ activity_name: str,
286
+ error: Exception,
287
+ input_data: dict[str, Any] | None = None,
288
+ retry_metadata: Any = None,
289
+ ) -> None:
290
+ """
291
+ Record that an activity failed (internal use only).
292
+
293
+ Args:
294
+ activity_id: Activity ID
295
+ activity_name: Name of the activity
296
+ error: The exception that was raised
297
+ input_data: Activity input parameters (args and kwargs)
298
+ retry_metadata: Optional retry metadata (RetryMetadata instance)
299
+ """
300
+ import traceback
301
+
302
+ # Capture full stack trace
303
+ stack_trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
304
+
305
+ event_data: dict[str, Any] = {
306
+ "activity_name": activity_name,
307
+ "error_type": type(error).__name__,
308
+ "error_message": str(error),
309
+ "stack_trace": stack_trace,
310
+ "input": input_data or {},
311
+ }
312
+
313
+ # Include retry metadata if provided
314
+ if retry_metadata is not None:
315
+ event_data["retry_metadata"] = retry_metadata.to_dict()
316
+
317
+ await self.storage.append_history(
318
+ self.instance_id,
319
+ activity_id=activity_id,
320
+ event_type="ActivityFailed",
321
+ event_data=event_data,
322
+ )
323
+
324
+ async def _get_instance(self) -> dict[str, Any] | None:
325
+ """
326
+ Get the workflow instance metadata (internal use only).
327
+
328
+ Returns:
329
+ Instance metadata dictionary or None if not found
330
+ """
331
+ return await self.storage.get_instance(self.instance_id)
332
+
333
+ async def _update_status(self, status: str, output_data: dict[str, Any] | None = None) -> None:
334
+ """
335
+ Update the workflow instance status (internal use only).
336
+
337
+ Args:
338
+ status: New status (e.g., "completed", "failed", "waiting_for_event")
339
+ output_data: Optional output data for completed workflows
340
+ """
341
+ await self.storage.update_instance_status(self.instance_id, status, output_data)
342
+
343
+ async def _register_event_subscription(
344
+ self,
345
+ event_type: str,
346
+ timeout_seconds: int | None = None,
347
+ activity_id: str | None = None,
348
+ ) -> None:
349
+ """
350
+ Register an event subscription for wait_event (internal use only).
351
+
352
+ This is called when a workflow calls wait_event() and needs to pause
353
+ until a matching event arrives.
354
+
355
+ Args:
356
+ event_type: CloudEvent type to wait for
357
+ timeout_seconds: Optional timeout in seconds
358
+ activity_id: The activity ID where wait_event was called
359
+ """
360
+ from datetime import UTC, datetime, timedelta
361
+
362
+ timeout_at = None
363
+ if timeout_seconds is not None:
364
+ timeout_at = datetime.now(UTC) + timedelta(seconds=timeout_seconds)
365
+
366
+ await self.storage.add_event_subscription(
367
+ instance_id=self.instance_id,
368
+ event_type=event_type,
369
+ timeout_at=timeout_at,
370
+ )
371
+
372
+ # Update current activity ID
373
+ if activity_id is not None:
374
+ await self.storage.update_instance_activity(self.instance_id, activity_id)
375
+
376
+ async def _record_event_received(self, activity_id: str, event_data: dict[str, Any]) -> None:
377
+ """
378
+ Record that an event was received during wait_event (internal use only).
379
+
380
+ This is called when resuming a workflow after an event arrives.
381
+
382
+ Args:
383
+ activity_id: The activity ID where wait_event was called
384
+ event_data: The received event data
385
+ """
386
+ await self.storage.append_history(
387
+ instance_id=self.instance_id,
388
+ activity_id=activity_id,
389
+ event_type="EventReceived",
390
+ event_data={"event_data": event_data},
391
+ )
392
+
393
+ async def _push_compensation(self, compensation_action: Any, activity_id: str) -> None:
394
+ """
395
+ Register a compensation action for this workflow (internal use only).
396
+
397
+ Compensation actions are stored in LIFO order and executed on failure.
398
+
399
+ Args:
400
+ compensation_action: The CompensationAction to register
401
+ activity_id: The activity ID where compensation was registered
402
+ """
403
+ # Serialize compensation action with full args and kwargs
404
+ await self.storage.push_compensation(
405
+ instance_id=self.instance_id,
406
+ activity_id=activity_id,
407
+ activity_name=compensation_action.name,
408
+ args={
409
+ "name": compensation_action.name,
410
+ "args": list(compensation_action.args), # Convert tuple to list for JSON
411
+ "kwargs": compensation_action.kwargs,
412
+ },
413
+ )
414
+
415
+ async def _get_compensations(self) -> list[dict[str, Any]]:
416
+ """
417
+ Get all registered compensation actions (internal use only).
418
+
419
+ Returns:
420
+ List of compensation data dictionaries
421
+ """
422
+ return await self.storage.get_compensations(self.instance_id)
423
+
424
+ async def _clear_compensations(self) -> None:
425
+ """
426
+ Clear all registered compensations (internal use only).
427
+
428
+ This is called when a workflow completes successfully.
429
+ """
430
+ await self.storage.clear_compensations(self.instance_id)
431
+
432
+ @asynccontextmanager
433
+ async def transaction(self) -> AsyncIterator[None]:
434
+ """
435
+ Create a transactional context for atomic operations.
436
+
437
+ This context manager allows you to execute multiple storage operations
438
+ within a single database transaction. All operations will be committed
439
+ together, or rolled back together if an exception occurs.
440
+
441
+ Example:
442
+ async with ctx.transaction():
443
+ # All operations here are in the same transaction
444
+ await ctx.storage.append_history(...)
445
+ await send_event_transactional(ctx, ...)
446
+ # If any operation fails, all changes are rolled back
447
+
448
+ Yields:
449
+ None
450
+
451
+ Raises:
452
+ Exception: If any operation within the transaction fails,
453
+ the transaction is rolled back and the exception is re-raised
454
+ """
455
+ await self.storage.begin_transaction()
456
+ try:
457
+ yield
458
+ await self.storage.commit_transaction()
459
+ except Exception:
460
+ await self.storage.rollback_transaction()
461
+ raise
462
+
463
+ def in_transaction(self) -> bool:
464
+ """
465
+ Check if currently in a transaction.
466
+
467
+ This method is useful for ensuring that transactional operations
468
+ (like send_event_transactional) are called within a transaction context.
469
+
470
+ Returns:
471
+ True if inside a transaction context, False otherwise
472
+
473
+ Example:
474
+ if ctx.in_transaction():
475
+ await send_event_transactional(ctx, "order.created", ...)
476
+ else:
477
+ logger.warning("Not in transaction, using outbox pattern")
478
+ await send_event_transactional(ctx, "order.created", ...)
479
+ """
480
+ return self.storage.in_transaction()
481
+
482
+ def __repr__(self) -> str:
483
+ """String representation of the context."""
484
+ return (
485
+ f"WorkflowContext(instance_id={self.instance_id!r}, "
486
+ f"workflow_name={self.workflow_name!r}, "
487
+ f"executed_activities={len(self.executed_activity_ids)}, "
488
+ f"is_replaying={self.is_replaying})"
489
+ )