acp-sdk 0.8.3__tar.gz → 0.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/PKG-INFO +2 -2
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/pyproject.toml +2 -2
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/client.py +4 -1
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/models.py +17 -1
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/agent.py +11 -3
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/app.py +9 -3
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/bundle.py +4 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/server.py +150 -15
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/types.py +1 -1
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/fixtures/server.py +13 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/test_suites/test_runs.py +24 -6
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/client/test_client.py +22 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/models/test_models.py +41 -1
- acp_sdk-0.9.0/tests/unit/server/__init__.py +0 -0
- acp_sdk-0.9.0/tests/unit/server/test_server.py +31 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/.gitignore +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/.python-version +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/README.md +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/.gitignore +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/Makefile +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/conf.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/index.rst +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/make.bat +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/pytest.ini +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/__init__.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/__init__.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/types.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/utils.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/instrumentation.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/__init__.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/errors.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/schemas.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/py.typed +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/__init__.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/context.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/errors.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/logging.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/session.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/telemetry.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/utils.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/version.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/conftest.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/__init__.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/config.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/fixtures/__init__.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/fixtures/client.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/test_suites/__init__.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/test_suites/test_discovery.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/client/test_utils.py +0 -0
- {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/models/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: acp-sdk
|
3
|
-
Version: 0.8.3
|
3
|
+
Version: 0.9.0
|
4
4
|
Summary: Agent Communication Protocol SDK
|
5
5
|
Author: IBM Corp.
|
6
6
|
Maintainer-email: Tomas Pilar <thomas7pilar@gmail.com>
|
@@ -16,7 +16,7 @@ Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31.1
|
|
16
16
|
Requires-Dist: opentelemetry-instrumentation-fastapi>=0.52b1
|
17
17
|
Requires-Dist: opentelemetry-instrumentation-httpx>=0.52b1
|
18
18
|
Requires-Dist: opentelemetry-sdk>=1.31.1
|
19
|
-
Requires-Dist: pydantic>=2.
|
19
|
+
Requires-Dist: pydantic>=2.0.0
|
20
20
|
Description-Content-Type: text/markdown
|
21
21
|
|
22
22
|
# Agent Communication Protocol SDK for Python
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "acp-sdk"
|
3
|
-
version = "0.8.3"
|
3
|
+
version = "0.9.0"
|
4
4
|
description = "Agent Communication Protocol SDK"
|
5
5
|
license = "Apache-2.0"
|
6
6
|
readme = "README.md"
|
@@ -9,7 +9,7 @@ maintainers = [{ name = "Tomas Pilar", email = "thomas7pilar@gmail.com" }]
|
|
9
9
|
requires-python = ">=3.11, <4.0"
|
10
10
|
dependencies = [
|
11
11
|
"opentelemetry-api>=1.31.1",
|
12
|
-
"pydantic>=2.
|
12
|
+
"pydantic>=2.0.0",
|
13
13
|
"httpx>=0.26.0",
|
14
14
|
"httpx-sse>=0.4.0",
|
15
15
|
"opentelemetry-instrumentation-httpx>=0.52b1",
|
@@ -22,6 +22,7 @@ from acp_sdk.models import (
|
|
22
22
|
AgentsListResponse,
|
23
23
|
AwaitResume,
|
24
24
|
Error,
|
25
|
+
ErrorEvent,
|
25
26
|
Event,
|
26
27
|
PingResponse,
|
27
28
|
Run,
|
@@ -224,7 +225,9 @@ class Client:
|
|
224
225
|
await event_source.response.aread()
|
225
226
|
self._raise_error(event_source.response)
|
226
227
|
async for event in event_source.aiter_sse():
|
227
|
-
event = TypeAdapter(Event).validate_json(event.data)
|
228
|
+
event: Event = TypeAdapter(Event).validate_json(event.data)
|
229
|
+
if isinstance(event, ErrorEvent):
|
230
|
+
raise ACPError(error=event.error)
|
228
231
|
yield event
|
229
232
|
|
230
233
|
def _raise_error(self, response: httpx.Response) -> None:
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import asyncio
|
1
2
|
import uuid
|
2
3
|
from datetime import datetime, timezone
|
3
4
|
from enum import Enum
|
@@ -5,7 +6,7 @@ from typing import Any, Literal, Optional, Union
|
|
5
6
|
|
6
7
|
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
7
8
|
|
8
|
-
from acp_sdk.models.errors import Error
|
9
|
+
from acp_sdk.models.errors import ACPError, Error
|
9
10
|
|
10
11
|
|
11
12
|
class AnyModel(BaseModel):
|
@@ -196,6 +197,15 @@ class Run(BaseModel):
|
|
196
197
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
197
198
|
finished_at: datetime | None = None
|
198
199
|
|
200
|
+
def raise_for_status(self) -> "Run":
|
201
|
+
match self.status:
|
202
|
+
case RunStatus.CANCELLED:
|
203
|
+
raise asyncio.CancelledError()
|
204
|
+
case RunStatus.FAILED:
|
205
|
+
raise ACPError(error=self.error)
|
206
|
+
case _:
|
207
|
+
return self
|
208
|
+
|
199
209
|
|
200
210
|
class MessageCreatedEvent(BaseModel):
|
201
211
|
type: Literal["message.created"] = "message.created"
|
@@ -252,7 +262,13 @@ class RunCompletedEvent(BaseModel):
|
|
252
262
|
run: Run
|
253
263
|
|
254
264
|
|
265
|
+
class ErrorEvent(BaseModel):
|
266
|
+
type: Literal["error"] = "error"
|
267
|
+
error: Error
|
268
|
+
|
269
|
+
|
255
270
|
Event = Union[
|
271
|
+
ErrorEvent,
|
256
272
|
RunCreatedEvent,
|
257
273
|
RunInProgressEvent,
|
258
274
|
MessageCreatedEvent,
|
@@ -58,13 +58,13 @@ class Agent(abc.ABC):
|
|
58
58
|
run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, input, context)
|
59
59
|
|
60
60
|
try:
|
61
|
-
while
|
61
|
+
while not run.done() or yield_queue.async_q.qsize() > 0:
|
62
62
|
value = yield await yield_queue.async_q.get()
|
63
|
+
if isinstance(value, Exception):
|
64
|
+
raise value
|
63
65
|
await yield_resume_queue.async_q.put(value)
|
64
66
|
except janus.AsyncQueueShutDown:
|
65
67
|
pass
|
66
|
-
finally:
|
67
|
-
await run # Raise exceptions
|
68
68
|
|
69
69
|
async def _run_async_gen(self, input: list[Message], context: Context) -> None:
|
70
70
|
try:
|
@@ -74,12 +74,16 @@ class Agent(abc.ABC):
|
|
74
74
|
value = await context.yield_async(await gen.asend(value))
|
75
75
|
except StopAsyncIteration:
|
76
76
|
pass
|
77
|
+
except Exception as e:
|
78
|
+
await context.yield_async(e)
|
77
79
|
finally:
|
78
80
|
context.shutdown()
|
79
81
|
|
80
82
|
async def _run_coro(self, input: list[Message], context: Context) -> None:
|
81
83
|
try:
|
82
84
|
await context.yield_async(await self.run(input, context))
|
85
|
+
except Exception as e:
|
86
|
+
await context.yield_async(e)
|
83
87
|
finally:
|
84
88
|
context.shutdown()
|
85
89
|
|
@@ -91,12 +95,16 @@ class Agent(abc.ABC):
|
|
91
95
|
value = context.yield_sync(gen.send(value))
|
92
96
|
except StopIteration:
|
93
97
|
pass
|
98
|
+
except Exception as e:
|
99
|
+
context.yield_sync(e)
|
94
100
|
finally:
|
95
101
|
context.shutdown()
|
96
102
|
|
97
103
|
def _run_func(self, input: list[Message], context: Context) -> None:
|
98
104
|
try:
|
99
105
|
context.yield_sync(self.run(input, context))
|
106
|
+
except Exception as e:
|
107
|
+
context.yield_sync(e)
|
100
108
|
finally:
|
101
109
|
context.shutdown()
|
102
110
|
|
@@ -6,6 +6,7 @@ from enum import Enum
|
|
6
6
|
|
7
7
|
from cachetools import TTLCache
|
8
8
|
from fastapi import Depends, FastAPI, HTTPException, status
|
9
|
+
from fastapi.applications import AppType, Lifespan
|
9
10
|
from fastapi.encoders import jsonable_encoder
|
10
11
|
from fastapi.responses import JSONResponse, StreamingResponse
|
11
12
|
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
@@ -53,19 +54,24 @@ def create_app(
|
|
53
54
|
*agents: Agent,
|
54
55
|
run_limit: int = 1000,
|
55
56
|
run_ttl: timedelta = timedelta(hours=1),
|
57
|
+
lifespan: Lifespan[AppType] | None = None,
|
56
58
|
dependencies: list[Depends] | None = None,
|
57
59
|
) -> FastAPI:
|
58
60
|
executor: ThreadPoolExecutor
|
59
61
|
|
60
62
|
@asynccontextmanager
|
61
|
-
async def
|
63
|
+
async def internal_lifespan(app: FastAPI) -> AsyncGenerator[None]:
|
62
64
|
nonlocal executor
|
63
65
|
with ThreadPoolExecutor() as exec:
|
64
66
|
executor = exec
|
65
|
-
|
67
|
+
if not lifespan:
|
68
|
+
yield None
|
69
|
+
else:
|
70
|
+
async with lifespan(app) as state:
|
71
|
+
yield state
|
66
72
|
|
67
73
|
app = FastAPI(
|
68
|
-
lifespan=
|
74
|
+
lifespan=internal_lifespan,
|
69
75
|
dependencies=dependencies,
|
70
76
|
)
|
71
77
|
|
@@ -142,6 +142,8 @@ class RunBundle:
|
|
142
142
|
run_logger.info("Run resumed")
|
143
143
|
elif isinstance(next, Error):
|
144
144
|
raise ACPError(error=next)
|
145
|
+
elif isinstance(next, ACPError):
|
146
|
+
raise next
|
145
147
|
elif next is None:
|
146
148
|
await flush_message()
|
147
149
|
elif isinstance(next, BaseModel):
|
@@ -176,3 +178,5 @@ class RunBundle:
|
|
176
178
|
finally:
|
177
179
|
self.await_or_terminate_event.set()
|
178
180
|
await self.stream_queue.put(None)
|
181
|
+
if not self.task.done():
|
182
|
+
self.task.cancel()
|
@@ -1,12 +1,14 @@
|
|
1
1
|
import asyncio
|
2
2
|
import os
|
3
|
-
from collections.abc import Awaitable
|
3
|
+
from collections.abc import AsyncGenerator, Awaitable
|
4
|
+
from contextlib import asynccontextmanager
|
4
5
|
from datetime import timedelta
|
5
6
|
from typing import Any, Callable
|
6
7
|
|
7
8
|
import requests
|
8
9
|
import uvicorn
|
9
10
|
import uvicorn.config
|
11
|
+
from fastapi import FastAPI
|
10
12
|
|
11
13
|
from acp_sdk.models import Metadata
|
12
14
|
from acp_sdk.server.agent import Agent
|
@@ -20,8 +22,8 @@ from acp_sdk.server.utils import async_request_with_retry
|
|
20
22
|
|
21
23
|
class Server:
|
22
24
|
def __init__(self) -> None:
|
23
|
-
self.
|
24
|
-
self.
|
25
|
+
self.agents: list[Agent] = []
|
26
|
+
self.server: uvicorn.Server | None = None
|
25
27
|
|
26
28
|
def agent(
|
27
29
|
self,
|
@@ -40,10 +42,15 @@ class Server:
|
|
40
42
|
return decorator
|
41
43
|
|
42
44
|
def register(self, *agents: Agent) -> None:
|
43
|
-
self.
|
45
|
+
self.agents.extend(agents)
|
44
46
|
|
45
|
-
|
47
|
+
@asynccontextmanager
|
48
|
+
async def lifespan(self, app: FastAPI) -> AsyncGenerator[None]:
|
49
|
+
yield
|
50
|
+
|
51
|
+
async def serve(
|
46
52
|
self,
|
53
|
+
*,
|
47
54
|
configure_logger: bool = True,
|
48
55
|
configure_telemetry: bool = False,
|
49
56
|
self_registration: bool = True,
|
@@ -101,9 +108,14 @@ class Server:
|
|
101
108
|
factory: bool = False,
|
102
109
|
h11_max_incomplete_event_size: int | None = None,
|
103
110
|
) -> None:
|
104
|
-
if self.
|
111
|
+
if self.server:
|
105
112
|
raise RuntimeError("The server is already running")
|
106
113
|
|
114
|
+
if headers is None:
|
115
|
+
headers = [("server", "acp")]
|
116
|
+
elif not any(k.lower() == "server" for k, _ in headers):
|
117
|
+
headers.append(("server", "acp"))
|
118
|
+
|
107
119
|
import uvicorn
|
108
120
|
|
109
121
|
if configure_logger:
|
@@ -112,7 +124,7 @@ class Server:
|
|
112
124
|
configure_telemetry_func()
|
113
125
|
|
114
126
|
config = uvicorn.Config(
|
115
|
-
create_app(*self.
|
127
|
+
create_app(*self.agents, lifespan=self.lifespan, run_limit=run_limit, run_ttl=run_ttl),
|
116
128
|
host,
|
117
129
|
port,
|
118
130
|
uds,
|
@@ -161,23 +173,139 @@ class Server:
|
|
161
173
|
factory,
|
162
174
|
h11_max_incomplete_event_size,
|
163
175
|
)
|
164
|
-
self.
|
176
|
+
self.server = uvicorn.Server(config)
|
177
|
+
await self._serve(self_registration=self_registration)
|
165
178
|
|
166
|
-
|
179
|
+
def run(
|
180
|
+
self,
|
181
|
+
*,
|
182
|
+
configure_logger: bool = True,
|
183
|
+
configure_telemetry: bool = False,
|
184
|
+
self_registration: bool = True,
|
185
|
+
run_limit: int = 1000,
|
186
|
+
run_ttl: timedelta = timedelta(hours=1),
|
187
|
+
host: str = "127.0.0.1",
|
188
|
+
port: int = 8000,
|
189
|
+
uds: str | None = None,
|
190
|
+
fd: int | None = None,
|
191
|
+
loop: uvicorn.config.LoopSetupType = "auto",
|
192
|
+
http: type[asyncio.Protocol] | uvicorn.config.HTTPProtocolType = "auto",
|
193
|
+
ws: type[asyncio.Protocol] | uvicorn.config.WSProtocolType = "auto",
|
194
|
+
ws_max_size: int = 16 * 1024 * 1024,
|
195
|
+
ws_max_queue: int = 32,
|
196
|
+
ws_ping_interval: float | None = 20.0,
|
197
|
+
ws_ping_timeout: float | None = 20.0,
|
198
|
+
ws_per_message_deflate: bool = True,
|
199
|
+
lifespan: uvicorn.config.LifespanType = "auto",
|
200
|
+
env_file: str | os.PathLike[str] | None = None,
|
201
|
+
log_config: dict[str, Any]
|
202
|
+
| str
|
203
|
+
| uvicorn.config.RawConfigParser
|
204
|
+
| uvicorn.config.IO[Any]
|
205
|
+
| None = uvicorn.config.LOGGING_CONFIG,
|
206
|
+
log_level: str | int | None = None,
|
207
|
+
access_log: bool = True,
|
208
|
+
use_colors: bool | None = None,
|
209
|
+
interface: uvicorn.config.InterfaceType = "auto",
|
210
|
+
reload: bool = False,
|
211
|
+
reload_dirs: list[str] | str | None = None,
|
212
|
+
reload_delay: float = 0.25,
|
213
|
+
reload_includes: list[str] | str | None = None,
|
214
|
+
reload_excludes: list[str] | str | None = None,
|
215
|
+
workers: int | None = None,
|
216
|
+
proxy_headers: bool = True,
|
217
|
+
server_header: bool = True,
|
218
|
+
date_header: bool = True,
|
219
|
+
forwarded_allow_ips: list[str] | str | None = None,
|
220
|
+
root_path: str = "",
|
221
|
+
limit_concurrency: int | None = None,
|
222
|
+
limit_max_requests: int | None = None,
|
223
|
+
backlog: int = 2048,
|
224
|
+
timeout_keep_alive: int = 5,
|
225
|
+
timeout_notify: int = 30,
|
226
|
+
timeout_graceful_shutdown: int | None = None,
|
227
|
+
callback_notify: Callable[..., Awaitable[None]] | None = None,
|
228
|
+
ssl_keyfile: str | os.PathLike[str] | None = None,
|
229
|
+
ssl_certfile: str | os.PathLike[str] | None = None,
|
230
|
+
ssl_keyfile_password: str | None = None,
|
231
|
+
ssl_version: int = uvicorn.config.SSL_PROTOCOL_VERSION,
|
232
|
+
ssl_cert_reqs: int = uvicorn.config.ssl.CERT_NONE,
|
233
|
+
ssl_ca_certs: str | None = None,
|
234
|
+
ssl_ciphers: str = "TLSv1",
|
235
|
+
headers: list[tuple[str, str]] | None = None,
|
236
|
+
factory: bool = False,
|
237
|
+
h11_max_incomplete_event_size: int | None = None,
|
238
|
+
) -> None:
|
239
|
+
asyncio.run(
|
240
|
+
self.serve(
|
241
|
+
configure_logger=configure_logger,
|
242
|
+
configure_telemetry=configure_telemetry,
|
243
|
+
self_registration=self_registration,
|
244
|
+
run_limit=run_limit,
|
245
|
+
run_ttl=run_ttl,
|
246
|
+
host=host,
|
247
|
+
port=port,
|
248
|
+
uds=uds,
|
249
|
+
fd=fd,
|
250
|
+
loop=loop,
|
251
|
+
http=http,
|
252
|
+
ws=ws,
|
253
|
+
ws_max_size=ws_max_size,
|
254
|
+
ws_max_queue=ws_max_queue,
|
255
|
+
ws_ping_interval=ws_ping_interval,
|
256
|
+
ws_ping_timeout=ws_ping_timeout,
|
257
|
+
ws_per_message_deflate=ws_per_message_deflate,
|
258
|
+
lifespan=lifespan,
|
259
|
+
env_file=env_file,
|
260
|
+
log_config=log_config,
|
261
|
+
log_level=log_level,
|
262
|
+
access_log=access_log,
|
263
|
+
use_colors=use_colors,
|
264
|
+
interface=interface,
|
265
|
+
reload=reload,
|
266
|
+
reload_dirs=reload_dirs,
|
267
|
+
reload_delay=reload_delay,
|
268
|
+
reload_includes=reload_includes,
|
269
|
+
reload_excludes=reload_excludes,
|
270
|
+
workers=workers,
|
271
|
+
proxy_headers=proxy_headers,
|
272
|
+
server_header=server_header,
|
273
|
+
date_header=date_header,
|
274
|
+
forwarded_allow_ips=forwarded_allow_ips,
|
275
|
+
root_path=root_path,
|
276
|
+
limit_concurrency=limit_concurrency,
|
277
|
+
limit_max_requests=limit_max_requests,
|
278
|
+
backlog=backlog,
|
279
|
+
timeout_keep_alive=timeout_keep_alive,
|
280
|
+
timeout_notify=timeout_notify,
|
281
|
+
timeout_graceful_shutdown=timeout_graceful_shutdown,
|
282
|
+
callback_notify=callback_notify,
|
283
|
+
ssl_keyfile=ssl_keyfile,
|
284
|
+
ssl_certfile=ssl_certfile,
|
285
|
+
ssl_keyfile_password=ssl_keyfile_password,
|
286
|
+
ssl_version=ssl_version,
|
287
|
+
ssl_cert_reqs=ssl_cert_reqs,
|
288
|
+
ssl_ca_certs=ssl_ca_certs,
|
289
|
+
ssl_ciphers=ssl_ciphers,
|
290
|
+
headers=headers,
|
291
|
+
factory=factory,
|
292
|
+
h11_max_incomplete_event_size=h11_max_incomplete_event_size,
|
293
|
+
)
|
294
|
+
)
|
167
295
|
|
168
296
|
async def _serve(self, self_registration: bool = True) -> None:
|
169
297
|
registration_task = asyncio.create_task(self._register_agent()) if self_registration else None
|
170
|
-
await self.
|
298
|
+
await self.server.serve()
|
171
299
|
if registration_task:
|
172
300
|
registration_task.cancel()
|
173
301
|
|
174
302
|
@property
|
175
303
|
def should_exit(self) -> bool:
|
176
|
-
return self.
|
304
|
+
return self.server.should_exit if self.server else False
|
177
305
|
|
178
306
|
@should_exit.setter
|
179
307
|
def should_exit(self, value: bool) -> None:
|
180
|
-
self.
|
308
|
+
self.server.should_exit = value
|
181
309
|
|
182
310
|
async def _register_agent(self) -> None:
|
183
311
|
"""If not in PRODUCTION mode, register agent to the beeai platform and provide missing env variables"""
|
@@ -187,7 +315,7 @@ class Server:
|
|
187
315
|
|
188
316
|
url = os.getenv("PLATFORM_URL", "http://127.0.0.1:8333")
|
189
317
|
request_data = {
|
190
|
-
"location": f"http://{self.
|
318
|
+
"location": f"http://{self.server.config.host}:{self.server.config.port}",
|
191
319
|
}
|
192
320
|
try:
|
193
321
|
await async_request_with_retry(
|
@@ -198,7 +326,7 @@ class Server:
|
|
198
326
|
# check missing env keyes
|
199
327
|
envs_request = await async_request_with_retry(lambda client: client.get(f"{url}/api/v1/variables"))
|
200
328
|
envs = envs_request.get("env")
|
201
|
-
for agent in self.
|
329
|
+
for agent in self.agents:
|
202
330
|
# register all available envs
|
203
331
|
missing_keyes = []
|
204
332
|
for env in agent.metadata.model_dump().get("env", []):
|
@@ -215,4 +343,11 @@ class Server:
|
|
215
343
|
except requests.exceptions.ConnectionError as e:
|
216
344
|
logger.warning(f"Can not reach server, check if running on {url} : {e}")
|
217
345
|
except (requests.exceptions.HTTPError, Exception) as e:
|
218
|
-
|
346
|
+
try:
|
347
|
+
error_message = e.response.json().get("detail")
|
348
|
+
if error_message:
|
349
|
+
logger.warning(f"Agent can not be registered to beeai server: {error_message}")
|
350
|
+
else:
|
351
|
+
logger.warning(f"Agent can not be registered to beeai server: {e}")
|
352
|
+
except Exception:
|
353
|
+
logger.warning(f"Agent can not be registered to beeai server: {e}")
|
@@ -5,5 +5,5 @@ from pydantic import BaseModel
|
|
5
5
|
from acp_sdk.models import AwaitRequest, AwaitResume, Message
|
6
6
|
from acp_sdk.models.models import MessagePart
|
7
7
|
|
8
|
-
RunYield = Message | MessagePart | str | AwaitRequest | BaseModel | dict[str | Any] | None
|
8
|
+
RunYield = Message | MessagePart | str | AwaitRequest | BaseModel | dict[str | Any] | None | Exception
|
9
9
|
RunYieldResume = AwaitResume | None
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import asyncio
|
1
2
|
import base64
|
2
3
|
import time
|
3
4
|
from collections.abc import AsyncGenerator, AsyncIterator, Generator
|
@@ -6,6 +7,7 @@ from threading import Thread
|
|
6
7
|
|
7
8
|
import pytest
|
8
9
|
from acp_sdk.models import Artifact, AwaitResume, Error, ErrorCode, Message, MessageAwaitRequest, MessagePart
|
10
|
+
from acp_sdk.models.errors import ACPError
|
9
11
|
from acp_sdk.server import Context, Server
|
10
12
|
|
11
13
|
from e2e.config import Config
|
@@ -21,6 +23,12 @@ def server(request: pytest.FixtureRequest) -> Generator[None]:
|
|
21
23
|
for message in input:
|
22
24
|
yield message
|
23
25
|
|
26
|
+
@server.agent()
|
27
|
+
async def slow_echo(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
28
|
+
for message in input:
|
29
|
+
await asyncio.sleep(1)
|
30
|
+
yield message
|
31
|
+
|
24
32
|
@server.agent()
|
25
33
|
async def awaiter(
|
26
34
|
input: list[Message], context: Context
|
@@ -31,6 +39,11 @@ def server(request: pytest.FixtureRequest) -> Generator[None]:
|
|
31
39
|
@server.agent()
|
32
40
|
async def failer(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
33
41
|
yield Error(code=ErrorCode.INVALID_INPUT, message="Wrong question buddy!")
|
42
|
+
raise RuntimeError("Unreachable code")
|
43
|
+
|
44
|
+
@server.agent()
|
45
|
+
async def raiser(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
46
|
+
raise ACPError(Error(code=ErrorCode.INVALID_INPUT, message="Wrong question buddy!"))
|
34
47
|
|
35
48
|
@server.agent()
|
36
49
|
async def sessioner(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
@@ -5,18 +5,20 @@ from datetime import timedelta
|
|
5
5
|
import pytest
|
6
6
|
from acp_sdk.client import Client
|
7
7
|
from acp_sdk.models import (
|
8
|
+
ACPError,
|
9
|
+
AgentName,
|
8
10
|
ArtifactEvent,
|
9
11
|
ErrorCode,
|
10
12
|
Message,
|
11
13
|
MessageAwaitResume,
|
12
14
|
MessagePart,
|
13
15
|
MessagePartEvent,
|
16
|
+
RunCancelledEvent,
|
14
17
|
RunCompletedEvent,
|
15
18
|
RunCreatedEvent,
|
16
19
|
RunInProgressEvent,
|
17
20
|
RunStatus,
|
18
21
|
)
|
19
|
-
from acp_sdk.models.errors import ACPError
|
20
22
|
from acp_sdk.server import Server
|
21
23
|
|
22
24
|
input = [Message(parts=[MessagePart(content="Hello!")])]
|
@@ -70,19 +72,35 @@ async def test_run_events_are_stream(server: Server, client: Client) -> None:
|
|
70
72
|
|
71
73
|
|
72
74
|
@pytest.mark.asyncio
|
73
|
-
|
74
|
-
|
75
|
+
@pytest.mark.parametrize("agent", ["failer", "raiser"])
|
76
|
+
async def test_failure(server: Server, client: Client, agent: AgentName) -> None:
|
77
|
+
run = await client.run_sync(agent=agent, input=input)
|
75
78
|
assert run.status == RunStatus.FAILED
|
76
79
|
assert run.error is not None
|
77
80
|
assert run.error.code == ErrorCode.INVALID_INPUT
|
78
81
|
|
79
82
|
|
80
83
|
@pytest.mark.asyncio
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
+
@pytest.mark.parametrize("agent", ["awaiter", "slow_echo"])
|
85
|
+
async def test_run_cancel(server: Server, client: Client, agent: AgentName) -> None:
|
86
|
+
run = await client.run_async(agent=agent, input=input)
|
84
87
|
run = await client.run_cancel(run_id=run.run_id)
|
85
88
|
assert run.status == RunStatus.CANCELLING
|
89
|
+
await asyncio.sleep(2)
|
90
|
+
run = await client.run_status(run_id=run.run_id)
|
91
|
+
assert run.status == RunStatus.CANCELLED
|
92
|
+
|
93
|
+
|
94
|
+
@pytest.mark.asyncio
|
95
|
+
@pytest.mark.parametrize("agent", ["slow_echo"])
|
96
|
+
async def test_run_cancel_stream(server: Server, client: Client, agent: AgentName) -> None:
|
97
|
+
last_event = None
|
98
|
+
async for event in client.run_stream(agent=agent, input=input):
|
99
|
+
last_event = event
|
100
|
+
if isinstance(event, RunCreatedEvent):
|
101
|
+
run = await client.run_cancel(run_id=event.run.run_id)
|
102
|
+
assert run.status == RunStatus.CANCELLING
|
103
|
+
assert isinstance(last_event, RunCancelledEvent)
|
86
104
|
|
87
105
|
|
88
106
|
@pytest.mark.asyncio
|
@@ -4,8 +4,12 @@ import uuid
|
|
4
4
|
import pytest
|
5
5
|
from acp_sdk.client import Client
|
6
6
|
from acp_sdk.models import (
|
7
|
+
ACPError,
|
7
8
|
Agent,
|
8
9
|
AgentsListResponse,
|
10
|
+
Error,
|
11
|
+
ErrorCode,
|
12
|
+
ErrorEvent,
|
9
13
|
Message,
|
10
14
|
MessageAwaitResume,
|
11
15
|
MessagePart,
|
@@ -77,6 +81,24 @@ async def test_run_stream(httpx_mock: HTTPXMock) -> None:
|
|
77
81
|
assert event == mock_event
|
78
82
|
|
79
83
|
|
84
|
+
@pytest.mark.asyncio
|
85
|
+
async def test_run_stream_error(httpx_mock: HTTPXMock) -> None:
|
86
|
+
error = Error(code=ErrorCode.SERVER_ERROR, message="whoops")
|
87
|
+
mock_event = ErrorEvent(error=error)
|
88
|
+
httpx_mock.add_response(
|
89
|
+
url="http://test/runs",
|
90
|
+
method="POST",
|
91
|
+
headers={"content-type": "text/event-stream"},
|
92
|
+
content=f"data: {mock_event.model_dump_json()}\n\n",
|
93
|
+
)
|
94
|
+
|
95
|
+
async with Client(base_url="http://test") as client:
|
96
|
+
with pytest.raises(ACPError) as e:
|
97
|
+
async for _ in client.run_stream("Howdy!", agent=mock_run.agent_name):
|
98
|
+
raise AssertionError()
|
99
|
+
assert e.value.error == error
|
100
|
+
|
101
|
+
|
80
102
|
@pytest.mark.asyncio
|
81
103
|
async def test_run_status(httpx_mock: HTTPXMock) -> None:
|
82
104
|
httpx_mock.add_response(url=f"http://test/runs/{mock_run.run_id}", method="GET", content=mock_run.model_dump_json())
|
@@ -1,5 +1,8 @@
|
|
1
|
+
import asyncio
|
2
|
+
|
1
3
|
import pytest
|
2
|
-
from acp_sdk.models.
|
4
|
+
from acp_sdk.models.errors import ACPError, Error, ErrorCode
|
5
|
+
from acp_sdk.models.models import Message, MessagePart, Run, RunStatus
|
3
6
|
|
4
7
|
timestamp = "2021-09-09T22:02:47.89Z"
|
5
8
|
|
@@ -96,3 +99,40 @@ def test_message_add(first: Message, second: Message, result: Message) -> None:
|
|
96
99
|
)
|
97
100
|
def test_message_compress(uncompressed: Message, compressed: Message) -> None:
|
98
101
|
assert uncompressed.compress() == compressed
|
102
|
+
|
103
|
+
|
104
|
+
@pytest.mark.parametrize(
|
105
|
+
"run,error",
|
106
|
+
[
|
107
|
+
(
|
108
|
+
Run(agent_name="foo", status=RunStatus.CANCELLED),
|
109
|
+
asyncio.CancelledError,
|
110
|
+
),
|
111
|
+
(
|
112
|
+
Run(
|
113
|
+
agent_name="foo",
|
114
|
+
status=RunStatus.FAILED,
|
115
|
+
error=Error(code=ErrorCode.SERVER_ERROR, message="Unspecified"),
|
116
|
+
),
|
117
|
+
ACPError,
|
118
|
+
),
|
119
|
+
(
|
120
|
+
Run(agent_name="foo"),
|
121
|
+
None,
|
122
|
+
),
|
123
|
+
(
|
124
|
+
Run(agent_name="foo", status=RunStatus.IN_PROGRESS),
|
125
|
+
None,
|
126
|
+
),
|
127
|
+
(
|
128
|
+
Run(agent_name="foo", status=RunStatus.COMPLETED),
|
129
|
+
None,
|
130
|
+
),
|
131
|
+
],
|
132
|
+
)
|
133
|
+
def test_run_raise_on_status_raise(run: Run, error: type[Exception] | None) -> None:
|
134
|
+
if error:
|
135
|
+
with pytest.raises(error):
|
136
|
+
run.raise_for_status()
|
137
|
+
else:
|
138
|
+
run.raise_for_status()
|
File without changes
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncGenerator
|
3
|
+
from contextlib import asynccontextmanager
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
from acp_sdk.server import Server
|
7
|
+
from fastapi import FastAPI
|
8
|
+
|
9
|
+
|
10
|
+
@pytest.mark.asyncio
|
11
|
+
async def test_lifespan() -> None:
|
12
|
+
entry = False
|
13
|
+
exit = False
|
14
|
+
|
15
|
+
class TestServer(Server):
|
16
|
+
@asynccontextmanager
|
17
|
+
async def lifespan(self, app: FastAPI) -> AsyncGenerator[None]:
|
18
|
+
nonlocal entry
|
19
|
+
nonlocal exit
|
20
|
+
entry = True
|
21
|
+
yield
|
22
|
+
exit = True
|
23
|
+
|
24
|
+
server = TestServer()
|
25
|
+
task = asyncio.create_task(server.serve())
|
26
|
+
await asyncio.sleep(1)
|
27
|
+
server.should_exit = True
|
28
|
+
await task
|
29
|
+
|
30
|
+
assert entry
|
31
|
+
assert exit
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|