rappel 0.7.2__py3-none-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rappel might be problematic. Click here for more details.

rappel/workflow.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Workflow base class and registration decorator.
3
+
4
+ This module provides the foundation for defining workflows that can be
5
+ compiled to IR and executed by the Rappel runtime.
6
+ """
7
+
8
+ import hashlib
9
+ import inspect
10
+ import os
11
+ from dataclasses import dataclass
12
+ from datetime import timedelta
13
+ from functools import wraps
14
+ from threading import RLock
15
+ from typing import Any, Awaitable, ClassVar, Optional, TypeVar
16
+
17
+ from proto import ast_pb2 as ir
18
+ from proto import messages_pb2 as pb2
19
+
20
+ from . import bridge
21
+ from .actions import deserialize_result_payload
22
+ from .ir_builder import build_workflow_ir
23
+ from .logger import configure as configure_logger
24
+ from .serialization import build_arguments_from_kwargs
25
+ from .workflow_runtime import WorkflowNodeResult
26
+
27
+ logger = configure_logger("rappel.workflow")
28
+
29
+ TWorkflow = TypeVar("TWorkflow", bound="Workflow")
30
+ TResult = TypeVar("TResult")
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class RetryPolicy:
35
+ """Retry policy for action execution.
36
+
37
+ Maps to IR RetryPolicy: [ExceptionType -> retry: N, backoff: Xs]
38
+
39
+ Args:
40
+ attempts: Maximum number of retry attempts.
41
+ exception_types: List of exception type names to retry on. Empty = catch all.
42
+ backoff_seconds: Constant backoff duration between retries in seconds.
43
+ """
44
+
45
+ attempts: Optional[int] = None
46
+ exception_types: Optional[list[str]] = None
47
+ backoff_seconds: Optional[float] = None
48
+
49
+
50
+ class Workflow:
51
+ """Base class for workflow definitions."""
52
+
53
+ name: ClassVar[Optional[str]] = None
54
+ """Human-friendly identifier. Override to pin the registry key; defaults to lowercase class name."""
55
+
56
+ concurrent: ClassVar[bool] = False
57
+ """When True, downstream engines may respect DAG-parallel execution; False preserves sequential semantics."""
58
+
59
+ _workflow_ir: ClassVar[Optional[ir.Program]] = None
60
+ _ir_lock: ClassVar[RLock] = RLock()
61
+ _workflow_version_id: ClassVar[Optional[str]] = None
62
+
63
+ async def run(self, *args: Any, **kwargs: Any) -> Any:
64
+ raise NotImplementedError
65
+
66
+ @classmethod
67
+ def _normalize_run_inputs(cls, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
68
+ try:
69
+ run_impl = cls.__workflow_run_impl__ # type: ignore[attr-defined]
70
+ except AttributeError:
71
+ run_impl = cls.run
72
+ sig = inspect.signature(run_impl)
73
+ params = list(sig.parameters.keys())[1:] # Skip 'self'
74
+
75
+ normalized = dict(kwargs)
76
+ for i, arg in enumerate(args):
77
+ if i < len(params):
78
+ normalized[params[i]] = arg
79
+
80
+ bound = sig.bind_partial(None, **normalized)
81
+ bound.apply_defaults()
82
+ return {key: value for key, value in bound.arguments.items() if key != "self"}
83
+
84
+ @classmethod
85
+ def _build_initial_context(
86
+ cls, args: tuple[Any, ...], kwargs: dict[str, Any]
87
+ ) -> pb2.WorkflowArguments:
88
+ initial_kwargs = cls._normalize_run_inputs(args, kwargs)
89
+ return build_arguments_from_kwargs(initial_kwargs)
90
+
91
+ async def run_action(
92
+ self,
93
+ awaitable: Awaitable[TResult],
94
+ *,
95
+ retry: Optional[RetryPolicy] = None,
96
+ timeout: Optional[float | int | timedelta] = None,
97
+ ) -> TResult:
98
+ """Helper that simply awaits the provided action coroutine.
99
+
100
+ The retry and timeout arguments are consumed by the workflow compiler
101
+ (IR builder) rather than the runtime execution path.
102
+
103
+ Args:
104
+ awaitable: The action coroutine to execute.
105
+ retry: Retry policy including max attempts, exception types, and backoff.
106
+ timeout: Timeout duration in seconds (or timedelta).
107
+ """
108
+ # Parameters are intentionally unused at runtime; the workflow compiler
109
+ # inspects the AST to record them.
110
+ del retry, timeout
111
+ return await awaitable
112
+
113
+ @classmethod
114
+ def short_name(cls) -> str:
115
+ if cls.name:
116
+ return cls.name
117
+ return cls.__name__.lower()
118
+
119
+ @classmethod
120
+ def workflow_ir(cls) -> ir.Program:
121
+ """Build and cache the IR program for this workflow."""
122
+ if cls._workflow_ir is None:
123
+ with cls._ir_lock:
124
+ if cls._workflow_ir is None:
125
+ cls._workflow_ir = build_workflow_ir(cls)
126
+ return cls._workflow_ir
127
+
128
+ @classmethod
129
+ def _build_registration_payload(
130
+ cls, initial_context: Optional[pb2.WorkflowArguments] = None
131
+ ) -> pb2.WorkflowRegistration:
132
+ """Build a registration payload with the serialized IR."""
133
+ program = cls.workflow_ir()
134
+
135
+ # Serialize IR to bytes
136
+ ir_bytes = program.SerializeToString()
137
+ ir_hash = hashlib.sha256(ir_bytes).hexdigest()
138
+
139
+ message = pb2.WorkflowRegistration(
140
+ workflow_name=cls.short_name(),
141
+ ir=ir_bytes,
142
+ ir_hash=ir_hash,
143
+ concurrent=cls.concurrent,
144
+ )
145
+
146
+ if initial_context:
147
+ message.initial_context.CopyFrom(initial_context)
148
+
149
+ return message
150
+
151
+
152
+ class WorkflowRegistry:
153
+ """Registry of workflow definitions keyed by workflow name."""
154
+
155
+ def __init__(self) -> None:
156
+ self._workflows: dict[str, type[Workflow]] = {}
157
+ self._lock = RLock()
158
+
159
+ def register(self, name: str, workflow_cls: type[Workflow]) -> None:
160
+ with self._lock:
161
+ if name in self._workflows:
162
+ raise ValueError(f"workflow '{name}' already registered")
163
+ self._workflows[name] = workflow_cls
164
+
165
+ def get(self, name: str) -> Optional[type[Workflow]]:
166
+ with self._lock:
167
+ return self._workflows.get(name)
168
+
169
+ def names(self) -> list[str]:
170
+ with self._lock:
171
+ return sorted(self._workflows.keys())
172
+
173
+ def reset(self) -> None:
174
+ with self._lock:
175
+ self._workflows.clear()
176
+
177
+
178
+ workflow_registry = WorkflowRegistry()
179
+
180
+
181
+ def workflow(cls: type[TWorkflow]) -> type[TWorkflow]:
182
+ """Decorator that registers workflow classes and caches their IR."""
183
+
184
+ if not issubclass(cls, Workflow):
185
+ raise TypeError("workflow decorator requires Workflow subclasses")
186
+ run_impl = cls.run
187
+ if not inspect.iscoroutinefunction(run_impl):
188
+ raise TypeError("workflow run() must be defined with 'async def'")
189
+
190
+ @wraps(run_impl)
191
+ async def run_public(self: Workflow, *args: Any, **kwargs: Any) -> Any:
192
+ if _running_under_pytest():
193
+ cls.workflow_ir()
194
+ return await run_impl(self, *args, **kwargs)
195
+
196
+ initial_context = cls._build_initial_context(args, kwargs)
197
+
198
+ payload = cls._build_registration_payload(initial_context)
199
+ run_result = await bridge.run_instance(payload.SerializeToString())
200
+ cls._workflow_version_id = run_result.workflow_version_id
201
+ if _skip_wait_for_instance():
202
+ logger.info(
203
+ "Skipping wait_for_instance for workflow %s due to RAPPEL_SKIP_WAIT_FOR_INSTANCE",
204
+ cls.short_name(),
205
+ )
206
+ return None
207
+ result_bytes = await bridge.wait_for_instance(
208
+ instance_id=run_result.workflow_instance_id,
209
+ poll_interval_secs=1.0,
210
+ )
211
+ if result_bytes is None:
212
+ raise TimeoutError(
213
+ f"workflow instance {run_result.workflow_instance_id} did not complete"
214
+ )
215
+ arguments = pb2.WorkflowArguments()
216
+ arguments.ParseFromString(result_bytes)
217
+ result = deserialize_result_payload(arguments)
218
+ if result.error:
219
+ raise RuntimeError(f"workflow failed: {result.error}")
220
+
221
+ # Unwrap WorkflowNodeResult if present (internal worker representation)
222
+ if isinstance(result.result, WorkflowNodeResult):
223
+ # Extract the actual result from the variables dict
224
+ variables = result.result.variables
225
+ program = cls.workflow_ir()
226
+ # Get the return variable from the IR if available
227
+ if program.functions:
228
+ outputs = list(program.functions[0].io.outputs)
229
+ if outputs:
230
+ return_var = outputs[0]
231
+ if return_var in variables:
232
+ return variables[return_var]
233
+ return None
234
+
235
+ return result.result
236
+
237
+ cls.__workflow_run_impl__ = run_impl
238
+ cls.run = run_public # type: ignore[assignment]
239
+ workflow_registry.register(cls.short_name(), cls)
240
+ return cls
241
+
242
+
243
+ def _running_under_pytest() -> bool:
244
+ return bool(os.environ.get("PYTEST_CURRENT_TEST"))
245
+
246
+
247
+ def _skip_wait_for_instance() -> bool:
248
+ value = os.environ.get("RAPPEL_SKIP_WAIT_FOR_INSTANCE")
249
+ if not value:
250
+ return False
251
+ return value.strip().lower() not in {"0", "false", "no"}
@@ -0,0 +1,287 @@
1
+ """Runtime helpers for executing actions inside the worker.
2
+
3
+ This module provides the execution layer for Python workers that receive
4
+ action dispatch commands from the Rust scheduler.
5
+ """
6
+
7
+ import asyncio
8
+ import dataclasses
9
+ from base64 import b64decode
10
+ from dataclasses import dataclass
11
+ from datetime import date, datetime, time, timedelta
12
+ from decimal import Decimal
13
+ from pathlib import Path, PurePath
14
+ from typing import Any, Dict, get_args, get_origin, get_type_hints
15
+ from uuid import UUID
16
+
17
+ from pydantic import BaseModel
18
+
19
+ from proto import messages_pb2 as pb2
20
+
21
+ from .dependencies import provide_dependencies
22
+ from .registry import registry
23
+ from .serialization import arguments_to_kwargs
24
+
25
+
26
+ class WorkflowNodeResult(BaseModel):
27
+ """Result from a workflow node execution containing variable bindings."""
28
+
29
+ variables: Dict[str, Any]
30
+
31
+
32
+ @dataclass
33
+ class ActionExecutionResult:
34
+ """Result of an action execution."""
35
+
36
+ result: Any
37
+ exception: BaseException | None = None
38
+
39
+
40
+ def _is_pydantic_model(cls: type) -> bool:
41
+ """Check if a class is a Pydantic BaseModel subclass."""
42
+ try:
43
+ return isinstance(cls, type) and issubclass(cls, BaseModel)
44
+ except TypeError:
45
+ return False
46
+
47
+
48
+ def _is_dataclass_type(cls: type) -> bool:
49
+ """Check if a class is a dataclass."""
50
+ return dataclasses.is_dataclass(cls) and isinstance(cls, type)
51
+
52
+
53
+ def _coerce_primitive(value: Any, target_type: type) -> Any:
54
+ """Coerce a value to a primitive type based on target_type.
55
+
56
+ Handles conversion of serialized values (strings, floats) back to their
57
+ native Python types (UUID, datetime, etc.).
58
+ """
59
+ # Handle None
60
+ if value is None:
61
+ return None
62
+
63
+ # UUID from string
64
+ if target_type is UUID:
65
+ if isinstance(value, UUID):
66
+ return value
67
+ if isinstance(value, str):
68
+ return UUID(value)
69
+ return value
70
+
71
+ # datetime from ISO string
72
+ if target_type is datetime:
73
+ if isinstance(value, datetime):
74
+ return value
75
+ if isinstance(value, str):
76
+ return datetime.fromisoformat(value)
77
+ return value
78
+
79
+ # date from ISO string
80
+ if target_type is date:
81
+ if isinstance(value, date):
82
+ return value
83
+ if isinstance(value, str):
84
+ return date.fromisoformat(value)
85
+ return value
86
+
87
+ # time from ISO string
88
+ if target_type is time:
89
+ if isinstance(value, time):
90
+ return value
91
+ if isinstance(value, str):
92
+ return time.fromisoformat(value)
93
+ return value
94
+
95
+ # timedelta from total seconds
96
+ if target_type is timedelta:
97
+ if isinstance(value, timedelta):
98
+ return value
99
+ if isinstance(value, (int, float)):
100
+ return timedelta(seconds=value)
101
+ return value
102
+
103
+ # Decimal from string
104
+ if target_type is Decimal:
105
+ if isinstance(value, Decimal):
106
+ return value
107
+ if isinstance(value, (str, int, float)):
108
+ return Decimal(str(value))
109
+ return value
110
+
111
+ # bytes from base64 string
112
+ if target_type is bytes:
113
+ if isinstance(value, bytes):
114
+ return value
115
+ if isinstance(value, str):
116
+ return b64decode(value)
117
+ return value
118
+
119
+ # Path from string
120
+ if target_type is Path or target_type is PurePath:
121
+ if isinstance(value, PurePath):
122
+ return value
123
+ if isinstance(value, str):
124
+ return Path(value)
125
+ return value
126
+
127
+ return value
128
+
129
+
130
+ # Types that can be coerced from serialized form
131
+ COERCIBLE_TYPES = (UUID, datetime, date, time, timedelta, Decimal, bytes, Path, PurePath)
132
+
133
+
134
+ def _coerce_dict_to_model(value: Any, target_type: type) -> Any:
135
+ """Convert a dict to a Pydantic model or dataclass if needed.
136
+
137
+ If value is a dict and target_type is a Pydantic model or dataclass,
138
+ instantiate the model with the dict values. Otherwise, return value unchanged.
139
+ """
140
+ if not isinstance(value, dict):
141
+ return value
142
+
143
+ if _is_pydantic_model(target_type):
144
+ # Use model_validate for Pydantic v2, fall back to direct instantiation
145
+ model_validate = getattr(target_type, "model_validate", None)
146
+ if model_validate is not None:
147
+ return model_validate(value)
148
+ return target_type(**value)
149
+
150
+ if _is_dataclass_type(target_type):
151
+ return target_type(**value)
152
+
153
+ return value
154
+
155
+
156
+ def _coerce_value(value: Any, target_type: type) -> Any:
157
+ """Coerce a value to the target type.
158
+
159
+ Handles:
160
+ - Primitive types (UUID, datetime, etc.)
161
+ - Pydantic models and dataclasses (from dicts)
162
+ - Generic collections like list[UUID], set[datetime]
163
+ """
164
+ # Handle None
165
+ if value is None:
166
+ return None
167
+
168
+ # Check for coercible primitive types
169
+ if isinstance(target_type, type) and issubclass(target_type, COERCIBLE_TYPES):
170
+ return _coerce_primitive(value, target_type)
171
+
172
+ # Check for Pydantic models or dataclasses
173
+ if isinstance(value, dict):
174
+ coerced = _coerce_dict_to_model(value, target_type)
175
+ if coerced is not value:
176
+ return coerced
177
+
178
+ # Handle generic types like list[UUID], set[datetime]
179
+ origin = get_origin(target_type)
180
+ if origin is not None:
181
+ args = get_args(target_type)
182
+
183
+ # Handle list[T]
184
+ if origin is list and isinstance(value, list) and args:
185
+ item_type = args[0]
186
+ return [_coerce_value(item, item_type) for item in value]
187
+
188
+ # Handle set[T] (serialized as list)
189
+ if origin is set and isinstance(value, list) and args:
190
+ item_type = args[0]
191
+ return {_coerce_value(item, item_type) for item in value}
192
+
193
+ # Handle frozenset[T] (serialized as list)
194
+ if origin is frozenset and isinstance(value, list) and args:
195
+ item_type = args[0]
196
+ return frozenset(_coerce_value(item, item_type) for item in value)
197
+
198
+ # Handle tuple[T, ...] (serialized as list)
199
+ if origin is tuple and isinstance(value, (list, tuple)) and args:
200
+ # Variable length tuple like tuple[int, ...]
201
+ if len(args) == 2 and args[1] is ...:
202
+ item_type = args[0]
203
+ return tuple(_coerce_value(item, item_type) for item in value)
204
+ # Fixed length tuple like tuple[int, str, UUID]
205
+ return tuple(
206
+ _coerce_value(item, item_type) for item, item_type in zip(value, args, strict=False)
207
+ )
208
+
209
+ # Handle dict[K, V]
210
+ if origin is dict and isinstance(value, dict) and len(args) == 2:
211
+ key_type, val_type = args
212
+ return {
213
+ _coerce_value(k, key_type): _coerce_value(v, val_type) for k, v in value.items()
214
+ }
215
+
216
+ return value
217
+
218
+
219
+ def _coerce_kwargs_to_type_hints(handler: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
220
+ """Coerce kwargs to expected types based on handler's type hints.
221
+
222
+ Handles:
223
+ - Pydantic models and dataclasses (from dicts)
224
+ - Primitive types like UUID, datetime, Decimal, etc.
225
+ - Generic collections like list[UUID], dict[str, datetime]
226
+ """
227
+ try:
228
+ type_hints = get_type_hints(handler)
229
+ except Exception:
230
+ # If we can't get type hints (e.g., forward references), return as-is
231
+ return kwargs
232
+
233
+ coerced = {}
234
+ for key, value in kwargs.items():
235
+ if key in type_hints:
236
+ target_type = type_hints[key]
237
+ coerced[key] = _coerce_value(value, target_type)
238
+ else:
239
+ coerced[key] = value
240
+
241
+ return coerced
242
+
243
+
244
+ async def execute_action(dispatch: pb2.ActionDispatch) -> ActionExecutionResult:
245
+ """Execute an action based on the dispatch command.
246
+
247
+ Args:
248
+ dispatch: The action dispatch command from the Rust scheduler.
249
+
250
+ Returns:
251
+ The result of executing the action.
252
+ """
253
+ action_name = dispatch.action_name
254
+ module_name = dispatch.module_name
255
+
256
+ # Import the module if specified (this registers actions via @action decorator)
257
+ if module_name:
258
+ import importlib
259
+
260
+ importlib.import_module(module_name)
261
+
262
+ # Get the action handler using both module and name
263
+ handler = registry.get(module_name, action_name)
264
+ if handler is None:
265
+ return ActionExecutionResult(
266
+ result=None,
267
+ exception=KeyError(f"action '{module_name}:{action_name}' not registered"),
268
+ )
269
+
270
+ # Deserialize kwargs
271
+ kwargs = arguments_to_kwargs(dispatch.kwargs)
272
+
273
+ # Coerce dict arguments to Pydantic models or dataclasses based on type hints
274
+ # This is needed because the IR converts model constructor calls to dicts
275
+ kwargs = _coerce_kwargs_to_type_hints(handler, kwargs)
276
+
277
+ try:
278
+ async with provide_dependencies(handler, kwargs) as call_kwargs:
279
+ value = handler(**call_kwargs)
280
+ if asyncio.iscoroutine(value):
281
+ value = await value
282
+ return ActionExecutionResult(result=value)
283
+ except Exception as e:
284
+ return ActionExecutionResult(
285
+ result=None,
286
+ exception=e,
287
+ )
Binary file
Binary file