acp-sdk 0.8.3__tar.gz → 0.9.0__tar.gz

This diff shows the differences between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (50)
  1. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/PKG-INFO +2 -2
  2. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/pyproject.toml +2 -2
  3. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/client.py +4 -1
  4. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/models.py +17 -1
  5. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/agent.py +11 -3
  6. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/app.py +9 -3
  7. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/bundle.py +4 -0
  8. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/server.py +150 -15
  9. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/types.py +1 -1
  10. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/fixtures/server.py +13 -0
  11. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/test_suites/test_runs.py +24 -6
  12. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/client/test_client.py +22 -0
  13. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/models/test_models.py +41 -1
  14. acp_sdk-0.9.0/tests/unit/server/__init__.py +0 -0
  15. acp_sdk-0.9.0/tests/unit/server/test_server.py +31 -0
  16. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/.gitignore +0 -0
  17. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/.python-version +0 -0
  18. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/README.md +0 -0
  19. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/.gitignore +0 -0
  20. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/Makefile +0 -0
  21. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/conf.py +0 -0
  22. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/index.rst +0 -0
  23. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/docs/make.bat +0 -0
  24. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/pytest.ini +0 -0
  25. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/__init__.py +0 -0
  26. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/__init__.py +0 -0
  27. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/types.py +0 -0
  28. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/client/utils.py +0 -0
  29. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/instrumentation.py +0 -0
  30. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/__init__.py +0 -0
  31. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/errors.py +0 -0
  32. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/models/schemas.py +0 -0
  33. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/py.typed +0 -0
  34. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/__init__.py +0 -0
  35. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/context.py +0 -0
  36. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/errors.py +0 -0
  37. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/logging.py +0 -0
  38. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/session.py +0 -0
  39. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/telemetry.py +0 -0
  40. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/server/utils.py +0 -0
  41. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/src/acp_sdk/version.py +0 -0
  42. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/conftest.py +0 -0
  43. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/__init__.py +0 -0
  44. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/config.py +0 -0
  45. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/fixtures/__init__.py +0 -0
  46. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/fixtures/client.py +0 -0
  47. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/test_suites/__init__.py +0 -0
  48. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/e2e/test_suites/test_discovery.py +0 -0
  49. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/client/test_utils.py +0 -0
  50. {acp_sdk-0.8.3 → acp_sdk-0.9.0}/tests/unit/models/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: acp-sdk
3
- Version: 0.8.3
3
+ Version: 0.9.0
4
4
  Summary: Agent Communication Protocol SDK
5
5
  Author: IBM Corp.
6
6
  Maintainer-email: Tomas Pilar <thomas7pilar@gmail.com>
@@ -16,7 +16,7 @@ Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31.1
16
16
  Requires-Dist: opentelemetry-instrumentation-fastapi>=0.52b1
17
17
  Requires-Dist: opentelemetry-instrumentation-httpx>=0.52b1
18
18
  Requires-Dist: opentelemetry-sdk>=1.31.1
19
- Requires-Dist: pydantic>=2.11.1
19
+ Requires-Dist: pydantic>=2.0.0
20
20
  Description-Content-Type: text/markdown
21
21
 
22
22
  # Agent Communication Protocol SDK for Python
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "acp-sdk"
3
- version = "0.8.3"
3
+ version = "0.9.0"
4
4
  description = "Agent Communication Protocol SDK"
5
5
  license = "Apache-2.0"
6
6
  readme = "README.md"
@@ -9,7 +9,7 @@ maintainers = [{ name = "Tomas Pilar", email = "thomas7pilar@gmail.com" }]
9
9
  requires-python = ">=3.11, <4.0"
10
10
  dependencies = [
11
11
  "opentelemetry-api>=1.31.1",
12
- "pydantic>=2.11.1",
12
+ "pydantic>=2.0.0",
13
13
  "httpx>=0.26.0",
14
14
  "httpx-sse>=0.4.0",
15
15
  "opentelemetry-instrumentation-httpx>=0.52b1",
@@ -22,6 +22,7 @@ from acp_sdk.models import (
22
22
  AgentsListResponse,
23
23
  AwaitResume,
24
24
  Error,
25
+ ErrorEvent,
25
26
  Event,
26
27
  PingResponse,
27
28
  Run,
@@ -224,7 +225,9 @@ class Client:
224
225
  await event_source.response.aread()
225
226
  self._raise_error(event_source.response)
226
227
  async for event in event_source.aiter_sse():
227
- event = TypeAdapter(Event).validate_json(event.data)
228
+ event: Event = TypeAdapter(Event).validate_json(event.data)
229
+ if isinstance(event, ErrorEvent):
230
+ raise ACPError(error=event.error)
228
231
  yield event
229
232
 
230
233
  def _raise_error(self, response: httpx.Response) -> None:
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import uuid
2
3
  from datetime import datetime, timezone
3
4
  from enum import Enum
@@ -5,7 +6,7 @@ from typing import Any, Literal, Optional, Union
5
6
 
6
7
  from pydantic import AnyUrl, BaseModel, ConfigDict, Field
7
8
 
8
- from acp_sdk.models.errors import Error
9
+ from acp_sdk.models.errors import ACPError, Error
9
10
 
10
11
 
11
12
  class AnyModel(BaseModel):
@@ -196,6 +197,15 @@ class Run(BaseModel):
196
197
  created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
197
198
  finished_at: datetime | None = None
198
199
 
200
+ def raise_for_status(self) -> "Run":
201
+ match self.status:
202
+ case RunStatus.CANCELLED:
203
+ raise asyncio.CancelledError()
204
+ case RunStatus.FAILED:
205
+ raise ACPError(error=self.error)
206
+ case _:
207
+ return self
208
+
199
209
 
200
210
  class MessageCreatedEvent(BaseModel):
201
211
  type: Literal["message.created"] = "message.created"
@@ -252,7 +262,13 @@ class RunCompletedEvent(BaseModel):
252
262
  run: Run
253
263
 
254
264
 
265
+ class ErrorEvent(BaseModel):
266
+ type: Literal["error"] = "error"
267
+ error: Error
268
+
269
+
255
270
  Event = Union[
271
+ ErrorEvent,
256
272
  RunCreatedEvent,
257
273
  RunInProgressEvent,
258
274
  MessageCreatedEvent,
@@ -58,13 +58,13 @@ class Agent(abc.ABC):
58
58
  run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, input, context)
59
59
 
60
60
  try:
61
- while True:
61
+ while not run.done() or yield_queue.async_q.qsize() > 0:
62
62
  value = yield await yield_queue.async_q.get()
63
+ if isinstance(value, Exception):
64
+ raise value
63
65
  await yield_resume_queue.async_q.put(value)
64
66
  except janus.AsyncQueueShutDown:
65
67
  pass
66
- finally:
67
- await run # Raise exceptions
68
68
 
69
69
  async def _run_async_gen(self, input: list[Message], context: Context) -> None:
70
70
  try:
@@ -74,12 +74,16 @@ class Agent(abc.ABC):
74
74
  value = await context.yield_async(await gen.asend(value))
75
75
  except StopAsyncIteration:
76
76
  pass
77
+ except Exception as e:
78
+ await context.yield_async(e)
77
79
  finally:
78
80
  context.shutdown()
79
81
 
80
82
  async def _run_coro(self, input: list[Message], context: Context) -> None:
81
83
  try:
82
84
  await context.yield_async(await self.run(input, context))
85
+ except Exception as e:
86
+ await context.yield_async(e)
83
87
  finally:
84
88
  context.shutdown()
85
89
 
@@ -91,12 +95,16 @@ class Agent(abc.ABC):
91
95
  value = context.yield_sync(gen.send(value))
92
96
  except StopIteration:
93
97
  pass
98
+ except Exception as e:
99
+ context.yield_sync(e)
94
100
  finally:
95
101
  context.shutdown()
96
102
 
97
103
  def _run_func(self, input: list[Message], context: Context) -> None:
98
104
  try:
99
105
  context.yield_sync(self.run(input, context))
106
+ except Exception as e:
107
+ context.yield_sync(e)
100
108
  finally:
101
109
  context.shutdown()
102
110
 
@@ -6,6 +6,7 @@ from enum import Enum
6
6
 
7
7
  from cachetools import TTLCache
8
8
  from fastapi import Depends, FastAPI, HTTPException, status
9
+ from fastapi.applications import AppType, Lifespan
9
10
  from fastapi.encoders import jsonable_encoder
10
11
  from fastapi.responses import JSONResponse, StreamingResponse
11
12
  from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
@@ -53,19 +54,24 @@ def create_app(
53
54
  *agents: Agent,
54
55
  run_limit: int = 1000,
55
56
  run_ttl: timedelta = timedelta(hours=1),
57
+ lifespan: Lifespan[AppType] | None = None,
56
58
  dependencies: list[Depends] | None = None,
57
59
  ) -> FastAPI:
58
60
  executor: ThreadPoolExecutor
59
61
 
60
62
  @asynccontextmanager
61
- async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
63
+ async def internal_lifespan(app: FastAPI) -> AsyncGenerator[None]:
62
64
  nonlocal executor
63
65
  with ThreadPoolExecutor() as exec:
64
66
  executor = exec
65
- yield
67
+ if not lifespan:
68
+ yield None
69
+ else:
70
+ async with lifespan(app) as state:
71
+ yield state
66
72
 
67
73
  app = FastAPI(
68
- lifespan=lifespan,
74
+ lifespan=internal_lifespan,
69
75
  dependencies=dependencies,
70
76
  )
71
77
 
@@ -142,6 +142,8 @@ class RunBundle:
142
142
  run_logger.info("Run resumed")
143
143
  elif isinstance(next, Error):
144
144
  raise ACPError(error=next)
145
+ elif isinstance(next, ACPError):
146
+ raise next
145
147
  elif next is None:
146
148
  await flush_message()
147
149
  elif isinstance(next, BaseModel):
@@ -176,3 +178,5 @@ class RunBundle:
176
178
  finally:
177
179
  self.await_or_terminate_event.set()
178
180
  await self.stream_queue.put(None)
181
+ if not self.task.done():
182
+ self.task.cancel()
@@ -1,12 +1,14 @@
1
1
  import asyncio
2
2
  import os
3
- from collections.abc import Awaitable
3
+ from collections.abc import AsyncGenerator, Awaitable
4
+ from contextlib import asynccontextmanager
4
5
  from datetime import timedelta
5
6
  from typing import Any, Callable
6
7
 
7
8
  import requests
8
9
  import uvicorn
9
10
  import uvicorn.config
11
+ from fastapi import FastAPI
10
12
 
11
13
  from acp_sdk.models import Metadata
12
14
  from acp_sdk.server.agent import Agent
@@ -20,8 +22,8 @@ from acp_sdk.server.utils import async_request_with_retry
20
22
 
21
23
  class Server:
22
24
  def __init__(self) -> None:
23
- self._agents: list[Agent] = []
24
- self._server: uvicorn.Server | None = None
25
+ self.agents: list[Agent] = []
26
+ self.server: uvicorn.Server | None = None
25
27
 
26
28
  def agent(
27
29
  self,
@@ -40,10 +42,15 @@ class Server:
40
42
  return decorator
41
43
 
42
44
  def register(self, *agents: Agent) -> None:
43
- self._agents.extend(agents)
45
+ self.agents.extend(agents)
44
46
 
45
- def run(
47
+ @asynccontextmanager
48
+ async def lifespan(self, app: FastAPI) -> AsyncGenerator[None]:
49
+ yield
50
+
51
+ async def serve(
46
52
  self,
53
+ *,
47
54
  configure_logger: bool = True,
48
55
  configure_telemetry: bool = False,
49
56
  self_registration: bool = True,
@@ -101,9 +108,14 @@ class Server:
101
108
  factory: bool = False,
102
109
  h11_max_incomplete_event_size: int | None = None,
103
110
  ) -> None:
104
- if self._server:
111
+ if self.server:
105
112
  raise RuntimeError("The server is already running")
106
113
 
114
+ if headers is None:
115
+ headers = [("server", "acp")]
116
+ elif not any(k.lower() == "server" for k, _ in headers):
117
+ headers.append(("server", "acp"))
118
+
107
119
  import uvicorn
108
120
 
109
121
  if configure_logger:
@@ -112,7 +124,7 @@ class Server:
112
124
  configure_telemetry_func()
113
125
 
114
126
  config = uvicorn.Config(
115
- create_app(*self._agents, run_limit=run_limit, run_ttl=run_ttl),
127
+ create_app(*self.agents, lifespan=self.lifespan, run_limit=run_limit, run_ttl=run_ttl),
116
128
  host,
117
129
  port,
118
130
  uds,
@@ -161,23 +173,139 @@ class Server:
161
173
  factory,
162
174
  h11_max_incomplete_event_size,
163
175
  )
164
- self._server = uvicorn.Server(config)
176
+ self.server = uvicorn.Server(config)
177
+ await self._serve(self_registration=self_registration)
165
178
 
166
- asyncio.run(self._serve(self_registration=self_registration))
179
+ def run(
180
+ self,
181
+ *,
182
+ configure_logger: bool = True,
183
+ configure_telemetry: bool = False,
184
+ self_registration: bool = True,
185
+ run_limit: int = 1000,
186
+ run_ttl: timedelta = timedelta(hours=1),
187
+ host: str = "127.0.0.1",
188
+ port: int = 8000,
189
+ uds: str | None = None,
190
+ fd: int | None = None,
191
+ loop: uvicorn.config.LoopSetupType = "auto",
192
+ http: type[asyncio.Protocol] | uvicorn.config.HTTPProtocolType = "auto",
193
+ ws: type[asyncio.Protocol] | uvicorn.config.WSProtocolType = "auto",
194
+ ws_max_size: int = 16 * 1024 * 1024,
195
+ ws_max_queue: int = 32,
196
+ ws_ping_interval: float | None = 20.0,
197
+ ws_ping_timeout: float | None = 20.0,
198
+ ws_per_message_deflate: bool = True,
199
+ lifespan: uvicorn.config.LifespanType = "auto",
200
+ env_file: str | os.PathLike[str] | None = None,
201
+ log_config: dict[str, Any]
202
+ | str
203
+ | uvicorn.config.RawConfigParser
204
+ | uvicorn.config.IO[Any]
205
+ | None = uvicorn.config.LOGGING_CONFIG,
206
+ log_level: str | int | None = None,
207
+ access_log: bool = True,
208
+ use_colors: bool | None = None,
209
+ interface: uvicorn.config.InterfaceType = "auto",
210
+ reload: bool = False,
211
+ reload_dirs: list[str] | str | None = None,
212
+ reload_delay: float = 0.25,
213
+ reload_includes: list[str] | str | None = None,
214
+ reload_excludes: list[str] | str | None = None,
215
+ workers: int | None = None,
216
+ proxy_headers: bool = True,
217
+ server_header: bool = True,
218
+ date_header: bool = True,
219
+ forwarded_allow_ips: list[str] | str | None = None,
220
+ root_path: str = "",
221
+ limit_concurrency: int | None = None,
222
+ limit_max_requests: int | None = None,
223
+ backlog: int = 2048,
224
+ timeout_keep_alive: int = 5,
225
+ timeout_notify: int = 30,
226
+ timeout_graceful_shutdown: int | None = None,
227
+ callback_notify: Callable[..., Awaitable[None]] | None = None,
228
+ ssl_keyfile: str | os.PathLike[str] | None = None,
229
+ ssl_certfile: str | os.PathLike[str] | None = None,
230
+ ssl_keyfile_password: str | None = None,
231
+ ssl_version: int = uvicorn.config.SSL_PROTOCOL_VERSION,
232
+ ssl_cert_reqs: int = uvicorn.config.ssl.CERT_NONE,
233
+ ssl_ca_certs: str | None = None,
234
+ ssl_ciphers: str = "TLSv1",
235
+ headers: list[tuple[str, str]] | None = None,
236
+ factory: bool = False,
237
+ h11_max_incomplete_event_size: int | None = None,
238
+ ) -> None:
239
+ asyncio.run(
240
+ self.serve(
241
+ configure_logger=configure_logger,
242
+ configure_telemetry=configure_telemetry,
243
+ self_registration=self_registration,
244
+ run_limit=run_limit,
245
+ run_ttl=run_ttl,
246
+ host=host,
247
+ port=port,
248
+ uds=uds,
249
+ fd=fd,
250
+ loop=loop,
251
+ http=http,
252
+ ws=ws,
253
+ ws_max_size=ws_max_size,
254
+ ws_max_queue=ws_max_queue,
255
+ ws_ping_interval=ws_ping_interval,
256
+ ws_ping_timeout=ws_ping_timeout,
257
+ ws_per_message_deflate=ws_per_message_deflate,
258
+ lifespan=lifespan,
259
+ env_file=env_file,
260
+ log_config=log_config,
261
+ log_level=log_level,
262
+ access_log=access_log,
263
+ use_colors=use_colors,
264
+ interface=interface,
265
+ reload=reload,
266
+ reload_dirs=reload_dirs,
267
+ reload_delay=reload_delay,
268
+ reload_includes=reload_includes,
269
+ reload_excludes=reload_excludes,
270
+ workers=workers,
271
+ proxy_headers=proxy_headers,
272
+ server_header=server_header,
273
+ date_header=date_header,
274
+ forwarded_allow_ips=forwarded_allow_ips,
275
+ root_path=root_path,
276
+ limit_concurrency=limit_concurrency,
277
+ limit_max_requests=limit_max_requests,
278
+ backlog=backlog,
279
+ timeout_keep_alive=timeout_keep_alive,
280
+ timeout_notify=timeout_notify,
281
+ timeout_graceful_shutdown=timeout_graceful_shutdown,
282
+ callback_notify=callback_notify,
283
+ ssl_keyfile=ssl_keyfile,
284
+ ssl_certfile=ssl_certfile,
285
+ ssl_keyfile_password=ssl_keyfile_password,
286
+ ssl_version=ssl_version,
287
+ ssl_cert_reqs=ssl_cert_reqs,
288
+ ssl_ca_certs=ssl_ca_certs,
289
+ ssl_ciphers=ssl_ciphers,
290
+ headers=headers,
291
+ factory=factory,
292
+ h11_max_incomplete_event_size=h11_max_incomplete_event_size,
293
+ )
294
+ )
167
295
 
168
296
  async def _serve(self, self_registration: bool = True) -> None:
169
297
  registration_task = asyncio.create_task(self._register_agent()) if self_registration else None
170
- await self._server.serve()
298
+ await self.server.serve()
171
299
  if registration_task:
172
300
  registration_task.cancel()
173
301
 
174
302
  @property
175
303
  def should_exit(self) -> bool:
176
- return self._server.should_exit if self._server else False
304
+ return self.server.should_exit if self.server else False
177
305
 
178
306
  @should_exit.setter
179
307
  def should_exit(self, value: bool) -> None:
180
- self._server.should_exit = value
308
+ self.server.should_exit = value
181
309
 
182
310
  async def _register_agent(self) -> None:
183
311
  """If not in PRODUCTION mode, register agent to the beeai platform and provide missing env variables"""
@@ -187,7 +315,7 @@ class Server:
187
315
 
188
316
  url = os.getenv("PLATFORM_URL", "http://127.0.0.1:8333")
189
317
  request_data = {
190
- "location": f"http://{self._server.config.host}:{self._server.config.port}",
318
+ "location": f"http://{self.server.config.host}:{self.server.config.port}",
191
319
  }
192
320
  try:
193
321
  await async_request_with_retry(
@@ -198,7 +326,7 @@ class Server:
198
326
  # check missing env keyes
199
327
  envs_request = await async_request_with_retry(lambda client: client.get(f"{url}/api/v1/variables"))
200
328
  envs = envs_request.get("env")
201
- for agent in self._agents:
329
+ for agent in self.agents:
202
330
  # register all available envs
203
331
  missing_keyes = []
204
332
  for env in agent.metadata.model_dump().get("env", []):
@@ -215,4 +343,11 @@ class Server:
215
343
  except requests.exceptions.ConnectionError as e:
216
344
  logger.warning(f"Can not reach server, check if running on {url} : {e}")
217
345
  except (requests.exceptions.HTTPError, Exception) as e:
218
- logger.warning(f"Agent can not be registered to beeai server: {e}")
346
+ try:
347
+ error_message = e.response.json().get("detail")
348
+ if error_message:
349
+ logger.warning(f"Agent can not be registered to beeai server: {error_message}")
350
+ else:
351
+ logger.warning(f"Agent can not be registered to beeai server: {e}")
352
+ except Exception:
353
+ logger.warning(f"Agent can not be registered to beeai server: {e}")
@@ -5,5 +5,5 @@ from pydantic import BaseModel
5
5
  from acp_sdk.models import AwaitRequest, AwaitResume, Message
6
6
  from acp_sdk.models.models import MessagePart
7
7
 
8
- RunYield = Message | MessagePart | str | AwaitRequest | BaseModel | dict[str | Any] | None
8
+ RunYield = Message | MessagePart | str | AwaitRequest | BaseModel | dict[str | Any] | None | Exception
9
9
  RunYieldResume = AwaitResume | None
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import base64
2
3
  import time
3
4
  from collections.abc import AsyncGenerator, AsyncIterator, Generator
@@ -6,6 +7,7 @@ from threading import Thread
6
7
 
7
8
  import pytest
8
9
  from acp_sdk.models import Artifact, AwaitResume, Error, ErrorCode, Message, MessageAwaitRequest, MessagePart
10
+ from acp_sdk.models.errors import ACPError
9
11
  from acp_sdk.server import Context, Server
10
12
 
11
13
  from e2e.config import Config
@@ -21,6 +23,12 @@ def server(request: pytest.FixtureRequest) -> Generator[None]:
21
23
  for message in input:
22
24
  yield message
23
25
 
26
+ @server.agent()
27
+ async def slow_echo(input: list[Message], context: Context) -> AsyncIterator[Message]:
28
+ for message in input:
29
+ await asyncio.sleep(1)
30
+ yield message
31
+
24
32
  @server.agent()
25
33
  async def awaiter(
26
34
  input: list[Message], context: Context
@@ -31,6 +39,11 @@ def server(request: pytest.FixtureRequest) -> Generator[None]:
31
39
  @server.agent()
32
40
  async def failer(input: list[Message], context: Context) -> AsyncIterator[Message]:
33
41
  yield Error(code=ErrorCode.INVALID_INPUT, message="Wrong question buddy!")
42
+ raise RuntimeError("Unreachable code")
43
+
44
+ @server.agent()
45
+ async def raiser(input: list[Message], context: Context) -> AsyncIterator[Message]:
46
+ raise ACPError(Error(code=ErrorCode.INVALID_INPUT, message="Wrong question buddy!"))
34
47
 
35
48
  @server.agent()
36
49
  async def sessioner(input: list[Message], context: Context) -> AsyncIterator[Message]:
@@ -5,18 +5,20 @@ from datetime import timedelta
5
5
  import pytest
6
6
  from acp_sdk.client import Client
7
7
  from acp_sdk.models import (
8
+ ACPError,
9
+ AgentName,
8
10
  ArtifactEvent,
9
11
  ErrorCode,
10
12
  Message,
11
13
  MessageAwaitResume,
12
14
  MessagePart,
13
15
  MessagePartEvent,
16
+ RunCancelledEvent,
14
17
  RunCompletedEvent,
15
18
  RunCreatedEvent,
16
19
  RunInProgressEvent,
17
20
  RunStatus,
18
21
  )
19
- from acp_sdk.models.errors import ACPError
20
22
  from acp_sdk.server import Server
21
23
 
22
24
  input = [Message(parts=[MessagePart(content="Hello!")])]
@@ -70,19 +72,35 @@ async def test_run_events_are_stream(server: Server, client: Client) -> None:
70
72
 
71
73
 
72
74
  @pytest.mark.asyncio
73
- async def test_failure(server: Server, client: Client) -> None:
74
- run = await client.run_sync(agent="failer", input=input)
75
+ @pytest.mark.parametrize("agent", ["failer", "raiser"])
76
+ async def test_failure(server: Server, client: Client, agent: AgentName) -> None:
77
+ run = await client.run_sync(agent=agent, input=input)
75
78
  assert run.status == RunStatus.FAILED
76
79
  assert run.error is not None
77
80
  assert run.error.code == ErrorCode.INVALID_INPUT
78
81
 
79
82
 
80
83
  @pytest.mark.asyncio
81
- async def test_run_cancel(server: Server, client: Client) -> None:
82
- run = await client.run_sync(agent="awaiter", input=input)
83
- assert run.status == RunStatus.AWAITING
84
+ @pytest.mark.parametrize("agent", ["awaiter", "slow_echo"])
85
+ async def test_run_cancel(server: Server, client: Client, agent: AgentName) -> None:
86
+ run = await client.run_async(agent=agent, input=input)
84
87
  run = await client.run_cancel(run_id=run.run_id)
85
88
  assert run.status == RunStatus.CANCELLING
89
+ await asyncio.sleep(2)
90
+ run = await client.run_status(run_id=run.run_id)
91
+ assert run.status == RunStatus.CANCELLED
92
+
93
+
94
+ @pytest.mark.asyncio
95
+ @pytest.mark.parametrize("agent", ["slow_echo"])
96
+ async def test_run_cancel_stream(server: Server, client: Client, agent: AgentName) -> None:
97
+ last_event = None
98
+ async for event in client.run_stream(agent=agent, input=input):
99
+ last_event = event
100
+ if isinstance(event, RunCreatedEvent):
101
+ run = await client.run_cancel(run_id=event.run.run_id)
102
+ assert run.status == RunStatus.CANCELLING
103
+ assert isinstance(last_event, RunCancelledEvent)
86
104
 
87
105
 
88
106
  @pytest.mark.asyncio
@@ -4,8 +4,12 @@ import uuid
4
4
  import pytest
5
5
  from acp_sdk.client import Client
6
6
  from acp_sdk.models import (
7
+ ACPError,
7
8
  Agent,
8
9
  AgentsListResponse,
10
+ Error,
11
+ ErrorCode,
12
+ ErrorEvent,
9
13
  Message,
10
14
  MessageAwaitResume,
11
15
  MessagePart,
@@ -77,6 +81,24 @@ async def test_run_stream(httpx_mock: HTTPXMock) -> None:
77
81
  assert event == mock_event
78
82
 
79
83
 
84
+ @pytest.mark.asyncio
85
+ async def test_run_stream_error(httpx_mock: HTTPXMock) -> None:
86
+ error = Error(code=ErrorCode.SERVER_ERROR, message="whoops")
87
+ mock_event = ErrorEvent(error=error)
88
+ httpx_mock.add_response(
89
+ url="http://test/runs",
90
+ method="POST",
91
+ headers={"content-type": "text/event-stream"},
92
+ content=f"data: {mock_event.model_dump_json()}\n\n",
93
+ )
94
+
95
+ async with Client(base_url="http://test") as client:
96
+ with pytest.raises(ACPError) as e:
97
+ async for _ in client.run_stream("Howdy!", agent=mock_run.agent_name):
98
+ raise AssertionError()
99
+ assert e.value.error == error
100
+
101
+
80
102
  @pytest.mark.asyncio
81
103
  async def test_run_status(httpx_mock: HTTPXMock) -> None:
82
104
  httpx_mock.add_response(url=f"http://test/runs/{mock_run.run_id}", method="GET", content=mock_run.model_dump_json())
@@ -1,5 +1,8 @@
1
+ import asyncio
2
+
1
3
  import pytest
2
- from acp_sdk.models.models import Message, MessagePart
4
+ from acp_sdk.models.errors import ACPError, Error, ErrorCode
5
+ from acp_sdk.models.models import Message, MessagePart, Run, RunStatus
3
6
 
4
7
  timestamp = "2021-09-09T22:02:47.89Z"
5
8
 
@@ -96,3 +99,40 @@ def test_message_add(first: Message, second: Message, result: Message) -> None:
96
99
  )
97
100
  def test_message_compress(uncompressed: Message, compressed: Message) -> None:
98
101
  assert uncompressed.compress() == compressed
102
+
103
+
104
+ @pytest.mark.parametrize(
105
+ "run,error",
106
+ [
107
+ (
108
+ Run(agent_name="foo", status=RunStatus.CANCELLED),
109
+ asyncio.CancelledError,
110
+ ),
111
+ (
112
+ Run(
113
+ agent_name="foo",
114
+ status=RunStatus.FAILED,
115
+ error=Error(code=ErrorCode.SERVER_ERROR, message="Unspecified"),
116
+ ),
117
+ ACPError,
118
+ ),
119
+ (
120
+ Run(agent_name="foo"),
121
+ None,
122
+ ),
123
+ (
124
+ Run(agent_name="foo", status=RunStatus.IN_PROGRESS),
125
+ None,
126
+ ),
127
+ (
128
+ Run(agent_name="foo", status=RunStatus.COMPLETED),
129
+ None,
130
+ ),
131
+ ],
132
+ )
133
+ def test_run_raise_on_status_raise(run: Run, error: type[Exception] | None) -> None:
134
+ if error:
135
+ with pytest.raises(error):
136
+ run.raise_for_status()
137
+ else:
138
+ run.raise_for_status()
File without changes
@@ -0,0 +1,31 @@
1
+ import asyncio
2
+ from collections.abc import AsyncGenerator
3
+ from contextlib import asynccontextmanager
4
+
5
+ import pytest
6
+ from acp_sdk.server import Server
7
+ from fastapi import FastAPI
8
+
9
+
10
+ @pytest.mark.asyncio
11
+ async def test_lifespan() -> None:
12
+ entry = False
13
+ exit = False
14
+
15
+ class TestServer(Server):
16
+ @asynccontextmanager
17
+ async def lifespan(self, app: FastAPI) -> AsyncGenerator[None]:
18
+ nonlocal entry
19
+ nonlocal exit
20
+ entry = True
21
+ yield
22
+ exit = True
23
+
24
+ server = TestServer()
25
+ task = asyncio.create_task(server.serve())
26
+ await asyncio.sleep(1)
27
+ server.should_exit = True
28
+ await task
29
+
30
+ assert entry
31
+ assert exit
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes