cap-sdk-python 2.0.16__tar.gz → 2.0.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cap_sdk_python-2.0.16/cap_sdk_python.egg-info → cap_sdk_python-2.0.18}/PKG-INFO +32 -2
- cap_sdk_python-2.0.16/PKG-INFO → cap_sdk_python-2.0.18/README.md +28 -12
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/__init__.py +11 -1
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/client.py +2 -2
- cap_sdk_python-2.0.18/cap/runtime.py +394 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/worker.py +4 -3
- cap_sdk_python-2.0.16/README.md → cap_sdk_python-2.0.18/cap_sdk_python.egg-info/PKG-INFO +42 -1
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap_sdk_python.egg-info/SOURCES.txt +3 -0
- cap_sdk_python-2.0.18/cap_sdk_python.egg-info/requires.txt +6 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/pyproject.toml +4 -1
- cap_sdk_python-2.0.18/tests/test_conformance.py +109 -0
- cap_sdk_python-2.0.18/tests/test_runtime.py +130 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/tests/test_sdk.py +1 -1
- cap_sdk_python-2.0.16/cap_sdk_python.egg-info/requires.txt +0 -3
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/bus.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/__init__.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/__init__.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/__init__.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/__init__.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/alert_pb2.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/alert_pb2_grpc.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/buspacket_pb2.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/buspacket_pb2_grpc.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/heartbeat_pb2.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/heartbeat_pb2_grpc.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/job_pb2.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/job_pb2_grpc.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/safety_pb2.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/safety_pb2_grpc.py +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap_sdk_python.egg-info/dependency_links.txt +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap_sdk_python.egg-info/top_level.txt +0 -0
- {cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cap-sdk-python
|
|
3
|
-
Version: 2.0.
|
|
3
|
+
Version: 2.0.18
|
|
4
4
|
Summary: CAP (Cordum Agent Protocol) Python SDK
|
|
5
5
|
Project-URL: Homepage, https://github.com/cordum-io/cap
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -8,6 +8,9 @@ Description-Content-Type: text/markdown
|
|
|
8
8
|
Requires-Dist: protobuf>=4.25.0
|
|
9
9
|
Requires-Dist: grpcio>=1.59.0
|
|
10
10
|
Requires-Dist: nats-py>=2.6.0
|
|
11
|
+
Requires-Dist: cryptography>=41.0.0
|
|
12
|
+
Requires-Dist: pydantic>=2.6.0
|
|
13
|
+
Requires-Dist: redis>=5.0.0
|
|
11
14
|
|
|
12
15
|
# CAP Python SDK
|
|
13
16
|
|
|
@@ -77,7 +80,7 @@ Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
|
77
80
|
## Defaults
|
|
78
81
|
- Subjects: `sys.job.submit`, `sys.job.result`, `sys.heartbeat`.
|
|
79
82
|
- Protocol version: `1`.
|
|
80
|
-
- Signing: `submit_job` and `run_worker` sign envelopes when given an `ec.EllipticCurvePrivateKey`. Generate a keypair with `cryptography`:
|
|
83
|
+
- Signing: `submit_job` and `run_worker` sign envelopes when given an `ec.EllipticCurvePrivateKey`. Signatures use deterministic protobuf serialization (map entries ordered by key) for cross-SDK verification. Generate a keypair with `cryptography`:
|
|
81
84
|
```python
|
|
82
85
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
83
86
|
priv = ec.generate_private_key(ec.SECP256R1())
|
|
@@ -88,3 +91,30 @@ Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
|
88
91
|
- Pass `private_key=None` to `submit_job` if you want to send unsigned envelopes.
|
|
89
92
|
|
|
90
93
|
Swap out `cap.bus` if you need a different transport.
|
|
94
|
+
|
|
95
|
+
## Runtime (High-Level SDK)
|
|
96
|
+
The runtime hides NATS/Redis plumbing and gives you typed handlers.
|
|
97
|
+
|
|
98
|
+
```python
|
|
99
|
+
import asyncio
|
|
100
|
+
from pydantic import BaseModel
|
|
101
|
+
from cap.runtime import Agent, Context
|
|
102
|
+
|
|
103
|
+
class Input(BaseModel):
|
|
104
|
+
prompt: str
|
|
105
|
+
|
|
106
|
+
class Output(BaseModel):
|
|
107
|
+
summary: str
|
|
108
|
+
|
|
109
|
+
agent = Agent(retries=2)
|
|
110
|
+
|
|
111
|
+
@agent.job("job.summarize", input_model=Input, output_model=Output)
|
|
112
|
+
async def summarize(ctx: Context, data: Input) -> Output:
|
|
113
|
+
return Output(summary=data.prompt[:140])
|
|
114
|
+
|
|
115
|
+
asyncio.run(agent.run())
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
### Environment
|
|
119
|
+
- `NATS_URL` (default `nats://127.0.0.1:4222`)
|
|
120
|
+
- `REDIS_URL` (default `redis://127.0.0.1:6379/0`)
|
|
@@ -1,14 +1,3 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: cap-sdk-python
|
|
3
|
-
Version: 2.0.16
|
|
4
|
-
Summary: CAP (Cordum Agent Protocol) Python SDK
|
|
5
|
-
Project-URL: Homepage, https://github.com/cordum-io/cap
|
|
6
|
-
Requires-Python: >=3.9
|
|
7
|
-
Description-Content-Type: text/markdown
|
|
8
|
-
Requires-Dist: protobuf>=4.25.0
|
|
9
|
-
Requires-Dist: grpcio>=1.59.0
|
|
10
|
-
Requires-Dist: nats-py>=2.6.0
|
|
11
|
-
|
|
12
1
|
# CAP Python SDK
|
|
13
2
|
|
|
14
3
|
Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
@@ -77,7 +66,7 @@ Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
|
77
66
|
## Defaults
|
|
78
67
|
- Subjects: `sys.job.submit`, `sys.job.result`, `sys.heartbeat`.
|
|
79
68
|
- Protocol version: `1`.
|
|
80
|
-
- Signing: `submit_job` and `run_worker` sign envelopes when given an `ec.EllipticCurvePrivateKey`. Generate a keypair with `cryptography`:
|
|
69
|
+
- Signing: `submit_job` and `run_worker` sign envelopes when given an `ec.EllipticCurvePrivateKey`. Signatures use deterministic protobuf serialization (map entries ordered by key) for cross-SDK verification. Generate a keypair with `cryptography`:
|
|
81
70
|
```python
|
|
82
71
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
83
72
|
priv = ec.generate_private_key(ec.SECP256R1())
|
|
@@ -88,3 +77,30 @@ Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
|
88
77
|
- Pass `private_key=None` to `submit_job` if you want to send unsigned envelopes.
|
|
89
78
|
|
|
90
79
|
Swap out `cap.bus` if you need a different transport.
|
|
80
|
+
|
|
81
|
+
## Runtime (High-Level SDK)
|
|
82
|
+
The runtime hides NATS/Redis plumbing and gives you typed handlers.
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
import asyncio
|
|
86
|
+
from pydantic import BaseModel
|
|
87
|
+
from cap.runtime import Agent, Context
|
|
88
|
+
|
|
89
|
+
class Input(BaseModel):
|
|
90
|
+
prompt: str
|
|
91
|
+
|
|
92
|
+
class Output(BaseModel):
|
|
93
|
+
summary: str
|
|
94
|
+
|
|
95
|
+
agent = Agent(retries=2)
|
|
96
|
+
|
|
97
|
+
@agent.job("job.summarize", input_model=Input, output_model=Output)
|
|
98
|
+
async def summarize(ctx: Context, data: Input) -> Output:
|
|
99
|
+
return Output(summary=data.prompt[:140])
|
|
100
|
+
|
|
101
|
+
asyncio.run(agent.run())
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### Environment
|
|
105
|
+
- `NATS_URL` (default `nats://127.0.0.1:4222`)
|
|
106
|
+
- `REDIS_URL` (default `redis://127.0.0.1:6379/0`)
|
|
@@ -26,5 +26,15 @@ except Exception:
|
|
|
26
26
|
from .client import submit_job
|
|
27
27
|
from .worker import run_worker
|
|
28
28
|
from .bus import connect_nats
|
|
29
|
+
from .runtime import Agent, Context, BlobStore, RedisBlobStore, InMemoryBlobStore
|
|
29
30
|
|
|
30
|
-
__all__ = [
|
|
31
|
+
__all__ = [
|
|
32
|
+
"submit_job",
|
|
33
|
+
"run_worker",
|
|
34
|
+
"connect_nats",
|
|
35
|
+
"Agent",
|
|
36
|
+
"Context",
|
|
37
|
+
"BlobStore",
|
|
38
|
+
"RedisBlobStore",
|
|
39
|
+
"InMemoryBlobStore",
|
|
40
|
+
]
|
|
@@ -26,8 +26,8 @@ async def submit_job(
|
|
|
26
26
|
packet.job_request.CopyFrom(job_request)
|
|
27
27
|
|
|
28
28
|
if private_key:
|
|
29
|
-
unsigned_data = packet.SerializeToString()
|
|
29
|
+
unsigned_data = packet.SerializeToString(deterministic=True)
|
|
30
30
|
signature = private_key.sign(unsigned_data, ec.ECDSA(hashes.SHA256()))
|
|
31
31
|
packet.signature = signature
|
|
32
32
|
|
|
33
|
-
await nc.publish(SUBJECT_SUBMIT, packet.SerializeToString())
|
|
33
|
+
await nc.publish(SUBJECT_SUBMIT, packet.SerializeToString(deterministic=True))
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any, Awaitable, Callable, Dict, Optional, Protocol, Type, TypeVar, Union
|
|
8
|
+
|
|
9
|
+
from google.protobuf import timestamp_pb2
|
|
10
|
+
from cap.pb.cordum.agent.v1 import buspacket_pb2, job_pb2
|
|
11
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
12
|
+
from cryptography.hazmat.primitives import hashes
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import redis.asyncio as redis_async # type: ignore
|
|
16
|
+
except Exception: # pragma: no cover - optional until runtime used
|
|
17
|
+
redis_async = None
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
from pydantic import BaseModel, ValidationError
|
|
21
|
+
except Exception: # pragma: no cover - optional until runtime used
|
|
22
|
+
BaseModel = None
|
|
23
|
+
ValidationError = Exception
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
DEFAULT_PROTOCOL_VERSION = 1
|
|
27
|
+
SUBJECT_RESULT = "sys.job.result"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BlobStore(Protocol):
|
|
31
|
+
async def get(self, key: str) -> Optional[bytes]:
|
|
32
|
+
...
|
|
33
|
+
|
|
34
|
+
async def set(self, key: str, data: bytes) -> None:
|
|
35
|
+
...
|
|
36
|
+
|
|
37
|
+
async def close(self) -> None:
|
|
38
|
+
...
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RedisBlobStore:
|
|
42
|
+
def __init__(self, redis_url: str) -> None:
|
|
43
|
+
if redis_async is None:
|
|
44
|
+
raise RuntimeError("redis is required for RedisBlobStore")
|
|
45
|
+
self._client = redis_async.from_url(redis_url)
|
|
46
|
+
|
|
47
|
+
async def get(self, key: str) -> Optional[bytes]:
|
|
48
|
+
value = await self._client.get(key)
|
|
49
|
+
return value
|
|
50
|
+
|
|
51
|
+
async def set(self, key: str, data: bytes) -> None:
|
|
52
|
+
await self._client.set(key, data)
|
|
53
|
+
|
|
54
|
+
async def close(self) -> None:
|
|
55
|
+
await self._client.close()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class InMemoryBlobStore:
|
|
59
|
+
def __init__(self) -> None:
|
|
60
|
+
self._data: Dict[str, bytes] = {}
|
|
61
|
+
|
|
62
|
+
async def get(self, key: str) -> Optional[bytes]:
|
|
63
|
+
return self._data.get(key)
|
|
64
|
+
|
|
65
|
+
async def set(self, key: str, data: bytes) -> None:
|
|
66
|
+
self._data[key] = data
|
|
67
|
+
|
|
68
|
+
async def close(self) -> None:
|
|
69
|
+
return None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def pointer_for_key(key: str) -> str:
|
|
73
|
+
return "redis://" + key
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def key_from_pointer(ptr: str) -> str:
|
|
77
|
+
if not ptr:
|
|
78
|
+
raise ValueError("empty pointer")
|
|
79
|
+
if not ptr.startswith("redis://"):
|
|
80
|
+
raise ValueError("unsupported pointer scheme")
|
|
81
|
+
key = ptr[len("redis://") :]
|
|
82
|
+
if not key:
|
|
83
|
+
raise ValueError("missing pointer key")
|
|
84
|
+
return key
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _default_logger() -> logging.Logger:
|
|
88
|
+
logger = logging.getLogger("cap.runtime")
|
|
89
|
+
if not logger.handlers:
|
|
90
|
+
handler = logging.StreamHandler()
|
|
91
|
+
formatter = logging.Formatter("%(levelname)s %(message)s")
|
|
92
|
+
handler.setFormatter(formatter)
|
|
93
|
+
logger.addHandler(handler)
|
|
94
|
+
logger.setLevel(logging.INFO)
|
|
95
|
+
return logger
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
|
|
99
|
+
class Context:
|
|
100
|
+
job: job_pb2.JobRequest
|
|
101
|
+
packet: buspacket_pb2.BusPacket
|
|
102
|
+
logger: logging.LoggerAdapter
|
|
103
|
+
|
|
104
|
+
@property
|
|
105
|
+
def job_id(self) -> str:
|
|
106
|
+
return self.job.job_id
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def trace_id(self) -> str:
|
|
110
|
+
return self.packet.trace_id
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
TIn = TypeVar("TIn")
|
|
114
|
+
TOut = TypeVar("TOut")
|
|
115
|
+
TAny = TypeVar("TAny")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@dataclass
|
|
119
|
+
class HandlerSpec:
|
|
120
|
+
topic: str
|
|
121
|
+
func: Callable[[Context, Any], Union[Awaitable[Any], Any]]
|
|
122
|
+
input_model: Optional[Type[Any]]
|
|
123
|
+
output_model: Optional[Type[Any]]
|
|
124
|
+
retries: int
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class Agent:
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
*,
|
|
131
|
+
nats_url: Optional[str] = None,
|
|
132
|
+
redis_url: Optional[str] = None,
|
|
133
|
+
store: Optional[BlobStore] = None,
|
|
134
|
+
public_keys: Optional[Dict[str, ec.EllipticCurvePublicKey]] = None,
|
|
135
|
+
private_key: Optional[ec.EllipticCurvePrivateKey] = None,
|
|
136
|
+
sender_id: str = "cap-runtime",
|
|
137
|
+
retries: int = 0,
|
|
138
|
+
io_timeout: Optional[float] = 5.0,
|
|
139
|
+
max_context_bytes: Optional[int] = 2 * 1024 * 1024,
|
|
140
|
+
max_result_bytes: Optional[int] = 2 * 1024 * 1024,
|
|
141
|
+
connect_fn: Optional[Callable[..., Awaitable[Any]]] = None,
|
|
142
|
+
logger: Optional[logging.Logger] = None,
|
|
143
|
+
) -> None:
|
|
144
|
+
self._nats_url = nats_url or os.getenv("NATS_URL", "nats://127.0.0.1:4222")
|
|
145
|
+
self._redis_url = redis_url or os.getenv("REDIS_URL", "redis://127.0.0.1:6379/0")
|
|
146
|
+
self._store = store
|
|
147
|
+
self._public_keys = public_keys
|
|
148
|
+
self._private_key = private_key
|
|
149
|
+
self._sender_id = sender_id
|
|
150
|
+
self._default_retries = max(0, retries)
|
|
151
|
+
self._io_timeout = io_timeout if io_timeout and io_timeout > 0 else None
|
|
152
|
+
self._max_context_bytes = max_context_bytes if max_context_bytes and max_context_bytes > 0 else None
|
|
153
|
+
self._max_result_bytes = max_result_bytes if max_result_bytes and max_result_bytes > 0 else None
|
|
154
|
+
self._connect_fn = connect_fn
|
|
155
|
+
self._logger = logger or _default_logger()
|
|
156
|
+
self._handlers: Dict[str, HandlerSpec] = {}
|
|
157
|
+
self._nc = None
|
|
158
|
+
|
|
159
|
+
def job(
|
|
160
|
+
self,
|
|
161
|
+
topic: str,
|
|
162
|
+
*,
|
|
163
|
+
input_model: Optional[Type[Any]] = None,
|
|
164
|
+
output_model: Optional[Type[Any]] = None,
|
|
165
|
+
retries: Optional[int] = None,
|
|
166
|
+
) -> Callable[
|
|
167
|
+
[Callable[[Context, Any], Union[Awaitable[Any], Any]]],
|
|
168
|
+
Callable[[Context, Any], Union[Awaitable[Any], Any]],
|
|
169
|
+
]:
|
|
170
|
+
def decorator(func: Callable[[Context, Any], Union[Awaitable[Any], Any]]):
|
|
171
|
+
spec = HandlerSpec(
|
|
172
|
+
topic=topic,
|
|
173
|
+
func=func,
|
|
174
|
+
input_model=input_model,
|
|
175
|
+
output_model=output_model,
|
|
176
|
+
retries=self._default_retries if retries is None else max(0, retries),
|
|
177
|
+
)
|
|
178
|
+
self._handlers[topic] = spec
|
|
179
|
+
return func
|
|
180
|
+
|
|
181
|
+
return decorator
|
|
182
|
+
|
|
183
|
+
async def start(self) -> None:
|
|
184
|
+
if not self._handlers:
|
|
185
|
+
raise RuntimeError("no handlers registered")
|
|
186
|
+
if self._connect_fn is None:
|
|
187
|
+
try:
|
|
188
|
+
import nats # type: ignore
|
|
189
|
+
except ImportError as exc:
|
|
190
|
+
raise RuntimeError("nats-py is required to connect to NATS") from exc
|
|
191
|
+
self._connect_fn = nats.connect
|
|
192
|
+
|
|
193
|
+
self._nc = await self._with_timeout(
|
|
194
|
+
self._connect_fn(servers=self._nats_url, name=self._sender_id),
|
|
195
|
+
"nats connect",
|
|
196
|
+
)
|
|
197
|
+
if self._store is None:
|
|
198
|
+
self._store = RedisBlobStore(self._redis_url)
|
|
199
|
+
|
|
200
|
+
for topic, spec in self._handlers.items():
|
|
201
|
+
await self._nc.subscribe(topic, queue=topic, cb=lambda msg, s=spec: asyncio.create_task(self._on_msg(msg, s)))
|
|
202
|
+
|
|
203
|
+
async def close(self) -> None:
|
|
204
|
+
if self._nc is not None:
|
|
205
|
+
await self._nc.drain()
|
|
206
|
+
if self._store is not None:
|
|
207
|
+
await self._store.close()
|
|
208
|
+
|
|
209
|
+
async def run(self) -> None:
|
|
210
|
+
await self.start()
|
|
211
|
+
try:
|
|
212
|
+
while True:
|
|
213
|
+
await asyncio.sleep(1)
|
|
214
|
+
finally:
|
|
215
|
+
await self.close()
|
|
216
|
+
|
|
217
|
+
async def _on_msg(self, msg: Any, spec: HandlerSpec) -> None:
|
|
218
|
+
packet = buspacket_pb2.BusPacket()
|
|
219
|
+
try:
|
|
220
|
+
packet.ParseFromString(msg.data)
|
|
221
|
+
except Exception as exc:
|
|
222
|
+
self._logger.error("runtime: decode failed: %s", exc)
|
|
223
|
+
return
|
|
224
|
+
|
|
225
|
+
if self._public_keys is not None:
|
|
226
|
+
sender_key = self._public_keys.get(packet.sender_id)
|
|
227
|
+
if not sender_key:
|
|
228
|
+
self._logger.warning("runtime: no public key for sender %s", packet.sender_id)
|
|
229
|
+
return
|
|
230
|
+
if not packet.signature:
|
|
231
|
+
self._logger.warning("runtime: missing signature for sender %s", packet.sender_id)
|
|
232
|
+
return
|
|
233
|
+
signature = packet.signature
|
|
234
|
+
packet.ClearField("signature")
|
|
235
|
+
unsigned = packet.SerializeToString(deterministic=True)
|
|
236
|
+
packet.signature = signature
|
|
237
|
+
try:
|
|
238
|
+
sender_key.verify(signature, unsigned, ec.ECDSA(hashes.SHA256()))
|
|
239
|
+
except Exception:
|
|
240
|
+
self._logger.warning("runtime: invalid signature from sender %s", packet.sender_id)
|
|
241
|
+
return
|
|
242
|
+
|
|
243
|
+
req = packet.job_request
|
|
244
|
+
if not req.job_id:
|
|
245
|
+
return
|
|
246
|
+
|
|
247
|
+
ctx_logger = logging.LoggerAdapter(
|
|
248
|
+
self._logger,
|
|
249
|
+
{
|
|
250
|
+
"job_id": req.job_id,
|
|
251
|
+
"trace_id": packet.trace_id,
|
|
252
|
+
"topic": req.topic,
|
|
253
|
+
},
|
|
254
|
+
)
|
|
255
|
+
ctx = Context(job=req, packet=packet, logger=ctx_logger)
|
|
256
|
+
|
|
257
|
+
store = self._store
|
|
258
|
+
if store is None:
|
|
259
|
+
ctx_logger.error("runtime: blob store not initialized")
|
|
260
|
+
return
|
|
261
|
+
|
|
262
|
+
try:
|
|
263
|
+
key = key_from_pointer(req.context_ptr)
|
|
264
|
+
payload = await self._with_timeout(store.get(key), "context fetch")
|
|
265
|
+
if payload is None:
|
|
266
|
+
raise ValueError("context not found")
|
|
267
|
+
if self._max_context_bytes is not None and len(payload) > self._max_context_bytes:
|
|
268
|
+
raise ValueError("context exceeds max size")
|
|
269
|
+
except Exception as exc:
|
|
270
|
+
await self._publish_failure(ctx, req, str(exc), execution_ms=0)
|
|
271
|
+
return
|
|
272
|
+
|
|
273
|
+
try:
|
|
274
|
+
raw = json.loads(payload.decode("utf-8"))
|
|
275
|
+
except Exception as exc:
|
|
276
|
+
await self._publish_failure(ctx, req, f"context decode failed: {exc}", execution_ms=0)
|
|
277
|
+
return
|
|
278
|
+
|
|
279
|
+
try:
|
|
280
|
+
input_data = self._validate_input(spec, raw)
|
|
281
|
+
except Exception as exc:
|
|
282
|
+
await self._publish_failure(ctx, req, f"input validation failed: {exc}", execution_ms=0)
|
|
283
|
+
return
|
|
284
|
+
|
|
285
|
+
start_time = time.monotonic()
|
|
286
|
+
error: Optional[str] = None
|
|
287
|
+
output: Any = None
|
|
288
|
+
for attempt in range(spec.retries + 1):
|
|
289
|
+
try:
|
|
290
|
+
output = spec.func(ctx, input_data)
|
|
291
|
+
if asyncio.iscoroutine(output):
|
|
292
|
+
output = await output
|
|
293
|
+
output = self._validate_output(spec, output)
|
|
294
|
+
error = None
|
|
295
|
+
break
|
|
296
|
+
except Exception as exc:
|
|
297
|
+
error = str(exc)
|
|
298
|
+
ctx_logger.warning("runtime: handler failed (attempt %d/%d): %s", attempt + 1, spec.retries + 1, exc)
|
|
299
|
+
if attempt >= spec.retries:
|
|
300
|
+
break
|
|
301
|
+
|
|
302
|
+
elapsed_ms = int((time.monotonic() - start_time) * 1000)
|
|
303
|
+
if error is not None:
|
|
304
|
+
await self._publish_failure(ctx, req, error, execution_ms=elapsed_ms)
|
|
305
|
+
return
|
|
306
|
+
|
|
307
|
+
try:
|
|
308
|
+
result_payload = self._serialize_output(output)
|
|
309
|
+
if self._max_result_bytes is not None and len(result_payload) > self._max_result_bytes:
|
|
310
|
+
raise ValueError("result exceeds max size")
|
|
311
|
+
result_key = f"res:{req.job_id}"
|
|
312
|
+
await self._with_timeout(store.set(result_key, result_payload), "result write")
|
|
313
|
+
result_ptr = pointer_for_key(result_key)
|
|
314
|
+
except Exception as exc:
|
|
315
|
+
await self._publish_failure(ctx, req, f"result write failed: {exc}", execution_ms=elapsed_ms)
|
|
316
|
+
return
|
|
317
|
+
|
|
318
|
+
result = job_pb2.JobResult(
|
|
319
|
+
job_id=req.job_id,
|
|
320
|
+
status=job_pb2.JOB_STATUS_SUCCEEDED,
|
|
321
|
+
result_ptr=result_ptr,
|
|
322
|
+
worker_id=self._sender_id,
|
|
323
|
+
execution_ms=elapsed_ms,
|
|
324
|
+
)
|
|
325
|
+
await self._publish_result(ctx, result)
|
|
326
|
+
|
|
327
|
+
def _validate_input(self, spec: HandlerSpec, data: Any) -> Any:
|
|
328
|
+
if spec.input_model is None:
|
|
329
|
+
return data
|
|
330
|
+
if BaseModel is not None and isinstance(spec.input_model, type) and issubclass(spec.input_model, BaseModel):
|
|
331
|
+
return spec.input_model.model_validate(data)
|
|
332
|
+
return spec.input_model(**data)
|
|
333
|
+
|
|
334
|
+
def _validate_output(self, spec: HandlerSpec, data: Any) -> Any:
|
|
335
|
+
if spec.output_model is None:
|
|
336
|
+
return data
|
|
337
|
+
if BaseModel is not None and isinstance(spec.output_model, type) and issubclass(spec.output_model, BaseModel):
|
|
338
|
+
return spec.output_model.model_validate(data)
|
|
339
|
+
return spec.output_model(**data)
|
|
340
|
+
|
|
341
|
+
def _serialize_output(self, data: Any) -> bytes:
|
|
342
|
+
if BaseModel is not None and isinstance(data, BaseModel):
|
|
343
|
+
return json.dumps(data.model_dump(mode="json")).encode("utf-8")
|
|
344
|
+
if isinstance(data, (dict, list, str, int, float, bool)) or data is None:
|
|
345
|
+
return json.dumps(data).encode("utf-8")
|
|
346
|
+
if hasattr(data, "__dict__"):
|
|
347
|
+
return json.dumps(data.__dict__).encode("utf-8")
|
|
348
|
+
raise ValueError("output is not JSON serializable")
|
|
349
|
+
|
|
350
|
+
async def _publish_failure(
|
|
351
|
+
self,
|
|
352
|
+
ctx: Context,
|
|
353
|
+
req: job_pb2.JobRequest,
|
|
354
|
+
error: str,
|
|
355
|
+
execution_ms: int,
|
|
356
|
+
) -> None:
|
|
357
|
+
result = job_pb2.JobResult(
|
|
358
|
+
job_id=req.job_id,
|
|
359
|
+
status=job_pb2.JOB_STATUS_FAILED,
|
|
360
|
+
error_message=error,
|
|
361
|
+
worker_id=self._sender_id,
|
|
362
|
+
execution_ms=execution_ms,
|
|
363
|
+
)
|
|
364
|
+
await self._publish_result(ctx, result)
|
|
365
|
+
|
|
366
|
+
async def _publish_result(self, ctx: Context, result: job_pb2.JobResult) -> None:
|
|
367
|
+
if self._nc is None:
|
|
368
|
+
ctx.logger.error("runtime: NATS not initialized")
|
|
369
|
+
return
|
|
370
|
+
packet = buspacket_pb2.BusPacket()
|
|
371
|
+
packet.trace_id = ctx.packet.trace_id
|
|
372
|
+
packet.sender_id = self._sender_id
|
|
373
|
+
packet.protocol_version = DEFAULT_PROTOCOL_VERSION
|
|
374
|
+
ts = timestamp_pb2.Timestamp()
|
|
375
|
+
ts.GetCurrentTime()
|
|
376
|
+
packet.created_at.CopyFrom(ts)
|
|
377
|
+
packet.job_result.CopyFrom(result)
|
|
378
|
+
|
|
379
|
+
if self._private_key is not None:
|
|
380
|
+
unsigned = packet.SerializeToString(deterministic=True)
|
|
381
|
+
packet.signature = self._private_key.sign(unsigned, ec.ECDSA(hashes.SHA256()))
|
|
382
|
+
|
|
383
|
+
await self._with_timeout(
|
|
384
|
+
self._nc.publish(SUBJECT_RESULT, packet.SerializeToString(deterministic=True)),
|
|
385
|
+
"result publish",
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
async def _with_timeout(self, coro: Awaitable[TAny], label: str) -> TAny:
|
|
389
|
+
if self._io_timeout is None:
|
|
390
|
+
return await coro
|
|
391
|
+
try:
|
|
392
|
+
return await asyncio.wait_for(coro, timeout=self._io_timeout)
|
|
393
|
+
except asyncio.TimeoutError as exc:
|
|
394
|
+
raise TimeoutError(f"{label} timed out") from exc
|
|
@@ -38,7 +38,8 @@ async def run_worker(nats_url: str, subject: str, handler: Callable[[job_pb2.Job
|
|
|
38
38
|
|
|
39
39
|
signature = packet.signature
|
|
40
40
|
packet.ClearField("signature")
|
|
41
|
-
unsigned_data = packet.SerializeToString()
|
|
41
|
+
unsigned_data = packet.SerializeToString(deterministic=True)
|
|
42
|
+
packet.signature = signature
|
|
42
43
|
try:
|
|
43
44
|
public_key.verify(signature, unsigned_data, ec.ECDSA(hashes.SHA256()))
|
|
44
45
|
except Exception:
|
|
@@ -76,11 +77,11 @@ async def run_worker(nats_url: str, subject: str, handler: Callable[[job_pb2.Job
|
|
|
76
77
|
out.job_result.CopyFrom(res)
|
|
77
78
|
|
|
78
79
|
if private_key:
|
|
79
|
-
unsigned_data = out.SerializeToString()
|
|
80
|
+
unsigned_data = out.SerializeToString(deterministic=True)
|
|
80
81
|
signature = private_key.sign(unsigned_data, ec.ECDSA(hashes.SHA256()))
|
|
81
82
|
out.signature = signature
|
|
82
83
|
|
|
83
|
-
await nc.publish(SUBJECT_RESULT, out.SerializeToString())
|
|
84
|
+
await nc.publish(SUBJECT_RESULT, out.SerializeToString(deterministic=True))
|
|
84
85
|
|
|
85
86
|
await nc.subscribe(subject, queue=subject, cb=on_msg)
|
|
86
87
|
try:
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cap-sdk-python
|
|
3
|
+
Version: 2.0.18
|
|
4
|
+
Summary: CAP (Cordum Agent Protocol) Python SDK
|
|
5
|
+
Project-URL: Homepage, https://github.com/cordum-io/cap
|
|
6
|
+
Requires-Python: >=3.9
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: protobuf>=4.25.0
|
|
9
|
+
Requires-Dist: grpcio>=1.59.0
|
|
10
|
+
Requires-Dist: nats-py>=2.6.0
|
|
11
|
+
Requires-Dist: cryptography>=41.0.0
|
|
12
|
+
Requires-Dist: pydantic>=2.6.0
|
|
13
|
+
Requires-Dist: redis>=5.0.0
|
|
14
|
+
|
|
1
15
|
# CAP Python SDK
|
|
2
16
|
|
|
3
17
|
Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
@@ -66,7 +80,7 @@ Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
|
66
80
|
## Defaults
|
|
67
81
|
- Subjects: `sys.job.submit`, `sys.job.result`, `sys.heartbeat`.
|
|
68
82
|
- Protocol version: `1`.
|
|
69
|
-
- Signing: `submit_job` and `run_worker` sign envelopes when given an `ec.EllipticCurvePrivateKey`. Generate a keypair with `cryptography`:
|
|
83
|
+
- Signing: `submit_job` and `run_worker` sign envelopes when given an `ec.EllipticCurvePrivateKey`. Signatures use deterministic protobuf serialization (map entries ordered by key) for cross-SDK verification. Generate a keypair with `cryptography`:
|
|
70
84
|
```python
|
|
71
85
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
72
86
|
priv = ec.generate_private_key(ec.SECP256R1())
|
|
@@ -77,3 +91,30 @@ Asyncio-first SDK with NATS helpers for CAP workers and clients.
|
|
|
77
91
|
- Pass `private_key=None` to `submit_job` if you want to send unsigned envelopes.
|
|
78
92
|
|
|
79
93
|
Swap out `cap.bus` if you need a different transport.
|
|
94
|
+
|
|
95
|
+
## Runtime (High-Level SDK)
|
|
96
|
+
The runtime hides NATS/Redis plumbing and gives you typed handlers.
|
|
97
|
+
|
|
98
|
+
```python
|
|
99
|
+
import asyncio
|
|
100
|
+
from pydantic import BaseModel
|
|
101
|
+
from cap.runtime import Agent, Context
|
|
102
|
+
|
|
103
|
+
class Input(BaseModel):
|
|
104
|
+
prompt: str
|
|
105
|
+
|
|
106
|
+
class Output(BaseModel):
|
|
107
|
+
summary: str
|
|
108
|
+
|
|
109
|
+
agent = Agent(retries=2)
|
|
110
|
+
|
|
111
|
+
@agent.job("job.summarize", input_model=Input, output_model=Output)
|
|
112
|
+
async def summarize(ctx: Context, data: Input) -> Output:
|
|
113
|
+
return Output(summary=data.prompt[:140])
|
|
114
|
+
|
|
115
|
+
asyncio.run(agent.run())
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
### Environment
|
|
119
|
+
- `NATS_URL` (default `nats://127.0.0.1:4222`)
|
|
120
|
+
- `REDIS_URL` (default `redis://127.0.0.1:6379/0`)
|
|
@@ -3,6 +3,7 @@ pyproject.toml
|
|
|
3
3
|
cap/__init__.py
|
|
4
4
|
cap/bus.py
|
|
5
5
|
cap/client.py
|
|
6
|
+
cap/runtime.py
|
|
6
7
|
cap/worker.py
|
|
7
8
|
cap/pb/__init__.py
|
|
8
9
|
cap/pb/cordum/__init__.py
|
|
@@ -23,4 +24,6 @@ cap_sdk_python.egg-info/SOURCES.txt
|
|
|
23
24
|
cap_sdk_python.egg-info/dependency_links.txt
|
|
24
25
|
cap_sdk_python.egg-info/requires.txt
|
|
25
26
|
cap_sdk_python.egg-info/top_level.txt
|
|
27
|
+
tests/test_conformance.py
|
|
28
|
+
tests/test_runtime.py
|
|
26
29
|
tests/test_sdk.py
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "cap-sdk-python"
|
|
7
|
-
version = "2.0.
|
|
7
|
+
version = "2.0.18"
|
|
8
8
|
description = "CAP (Cordum Agent Protocol) Python SDK"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.9"
|
|
@@ -12,6 +12,9 @@ dependencies = [
|
|
|
12
12
|
"protobuf>=4.25.0",
|
|
13
13
|
"grpcio>=1.59.0",
|
|
14
14
|
"nats-py>=2.6.0",
|
|
15
|
+
"cryptography>=41.0.0",
|
|
16
|
+
"pydantic>=2.6.0",
|
|
17
|
+
"redis>=5.0.0",
|
|
15
18
|
]
|
|
16
19
|
|
|
17
20
|
[project.urls]
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import unittest
|
|
4
|
+
|
|
5
|
+
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
6
|
+
_sdk_root = os.path.join(_repo_root, "sdk", "python")
|
|
7
|
+
|
|
8
|
+
# Avoid loading duplicate generated stubs from both /python and sdk/python/cap/pb.
|
|
9
|
+
sys.path = [p for p in sys.path if not p.rstrip("/").endswith("python")]
|
|
10
|
+
# Ensure the SDK package and generated modules are discoverable from repo root.
|
|
11
|
+
sys.path.insert(0, _sdk_root)
|
|
12
|
+
sys.path.append(os.path.join(_sdk_root, "cap", "pb"))
|
|
13
|
+
|
|
14
|
+
from cap.pb.cordum.agent.v1 import buspacket_pb2
|
|
15
|
+
from cryptography.hazmat.primitives import hashes, serialization
|
|
16
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
FIXTURE_DIR = os.path.join(_repo_root, "spec", "conformance", "fixtures")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def load_packet(name: str) -> buspacket_pb2.BusPacket:
|
|
23
|
+
pkt = buspacket_pb2.BusPacket()
|
|
24
|
+
with open(os.path.join(FIXTURE_DIR, name), "rb") as handle:
|
|
25
|
+
pkt.ParseFromString(handle.read())
|
|
26
|
+
return pkt
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TestConformanceFixtures(unittest.TestCase):
    """Verify the frozen wire-format fixtures against the BusPacket contract."""

    @classmethod
    def setUpClass(cls):
        # Every fixture is signed with the same key pair; load the public half once.
        key_path = os.path.join(FIXTURE_DIR, "public_key.pem")
        with open(key_path, "rb") as fh:
            cls.public_key = serialization.load_pem_public_key(fh.read())

    def _verify_signature(self, packet: buspacket_pb2.BusPacket):
        """Check the ECDSA(SHA-256) signature over the deterministic unsigned bytes."""
        sig = packet.signature
        # The signature covers the packet serialized WITHOUT its signature field.
        packet.ClearField("signature")
        unsigned_bytes = packet.SerializeToString(deterministic=True)
        packet.signature = sig  # restore so the caller sees the packet unchanged
        self.public_key.verify(sig, unsigned_bytes, ec.ECDSA(hashes.SHA256()))

    def _assert_common(self, packet: buspacket_pb2.BusPacket, trace_id: str, sender_id: str):
        """Shared envelope assertions: signature plus the common header fields."""
        self._verify_signature(packet)
        self.assertEqual(packet.trace_id, trace_id)
        self.assertEqual(packet.sender_id, sender_id)
        self.assertEqual(packet.protocol_version, 1)
        # Fixtures are pinned to a fixed creation timestamp.
        self.assertEqual(packet.created_at.seconds, 1704164645)

    def test_job_request_fixture(self):
        packet = load_packet("buspacket_job_request.bin")
        self._assert_common(packet, "trace-job-request", "client-1")
        request = packet.job_request
        self.assertEqual(request.job_id, "job-req-1")
        self.assertEqual(request.topic, "job.tools")
        self.assertEqual(request.priority, 1)
        self.assertEqual(request.context_ptr, "redis://ctx:job-req-1")
        self.assertEqual(request.env["region"], "us-east-1")
        self.assertEqual(request.env["sandbox"], "true")
        self.assertEqual(request.labels["env"], "prod")
        self.assertEqual(request.labels["team"], "platform")
        self.assertEqual(request.meta.idempotency_key, "idem-123")
        self.assertEqual(request.meta.labels["source"], "conformance")
        self.assertEqual(request.compensation.topic, "job.rollback")
        self.assertEqual(request.compensation.labels["rollback"], "true")

    def test_job_result_fixture(self):
        packet = load_packet("buspacket_job_result.bin")
        self._assert_common(packet, "trace-job-result", "worker-1")
        result = packet.job_result
        self.assertEqual(result.job_id, "job-res-1")
        self.assertEqual(result.worker_id, "worker-1")
        self.assertEqual(result.status, 10)
        self.assertEqual(result.error_code, "E_TEMP")
        self.assertEqual(len(result.artifact_ptrs), 2)

    def test_heartbeat_fixture(self):
        packet = load_packet("buspacket_heartbeat.bin")
        self._assert_common(packet, "trace-heartbeat", "worker-1")
        heartbeat = packet.heartbeat
        self.assertEqual(heartbeat.worker_id, "worker-1")
        self.assertEqual(heartbeat.pool, "job.tools")
        self.assertEqual(heartbeat.labels["zone"], "us-east-1a")
        self.assertEqual(heartbeat.progress_pct, 60)

    def test_job_progress_fixture(self):
        packet = load_packet("buspacket_job_progress.bin")
        self._assert_common(packet, "trace-progress", "worker-1")
        progress = packet.job_progress
        self.assertEqual(progress.job_id, "job-prog-1")
        self.assertEqual(progress.percent, 50)
        self.assertEqual(progress.status, 4)

    def test_job_cancel_fixture(self):
        packet = load_packet("buspacket_job_cancel.bin")
        self._assert_common(packet, "trace-cancel", "scheduler-1")
        cancel = packet.job_cancel
        self.assertEqual(cancel.job_id, "job-cancel-1")
        self.assertEqual(cancel.requested_by, "user-7")

    def test_alert_fixture(self):
        packet = load_packet("buspacket_alert.bin")
        self._assert_common(packet, "trace-alert", "scheduler-1")
        alert = packet.alert
        self.assertEqual(alert.level, "WARN")
        self.assertEqual(alert.component, "scheduler")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# Allow running the conformance suite directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import unittest
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from cap.runtime import Agent, InMemoryBlobStore
|
|
8
|
+
from cap.pb.cordum.agent.v1 import buspacket_pb2, job_pb2
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MockNATS:
    """In-process stand-in for a NATS connection used by the runtime tests.

    Subscriptions are captured in a dict keyed by subject, and published
    messages are recorded on an asyncio queue so tests can both deliver
    packets to registered handlers and inspect what the Agent emits.
    """

    def __init__(self):
        # subject -> callback registered via subscribe()
        self.subscriptions: dict = {}
        # queue of (subject, data) tuples recorded by publish()
        self.published: asyncio.Queue = asyncio.Queue()

    async def publish(self, subject, data):
        """Record an outgoing message instead of sending it anywhere."""
        await self.published.put((subject, data))

    async def subscribe(self, subject, queue, cb):
        """Remember *cb* for *subject*; the *queue* group name is ignored."""
        self.subscriptions[subject] = cb

    async def connect(self, servers, name):
        """Pretend to connect; the mock itself plays the connection role."""
        return self

    async def drain(self):
        """No-op shutdown hook mirroring the real client's drain()."""
        return None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class InputModel(BaseModel):
    """Input schema used by the runtime tests: a single free-form prompt string."""

    prompt: str
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class OutputModel(BaseModel):
    """Output schema produced by the test handlers: a single summary string."""

    summary: str
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TestRuntime(unittest.IsolatedAsyncioTestCase):
    """End-to-end Agent runtime tests against the in-process NATS stub."""

    async def asyncSetUp(self):
        self.store = InMemoryBlobStore()
        self.mock = MockNATS()

    async def _send_job(self, job_id: str, topic: str, context_ptr: str):
        """Build a JobRequest packet and deliver it straight to the handler subscribed on *topic*."""
        request = job_pb2.JobRequest(job_id=job_id, topic=topic, context_ptr=context_ptr)
        envelope = buspacket_pb2.BusPacket()
        envelope.trace_id = "trace-1"
        envelope.sender_id = "client-1"
        envelope.protocol_version = 1
        envelope.job_request.CopyFrom(request)
        raw = envelope.SerializeToString(deterministic=True)
        # Minimal stand-in for a NATS message object: only .data is read.
        msg = type("obj", (object,), {"data": raw})()
        await self.mock.subscriptions[topic](msg)

    async def _await_result(self) -> buspacket_pb2.BusPacket:
        """Pop the next published packet, assert it is a job result, and decode it."""
        subject, payload = await asyncio.wait_for(self.mock.published.get(), timeout=1)
        self.assertEqual(subject, "sys.job.result")
        decoded = buspacket_pb2.BusPacket()
        decoded.ParseFromString(payload)
        return decoded

    async def test_runtime_success(self):
        job_id = "job-1"
        ctx_key = f"ctx:{job_id}"
        await self.store.set(ctx_key, json.dumps({"prompt": "hello"}).encode("utf-8"))

        agent = Agent(store=self.store, connect_fn=self.mock.connect, sender_id="worker-1")

        @agent.job("job.test", input_model=InputModel, output_model=OutputModel)
        async def handler(ctx, data: InputModel) -> OutputModel:
            return OutputModel(summary=data.prompt.upper())

        await agent.start()
        await self._send_job(job_id, "job.test", f"redis://{ctx_key}")

        result_packet = await self._await_result()
        self.assertEqual(result_packet.job_result.status, job_pb2.JOB_STATUS_SUCCEEDED)

        # The serialized output model is stored under the result key.
        stored = await self.store.get(f"res:{job_id}")
        self.assertIsNotNone(stored)
        parsed = json.loads(stored.decode("utf-8"))
        self.assertEqual(parsed["summary"], "HELLO")

        await agent.close()

    async def test_runtime_input_validation_failure(self):
        job_id = "job-2"
        ctx_key = f"ctx:{job_id}"
        # Context deliberately misses the required "prompt" field.
        await self.store.set(ctx_key, json.dumps({"wrong": "field"}).encode("utf-8"))

        agent = Agent(store=self.store, connect_fn=self.mock.connect, sender_id="worker-2")

        @agent.job("job.validate", input_model=InputModel, output_model=OutputModel)
        async def handler(ctx, data: InputModel) -> OutputModel:
            return OutputModel(summary=data.prompt)

        await agent.start()
        await self._send_job(job_id, "job.validate", f"redis://{ctx_key}")

        result_packet = await self._await_result()
        self.assertEqual(result_packet.job_result.status, job_pb2.JOB_STATUS_FAILED)
        self.assertIn("input validation failed", result_packet.job_result.error_message)

        await agent.close()

    async def test_runtime_retries(self):
        job_id = "job-3"
        ctx_key = f"ctx:{job_id}"
        await self.store.set(ctx_key, json.dumps({"prompt": "retry"}).encode("utf-8"))

        # retries=1 allows one re-attempt after the first failure.
        agent = Agent(store=self.store, connect_fn=self.mock.connect, sender_id="worker-3", retries=1)
        attempts = {"count": 0}

        @agent.job("job.retry", input_model=InputModel, output_model=OutputModel)
        async def handler(ctx, data: InputModel) -> OutputModel:
            attempts["count"] += 1
            if attempts["count"] == 1:
                raise RuntimeError("boom")
            return OutputModel(summary=data.prompt)

        await agent.start()
        await self._send_job(job_id, "job.retry", f"redis://{ctx_key}")

        result_packet = await self._await_result()
        self.assertEqual(result_packet.job_result.status, job_pb2.JOB_STATUS_SUCCEEDED)
        self.assertEqual(attempts["count"], 2)

        await agent.close()
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Allow running the runtime suite directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|
@@ -93,7 +93,7 @@ class TestSDK(unittest.TestCase):
|
|
|
93
93
|
result_packet.ParseFromString(data)
|
|
94
94
|
signature = result_packet.signature
|
|
95
95
|
result_packet.ClearField("signature")
|
|
96
|
-
unsigned_data = result_packet.SerializeToString()
|
|
96
|
+
unsigned_data = result_packet.SerializeToString(deterministic=True)
|
|
97
97
|
worker_key.public_key().verify(signature, unsigned_data, ec.ECDSA(hashes.SHA256()))
|
|
98
98
|
|
|
99
99
|
worker_task.cancel()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/buspacket_pb2_grpc.py
RENAMED
|
File without changes
|
|
File without changes
|
{cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap/pb/cordum/agent/v1/heartbeat_pb2_grpc.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cap_sdk_python-2.0.16 → cap_sdk_python-2.0.18}/cap_sdk_python.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|