rappel 0.10.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rappel might be problematic. Click here for more details.

rappel/workflow.py ADDED
@@ -0,0 +1,338 @@
1
+ """
2
+ Workflow base class and registration decorator.
3
+
4
+ This module provides the foundation for defining workflows that can be
5
+ compiled to IR and executed by the Rappel runtime.
6
+ """
7
+
8
+ import hashlib
9
+ import inspect
10
+ import os
11
+ from dataclasses import dataclass
12
+ from datetime import timedelta
13
+ from functools import wraps
14
+ from threading import RLock
15
+ from types import UnionType
16
+ from typing import (
17
+ Any,
18
+ Awaitable,
19
+ ClassVar,
20
+ Optional,
21
+ TypeVar,
22
+ Union,
23
+ get_args,
24
+ get_origin,
25
+ get_type_hints,
26
+ )
27
+
28
+ from proto import ast_pb2 as ir
29
+ from proto import messages_pb2 as pb2
30
+
31
+ from . import bridge
32
+ from .actions import deserialize_result_payload
33
+ from .ir_builder import build_workflow_ir
34
+ from .logger import configure as configure_logger
35
+ from .serialization import build_arguments_from_kwargs
36
+ from .workflow_runtime import WorkflowNodeResult, _coerce_value
37
+
38
# Module-level logger obtained from the package's logging helper.
logger = configure_logger("rappel.workflow")

# TWorkflow is bound to Workflow subclasses (used by the @workflow decorator);
# TResult is the awaited result type of Workflow.run_action().
TWorkflow = TypeVar("TWorkflow", bound="Workflow")
TResult = TypeVar("TResult")
42
+
43
+
44
+ @dataclass(frozen=True)
45
+ class RetryPolicy:
46
+ """Retry policy for action execution.
47
+
48
+ Maps to IR RetryPolicy: [ExceptionType -> retry: N, backoff: Xs]
49
+
50
+ Args:
51
+ attempts: Maximum number of retry attempts.
52
+ exception_types: List of exception type names to retry on. Empty = catch all.
53
+ backoff_seconds: Constant backoff duration between retries in seconds.
54
+ """
55
+
56
+ attempts: Optional[int] = None
57
+ exception_types: Optional[list[str]] = None
58
+ backoff_seconds: Optional[float] = None
59
+
60
+
61
class Workflow:
    """Base class for workflow definitions.

    Subclasses implement ``async def run(...)``. The ``@workflow`` decorator
    stores that coroutine on ``__workflow_run_impl__`` and installs a public
    ``run`` wrapper; the helpers below read the stored coroutine when present.
    """

    # Human-friendly identifier. Override to pin the registry key; defaults to
    # the lowercase class name.
    name: ClassVar[Optional[str]] = None

    # When True, downstream engines may respect DAG-parallel execution; False
    # preserves sequential semantics. Forwarded in the registration payload.
    concurrent: ClassVar[bool] = False

    # Compiled IR cache. workflow_ir() reads it via cls.__dict__ so every
    # subclass compiles and caches its own program.
    _workflow_ir: ClassVar[Optional[ir.Program]] = None
    # A single lock object shared by all subclasses; serializes IR compilation.
    _ir_lock: ClassVar[RLock] = RLock()
    _workflow_version_id: ClassVar[Optional[str]] = None

    async def run(
        self, *args: Any, _blocking: bool = True, _priority: Optional[int] = None, **kwargs: Any
    ) -> Any:
        """Workflow entry point; subclasses must override it with ``async def``."""
        raise NotImplementedError

    @classmethod
    def _normalize_run_inputs(cls, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
        """Map positional and keyword ``run`` arguments onto parameter names.

        Positional values are assigned to ``run``'s parameters (after ``self``)
        in declaration order and merged with ``kwargs``. Defaults are then filled
        in for parameters that were not supplied.

        Raises:
            TypeError: if more positional values are given than ``run`` declares
                parameters, or a parameter receives both a positional and a
                keyword value.
        """
        try:
            run_impl = cls.__workflow_run_impl__  # type: ignore[attr-defined]
        except AttributeError:
            run_impl = cls.run
        sig = inspect.signature(run_impl)
        params = list(sig.parameters)[1:]  # Skip 'self'

        # Surplus positional values would otherwise be dropped without notice.
        if len(args) > len(params):
            raise TypeError(
                f"{cls.__name__}.run() takes {len(params)} argument(s) "
                f"but {len(args)} positional argument(s) were given"
            )

        normalized = dict(kwargs)
        for param_name, arg in zip(params, args):
            # Mirror Python's own call semantics instead of letting the
            # positional value silently overwrite the keyword one.
            if param_name in normalized:
                raise TypeError(
                    f"{cls.__name__}.run() got multiple values for argument '{param_name}'"
                )
            normalized[param_name] = arg

        bound = sig.bind_partial(None, **normalized)
        bound.apply_defaults()
        return {key: value for key, value in bound.arguments.items() if key != "self"}

    @classmethod
    def _build_initial_context(
        cls, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> pb2.WorkflowArguments:
        """Serialize the normalized ``run`` arguments into a WorkflowArguments message."""
        initial_kwargs = cls._normalize_run_inputs(args, kwargs)
        return build_arguments_from_kwargs(initial_kwargs)

    async def run_action(
        self,
        awaitable: Awaitable[TResult],
        *,
        retry: Optional[RetryPolicy] = None,
        timeout: Optional[float | int | timedelta] = None,
    ) -> TResult:
        """Helper that simply awaits the provided action coroutine.

        The retry and timeout arguments are consumed by the workflow compiler
        (IR builder) rather than the runtime execution path.

        Args:
            awaitable: The action coroutine to execute.
            retry: Retry policy including max attempts, exception types, and backoff.
            timeout: Timeout duration in seconds (or timedelta).
        """
        # Parameters are intentionally unused at runtime; the workflow compiler
        # inspects the AST to record them.
        del retry, timeout
        return await awaitable

    @classmethod
    def short_name(cls) -> str:
        """Return the registry key: ``name`` if set, else the lowercase class name."""
        if cls.name:
            return cls.name
        return cls.__name__.lower()

    @classmethod
    def workflow_ir(cls) -> ir.Program:
        """Build and cache the IR program for this workflow class.

        The cache is looked up in this class's own ``__dict__``; reading the
        attribute normally would return a parent class's cached program for a
        subclass of an already-compiled workflow.
        """
        program = cls.__dict__.get("_workflow_ir")
        if program is None:
            with cls._ir_lock:
                # Re-check under the lock in case another thread compiled it.
                program = cls.__dict__.get("_workflow_ir")
                if program is None:
                    program = build_workflow_ir(cls)
                    cls._workflow_ir = program
        return program

    @classmethod
    def _build_registration_payload(
        cls,
        initial_context: Optional[pb2.WorkflowArguments] = None,
        priority: Optional[int] = None,
    ) -> pb2.WorkflowRegistration:
        """Build a registration payload with the serialized IR and its SHA-256 hash."""
        program = cls.workflow_ir()

        # Serialize IR to bytes
        ir_bytes = program.SerializeToString()
        ir_hash = hashlib.sha256(ir_bytes).hexdigest()

        message = pb2.WorkflowRegistration(
            workflow_name=cls.short_name(),
            ir=ir_bytes,
            ir_hash=ir_hash,
            concurrent=cls.concurrent,
        )

        if initial_context:
            message.initial_context.CopyFrom(initial_context)

        if priority is not None:
            message.priority = priority

        return message
168
+
169
+
170
class WorkflowRegistry:
    """Name -> Workflow-subclass mapping; every operation holds a re-entrant lock."""

    def __init__(self) -> None:
        self._lock = RLock()
        self._workflows: dict[str, type[Workflow]] = {}

    def register(self, name: str, workflow_cls: type[Workflow]) -> None:
        """Record ``workflow_cls`` under ``name``.

        Raises:
            ValueError: if ``name`` is already registered.
        """
        with self._lock:
            if name not in self._workflows:
                self._workflows[name] = workflow_cls
                return
        raise ValueError(f"workflow '{name}' already registered")

    def get(self, name: str) -> Optional[type[Workflow]]:
        """Return the class registered under ``name``, or None."""
        with self._lock:
            found = self._workflows.get(name)
        return found

    def names(self) -> list[str]:
        """Return every registered name in sorted order."""
        with self._lock:
            return sorted(self._workflows)

    def reset(self) -> None:
        """Drop all registrations."""
        with self._lock:
            self._workflows.clear()


# Module-level registry that the @workflow decorator registers classes into.
workflow_registry = WorkflowRegistry()
197
+
198
+
199
def workflow(cls: type[TWorkflow]) -> type[TWorkflow]:
    """Register a Workflow subclass and replace its ``run`` with a bridge-backed wrapper.

    The user's coroutine is kept on ``cls.__workflow_run_impl__`` (read by the
    signature/return-type helpers in this module). The public ``run`` serializes
    the call and either executes it through the in-memory bridge (under pytest)
    or submits an instance and optionally waits for its result.

    Raises:
        TypeError: if ``cls`` is not a Workflow subclass or ``run`` is not ``async def``.
        ValueError: from the registry when the short name is already taken.
    """
    if not issubclass(cls, Workflow):
        raise TypeError("workflow decorator requires Workflow subclasses")
    user_run = cls.run
    if not inspect.iscoroutinefunction(user_run):
        raise TypeError("workflow run() must be defined with 'async def'")

    def encode_request(args: tuple[Any, ...], kwargs: dict[str, Any], priority: Optional[int]) -> bytes:
        # Arguments + compiled IR, serialized into the registration wire format.
        context = cls._build_initial_context(args, kwargs)
        registration = cls._build_registration_payload(context, priority=priority)
        return registration.SerializeToString()

    @wraps(user_run)
    async def run_public(
        self: Workflow,
        *args: Any,
        _blocking: bool = True,
        _priority: Optional[int] = None,
        **kwargs: Any,
    ) -> Any:
        # Test mode: route the whole run through bridge.execute_workflow with
        # the in-memory broker enabled.
        if _running_under_pytest():
            logger.debug(
                "pytest run: workflow=%s in_memory=%s",
                cls.short_name(),
                os.environ.get("RAPPEL_BRIDGE_IN_MEMORY"),
            )
            _enable_in_memory_broker()
            raw_result = await bridge.execute_workflow(encode_request(args, kwargs, _priority))
            return _deserialize_workflow_result(cls, raw_result)

        started = await bridge.run_instance(encode_request(args, kwargs, _priority))
        cls._workflow_version_id = started.workflow_version_id
        instance_id = started.workflow_instance_id

        # Fire-and-forget callers only get the instance id back.
        if not _blocking:
            return instance_id

        if _skip_wait_for_instance():
            logger.info(
                "Skipping wait_for_instance for workflow %s due to RAPPEL_SKIP_WAIT_FOR_INSTANCE",
                cls.short_name(),
            )
            return None

        raw_result = await bridge.wait_for_instance(
            instance_id=instance_id,
            poll_interval_secs=1.0,
        )
        if raw_result is None:
            raise TimeoutError(f"workflow instance {instance_id} did not complete")
        return _deserialize_workflow_result(cls, raw_result)

    cls.__workflow_run_impl__ = user_run
    cls.run = run_public  # type: ignore[assignment]
    workflow_registry.register(cls.short_name(), cls)
    return cls
257
+
258
+
259
def _running_under_pytest() -> bool:
    """Report whether PYTEST_CURRENT_TEST is set to a non-empty value."""
    return os.environ.get("PYTEST_CURRENT_TEST", "") != ""
261
+
262
+
263
def _skip_wait_for_instance() -> bool:
    """Return True when RAPPEL_SKIP_WAIT_FOR_INSTANCE asks to skip waiting for results.

    The variable being unset, empty or whitespace-only, or set to one of the
    usual "off" spellings ("0", "false", "no", "off"; case-insensitive,
    surrounding whitespace ignored) keeps waiting enabled. Any other value
    enables skipping.
    """
    value = os.environ.get("RAPPEL_SKIP_WAIT_FOR_INSTANCE", "")
    # "" covers whitespace-only values after strip(), which previously counted as "on".
    return value.strip().lower() not in {"", "0", "false", "no", "off"}
268
+
269
+
270
def _enable_in_memory_broker() -> None:
    """Set RAPPEL_BRIDGE_IN_MEMORY to "1" unless it already has a value."""
    if "RAPPEL_BRIDGE_IN_MEMORY" not in os.environ:
        os.environ["RAPPEL_BRIDGE_IN_MEMORY"] = "1"
272
+
273
+
274
def _deserialize_workflow_result(
    workflow_cls: type[Workflow],
    result_bytes: bytes,
) -> Any:
    """Decode a serialized workflow result into the value ``run()`` returned.

    Args:
        workflow_cls: The workflow class whose result is decoded. It is used to
            find the return variable in its IR and to read ``run()``'s return
            annotation.
        result_bytes: A serialized ``WorkflowArguments`` message.

    Returns:
        The workflow's return value, coerced to ``run()``'s declared return type
        when one is declared. Both result representations (a plain result and a
        ``WorkflowNodeResult``) go through the same coercion. ``None`` stays ``None``.

    Raises:
        RuntimeError: if the decoded payload reports a workflow error.
    """
    arguments = pb2.WorkflowArguments()
    arguments.ParseFromString(result_bytes)
    result = deserialize_result_payload(arguments)
    if result.error:
        raise RuntimeError(f"workflow failed: {result.error}")

    value = result.result
    if isinstance(value, WorkflowNodeResult):
        # Internal worker representation: extract the return variable from the
        # node's bindings, then coerce it exactly like a direct result.
        value = _workflow_return_variable(workflow_cls, value.variables)

    # Coercing None is always a no-op; skipping it also avoids Union edge cases.
    if value is None:
        return None
    target_type = _resolve_return_type(workflow_cls)
    if target_type is None:
        return value
    return _coerce_result_value(value, target_type)


def _workflow_return_variable(workflow_cls: type[Workflow], variables: dict[str, Any]) -> Any:
    """Return the binding of the workflow's return variable, or None.

    The return variable is the first output of the first function in the
    workflow's IR program. None is returned when there is no function, the
    function has no outputs, or the variable is not bound in ``variables``.
    """
    program = workflow_cls.workflow_ir()
    if not program.functions:
        return None
    outputs = list(program.functions[0].io.outputs)
    if not outputs:
        return None
    return variables.get(outputs[0])
303
+
304
+
305
def _resolve_return_type(workflow_cls: type[Workflow]) -> Optional[type]:
    """Return the declared return annotation of the workflow's ``run``.

    Prefers the original coroutine stored by ``@workflow`` and falls back to
    ``cls.run``. Returns None when annotations cannot be resolved, when there is
    no return annotation, or when it is ``Any``.
    """
    try:
        implementation = workflow_cls.__workflow_run_impl__  # type: ignore[attr-defined]
    except AttributeError:
        implementation = workflow_cls.run

    try:
        hints = get_type_hints(implementation)
    except Exception:
        # Unresolvable annotations: treat as "no declared type".
        return None

    declared = hints.get("return")
    return None if declared is None or declared is Any else declared
318
+
319
+
320
+ def _coerce_result_value(value: Any, target_type: type) -> Any:
321
+ origin = get_origin(target_type)
322
+ if origin is UnionType or origin is Union:
323
+ for arg in get_args(target_type):
324
+ if arg is type(None) and value is None:
325
+ return None
326
+ try:
327
+ coerced = _coerce_value(value, arg)
328
+ except Exception:
329
+ continue
330
+ if coerced is not value:
331
+ return coerced
332
+ if isinstance(arg, type) and isinstance(value, arg):
333
+ return value
334
+ return value
335
+ try:
336
+ return _coerce_value(value, target_type)
337
+ except Exception:
338
+ return value
@@ -0,0 +1,292 @@
1
+ """Runtime helpers for executing actions inside the worker.
2
+
3
+ This module provides the execution layer for Python workers that receive
4
+ action dispatch commands from the Rust scheduler.
5
+ """
6
+
7
+ import asyncio
8
+ import dataclasses
9
+ from base64 import b64decode
10
+ from dataclasses import dataclass
11
+ from datetime import date, datetime, time, timedelta
12
+ from decimal import Decimal
13
+ from pathlib import Path, PurePath
14
+ from typing import Any, Dict, get_args, get_origin, get_type_hints
15
+ from uuid import UUID
16
+
17
+ from pydantic import BaseModel
18
+
19
+ from proto import messages_pb2 as pb2
20
+
21
+ from .dependencies import provide_dependencies
22
+ from .registry import registry
23
+ from .serialization import arguments_to_kwargs
24
+
25
+
26
class WorkflowNodeResult(BaseModel):
    """Variable bindings produced by executing a workflow node.

    Serves as an internal worker representation of a node's outcome.
    """

    # Variable name -> bound value.
    variables: Dict[str, Any]
30
+
31
+
32
+ @dataclass
33
+ class ActionExecutionResult:
34
+ """Result of an action execution."""
35
+
36
+ result: Any
37
+ exception: BaseException | None = None
38
+
39
+
40
def _is_pydantic_model(cls: type) -> bool:
    """Return True if ``cls`` is a class deriving from Pydantic's ``BaseModel``.

    Non-class inputs yield False. A ``TypeError`` from ``issubclass`` is also
    reported as False.
    """
    if not isinstance(cls, type):
        return False
    try:
        return issubclass(cls, BaseModel)
    except TypeError:
        return False
46
+
47
+
48
def _is_dataclass_type(cls: type) -> bool:
    """Return True only for dataclass *classes* (dataclass instances give False)."""
    return isinstance(cls, type) and dataclasses.is_dataclass(cls)
51
+
52
+
53
def _is_subclass_of(candidate: Any, family: type) -> bool:
    """Return True when ``candidate`` is a class deriving from ``family``.

    Non-class annotations (``list[int]``, ``Optional[X]``, ...) yield False
    instead of raising.
    """
    try:
        return isinstance(candidate, type) and issubclass(candidate, family)
    except TypeError:
        return False


def _convert_path(raw: str, target: type) -> Any:
    # A bare PurePath target produces a concrete Path; subclasses get their own type.
    return Path(raw) if target is PurePath else target(raw)


def _convert_bytes(raw: str, target: type) -> Any:
    # bytes travel as base64 text.
    decoded = b64decode(raw)
    return decoded if target is bytes else target(decoded)


# Ordered (family, accepted serialized types, converter) entries.
# ``datetime`` must precede ``date`` because datetime subclasses date.
_PRIMITIVE_CONVERTERS = (
    (UUID, (str,), lambda raw, target: target(raw)),
    (datetime, (str,), lambda raw, target: target.fromisoformat(raw)),
    (date, (str,), lambda raw, target: target.fromisoformat(raw)),
    (time, (str,), lambda raw, target: target.fromisoformat(raw)),
    (timedelta, (int, float), lambda raw, target: target(seconds=raw)),
    (Decimal, (str, int, float), lambda raw, target: target(str(raw))),
    (bytes, (str,), _convert_bytes),
    (PurePath, (str,), _convert_path),
)


def _coerce_primitive(value: Any, target_type: type) -> Any:
    """Coerce a value to a primitive type based on target_type.

    Handles conversion of serialized values back to native Python types:
    UUID / datetime / date / time / Path from strings, timedelta from total
    seconds, Decimal from str/int/float, and bytes from base64 strings.
    Subclasses of these types (e.g. ``PurePosixPath``, ``PosixPath``) are
    converted to the requested subclass. ``_coerce_value`` routes every
    ``COERCIBLE_TYPES`` subclass here, and previously such targets were left
    unconverted. Values already of the target type, and values of any other
    shape, are returned unchanged. ``None`` stays ``None``.
    """
    if value is None:
        return None

    for family, accepted, convert in _PRIMITIVE_CONVERTERS:
        if not _is_subclass_of(target_type, family):
            continue
        if isinstance(value, target_type):
            return value
        if isinstance(value, accepted):
            return convert(value, target_type)
        return value

    return value


# Types that can be coerced from serialized form
COERCIBLE_TYPES = (UUID, datetime, date, time, timedelta, Decimal, bytes, Path, PurePath)
132
+
133
+
134
def _coerce_dict_to_model(value: Any, target_type: type) -> Any:
    """Build a Pydantic model or dataclass instance from a plain dict.

    Only dict values are converted, and only when ``target_type`` is a Pydantic
    model or a dataclass. Everything else is returned untouched (same object).
    """
    if isinstance(value, dict):
        if _is_pydantic_model(target_type):
            # Pydantic v2 exposes model_validate; otherwise build via the constructor.
            validate = getattr(target_type, "model_validate", None)
            return validate(value) if validate is not None else target_type(**value)
        if _is_dataclass_type(target_type):
            return target_type(**value)
    return value
154
+
155
+
156
def _coerce_value(value: Any, target_type: type) -> Any:
    """Coerce a deserialized value to ``target_type``.

    Handles:
    - Primitive types (UUID, datetime, etc.) via ``_coerce_primitive``
    - Pydantic models and dataclasses (from dicts)
    - Generic collections like list[UUID], set[datetime], frozenset[T],
      tuple[...], dict[K, V]
    - Optional[T], Union[...] and ``T | U``: members are tried in declaration
      order (see ``_coerce_union_value``)

    Values matching none of these cases are returned unchanged, and ``None``
    always stays ``None``.
    """
    if value is None:
        return None

    # Check for coercible primitive types
    if isinstance(target_type, type) and issubclass(target_type, COERCIBLE_TYPES):
        return _coerce_primitive(value, target_type)

    # Check for Pydantic models or dataclasses
    if isinstance(value, dict):
        coerced = _coerce_dict_to_model(value, target_type)
        if coerced is not value:
            return coerced

    origin = get_origin(target_type)
    if origin is None:
        return value
    args = get_args(target_type)

    # Imported locally because this module's top-level imports do not include them.
    from types import UnionType
    from typing import Union

    # Without this, e.g. Optional[UUID] / Optional[Model] parameters would
    # receive the raw serialized str / dict.
    if origin is Union or origin is UnionType:
        return _coerce_union_value(value, args)

    # Handle list[T]
    if origin is list and isinstance(value, list) and args:
        item_type = args[0]
        return [_coerce_value(item, item_type) for item in value]

    # Handle set[T] (serialized as list)
    if origin is set and isinstance(value, list) and args:
        item_type = args[0]
        return {_coerce_value(item, item_type) for item in value}

    # Handle frozenset[T] (serialized as list)
    if origin is frozenset and isinstance(value, list) and args:
        item_type = args[0]
        return frozenset(_coerce_value(item, item_type) for item in value)

    # Handle tuple[T, ...] (serialized as list)
    if origin is tuple and isinstance(value, (list, tuple)) and args:
        # Variable length tuple like tuple[int, ...]
        if len(args) == 2 and args[1] is ...:
            item_type = args[0]
            return tuple(_coerce_value(item, item_type) for item in value)
        # Fixed length tuple like tuple[int, str, UUID]
        return tuple(
            _coerce_value(item, item_type) for item, item_type in zip(value, args, strict=False)
        )

    # Handle dict[K, V]
    if origin is dict and isinstance(value, dict) and len(args) == 2:
        key_type, val_type = args
        return {
            _coerce_value(k, key_type): _coerce_value(v, val_type) for k, v in value.items()
        }

    return value


def _coerce_union_value(value: Any, members: tuple[Any, ...]) -> Any:
    """Coerce a non-None ``value`` against the members of a Union annotation.

    Non-None members are tried in order. A conversion that yields a new object
    is returned immediately, and a value that is already an instance of a
    plain-class member is returned as-is, which stops later members from
    converting it. Conversion errors only rule out that member. If nothing
    matches, ``value`` is returned unchanged.
    """
    for member in members:
        if member is type(None):
            continue
        try:
            coerced = _coerce_value(value, member)
        except Exception:
            # A failed conversion just means this member does not fit.
            continue
        if coerced is not value:
            return coerced
        if get_origin(member) is None and isinstance(member, type):
            try:
                if isinstance(value, member):
                    return value
            except TypeError:
                # e.g. Protocols that are not runtime-checkable.
                pass
    return value
217
+
218
+
219
def _coerce_kwargs_to_type_hints(handler: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert deserialized keyword arguments to the handler's annotated types.

    Each argument with a matching annotation goes through ``_coerce_value``
    (primitives such as UUID/datetime/Decimal, Pydantic models and dataclasses
    from dicts, and generic collections). Unannotated arguments pass through.
    A new dict is returned. If the annotations cannot be resolved, ``kwargs``
    itself is returned unchanged.
    """
    try:
        hints = get_type_hints(handler)
    except Exception:
        # Unresolvable annotations (e.g. forward references): leave arguments as-is.
        return kwargs

    return {
        arg_name: _coerce_value(arg_value, hints[arg_name]) if arg_name in hints else arg_value
        for arg_name, arg_value in kwargs.items()
    }
242
+
243
+
244
def _lookup_handler(module_name: str, action_name: str) -> Any:
    """Return the registered handler for ``module_name:action_name``, or None.

    When ``module_name`` is set, the module is imported first (presumably so
    that its action registrations run as an import side effect). If the handler
    is still missing, the module is reloaded once and the lookup retried.
    Import errors propagate to the caller.
    """
    import importlib

    module = importlib.import_module(module_name) if module_name else None
    handler = registry.get(module_name, action_name)
    if handler is None and module is not None:
        importlib.reload(module)
        handler = registry.get(module_name, action_name)
    return handler


async def execute_action(dispatch: pb2.ActionDispatch) -> ActionExecutionResult:
    """Execute an action based on the dispatch command.

    Failures are reported in the returned result rather than raised. This
    covers a module that fails to import, an unregistered action, arguments
    that cannot be deserialized or coerced, and exceptions from the handler
    itself.

    Args:
        dispatch: The action dispatch command from the Rust scheduler.

    Returns:
        The result of executing the action; ``exception`` is set on failure.
    """
    action_name = dispatch.action_name
    module_name = dispatch.module_name

    try:
        handler = _lookup_handler(module_name, action_name)
    except Exception as e:
        return ActionExecutionResult(result=None, exception=e)
    if handler is None:
        return ActionExecutionResult(
            result=None,
            exception=KeyError(f"action '{module_name}:{action_name}' not registered"),
        )

    try:
        # Deserialize kwargs inside the try so bad payloads become action failures.
        kwargs = arguments_to_kwargs(dispatch.kwargs)

        # Coerce dict arguments to Pydantic models or dataclasses based on type hints
        # This is needed because the IR converts model constructor calls to dicts
        kwargs = _coerce_kwargs_to_type_hints(handler, kwargs)

        async with provide_dependencies(handler, kwargs) as call_kwargs:
            value = handler(**call_kwargs)
            # Handlers may be plain functions or coroutine functions.
            if asyncio.iscoroutine(value):
                value = await value
            return ActionExecutionResult(result=value)
    except Exception as e:
        return ActionExecutionResult(
            result=None,
            exception=e,
        )