rappel 0.10.0__py3-none-win_amd64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those registries.

Potentially problematic release.

@@ -0,0 +1,279 @@
+ import dataclasses
+ import importlib
+ import traceback
+ from base64 import b64encode
+ from datetime import date, datetime, time, timedelta
+ from decimal import Decimal
+ from enum import Enum
+ from pathlib import PurePath
+ from typing import Any
+ from uuid import UUID
+
+ from google.protobuf import json_format, struct_pb2
+ from pydantic import BaseModel
+
+ from proto import messages_pb2 as pb2
+
+ NULL_VALUE = struct_pb2.NULL_VALUE  # type: ignore[attr-defined]
+
+ PRIMITIVE_TYPES = (str, int, float, bool, type(None))
+
+
+ def dumps(value: Any) -> pb2.WorkflowArgumentValue:
+     """Serialize a Python value into a WorkflowArgumentValue message."""
+
+     return _to_argument_value(value)
+
+
+ def loads(data: Any) -> Any:
+     """Deserialize a workflow argument payload into a Python object."""
+
+     if isinstance(data, pb2.WorkflowArgumentValue):
+         argument = data
+     elif isinstance(data, dict):
+         argument = pb2.WorkflowArgumentValue()
+         json_format.ParseDict(data, argument)
+     else:
+         raise TypeError("argument value payload must be a dict or ArgumentValue message")
+     return _from_argument_value(argument)
+
+
+ def build_arguments_from_kwargs(kwargs: dict[str, Any]) -> pb2.WorkflowArguments:
+     arguments = pb2.WorkflowArguments()
+     for key, value in kwargs.items():
+         entry = arguments.arguments.add()
+         entry.key = key
+         entry.value.CopyFrom(dumps(value))
+     return arguments
+
+
+ def arguments_to_kwargs(arguments: pb2.WorkflowArguments | None) -> dict[str, Any]:
+     if arguments is None:
+         return {}
+     result: dict[str, Any] = {}
+     for entry in arguments.arguments:
+         result[entry.key] = loads(entry.value)
+     return result
+
+
+ def _to_argument_value(value: Any) -> pb2.WorkflowArgumentValue:
+     argument = pb2.WorkflowArgumentValue()
+     if isinstance(value, PRIMITIVE_TYPES):
+         argument.primitive.CopyFrom(_serialize_primitive(value))
+         return argument
+     if isinstance(value, UUID):
+         # Serialize UUID as string primitive
+         argument.primitive.CopyFrom(_serialize_primitive(str(value)))
+         return argument
+     if isinstance(value, datetime):
+         # Serialize datetime as ISO format string
+         argument.primitive.CopyFrom(_serialize_primitive(value.isoformat()))
+         return argument
+     if isinstance(value, date):
+         # Serialize date as ISO format string (must come after datetime check)
+         argument.primitive.CopyFrom(_serialize_primitive(value.isoformat()))
+         return argument
+     if isinstance(value, time):
+         # Serialize time as ISO format string
+         argument.primitive.CopyFrom(_serialize_primitive(value.isoformat()))
+         return argument
+     if isinstance(value, timedelta):
+         # Serialize timedelta as total seconds
+         argument.primitive.CopyFrom(_serialize_primitive(value.total_seconds()))
+         return argument
+     if isinstance(value, Decimal):
+         # Serialize Decimal as string to preserve precision
+         argument.primitive.CopyFrom(_serialize_primitive(str(value)))
+         return argument
+     if isinstance(value, Enum):
+         # Serialize Enum as its value
+         return _to_argument_value(value.value)
+     if isinstance(value, bytes):
+         # Serialize bytes as base64 string
+         argument.primitive.CopyFrom(_serialize_primitive(b64encode(value).decode("ascii")))
+         return argument
+     if isinstance(value, PurePath):
+         # Serialize Path as string
+         argument.primitive.CopyFrom(_serialize_primitive(str(value)))
+         return argument
+     if isinstance(value, (set, frozenset)):
+         # Serialize sets as lists
+         argument.list_value.SetInParent()
+         for item in value:
+             item_value = argument.list_value.items.add()
+             item_value.CopyFrom(_to_argument_value(item))
+         return argument
+     if isinstance(value, BaseException):
+         argument.exception.type = value.__class__.__name__
+         argument.exception.module = value.__class__.__module__
+         argument.exception.message = str(value)
+         tb_text = "".join(traceback.format_exception(type(value), value, value.__traceback__))
+         argument.exception.traceback = tb_text
+         # Include the exception class hierarchy (MRO) for proper except matching.
+         # This allows `except LookupError:` to catch KeyError, etc.
+         for cls in value.__class__.__mro__:
+             if cls is object:
+                 continue  # Skip 'object' as it's not useful for exception matching
+             argument.exception.type_hierarchy.append(cls.__name__)
+         values = _serialize_exception_values(value)
+         for key, item in values.items():
+             entry = argument.exception.values.entries.add()
+             entry.key = key
+             try:
+                 entry.value.CopyFrom(_to_argument_value(item))
+             except TypeError:
+                 entry.value.CopyFrom(_to_argument_value(str(item)))
+         return argument
+     if _is_base_model(value):
+         model_class = value.__class__
+         model_data = _serialize_model_data(value)
+         argument.basemodel.module = model_class.__module__
+         argument.basemodel.name = model_class.__qualname__
+         # Serialize as dict to preserve types (Struct converts all numbers to float)
+         for key, item in model_data.items():
+             entry = argument.basemodel.data.entries.add()
+             entry.key = key
+             entry.value.CopyFrom(_to_argument_value(item))
+         return argument
+     if _is_dataclass_instance(value):
+         # Dataclasses use the same basemodel serialization path as Pydantic models
+         dc_class = value.__class__
+         dc_data = dataclasses.asdict(value)
+         argument.basemodel.module = dc_class.__module__
+         argument.basemodel.name = dc_class.__qualname__
+         for key, item in dc_data.items():
+             entry = argument.basemodel.data.entries.add()
+             entry.key = key
+             entry.value.CopyFrom(_to_argument_value(item))
+         return argument
+     if isinstance(value, dict):
+         argument.dict_value.SetInParent()
+         for key, item in value.items():
+             if not isinstance(key, str):
+                 raise TypeError("workflow dict keys must be strings")
+             entry = argument.dict_value.entries.add()
+             entry.key = key
+             entry.value.CopyFrom(_to_argument_value(item))
+         return argument
+     if isinstance(value, list):
+         argument.list_value.SetInParent()
+         for item in value:
+             item_value = argument.list_value.items.add()
+             item_value.CopyFrom(_to_argument_value(item))
+         return argument
+     if isinstance(value, tuple):
+         argument.tuple_value.SetInParent()
+         for item in value:
+             item_value = argument.tuple_value.items.add()
+             item_value.CopyFrom(_to_argument_value(item))
+         return argument
+     raise TypeError(f"unsupported value type {type(value)!r}")
+
+
+ def _from_argument_value(argument: pb2.WorkflowArgumentValue) -> Any:
+     kind = argument.WhichOneof("kind")  # type: ignore[attr-defined]
+     if kind == "primitive":
+         return _primitive_to_python(argument.primitive)
+     if kind == "basemodel":
+         module = argument.basemodel.module
+         name = argument.basemodel.name
+         # Deserialize dict entries to preserve types
+         data: dict[str, Any] = {}
+         for entry in argument.basemodel.data.entries:
+             data[entry.key] = _from_argument_value(entry.value)
+         return _instantiate_serialized_model(module, name, data)
+     if kind == "exception":
+         values: dict[str, Any] = {}
+         if argument.exception.HasField("values"):
+             for entry in argument.exception.values.entries:
+                 values[entry.key] = _from_argument_value(entry.value)
+         return {
+             "type": argument.exception.type,
+             "module": argument.exception.module,
+             "message": argument.exception.message,
+             "traceback": argument.exception.traceback,
+             "values": values,
+         }
+     if kind == "list_value":
+         return [_from_argument_value(item) for item in argument.list_value.items]
+     if kind == "tuple_value":
+         return tuple(_from_argument_value(item) for item in argument.tuple_value.items)
+     if kind == "dict_value":
+         result: dict[str, Any] = {}
+         for entry in argument.dict_value.entries:
+             result[entry.key] = _from_argument_value(entry.value)
+         return result
+     raise ValueError("argument value missing kind discriminator")
+
+
+ def _serialize_model_data(model: BaseModel) -> dict[str, Any]:
+     if hasattr(model, "model_dump"):
+         return model.model_dump(mode="json")  # type: ignore[attr-defined]
+     if hasattr(model, "dict"):
+         return model.dict()  # type: ignore[attr-defined]
+     return model.__dict__
+
+
+ def _serialize_exception_values(exc: BaseException) -> dict[str, Any]:
+     values = dict(vars(exc))
+     if "args" not in values:
+         values["args"] = exc.args
+     return values
+
+
+ def _serialize_primitive(value: Any) -> pb2.PrimitiveWorkflowArgument:
+     primitive = pb2.PrimitiveWorkflowArgument()
+     if value is None:
+         primitive.null_value = NULL_VALUE
+     elif isinstance(value, bool):
+         primitive.bool_value = value
+     elif isinstance(value, int) and not isinstance(value, bool):
+         primitive.int_value = value
+     elif isinstance(value, float):
+         primitive.double_value = value
+     elif isinstance(value, str):
+         primitive.string_value = value
+     else:  # pragma: no cover - unreachable given PRIMITIVE_TYPES
+         raise TypeError(f"unsupported primitive type {type(value)!r}")
+     return primitive
+
+
+ def _primitive_to_python(primitive: pb2.PrimitiveWorkflowArgument) -> Any:
+     kind = primitive.WhichOneof("kind")  # type: ignore[attr-defined]
+     if kind == "string_value":
+         return primitive.string_value
+     if kind == "double_value":
+         return primitive.double_value
+     if kind == "int_value":
+         return primitive.int_value
+     if kind == "bool_value":
+         return primitive.bool_value
+     if kind == "null_value":
+         return None
+     raise ValueError("primitive argument missing kind discriminator")
+
+
+ def _instantiate_serialized_model(module: str, name: str, model_data: dict[str, Any]) -> Any:
+     cls = _import_symbol(module, name)
+     if hasattr(cls, "model_validate"):
+         return cls.model_validate(model_data)  # type: ignore[attr-defined]
+     return cls(**model_data)
+
+
+ def _is_base_model(value: Any) -> bool:
+     return isinstance(value, BaseModel)
+
+
+ def _is_dataclass_instance(value: Any) -> bool:
+     """Check if value is a dataclass instance (not a class)."""
+     return dataclasses.is_dataclass(value) and not isinstance(value, type)
+
+
+ def _import_symbol(module: str, qualname: str) -> Any:
+     module_obj = importlib.import_module(module)
+     attr: Any = module_obj
+     for part in qualname.split("."):
+         attr = getattr(attr, part)
+     if not isinstance(attr, type):
+         raise ValueError(f"{qualname} from {module} is not a class")
+     return attr
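
Usage note: a minimal round-trip sketch for the serialization helpers above. The import path is an assumption (this diff does not show the file's name; rappel.serialization is a guess), the Job dataclass and its fields are illustrative only, and it assumes the generated proto.messages_pb2 module is importable.

from dataclasses import dataclass
from datetime import datetime

from rappel import serialization  # assumed module path, not confirmed by this diff


@dataclass
class Job:  # hypothetical example type
    name: str
    retries: int


kwargs = {
    "job": Job(name="backfill", retries=3),
    "started_at": datetime(2024, 1, 1, 12, 0),
}

# kwargs -> WorkflowArguments message -> plain Python values.
arguments = serialization.build_arguments_from_kwargs(kwargs)
restored = serialization.arguments_to_kwargs(arguments)

assert isinstance(restored["job"], Job) and restored["job"].retries == 3
assert restored["started_at"] == "2024-01-01T12:00:00"  # datetimes round-trip as ISO strings

Note that the scalar-like types (UUID, datetime, date, time, Decimal, paths, bytes) are serialized as string or number primitives and deserialize back as such; only Pydantic models and dataclasses are reconstructed as instances of their original class, via _instantiate_serialized_model.
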
rappel/worker.py ADDED
@@ -0,0 +1,191 @@
+ """gRPC worker client that executes rappel actions."""
+
+ import argparse
+ import asyncio
+ import importlib
+ import logging
+ import sys
+ import time
+ from typing import Any, AsyncIterator, cast
+
+ import grpc
+
+ from proto import messages_pb2 as pb2
+ from proto import messages_pb2_grpc as pb2_grpc
+ from rappel.actions import serialize_error_payload, serialize_result_payload
+
+ from . import workflow_runtime
+ from .logger import configure as configure_logger
+
+ LOGGER = configure_logger("rappel.worker")
+ aio = cast(Any, grpc).aio
+
+
+ def _parse_args(argv: list[str] | None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Rappel workflow worker")
+     parser.add_argument("--bridge", required=True, help="gRPC address of the Rust bridge")
+     parser.add_argument("--worker-id", required=True, type=int, help="Logical worker identifier")
+     parser.add_argument(
+         "--user-module",
+         action="append",
+         default=[],
+         help="Optional user module(s) to import eagerly",
+     )
+     return parser.parse_args(argv)
+
+
+ async def _outgoing_stream(
+     queue: "asyncio.Queue[pb2.Envelope]", worker_id: int
+ ) -> AsyncIterator[pb2.Envelope]:
+     hello = pb2.WorkerHello(worker_id=worker_id)
+     envelope = pb2.Envelope(
+         delivery_id=0,
+         partition_id=0,
+         kind=pb2.MessageKind.MESSAGE_KIND_WORKER_HELLO,
+         payload=hello.SerializeToString(),
+     )
+     yield envelope
+     try:
+         while True:
+             message = await queue.get()
+             yield message
+     except asyncio.CancelledError:  # pragma: no cover - best effort shutdown
+         return
+
+
+ async def _send_ack(outgoing: "asyncio.Queue[pb2.Envelope]", envelope: pb2.Envelope) -> None:
+     ack = pb2.Ack(acked_delivery_id=envelope.delivery_id)
+     ack_envelope = pb2.Envelope(
+         delivery_id=envelope.delivery_id,
+         partition_id=envelope.partition_id,
+         kind=pb2.MessageKind.MESSAGE_KIND_ACK,
+         payload=ack.SerializeToString(),
+     )
+     await outgoing.put(ack_envelope)
+
+
+ async def _handle_dispatch(
+     envelope: pb2.Envelope,
+     outgoing: "asyncio.Queue[pb2.Envelope]",
+ ) -> None:
+     await _send_ack(outgoing, envelope)
+     dispatch = pb2.ActionDispatch()
+     dispatch.ParseFromString(envelope.payload)
+     timeout_seconds = dispatch.timeout_seconds if dispatch.HasField("timeout_seconds") else 0
+
+     worker_start = time.perf_counter_ns()
+     success = True
+     action_name = dispatch.action_name
+     execution: workflow_runtime.ActionExecutionResult | None = None
+     try:
+         if timeout_seconds > 0:
+             execution = await asyncio.wait_for(
+                 workflow_runtime.execute_action(dispatch), timeout=timeout_seconds
+             )
+         else:
+             execution = await workflow_runtime.execute_action(dispatch)
+
+         if execution.exception:
+             success = False
+             response_payload = serialize_error_payload(action_name, execution.exception)
+         else:
+             response_payload = serialize_result_payload(execution.result)
+     except asyncio.TimeoutError:
+         success = False
+         error = TimeoutError(f"action {action_name} timed out after {timeout_seconds} seconds")
+         response_payload = serialize_error_payload(action_name, error)
+         LOGGER.warning(
+             "Action %s timed out after %ss for action_id=%s sequence=%s",
+             action_name,
+             timeout_seconds,
+             dispatch.action_id,
+             dispatch.sequence,
+         )
+     except Exception as exc:  # noqa: BLE001 - propagate structured errors
+         success = False
+         response_payload = serialize_error_payload(action_name, exc)
+         LOGGER.exception(
+             "Action %s failed for action_id=%s sequence=%s",
+             action_name,
+             dispatch.action_id,
+             dispatch.sequence,
+         )
+     worker_end = time.perf_counter_ns()
+     response = pb2.ActionResult(
+         action_id=dispatch.action_id,
+         success=success,
+         worker_start_ns=worker_start,
+         worker_end_ns=worker_end,
+     )
+     response.payload.CopyFrom(response_payload)
+     if dispatch.dispatch_token:
+         response.dispatch_token = dispatch.dispatch_token
+     response_envelope = pb2.Envelope(
+         delivery_id=envelope.delivery_id,
+         partition_id=envelope.partition_id,
+         kind=pb2.MessageKind.MESSAGE_KIND_ACTION_RESULT,
+         payload=response.SerializeToString(),
+     )
+     await outgoing.put(response_envelope)
+     LOGGER.debug("Handled action=%s seq=%s success=%s", action_name, dispatch.sequence, success)
+
+
+ async def _handle_incoming_stream(
+     stub: pb2_grpc.WorkerBridgeStub,
+     worker_id: int,
+     outgoing: "asyncio.Queue[pb2.Envelope]",
+ ) -> None:
+     """Process incoming messages, running action dispatches concurrently."""
+     pending_tasks: set[asyncio.Task[None]] = set()
+
+     async for envelope in stub.Attach(_outgoing_stream(outgoing, worker_id)):
+         kind = envelope.kind
+         if kind == pb2.MessageKind.MESSAGE_KIND_ACTION_DISPATCH:
+             # Spawn task to handle dispatch concurrently
+             task = asyncio.create_task(_handle_dispatch(envelope, outgoing))
+             pending_tasks.add(task)
+             task.add_done_callback(pending_tasks.discard)
+         elif kind == pb2.MessageKind.MESSAGE_KIND_HEARTBEAT:
+             LOGGER.debug("Received heartbeat delivery=%s", envelope.delivery_id)
+             await _send_ack(outgoing, envelope)
+         else:
+             LOGGER.warning("Unhandled message kind: %s", kind)
+             await _send_ack(outgoing, envelope)
+
+     # Wait for any remaining tasks on stream close
+     if pending_tasks:
+         await asyncio.gather(*pending_tasks, return_exceptions=True)
+
+
+ async def _run_worker(args: argparse.Namespace) -> None:
+     outgoing: "asyncio.Queue[pb2.Envelope]" = asyncio.Queue()
+     for module_name in args.user_module:
+         if not module_name:
+             continue
+         LOGGER.info("Preloading user module %s", module_name)
+         importlib.import_module(module_name)
+
+     async with aio.insecure_channel(args.bridge) as channel:
+         stub = pb2_grpc.WorkerBridgeStub(channel)
+         LOGGER.info("Worker %s connected to %s", args.worker_id, args.bridge)
+         try:
+             await _handle_incoming_stream(stub, args.worker_id, outgoing)
+         except aio.AioRpcError as exc:  # pragma: no cover
+             status = exc.code()
+             LOGGER.error("Worker stream closed: %s", status)
+             raise
+
+
+ def main(argv: list[str] | None = None) -> None:
+     args = _parse_args(argv)
+     logging.basicConfig(level=logging.INFO, format="[worker] %(message)s", stream=sys.stderr)
+     try:
+         asyncio.run(_run_worker(args))
+     except KeyboardInterrupt:  # pragma: no cover - exit quietly on Ctrl+C
+         return
+     except grpc.RpcError:
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
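
Usage note: a hedged sketch of launching the worker in-process. The flag names match _parse_args above; the bridge address is a placeholder, and the call blocks (via asyncio.run) until the gRPC stream to the bridge closes.

from rappel import worker  # rappel/worker.py as added in this diff

worker.main([
    "--bridge", "127.0.0.1:50051",  # placeholder address of the Rust bridge
    "--worker-id", "1",             # argparse converts this to int (type=int)
    # optionally: "--user-module", "myapp.workflows"  (hypothetical module to preload)
])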