rappel 0.5.7__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proto/ast_pb2.py +115 -0
- proto/ast_pb2.pyi +1522 -0
- proto/ast_pb2_grpc.py +24 -0
- proto/ast_pb2_grpc.pyi +22 -0
- proto/messages_pb2.py +106 -0
- proto/messages_pb2.pyi +1217 -0
- proto/messages_pb2_grpc.py +406 -0
- proto/messages_pb2_grpc.pyi +380 -0
- rappel/__init__.py +61 -0
- rappel/actions.py +108 -0
- rappel/bin/boot-rappel-singleton.exe +0 -0
- rappel/bin/rappel-bridge.exe +0 -0
- rappel/bin/start-workers.exe +0 -0
- rappel/bridge.py +228 -0
- rappel/dependencies.py +149 -0
- rappel/exceptions.py +18 -0
- rappel/formatter.py +110 -0
- rappel/ir_builder.py +3191 -0
- rappel/logger.py +39 -0
- rappel/registry.py +106 -0
- rappel/schedule.py +363 -0
- rappel/serialization.py +253 -0
- rappel/worker.py +191 -0
- rappel/workflow.py +240 -0
- rappel/workflow_runtime.py +287 -0
- rappel-0.5.7.data/scripts/boot-rappel-singleton.exe +0 -0
- rappel-0.5.7.data/scripts/rappel-bridge.exe +0 -0
- rappel-0.5.7.data/scripts/start-workers.exe +0 -0
- rappel-0.5.7.dist-info/METADATA +299 -0
- rappel-0.5.7.dist-info/RECORD +32 -0
- rappel-0.5.7.dist-info/WHEEL +4 -0
- rappel-0.5.7.dist-info/entry_points.txt +2 -0
rappel/serialization.py
ADDED
@@ -0,0 +1,253 @@
import dataclasses
import importlib
import traceback
from base64 import b64encode
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
from pathlib import PurePath
from typing import Any
from uuid import UUID

from google.protobuf import json_format, struct_pb2
from pydantic import BaseModel

from proto import messages_pb2 as pb2

NULL_VALUE = struct_pb2.NULL_VALUE  # type: ignore[attr-defined]

PRIMITIVE_TYPES = (str, int, float, bool, type(None))


def dumps(value: Any) -> pb2.WorkflowArgumentValue:
    """Serialize a Python value into a WorkflowArgumentValue message."""

    return _to_argument_value(value)


def loads(data: Any) -> Any:
    """Deserialize a workflow argument payload into a Python object."""

    if isinstance(data, pb2.WorkflowArgumentValue):
        argument = data
    elif isinstance(data, dict):
        argument = pb2.WorkflowArgumentValue()
        json_format.ParseDict(data, argument)
    else:
        raise TypeError("argument value payload must be a dict or ArgumentValue message")
    return _from_argument_value(argument)


def build_arguments_from_kwargs(kwargs: dict[str, Any]) -> pb2.WorkflowArguments:
    arguments = pb2.WorkflowArguments()
    for key, value in kwargs.items():
        entry = arguments.arguments.add()
        entry.key = key
        entry.value.CopyFrom(dumps(value))
    return arguments


def arguments_to_kwargs(arguments: pb2.WorkflowArguments | None) -> dict[str, Any]:
    if arguments is None:
        return {}
    result: dict[str, Any] = {}
    for entry in arguments.arguments:
        result[entry.key] = loads(entry.value)
    return result


def _to_argument_value(value: Any) -> pb2.WorkflowArgumentValue:
    argument = pb2.WorkflowArgumentValue()
    if isinstance(value, PRIMITIVE_TYPES):
        argument.primitive.CopyFrom(_serialize_primitive(value))
        return argument
    if isinstance(value, UUID):
        # Serialize UUID as string primitive
        argument.primitive.CopyFrom(_serialize_primitive(str(value)))
        return argument
    if isinstance(value, datetime):
        # Serialize datetime as ISO format string
        argument.primitive.CopyFrom(_serialize_primitive(value.isoformat()))
        return argument
    if isinstance(value, date):
        # Serialize date as ISO format string (must come after datetime check)
        argument.primitive.CopyFrom(_serialize_primitive(value.isoformat()))
        return argument
    if isinstance(value, time):
        # Serialize time as ISO format string
        argument.primitive.CopyFrom(_serialize_primitive(value.isoformat()))
        return argument
    if isinstance(value, timedelta):
        # Serialize timedelta as total seconds
        argument.primitive.CopyFrom(_serialize_primitive(value.total_seconds()))
        return argument
    if isinstance(value, Decimal):
        # Serialize Decimal as string to preserve precision
        argument.primitive.CopyFrom(_serialize_primitive(str(value)))
        return argument
    if isinstance(value, Enum):
        # Serialize Enum as its value
        return _to_argument_value(value.value)
    if isinstance(value, bytes):
        # Serialize bytes as base64 string
        argument.primitive.CopyFrom(_serialize_primitive(b64encode(value).decode("ascii")))
        return argument
    if isinstance(value, PurePath):
        # Serialize Path as string
        argument.primitive.CopyFrom(_serialize_primitive(str(value)))
        return argument
    if isinstance(value, (set, frozenset)):
        # Serialize sets as lists
        argument.list_value.SetInParent()
        for item in value:
            item_value = argument.list_value.items.add()
            item_value.CopyFrom(_to_argument_value(item))
        return argument
    if isinstance(value, BaseException):
        argument.exception.type = value.__class__.__name__
        argument.exception.module = value.__class__.__module__
        argument.exception.message = str(value)
        tb_text = "".join(traceback.format_exception(type(value), value, value.__traceback__))
        argument.exception.traceback = tb_text
        return argument
    if _is_base_model(value):
        model_class = value.__class__
        model_data = _serialize_model_data(value)
        argument.basemodel.module = model_class.__module__
        argument.basemodel.name = model_class.__qualname__
        # Serialize as dict to preserve types (Struct converts all numbers to float)
        for key, item in model_data.items():
            entry = argument.basemodel.data.entries.add()
            entry.key = key
            entry.value.CopyFrom(_to_argument_value(item))
        return argument
    if _is_dataclass_instance(value):
        # Dataclasses use the same basemodel serialization path as Pydantic models
        dc_class = value.__class__
        dc_data = dataclasses.asdict(value)
        argument.basemodel.module = dc_class.__module__
        argument.basemodel.name = dc_class.__qualname__
        for key, item in dc_data.items():
            entry = argument.basemodel.data.entries.add()
            entry.key = key
            entry.value.CopyFrom(_to_argument_value(item))
        return argument
    if isinstance(value, dict):
        argument.dict_value.SetInParent()
        for key, item in value.items():
            if not isinstance(key, str):
                raise TypeError("workflow dict keys must be strings")
            entry = argument.dict_value.entries.add()
            entry.key = key
            entry.value.CopyFrom(_to_argument_value(item))
        return argument
    if isinstance(value, list):
        argument.list_value.SetInParent()
        for item in value:
            item_value = argument.list_value.items.add()
            item_value.CopyFrom(_to_argument_value(item))
        return argument
    if isinstance(value, tuple):
        argument.tuple_value.SetInParent()
        for item in value:
            item_value = argument.tuple_value.items.add()
            item_value.CopyFrom(_to_argument_value(item))
        return argument
    raise TypeError(f"unsupported value type {type(value)!r}")


def _from_argument_value(argument: pb2.WorkflowArgumentValue) -> Any:
    kind = argument.WhichOneof("kind")  # type: ignore[attr-defined]
    if kind == "primitive":
        return _primitive_to_python(argument.primitive)
    if kind == "basemodel":
        module = argument.basemodel.module
        name = argument.basemodel.name
        # Deserialize dict entries to preserve types
        data: dict[str, Any] = {}
        for entry in argument.basemodel.data.entries:
            data[entry.key] = _from_argument_value(entry.value)
        return _instantiate_serialized_model(module, name, data)
    if kind == "exception":
        return {
            "type": argument.exception.type,
            "module": argument.exception.module,
            "message": argument.exception.message,
            "traceback": argument.exception.traceback,
        }
    if kind == "list_value":
        return [_from_argument_value(item) for item in argument.list_value.items]
    if kind == "tuple_value":
        return tuple(_from_argument_value(item) for item in argument.tuple_value.items)
    if kind == "dict_value":
        result: dict[str, Any] = {}
        for entry in argument.dict_value.entries:
            result[entry.key] = _from_argument_value(entry.value)
        return result
    raise ValueError("argument value missing kind discriminator")


def _serialize_model_data(model: BaseModel) -> dict[str, Any]:
    if hasattr(model, "model_dump"):
        return model.model_dump(mode="json")  # type: ignore[attr-defined]
    if hasattr(model, "dict"):
        return model.dict()  # type: ignore[attr-defined]
    return model.__dict__


def _serialize_primitive(value: Any) -> pb2.PrimitiveWorkflowArgument:
    primitive = pb2.PrimitiveWorkflowArgument()
    if value is None:
        primitive.null_value = NULL_VALUE
    elif isinstance(value, bool):
        primitive.bool_value = value
    elif isinstance(value, int) and not isinstance(value, bool):
        primitive.int_value = value
    elif isinstance(value, float):
        primitive.double_value = value
    elif isinstance(value, str):
        primitive.string_value = value
    else:  # pragma: no cover - unreachable given PRIMITIVE_TYPES
        raise TypeError(f"unsupported primitive type {type(value)!r}")
    return primitive


def _primitive_to_python(primitive: pb2.PrimitiveWorkflowArgument) -> Any:
    kind = primitive.WhichOneof("kind")  # type: ignore[attr-defined]
    if kind == "string_value":
        return primitive.string_value
    if kind == "double_value":
        return primitive.double_value
    if kind == "int_value":
        return primitive.int_value
    if kind == "bool_value":
        return primitive.bool_value
    if kind == "null_value":
        return None
    raise ValueError("primitive argument missing kind discriminator")


def _instantiate_serialized_model(module: str, name: str, model_data: dict[str, Any]) -> Any:
    cls = _import_symbol(module, name)
    if hasattr(cls, "model_validate"):
        return cls.model_validate(model_data)  # type: ignore[attr-defined]
    return cls(**model_data)


def _is_base_model(value: Any) -> bool:
    return isinstance(value, BaseModel)


def _is_dataclass_instance(value: Any) -> bool:
    """Check if value is a dataclass instance (not a class)."""
    return dataclasses.is_dataclass(value) and not isinstance(value, type)


def _import_symbol(module: str, qualname: str) -> Any:
    module_obj = importlib.import_module(module)
    attr: Any = module_obj
    for part in qualname.split("."):
        attr = getattr(attr, part)
    if not isinstance(attr, type):
        raise ValueError(f"{qualname} from {module} is not a class")
    return attr
rappel/worker.py
ADDED
@@ -0,0 +1,191 @@
"""gRPC worker client that executes rappel actions."""

import argparse
import asyncio
import importlib
import logging
import sys
import time
from typing import Any, AsyncIterator, cast

import grpc

from proto import messages_pb2 as pb2
from proto import messages_pb2_grpc as pb2_grpc
from rappel.actions import serialize_error_payload, serialize_result_payload

from . import workflow_runtime
from .logger import configure as configure_logger

LOGGER = configure_logger("rappel.worker")
aio = cast(Any, grpc).aio


def _parse_args(argv: list[str] | None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Rappel workflow worker")
    parser.add_argument("--bridge", required=True, help="gRPC address of the Rust bridge")
    parser.add_argument("--worker-id", required=True, type=int, help="Logical worker identifier")
    parser.add_argument(
        "--user-module",
        action="append",
        default=[],
        help="Optional user module(s) to import eagerly",
    )
    return parser.parse_args(argv)


async def _outgoing_stream(
    queue: "asyncio.Queue[pb2.Envelope]", worker_id: int
) -> AsyncIterator[pb2.Envelope]:
    hello = pb2.WorkerHello(worker_id=worker_id)
    envelope = pb2.Envelope(
        delivery_id=0,
        partition_id=0,
        kind=pb2.MessageKind.MESSAGE_KIND_WORKER_HELLO,
        payload=hello.SerializeToString(),
    )
    yield envelope
    try:
        while True:
            message = await queue.get()
            yield message
    except asyncio.CancelledError:  # pragma: no cover - best effort shutdown
        return


async def _send_ack(outgoing: "asyncio.Queue[pb2.Envelope]", envelope: pb2.Envelope) -> None:
    ack = pb2.Ack(acked_delivery_id=envelope.delivery_id)
    ack_envelope = pb2.Envelope(
        delivery_id=envelope.delivery_id,
        partition_id=envelope.partition_id,
        kind=pb2.MessageKind.MESSAGE_KIND_ACK,
        payload=ack.SerializeToString(),
    )
    await outgoing.put(ack_envelope)


async def _handle_dispatch(
    envelope: pb2.Envelope,
    outgoing: "asyncio.Queue[pb2.Envelope]",
) -> None:
    await _send_ack(outgoing, envelope)
    dispatch = pb2.ActionDispatch()
    dispatch.ParseFromString(envelope.payload)
    timeout_seconds = dispatch.timeout_seconds if dispatch.HasField("timeout_seconds") else 0

    worker_start = time.perf_counter_ns()
    success = True
    action_name = dispatch.action_name
    execution: workflow_runtime.ActionExecutionResult | None = None
    try:
        if timeout_seconds > 0:
            execution = await asyncio.wait_for(
                workflow_runtime.execute_action(dispatch), timeout=timeout_seconds
            )
        else:
            execution = await workflow_runtime.execute_action(dispatch)

        if execution.exception:
            success = False
            response_payload = serialize_error_payload(action_name, execution.exception)
        else:
            response_payload = serialize_result_payload(execution.result)
    except asyncio.TimeoutError:
        success = False
        error = TimeoutError(f"action {action_name} timed out after {timeout_seconds} seconds")
        response_payload = serialize_error_payload(action_name, error)
        LOGGER.warning(
            "Action %s timed out after %ss for action_id=%s sequence=%s",
            action_name,
            timeout_seconds,
            dispatch.action_id,
            dispatch.sequence,
        )
    except Exception as exc:  # noqa: BLE001 - propagate structured errors
        success = False
        response_payload = serialize_error_payload(action_name, exc)
        LOGGER.exception(
            "Action %s failed for action_id=%s sequence=%s",
            action_name,
            dispatch.action_id,
            dispatch.sequence,
        )
    worker_end = time.perf_counter_ns()
    response = pb2.ActionResult(
        action_id=dispatch.action_id,
        success=success,
        worker_start_ns=worker_start,
        worker_end_ns=worker_end,
    )
    response.payload.CopyFrom(response_payload)
    if dispatch.dispatch_token:
        response.dispatch_token = dispatch.dispatch_token
    response_envelope = pb2.Envelope(
        delivery_id=envelope.delivery_id,
        partition_id=envelope.partition_id,
        kind=pb2.MessageKind.MESSAGE_KIND_ACTION_RESULT,
        payload=response.SerializeToString(),
    )
    await outgoing.put(response_envelope)
    LOGGER.debug("Handled action=%s seq=%s success=%s", action_name, dispatch.sequence, success)


async def _handle_incoming_stream(
    stub: pb2_grpc.WorkerBridgeStub,
    worker_id: int,
    outgoing: "asyncio.Queue[pb2.Envelope]",
) -> None:
    """Process incoming messages, running action dispatches concurrently."""
    pending_tasks: set[asyncio.Task[None]] = set()

    async for envelope in stub.Attach(_outgoing_stream(outgoing, worker_id)):
        kind = envelope.kind
        if kind == pb2.MessageKind.MESSAGE_KIND_ACTION_DISPATCH:
            # Spawn task to handle dispatch concurrently
            task = asyncio.create_task(_handle_dispatch(envelope, outgoing))
            pending_tasks.add(task)
            task.add_done_callback(pending_tasks.discard)
        elif kind == pb2.MessageKind.MESSAGE_KIND_HEARTBEAT:
            LOGGER.debug("Received heartbeat delivery=%s", envelope.delivery_id)
            await _send_ack(outgoing, envelope)
        else:
            LOGGER.warning("Unhandled message kind: %s", kind)
            await _send_ack(outgoing, envelope)

    # Wait for any remaining tasks on stream close
    if pending_tasks:
        await asyncio.gather(*pending_tasks, return_exceptions=True)


async def _run_worker(args: argparse.Namespace) -> None:
    outgoing: "asyncio.Queue[pb2.Envelope]" = asyncio.Queue()
    for module_name in args.user_module:
        if not module_name:
            continue
        LOGGER.info("Preloading user module %s", module_name)
        importlib.import_module(module_name)

    async with aio.insecure_channel(args.bridge) as channel:
        stub = pb2_grpc.WorkerBridgeStub(channel)
        LOGGER.info("Worker %s connected to %s", args.worker_id, args.bridge)
        try:
            await _handle_incoming_stream(stub, args.worker_id, outgoing)
        except aio.AioRpcError as exc:  # pragma: no cover
            status = exc.code()
            LOGGER.error("Worker stream closed: %s", status)
            raise


def main(argv: list[str] | None = None) -> None:
    args = _parse_args(argv)
    logging.basicConfig(level=logging.INFO, format="[worker] %(message)s", stream=sys.stderr)
    try:
        asyncio.run(_run_worker(args))
    except KeyboardInterrupt:  # pragma: no cover - exit quietly on Ctrl+C
        return
    except grpc.RpcError:
        sys.exit(1)


if __name__ == "__main__":
    main()
rappel/workflow.py
ADDED
@@ -0,0 +1,240 @@
"""
Workflow base class and registration decorator.

This module provides the foundation for defining workflows that can be
compiled to IR and executed by the Rappel runtime.
"""

import hashlib
import inspect
import os
from dataclasses import dataclass
from datetime import timedelta
from functools import wraps
from threading import RLock
from typing import Any, Awaitable, ClassVar, Optional, TypeVar

from proto import ast_pb2 as ir
from proto import messages_pb2 as pb2

from . import bridge
from .actions import deserialize_result_payload
from .ir_builder import build_workflow_ir
from .logger import configure as configure_logger
from .serialization import build_arguments_from_kwargs
from .workflow_runtime import WorkflowNodeResult

logger = configure_logger("rappel.workflow")

TWorkflow = TypeVar("TWorkflow", bound="Workflow")
TResult = TypeVar("TResult")


@dataclass(frozen=True)
class RetryPolicy:
    """Retry policy for action execution.

    Maps to IR RetryPolicy: [ExceptionType -> retry: N, backoff: Xs]

    Args:
        attempts: Maximum number of retry attempts.
        exception_types: List of exception type names to retry on. Empty = catch all.
        backoff_seconds: Constant backoff duration between retries in seconds.
    """

    attempts: Optional[int] = None
    exception_types: Optional[list[str]] = None
    backoff_seconds: Optional[float] = None


class Workflow:
    """Base class for workflow definitions."""

    name: ClassVar[Optional[str]] = None
    """Human-friendly identifier. Override to pin the registry key; defaults to lowercase class name."""

    concurrent: ClassVar[bool] = False
    """When True, downstream engines may respect DAG-parallel execution; False preserves sequential semantics."""

    _workflow_ir: ClassVar[Optional[ir.Program]] = None
    _ir_lock: ClassVar[RLock] = RLock()
    _workflow_version_id: ClassVar[Optional[str]] = None

    async def run(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    async def run_action(
        self,
        awaitable: Awaitable[TResult],
        *,
        retry: Optional[RetryPolicy] = None,
        timeout: Optional[float | int | timedelta] = None,
    ) -> TResult:
        """Helper that simply awaits the provided action coroutine.

        The retry and timeout arguments are consumed by the workflow compiler
        (IR builder) rather than the runtime execution path.

        Args:
            awaitable: The action coroutine to execute.
            retry: Retry policy including max attempts, exception types, and backoff.
            timeout: Timeout duration in seconds (or timedelta).
        """
        # Parameters are intentionally unused at runtime; the workflow compiler
        # inspects the AST to record them.
        del retry, timeout
        return await awaitable

    @classmethod
    def short_name(cls) -> str:
        if cls.name:
            return cls.name
        return cls.__name__.lower()

    @classmethod
    def workflow_ir(cls) -> ir.Program:
        """Build and cache the IR program for this workflow."""
        if cls._workflow_ir is None:
            with cls._ir_lock:
                if cls._workflow_ir is None:
                    cls._workflow_ir = build_workflow_ir(cls)
        return cls._workflow_ir

    @classmethod
    def _build_registration_payload(
        cls, initial_context: Optional[pb2.WorkflowArguments] = None
    ) -> pb2.WorkflowRegistration:
        """Build a registration payload with the serialized IR."""
        program = cls.workflow_ir()

        # Serialize IR to bytes
        ir_bytes = program.SerializeToString()
        ir_hash = hashlib.sha256(ir_bytes).hexdigest()

        message = pb2.WorkflowRegistration(
            workflow_name=cls.short_name(),
            ir=ir_bytes,
            ir_hash=ir_hash,
            concurrent=cls.concurrent,
        )

        if initial_context:
            message.initial_context.CopyFrom(initial_context)

        return message


class WorkflowRegistry:
    """Registry of workflow definitions keyed by workflow name."""

    def __init__(self) -> None:
        self._workflows: dict[str, type[Workflow]] = {}
        self._lock = RLock()

    def register(self, name: str, workflow_cls: type[Workflow]) -> None:
        with self._lock:
            if name in self._workflows:
                raise ValueError(f"workflow '{name}' already registered")
            self._workflows[name] = workflow_cls

    def get(self, name: str) -> Optional[type[Workflow]]:
        with self._lock:
            return self._workflows.get(name)

    def names(self) -> list[str]:
        with self._lock:
            return sorted(self._workflows.keys())

    def reset(self) -> None:
        with self._lock:
            self._workflows.clear()


workflow_registry = WorkflowRegistry()


def workflow(cls: type[TWorkflow]) -> type[TWorkflow]:
    """Decorator that registers workflow classes and caches their IR."""

    if not issubclass(cls, Workflow):
        raise TypeError("workflow decorator requires Workflow subclasses")
    run_impl = cls.run
    if not inspect.iscoroutinefunction(run_impl):
        raise TypeError("workflow run() must be defined with 'async def'")

    @wraps(run_impl)
    async def run_public(self: Workflow, *args: Any, **kwargs: Any) -> Any:
        if _running_under_pytest():
            cls.workflow_ir()
            return await run_impl(self, *args, **kwargs)

        # Get the signature of run() to map positional args to parameter names
        sig = inspect.signature(run_impl)
        params = list(sig.parameters.keys())[1:]  # Skip 'self'

        # Convert positional args to kwargs
        for i, arg in enumerate(args):
            if i < len(params):
                kwargs[params[i]] = arg

        bound = sig.bind_partial(self, **kwargs)
        bound.apply_defaults()
        initial_kwargs = {key: value for key, value in bound.arguments.items() if key != "self"}

        # Serialize kwargs using common logic
        initial_context = build_arguments_from_kwargs(initial_kwargs)

        payload = cls._build_registration_payload(initial_context)
        run_result = await bridge.run_instance(payload.SerializeToString())
        cls._workflow_version_id = run_result.workflow_version_id
        if _skip_wait_for_instance():
            logger.info(
                "Skipping wait_for_instance for workflow %s due to RAPPEL_SKIP_WAIT_FOR_INSTANCE",
                cls.short_name(),
            )
            return None
        result_bytes = await bridge.wait_for_instance(
            instance_id=run_result.workflow_instance_id,
            poll_interval_secs=1.0,
        )
        if result_bytes is None:
            raise TimeoutError(
                f"workflow instance {run_result.workflow_instance_id} did not complete"
            )
        arguments = pb2.WorkflowArguments()
        arguments.ParseFromString(result_bytes)
        result = deserialize_result_payload(arguments)
        if result.error:
            raise RuntimeError(f"workflow failed: {result.error}")

        # Unwrap WorkflowNodeResult if present (internal worker representation)
        if isinstance(result.result, WorkflowNodeResult):
            # Extract the actual result from the variables dict
            variables = result.result.variables
            program = cls.workflow_ir()
            # Get the return variable from the IR if available
            if program.functions:
                outputs = list(program.functions[0].io.outputs)
                if outputs:
                    return_var = outputs[0]
                    if return_var in variables:
                        return variables[return_var]
            return None

        return result.result

    cls.__workflow_run_impl__ = run_impl
    cls.run = run_public  # type: ignore[assignment]
    workflow_registry.register(cls.short_name(), cls)
    return cls


def _running_under_pytest() -> bool:
    return bool(os.environ.get("PYTEST_CURRENT_TEST"))


def _skip_wait_for_instance() -> bool:
    value = os.environ.get("RAPPEL_SKIP_WAIT_FOR_INSTANCE")
    if not value:
        return False
    return value.strip().lower() not in {"0", "false", "no"}