cap-sdk-python 2.5.3__tar.gz → 2.5.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/PKG-INFO +1 -1
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/__init__.py +22 -0
- cap_sdk_python-2.5.4/cap/heartbeat.py +153 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/alert_pb2_grpc.py +1 -1
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/buspacket_pb2_grpc.py +1 -1
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/handshake_pb2_grpc.py +1 -1
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/heartbeat_pb2_grpc.py +1 -1
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/job_pb2_grpc.py +1 -1
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/safety_pb2_grpc.py +1 -1
- cap_sdk_python-2.5.4/cap/progress.py +100 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/runtime.py +134 -85
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap_sdk_python.egg-info/PKG-INFO +1 -1
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap_sdk_python.egg-info/SOURCES.txt +4 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/pyproject.toml +1 -1
- cap_sdk_python-2.5.4/tests/test_heartbeat.py +311 -0
- cap_sdk_python-2.5.4/tests/test_progress.py +260 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/README.md +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/bus.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/client.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/errors.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/metrics.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/middleware.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/__init__.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/__init__.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/__init__.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/__init__.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/alert_pb2.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/buspacket_pb2.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/handshake_pb2.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/heartbeat_pb2.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/job_pb2.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/pb/cordum/agent/v1/safety_pb2.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/subjects.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/testing.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/validate.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap/worker.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap_sdk_python.egg-info/dependency_links.txt +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap_sdk_python.egg-info/requires.txt +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/cap_sdk_python.egg-info/top_level.txt +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/setup.cfg +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_conformance.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_errors.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_metrics.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_middleware.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_runtime.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_sdk.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_testing.py +0 -0
- {cap_sdk_python-2.5.3 → cap_sdk_python-2.5.4}/tests/test_validate.py +0 -0
|
@@ -35,6 +35,19 @@ from .bus import connect_nats
|
|
|
35
35
|
from .runtime import Agent, Context, BlobStore, RedisBlobStore, InMemoryBlobStore
|
|
36
36
|
from .middleware import Middleware, NextFn, logging_middleware
|
|
37
37
|
from .metrics import MetricsHook, NoopMetrics
|
|
38
|
+
from .heartbeat import (
|
|
39
|
+
heartbeat_payload,
|
|
40
|
+
heartbeat_payload_with_memory,
|
|
41
|
+
heartbeat_payload_with_progress,
|
|
42
|
+
emit_heartbeat,
|
|
43
|
+
heartbeat_loop,
|
|
44
|
+
)
|
|
45
|
+
from .progress import (
|
|
46
|
+
progress_payload,
|
|
47
|
+
cancel_payload,
|
|
48
|
+
emit_progress,
|
|
49
|
+
emit_cancel,
|
|
50
|
+
)
|
|
38
51
|
from .validate import (
|
|
39
52
|
ValidationError,
|
|
40
53
|
validate_job_request,
|
|
@@ -88,6 +101,15 @@ __all__ = [
|
|
|
88
101
|
"logging_middleware",
|
|
89
102
|
"MetricsHook",
|
|
90
103
|
"NoopMetrics",
|
|
104
|
+
"heartbeat_payload",
|
|
105
|
+
"heartbeat_payload_with_memory",
|
|
106
|
+
"heartbeat_payload_with_progress",
|
|
107
|
+
"emit_heartbeat",
|
|
108
|
+
"heartbeat_loop",
|
|
109
|
+
"progress_payload",
|
|
110
|
+
"cancel_payload",
|
|
111
|
+
"emit_progress",
|
|
112
|
+
"emit_cancel",
|
|
91
113
|
"ValidationError",
|
|
92
114
|
"validate_job_request",
|
|
93
115
|
"validate_job_result",
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Heartbeat helpers for CAP Python SDK.
|
|
2
|
+
|
|
3
|
+
These helpers build and publish heartbeat BusPacket envelopes.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Callable, Optional
|
|
9
|
+
|
|
10
|
+
from cryptography.hazmat.primitives import hashes
|
|
11
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
12
|
+
from google.protobuf import timestamp_pb2
|
|
13
|
+
|
|
14
|
+
from cap.client import DEFAULT_PROTOCOL_VERSION
|
|
15
|
+
from cap.metrics import MetricsHook
|
|
16
|
+
from cap.pb.cordum.agent.v1 import buspacket_pb2, heartbeat_pb2
|
|
17
|
+
from cap.subjects import SUBJECT_HEARTBEAT
|
|
18
|
+
|
|
19
|
+
_logger = logging.getLogger("cap.heartbeat")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def heartbeat_payload(
    worker_id: str,
    pool: str,
    active_jobs: int,
    max_parallel: int,
    cpu_load: float,
) -> bytes:
    """Return a serialized heartbeat packet that reports CPU load only.

    Delegates to ``heartbeat_payload_with_progress``; the memory, progress
    and memo fields are left at that function's defaults.
    """
    return heartbeat_payload_with_progress(
        worker_id,
        pool,
        active_jobs,
        max_parallel,
        cpu_load,
    )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def heartbeat_payload_with_memory(
    worker_id: str,
    pool: str,
    active_jobs: int,
    max_parallel: int,
    cpu_load: float,
    memory_load: float,
) -> bytes:
    """Return a serialized heartbeat packet reporting CPU and memory load.

    Delegates to ``heartbeat_payload_with_progress``; the progress and memo
    fields are left at that function's defaults.
    """
    return heartbeat_payload_with_progress(
        worker_id,
        pool,
        active_jobs,
        max_parallel,
        cpu_load,
        memory_load,
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def heartbeat_payload_with_progress(
    worker_id: str,
    pool: str,
    active_jobs: int,
    max_parallel: int,
    cpu_load: float,
    memory_load: float = 0.0,
    progress_pct: int = 0,
    last_memo: str = "",
) -> bytes:
    """Serialize a heartbeat BusPacket, optionally carrying progress details.

    The envelope's ``sender_id`` is set to ``worker_id``, ``created_at`` to
    the current time, and ``protocol_version`` to ``DEFAULT_PROTOCOL_VERSION``.

    Returns:
        The packet bytes, produced with deterministic serialization.
    """
    beat = heartbeat_pb2.Heartbeat(
        worker_id=worker_id,
        pool=pool,
        active_jobs=active_jobs,
        max_parallel_jobs=max_parallel,
        cpu_load=cpu_load,
        memory_load=memory_load,
        progress_pct=progress_pct,
        last_memo=last_memo,
    )

    now = timestamp_pb2.Timestamp()
    now.GetCurrentTime()

    envelope = buspacket_pb2.BusPacket()
    envelope.protocol_version = DEFAULT_PROTOCOL_VERSION
    envelope.sender_id = worker_id
    envelope.created_at.CopyFrom(now)
    envelope.heartbeat.CopyFrom(beat)
    return envelope.SerializeToString(deterministic=True)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
async def emit_heartbeat(
    nc,
    payload: bytes,
    private_key: Optional[ec.EllipticCurvePrivateKey] = None,
) -> None:
    """Publish a single heartbeat packet on ``SUBJECT_HEARTBEAT``.

    Without a ``private_key`` the payload is published unchanged. With one,
    the payload is parsed as a BusPacket, any existing signature is cleared,
    and an ECDSA/SHA-256 signature over the deterministic serialization of
    the unsigned packet is attached before publishing.
    """
    if private_key is None:
        await nc.publish(SUBJECT_HEARTBEAT, payload)
        return

    envelope = buspacket_pb2.BusPacket()
    envelope.ParseFromString(payload)
    # Sign the packet as it looks with no signature attached.
    envelope.ClearField("signature")
    to_sign = envelope.SerializeToString(deterministic=True)
    envelope.signature = private_key.sign(to_sign, ec.ECDSA(hashes.SHA256()))

    await nc.publish(SUBJECT_HEARTBEAT, envelope.SerializeToString(deterministic=True))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
async def heartbeat_loop(
    nc,
    payload_fn: Callable[[], bytes],
    interval: float = 5.0,
    private_key: Optional[ec.EllipticCurvePrivateKey] = None,
    metrics: MetricsHook | None = None,
    cancel_event: asyncio.Event | None = None,
) -> None:
    """Emit heartbeat packets periodically until cancelled.

    Each cycle waits ``interval`` seconds (negative values are treated as 0),
    builds a payload with ``payload_fn`` and publishes it via
    ``emit_heartbeat``. The first heartbeat is sent after the first wait.

    Args:
        nc: Connection object passed through to ``emit_heartbeat``.
        payload_fn: Zero-argument callable returning the serialized packet.
        interval: Seconds to wait between heartbeats.
        private_key: Optional signing key forwarded to ``emit_heartbeat``.
        metrics: Optional hook; ``on_heartbeat_sent`` is called with the
            worker id parsed from the payload after each successful publish.
        cancel_event: Optional event; once set, the loop returns without
            emitting another heartbeat.

    Raises:
        asyncio.CancelledError: Re-raised when the surrounding task is
            cancelled. Any other exception raised while building or emitting
            a heartbeat is logged and the loop continues.
    """
    sleep_interval = max(0.0, interval)

    while True:
        if cancel_event is None:
            await asyncio.sleep(sleep_interval)
        else:
            if cancel_event.is_set():
                return
            # wait_for cancels the inner waiter on timeout and when this
            # coroutine is cancelled, so no helper task is left pending
            # (hand-made sleep/wait tasks would outlive an outer cancel).
            try:
                await asyncio.wait_for(cancel_event.wait(), timeout=sleep_interval)
            except asyncio.TimeoutError:
                pass
            if cancel_event.is_set():
                return

        try:
            payload = payload_fn()
            await emit_heartbeat(nc=nc, payload=payload, private_key=private_key)
            if metrics is not None:
                packet = buspacket_pb2.BusPacket()
                packet.ParseFromString(payload)
                # Prefer the heartbeat's own worker id; fall back to the envelope sender.
                worker_id = packet.heartbeat.worker_id or packet.sender_id
                metrics.on_heartbeat_sent(worker_id)
        except asyncio.CancelledError:
            raise
        except Exception:
            # Best effort: one failed heartbeat must not stop the loop.
            _logger.exception("heartbeat emission failed")
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Progress and cancel emission helpers for CAP Python SDK.
|
|
2
|
+
|
|
3
|
+
These helpers build and publish progress/cancel BusPacket envelopes.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from cryptography.hazmat.primitives import hashes
|
|
9
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
10
|
+
from google.protobuf import timestamp_pb2
|
|
11
|
+
|
|
12
|
+
from cap.client import DEFAULT_PROTOCOL_VERSION
|
|
13
|
+
from cap.pb.cordum.agent.v1 import buspacket_pb2, job_pb2
|
|
14
|
+
from cap.subjects import SUBJECT_CANCEL, SUBJECT_PROGRESS
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def progress_payload(
    sender_id: str,
    job_id: str,
    step_id: str,
    percent: int,
    message: str,
) -> bytes:
    """Serialize a JobProgress update inside a BusPacket envelope.

    The envelope carries ``sender_id``, ``DEFAULT_PROTOCOL_VERSION`` and the
    current time as ``created_at``.

    Returns:
        The packet bytes, produced with deterministic serialization.
    """
    update = job_pb2.JobProgress(
        job_id=job_id,
        step_id=step_id,
        percent=percent,
        message=message,
    )

    now = timestamp_pb2.Timestamp()
    now.GetCurrentTime()

    envelope = buspacket_pb2.BusPacket()
    envelope.protocol_version = DEFAULT_PROTOCOL_VERSION
    envelope.sender_id = sender_id
    envelope.created_at.CopyFrom(now)
    envelope.job_progress.CopyFrom(update)
    return envelope.SerializeToString(deterministic=True)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def cancel_payload(
    sender_id: str,
    job_id: str,
    reason: str,
    requested_by: str,
) -> bytes:
    """Serialize a JobCancel request inside a BusPacket envelope.

    The envelope carries ``sender_id``, ``DEFAULT_PROTOCOL_VERSION`` and the
    current time as ``created_at``.

    Returns:
        The packet bytes, produced with deterministic serialization.
    """
    request = job_pb2.JobCancel(
        job_id=job_id,
        reason=reason,
        requested_by=requested_by,
    )

    now = timestamp_pb2.Timestamp()
    now.GetCurrentTime()

    envelope = buspacket_pb2.BusPacket()
    envelope.protocol_version = DEFAULT_PROTOCOL_VERSION
    envelope.sender_id = sender_id
    envelope.created_at.CopyFrom(now)
    envelope.job_cancel.CopyFrom(request)
    return envelope.SerializeToString(deterministic=True)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
async def emit_progress(
    nc,
    payload: bytes,
    private_key: Optional[ec.EllipticCurvePrivateKey] = None,
) -> None:
    """Publish a single progress packet on ``SUBJECT_PROGRESS``.

    Without a ``private_key`` the payload is published unchanged. With one,
    the payload is parsed as a BusPacket, any existing signature is cleared,
    and an ECDSA/SHA-256 signature over the deterministic serialization of
    the unsigned packet is attached before publishing.
    """
    if private_key is None:
        await nc.publish(SUBJECT_PROGRESS, payload)
        return

    envelope = buspacket_pb2.BusPacket()
    envelope.ParseFromString(payload)
    # Sign the packet as it looks with no signature attached.
    envelope.ClearField("signature")
    to_sign = envelope.SerializeToString(deterministic=True)
    envelope.signature = private_key.sign(to_sign, ec.ECDSA(hashes.SHA256()))

    await nc.publish(SUBJECT_PROGRESS, envelope.SerializeToString(deterministic=True))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
async def emit_cancel(
    nc,
    payload: bytes,
    private_key: Optional[ec.EllipticCurvePrivateKey] = None,
) -> None:
    """Publish a single cancel packet on ``SUBJECT_CANCEL``.

    Without a ``private_key`` the payload is published unchanged. With one,
    the payload is parsed as a BusPacket, any existing signature is cleared,
    and an ECDSA/SHA-256 signature over the deterministic serialization of
    the unsigned packet is attached before publishing.
    """
    if private_key is None:
        await nc.publish(SUBJECT_CANCEL, payload)
        return

    envelope = buspacket_pb2.BusPacket()
    envelope.ParseFromString(payload)
    # Sign the packet as it looks with no signature attached.
    envelope.ClearField("signature")
    to_sign = envelope.SerializeToString(deterministic=True)
    envelope.signature = private_key.sign(to_sign, ec.ECDSA(hashes.SHA256()))

    await nc.publish(SUBJECT_CANCEL, envelope.SerializeToString(deterministic=True))
|
|
@@ -26,6 +26,7 @@ except Exception: # pragma: no cover - optional until runtime used
|
|
|
26
26
|
from cap.subjects import SUBJECT_RESULT
|
|
27
27
|
from cap.errors import InvalidInputError, MalformedPacketError, SignatureInvalidError, SignatureMissingError
|
|
28
28
|
from cap.metrics import MetricsHook, NoopMetrics
|
|
29
|
+
from cap.heartbeat import heartbeat_loop, heartbeat_payload
|
|
29
30
|
|
|
30
31
|
DEFAULT_PROTOCOL_VERSION = 1
|
|
31
32
|
|
|
@@ -165,6 +166,11 @@ class Agent:
|
|
|
165
166
|
public_keys: Optional[Dict[str, ec.EllipticCurvePublicKey]] = None,
|
|
166
167
|
private_key: Optional[ec.EllipticCurvePrivateKey] = None,
|
|
167
168
|
sender_id: str = "cap-runtime",
|
|
169
|
+
worker_id: Optional[str] = None,
|
|
170
|
+
pool: str = "",
|
|
171
|
+
max_parallel: int = 1,
|
|
172
|
+
heartbeat_interval: float = 5.0,
|
|
173
|
+
heartbeat_payload_fn: Optional[Callable[[], bytes]] = None,
|
|
168
174
|
retries: int = 0,
|
|
169
175
|
io_timeout: Optional[float] = 5.0,
|
|
170
176
|
max_context_bytes: Optional[int] = 2 * 1024 * 1024,
|
|
@@ -178,7 +184,11 @@ class Agent:
|
|
|
178
184
|
self._store = store
|
|
179
185
|
self._public_keys = public_keys
|
|
180
186
|
self._private_key = private_key
|
|
181
|
-
self._sender_id = sender_id
|
|
187
|
+
self._sender_id = worker_id or sender_id
|
|
188
|
+
self._pool = pool
|
|
189
|
+
self._max_parallel = max(1, max_parallel)
|
|
190
|
+
self._heartbeat_interval = heartbeat_interval if heartbeat_interval > 0 else 5.0
|
|
191
|
+
self._heartbeat_payload_fn = heartbeat_payload_fn
|
|
182
192
|
self._default_retries = max(0, retries)
|
|
183
193
|
self._io_timeout = io_timeout if io_timeout and io_timeout > 0 else None
|
|
184
194
|
self._max_context_bytes = max_context_bytes if max_context_bytes and max_context_bytes > 0 else None
|
|
@@ -189,6 +199,9 @@ class Agent:
|
|
|
189
199
|
self._handlers: Dict[str, HandlerSpec] = {}
|
|
190
200
|
self._middlewares: list = []
|
|
191
201
|
self._nc = None
|
|
202
|
+
self._active_jobs: set[str] = set()
|
|
203
|
+
self._heartbeat_cancel_event: Optional[asyncio.Event] = None
|
|
204
|
+
self._heartbeat_task: Optional[asyncio.Task[Any]] = None
|
|
192
205
|
|
|
193
206
|
def use(self, *middlewares) -> None:
|
|
194
207
|
"""Append middleware to the agent. Middleware executes in registration order before the handler."""
|
|
@@ -247,13 +260,45 @@ class Agent:
|
|
|
247
260
|
for topic, spec in self._handlers.items():
|
|
248
261
|
await self._nc.subscribe(topic, queue=topic, cb=lambda msg, s=spec: asyncio.create_task(self._on_msg(msg, s)))
|
|
249
262
|
|
|
263
|
+
if self._heartbeat_task is None or self._heartbeat_task.done():
|
|
264
|
+
self._heartbeat_cancel_event = asyncio.Event()
|
|
265
|
+
payload_fn = self._heartbeat_payload_fn or self._default_heartbeat_payload
|
|
266
|
+
self._heartbeat_task = asyncio.create_task(
|
|
267
|
+
heartbeat_loop(
|
|
268
|
+
nc=self._nc,
|
|
269
|
+
payload_fn=payload_fn,
|
|
270
|
+
interval=self._heartbeat_interval,
|
|
271
|
+
private_key=self._private_key,
|
|
272
|
+
metrics=self._metrics,
|
|
273
|
+
cancel_event=self._heartbeat_cancel_event,
|
|
274
|
+
)
|
|
275
|
+
)
|
|
276
|
+
|
|
250
277
|
async def close(self) -> None:
    """Stop the heartbeat loop, drain the NATS connection and close the blob store.

    Steps run in that order: the heartbeat cancel event is set and the
    heartbeat task awaited, then ``self._nc`` is drained, then
    ``self._store`` is closed. Every later step is attempted even when an
    earlier one raises, so a failing heartbeat task or drain cannot leave
    the connection or the store open. An exception raised along the way
    still propagates to the caller after cleanup has been attempted.
    """
    try:
        if self._heartbeat_cancel_event is not None:
            self._heartbeat_cancel_event.set()
        if self._heartbeat_task is not None:
            await self._heartbeat_task
    finally:
        # Forget the heartbeat state regardless of how the task ended.
        self._heartbeat_task = None
        self._heartbeat_cancel_event = None
        try:
            if self._nc is not None:
                await self._nc.drain()
        finally:
            if self._store is not None:
                await self._store.close()
|
|
256
292
|
|
|
293
|
+
def _default_heartbeat_payload(self) -> bytes:
    """Build the heartbeat payload used when no custom payload function is set.

    Reports this agent's sender id, pool, number of currently tracked
    active jobs and parallelism limit; CPU load is always reported as 0.0.
    """
    in_flight = len(self._active_jobs)
    return heartbeat_payload(
        self._sender_id,
        self._pool,
        in_flight,
        self._max_parallel,
        0.0,
    )
|
|
301
|
+
|
|
257
302
|
async def run(self) -> None:
|
|
258
303
|
"""Start the agent and block until interrupted."""
|
|
259
304
|
await self.start()
|
|
@@ -293,99 +338,103 @@ class Agent:
|
|
|
293
338
|
if not req.job_id:
|
|
294
339
|
return
|
|
295
340
|
|
|
296
|
-
self.
|
|
297
|
-
|
|
298
|
-
ctx_logger = logging.LoggerAdapter(
|
|
299
|
-
self._logger,
|
|
300
|
-
{
|
|
301
|
-
"job_id": req.job_id,
|
|
302
|
-
"trace_id": packet.trace_id,
|
|
303
|
-
"topic": req.topic,
|
|
304
|
-
"sender_id": packet.sender_id,
|
|
305
|
-
},
|
|
306
|
-
)
|
|
307
|
-
ctx = Context(job=req, packet=packet, logger=ctx_logger)
|
|
308
|
-
|
|
309
|
-
store = self._store
|
|
310
|
-
if store is None:
|
|
311
|
-
ctx_logger.error("blob store not initialized")
|
|
312
|
-
return
|
|
313
|
-
|
|
341
|
+
self._active_jobs.add(req.job_id)
|
|
314
342
|
try:
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
343
|
+
self._metrics.on_job_received(req.job_id, req.topic)
|
|
344
|
+
|
|
345
|
+
ctx_logger = logging.LoggerAdapter(
|
|
346
|
+
self._logger,
|
|
347
|
+
{
|
|
348
|
+
"job_id": req.job_id,
|
|
349
|
+
"trace_id": packet.trace_id,
|
|
350
|
+
"topic": req.topic,
|
|
351
|
+
"sender_id": packet.sender_id,
|
|
352
|
+
},
|
|
353
|
+
)
|
|
354
|
+
ctx = Context(job=req, packet=packet, logger=ctx_logger)
|
|
324
355
|
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
return
|
|
356
|
+
store = self._store
|
|
357
|
+
if store is None:
|
|
358
|
+
ctx_logger.error("blob store not initialized")
|
|
359
|
+
return
|
|
330
360
|
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
361
|
+
try:
|
|
362
|
+
key = key_from_pointer(req.context_ptr)
|
|
363
|
+
payload = await self._with_timeout(store.get(key), "context fetch")
|
|
364
|
+
if payload is None:
|
|
365
|
+
raise ValueError("context not found")
|
|
366
|
+
if self._max_context_bytes is not None and len(payload) > self._max_context_bytes:
|
|
367
|
+
raise ValueError("context exceeds max size")
|
|
368
|
+
except Exception as exc:
|
|
369
|
+
await self._publish_failure(ctx, req, str(exc), execution_ms=0)
|
|
370
|
+
return
|
|
336
371
|
|
|
337
|
-
# Build middleware chain: outermost first, terminal calls handler.
|
|
338
|
-
async def _terminal(c: Context, d: Any) -> Any:
|
|
339
|
-
result = spec.func(c, d)
|
|
340
|
-
if asyncio.iscoroutine(result):
|
|
341
|
-
result = await result
|
|
342
|
-
return result
|
|
343
|
-
|
|
344
|
-
chain = _terminal
|
|
345
|
-
for mw in reversed(self._middlewares):
|
|
346
|
-
_next = chain
|
|
347
|
-
chain = (lambda m, n: (lambda c, d: m(c, d, n)))(mw, _next)
|
|
348
|
-
|
|
349
|
-
start_time = time.monotonic()
|
|
350
|
-
error: Optional[str] = None
|
|
351
|
-
output: Any = None
|
|
352
|
-
for attempt in range(spec.retries + 1):
|
|
353
372
|
try:
|
|
354
|
-
|
|
355
|
-
output = self._validate_output(spec, output)
|
|
356
|
-
error = None
|
|
357
|
-
break
|
|
373
|
+
raw = json.loads(payload.decode("utf-8"))
|
|
358
374
|
except Exception as exc:
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
if attempt >= spec.retries:
|
|
362
|
-
break
|
|
375
|
+
await self._publish_failure(ctx, req, f"context decode failed: {exc}", execution_ms=0)
|
|
376
|
+
return
|
|
363
377
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
378
|
+
try:
|
|
379
|
+
input_data = self._validate_input(spec, raw)
|
|
380
|
+
except Exception as exc:
|
|
381
|
+
await self._publish_failure(ctx, req, f"input validation failed: {exc}", execution_ms=0)
|
|
382
|
+
return
|
|
368
383
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
384
|
+
# Build middleware chain: outermost first, terminal calls handler.
|
|
385
|
+
async def _terminal(c: Context, d: Any) -> Any:
|
|
386
|
+
result = spec.func(c, d)
|
|
387
|
+
if asyncio.iscoroutine(result):
|
|
388
|
+
result = await result
|
|
389
|
+
return result
|
|
390
|
+
|
|
391
|
+
chain = _terminal
|
|
392
|
+
for mw in reversed(self._middlewares):
|
|
393
|
+
_next = chain
|
|
394
|
+
chain = (lambda m, n: (lambda c, d: m(c, d, n)))(mw, _next)
|
|
395
|
+
|
|
396
|
+
start_time = time.monotonic()
|
|
397
|
+
error: Optional[str] = None
|
|
398
|
+
output: Any = None
|
|
399
|
+
for attempt in range(spec.retries + 1):
|
|
400
|
+
try:
|
|
401
|
+
output = await chain(ctx, input_data)
|
|
402
|
+
output = self._validate_output(spec, output)
|
|
403
|
+
error = None
|
|
404
|
+
break
|
|
405
|
+
except Exception as exc:
|
|
406
|
+
error = str(exc)
|
|
407
|
+
ctx_logger.warning("handler failed (attempt %d/%d): %s", attempt + 1, spec.retries + 1, exc)
|
|
408
|
+
if attempt >= spec.retries:
|
|
409
|
+
break
|
|
410
|
+
|
|
411
|
+
elapsed_ms = int((time.monotonic() - start_time) * 1000)
|
|
412
|
+
if error is not None:
|
|
413
|
+
await self._publish_failure(ctx, req, error, execution_ms=elapsed_ms)
|
|
414
|
+
return
|
|
379
415
|
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
416
|
+
try:
|
|
417
|
+
result_payload = self._serialize_output(output)
|
|
418
|
+
if self._max_result_bytes is not None and len(result_payload) > self._max_result_bytes:
|
|
419
|
+
raise ValueError("result exceeds max size")
|
|
420
|
+
result_key = f"res:{req.job_id}"
|
|
421
|
+
await self._with_timeout(store.set(result_key, result_payload), "result write")
|
|
422
|
+
result_ptr = pointer_for_key(result_key)
|
|
423
|
+
except Exception as exc:
|
|
424
|
+
await self._publish_failure(ctx, req, f"result write failed: {exc}", execution_ms=elapsed_ms)
|
|
425
|
+
return
|
|
426
|
+
|
|
427
|
+
self._metrics.on_job_completed(req.job_id, elapsed_ms, "SUCCEEDED")
|
|
428
|
+
result = job_pb2.JobResult(
|
|
429
|
+
job_id=req.job_id,
|
|
430
|
+
status=job_pb2.JOB_STATUS_SUCCEEDED,
|
|
431
|
+
result_ptr=result_ptr,
|
|
432
|
+
worker_id=self._sender_id,
|
|
433
|
+
execution_ms=elapsed_ms,
|
|
434
|
+
)
|
|
435
|
+
await self._publish_result(ctx, result)
|
|
436
|
+
finally:
|
|
437
|
+
self._active_jobs.discard(req.job_id)
|
|
389
438
|
|
|
390
439
|
def _validate_input(self, spec: HandlerSpec, data: Any) -> Any:
|
|
391
440
|
if spec.input_model is None:
|
|
@@ -4,8 +4,10 @@ cap/__init__.py
|
|
|
4
4
|
cap/bus.py
|
|
5
5
|
cap/client.py
|
|
6
6
|
cap/errors.py
|
|
7
|
+
cap/heartbeat.py
|
|
7
8
|
cap/metrics.py
|
|
8
9
|
cap/middleware.py
|
|
10
|
+
cap/progress.py
|
|
9
11
|
cap/runtime.py
|
|
10
12
|
cap/subjects.py
|
|
11
13
|
cap/testing.py
|
|
@@ -34,8 +36,10 @@ cap_sdk_python.egg-info/requires.txt
|
|
|
34
36
|
cap_sdk_python.egg-info/top_level.txt
|
|
35
37
|
tests/test_conformance.py
|
|
36
38
|
tests/test_errors.py
|
|
39
|
+
tests/test_heartbeat.py
|
|
37
40
|
tests/test_metrics.py
|
|
38
41
|
tests/test_middleware.py
|
|
42
|
+
tests/test_progress.py
|
|
39
43
|
tests/test_runtime.py
|
|
40
44
|
tests/test_sdk.py
|
|
41
45
|
tests/test_testing.py
|