acp-sdk 0.7.0__tar.gz → 0.7.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/PKG-INFO +4 -4
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/README.md +3 -3
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/docs/client.md +4 -4
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/docs/server.md +4 -4
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/pyproject.toml +1 -1
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/client/client.py +10 -28
- acp_sdk-0.7.2/src/acp_sdk/client/types.py +3 -0
- acp_sdk-0.7.2/src/acp_sdk/client/utils.py +24 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/models/models.py +8 -2
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/agent.py +14 -14
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/app.py +16 -4
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/bundle.py +5 -5
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/server.py +54 -1
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/session.py +1 -1
- acp_sdk-0.7.2/src/acp_sdk/server/utils.py +49 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/fixtures/server.py +8 -8
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/test_suites/test_runs.py +23 -23
- acp_sdk-0.7.2/tests/unit/client/test_client.py +145 -0
- acp_sdk-0.7.2/tests/unit/client/test_utils.py +30 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/unit/models/test_models.py +19 -0
- acp_sdk-0.7.0/src/acp_sdk/server/utils.py +0 -14
- acp_sdk-0.7.0/tests/unit/client/test_client.py +0 -36
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/.gitignore +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/.python-version +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/docs/_sidebar.md +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/docs/index.html +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/docs/models.md +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/pytest.ini +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/__init__.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/client/__init__.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/instrumentation.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/models/__init__.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/models/errors.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/models/schemas.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/py.typed +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/__init__.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/context.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/errors.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/logging.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/telemetry.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/server/types.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/src/acp_sdk/version.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/conftest.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/__init__.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/config.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/fixtures/__init__.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/fixtures/client.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/test_suites/__init__.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/e2e/test_suites/test_discovery.py +0 -0
- {acp_sdk-0.7.0 → acp_sdk-0.7.2}/tests/unit/models/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: acp-sdk
|
3
|
-
Version: 0.7.
|
3
|
+
Version: 0.7.2
|
4
4
|
Summary: Agent Communication Protocol SDK
|
5
5
|
Author: IBM Corp.
|
6
6
|
Maintainer-email: Tomas Pilar <thomas7pilar@gmail.com>
|
@@ -44,9 +44,9 @@ Register an agent and run the server:
|
|
44
44
|
server = Server()
|
45
45
|
|
46
46
|
@server.agent()
|
47
|
-
async def echo(
|
47
|
+
async def echo(input: list[Message]):
|
48
48
|
"""Echoes everything"""
|
49
|
-
for message in
|
49
|
+
for message in input:
|
50
50
|
yield message
|
51
51
|
|
52
52
|
server.run(port=8000)
|
@@ -56,7 +56,7 @@ From another process, connect to the server and run the agent:
|
|
56
56
|
|
57
57
|
```py
|
58
58
|
async with Client(base_url="http://localhost:8000") as client:
|
59
|
-
run = await client.run_sync(agent="echo",
|
59
|
+
run = await client.run_sync(agent="echo", input=[Message(parts=[MessagePart(content="Howdy!")])])
|
60
60
|
print(run)
|
61
61
|
|
62
62
|
```
|
@@ -23,9 +23,9 @@ Register an agent and run the server:
|
|
23
23
|
server = Server()
|
24
24
|
|
25
25
|
@server.agent()
|
26
|
-
async def echo(
|
26
|
+
async def echo(input: list[Message]):
|
27
27
|
"""Echoes everything"""
|
28
|
-
for message in
|
28
|
+
for message in input:
|
29
29
|
yield message
|
30
30
|
|
31
31
|
server.run(port=8000)
|
@@ -35,7 +35,7 @@ From another process, connect to the server and run the agent:
|
|
35
35
|
|
36
36
|
```py
|
37
37
|
async with Client(base_url="http://localhost:8000") as client:
|
38
|
-
run = await client.run_sync(agent="echo",
|
38
|
+
run = await client.run_sync(agent="echo", input=[Message(parts=[MessagePart(content="Howdy!")])])
|
39
39
|
print(run)
|
40
40
|
|
41
41
|
```
|
@@ -71,15 +71,15 @@ async with Client(base_url="http://localhost:8000") as client:
|
|
71
71
|
message = Message(parts=[MessagePart(content="Hello")])
|
72
72
|
|
73
73
|
# Async
|
74
|
-
run = await client.run_async(
|
74
|
+
run = await client.run_async(agent="agent", input=[message])
|
75
75
|
print(run.status)
|
76
76
|
|
77
77
|
# Sync - waits for completion, failure, cancellation or await
|
78
|
-
run = await client.run_sync(
|
78
|
+
run = await client.run_sync(agent="agent", input=[message])
|
79
79
|
print(run.output)
|
80
80
|
|
81
81
|
# Stream - as sync but also receives events
|
82
|
-
async for event in client.run_stream(
|
82
|
+
async for event in client.run_stream(agent="agent", input=[message])
|
83
83
|
print(event)
|
84
84
|
```
|
85
85
|
|
@@ -95,5 +95,5 @@ async with Client(base_url="http://localhost:8000" as client:
|
|
95
95
|
|
96
96
|
async with client.session() as session:
|
97
97
|
for agent in agents:
|
98
|
-
await session.run_sync(
|
98
|
+
await session.run_sync(agent=agent.name, input=[Message(parts=[MessagePart(content="Hello!")])])
|
99
99
|
```
|
@@ -47,9 +47,9 @@ server = Server()
|
|
47
47
|
|
48
48
|
|
49
49
|
@server.agent()
|
50
|
-
async def echo(
|
50
|
+
async def echo(input: list[Message], context: Context) -> AsyncGenerator[RunYield, RunYieldResume]:
|
51
51
|
"""Echoes everything"""
|
52
|
-
for message in
|
52
|
+
for message in input:
|
53
53
|
await asyncio.sleep(0.5)
|
54
54
|
yield {"thought": "I should echo everything"}
|
55
55
|
await asyncio.sleep(0.5)
|
@@ -81,9 +81,9 @@ from acp_sdk.server import RunYield, RunYieldResume, agent, create_app
|
|
81
81
|
|
82
82
|
|
83
83
|
@agent()
|
84
|
-
async def echo(
|
84
|
+
async def echo(input: list[Message]) -> AsyncGenerator[RunYield, RunYieldResume]:
|
85
85
|
"""Echoes everything"""
|
86
|
-
for message in
|
86
|
+
for message in input:
|
87
87
|
yield message
|
88
88
|
|
89
89
|
|
@@ -11,6 +11,8 @@ from httpx_sse import EventSource, aconnect_sse
|
|
11
11
|
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
12
12
|
from pydantic import TypeAdapter
|
13
13
|
|
14
|
+
from acp_sdk.client.types import Input
|
15
|
+
from acp_sdk.client.utils import input_to_messages
|
14
16
|
from acp_sdk.instrumentation import get_tracer
|
15
17
|
from acp_sdk.models import (
|
16
18
|
ACPError,
|
@@ -21,7 +23,6 @@ from acp_sdk.models import (
|
|
21
23
|
AwaitResume,
|
22
24
|
Error,
|
23
25
|
Event,
|
24
|
-
Message,
|
25
26
|
Run,
|
26
27
|
RunCancelResponse,
|
27
28
|
RunCreatedEvent,
|
@@ -33,9 +34,6 @@ from acp_sdk.models import (
|
|
33
34
|
RunResumeResponse,
|
34
35
|
SessionId,
|
35
36
|
)
|
36
|
-
from acp_sdk.models.models import MessagePart
|
37
|
-
|
38
|
-
Input = list[Message] | Message | list[MessagePart] | MessagePart | list[str] | str
|
39
37
|
|
40
38
|
|
41
39
|
class Client:
|
@@ -122,12 +120,17 @@ class Client:
|
|
122
120
|
response = AgentReadResponse.model_validate(response.json())
|
123
121
|
return Agent(**response.model_dump())
|
124
122
|
|
123
|
+
async def ping(self) -> bool:
|
124
|
+
response = await self._client.get("/healthcheck")
|
125
|
+
self._raise_error(response)
|
126
|
+
return response.json() == "OK"
|
127
|
+
|
125
128
|
async def run_sync(self, input: Input, *, agent: AgentName) -> Run:
|
126
129
|
response = await self._client.post(
|
127
130
|
"/runs",
|
128
131
|
content=RunCreateRequest(
|
129
132
|
agent_name=agent,
|
130
|
-
input=
|
133
|
+
input=input_to_messages(input),
|
131
134
|
mode=RunMode.SYNC,
|
132
135
|
session_id=self._session_id,
|
133
136
|
).model_dump_json(),
|
@@ -142,7 +145,7 @@ class Client:
|
|
142
145
|
"/runs",
|
143
146
|
content=RunCreateRequest(
|
144
147
|
agent_name=agent,
|
145
|
-
input=
|
148
|
+
input=input_to_messages(input),
|
146
149
|
mode=RunMode.ASYNC,
|
147
150
|
session_id=self._session_id,
|
148
151
|
).model_dump_json(),
|
@@ -159,7 +162,7 @@ class Client:
|
|
159
162
|
"/runs",
|
160
163
|
content=RunCreateRequest(
|
161
164
|
agent_name=agent,
|
162
|
-
input=
|
165
|
+
input=input_to_messages(input),
|
163
166
|
mode=RunMode.STREAM,
|
164
167
|
session_id=self._session_id,
|
165
168
|
).model_dump_json(),
|
@@ -227,24 +230,3 @@ class Client:
|
|
227
230
|
|
228
231
|
def _set_session(self, run: Run) -> None:
|
229
232
|
self._session_id = run.session_id
|
230
|
-
|
231
|
-
def _unify_inputs(self, input: Input) -> list[Message]:
|
232
|
-
if isinstance(input, list):
|
233
|
-
if len(input) == 0:
|
234
|
-
return []
|
235
|
-
if all(isinstance(item, Message) for item in input):
|
236
|
-
return input
|
237
|
-
elif all(isinstance(item, MessagePart) for item in input):
|
238
|
-
return [Message(parts=input)]
|
239
|
-
elif all(isinstance(item, str) for item in input):
|
240
|
-
return [Message(parts=[MessagePart(content=content) for content in input])]
|
241
|
-
else:
|
242
|
-
raise RuntimeError("List with mixed types is not supported")
|
243
|
-
else:
|
244
|
-
if isinstance(input, str):
|
245
|
-
input = MessagePart(content=input)
|
246
|
-
if isinstance(input, MessagePart):
|
247
|
-
input = Message(parts=[input])
|
248
|
-
if isinstance(input, Message):
|
249
|
-
input = [input]
|
250
|
-
return input
|
@@ -0,0 +1,24 @@
|
|
1
|
+
from acp_sdk.client.types import Input
|
2
|
+
from acp_sdk.models.models import Message, MessagePart
|
3
|
+
|
4
|
+
|
5
|
+
def input_to_messages(input: Input) -> list[Message]:
|
6
|
+
if isinstance(input, list):
|
7
|
+
if len(input) == 0:
|
8
|
+
return []
|
9
|
+
if all(isinstance(item, Message) for item in input):
|
10
|
+
return input
|
11
|
+
elif all(isinstance(item, MessagePart) for item in input):
|
12
|
+
return [Message(parts=input)]
|
13
|
+
elif all(isinstance(item, str) for item in input):
|
14
|
+
return [Message(parts=[MessagePart(content=content) for content in input])]
|
15
|
+
else:
|
16
|
+
raise TypeError("List with mixed types is not supported")
|
17
|
+
else:
|
18
|
+
if isinstance(input, str):
|
19
|
+
input = MessagePart(content=input)
|
20
|
+
if isinstance(input, MessagePart):
|
21
|
+
input = Message(parts=[input])
|
22
|
+
if isinstance(input, Message):
|
23
|
+
input = [input]
|
24
|
+
return input
|
@@ -47,6 +47,11 @@ class Dependency(BaseModel):
|
|
47
47
|
name: str
|
48
48
|
|
49
49
|
|
50
|
+
class Capability(BaseModel):
|
51
|
+
name: str
|
52
|
+
description: str
|
53
|
+
|
54
|
+
|
50
55
|
class Metadata(BaseModel):
|
51
56
|
annotations: AnyModel | None = None
|
52
57
|
documentation: str | None = None
|
@@ -54,7 +59,8 @@ class Metadata(BaseModel):
|
|
54
59
|
programming_language: str | None = None
|
55
60
|
natural_languages: list[str] | None = None
|
56
61
|
framework: str | None = None
|
57
|
-
|
62
|
+
capabilities: list[Capability] | None = None
|
63
|
+
domains: list[str] | None = None
|
58
64
|
tags: list[str] | None = None
|
59
65
|
created_at: datetime | None = None
|
60
66
|
updated_at: datetime | None = None
|
@@ -93,7 +99,7 @@ class Message(BaseModel):
|
|
93
99
|
def __add__(self, other: "Message") -> "Message":
|
94
100
|
if not isinstance(other, Message):
|
95
101
|
raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
|
96
|
-
return Message(
|
102
|
+
return Message(parts=self.parts + other.parts)
|
97
103
|
|
98
104
|
def __str__(self) -> str:
|
99
105
|
return "".join(
|
@@ -32,14 +32,14 @@ class Agent(abc.ABC):
|
|
32
32
|
|
33
33
|
@abc.abstractmethod
|
34
34
|
def run(
|
35
|
-
self,
|
35
|
+
self, input: list[Message], context: Context
|
36
36
|
) -> (
|
37
37
|
AsyncGenerator[RunYield, RunYieldResume] | Generator[RunYield, RunYieldResume] | Coroutine[RunYield] | RunYield
|
38
38
|
):
|
39
39
|
pass
|
40
40
|
|
41
41
|
async def execute(
|
42
|
-
self,
|
42
|
+
self, input: list[Message], session_id: SessionId | None, executor: ThreadPoolExecutor
|
43
43
|
) -> AsyncGenerator[RunYield, RunYieldResume]:
|
44
44
|
yield_queue: janus.Queue[RunYield] = janus.Queue()
|
45
45
|
yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()
|
@@ -49,13 +49,13 @@ class Agent(abc.ABC):
|
|
49
49
|
)
|
50
50
|
|
51
51
|
if inspect.isasyncgenfunction(self.run):
|
52
|
-
run = asyncio.create_task(self._run_async_gen(
|
52
|
+
run = asyncio.create_task(self._run_async_gen(input, context))
|
53
53
|
elif inspect.iscoroutinefunction(self.run):
|
54
|
-
run = asyncio.create_task(self._run_coro(
|
54
|
+
run = asyncio.create_task(self._run_coro(input, context))
|
55
55
|
elif inspect.isgeneratorfunction(self.run):
|
56
|
-
run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen,
|
56
|
+
run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen, input, context)
|
57
57
|
else:
|
58
|
-
run = asyncio.get_running_loop().run_in_executor(executor, self._run_func,
|
58
|
+
run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, input, context)
|
59
59
|
|
60
60
|
try:
|
61
61
|
while True:
|
@@ -66,7 +66,7 @@ class Agent(abc.ABC):
|
|
66
66
|
finally:
|
67
67
|
await run # Raise exceptions
|
68
68
|
|
69
|
-
async def _run_async_gen(self, input: Message, context: Context) -> None:
|
69
|
+
async def _run_async_gen(self, input: list[Message], context: Context) -> None:
|
70
70
|
try:
|
71
71
|
gen: AsyncGenerator[RunYield, RunYieldResume] = self.run(input, context)
|
72
72
|
value = None
|
@@ -77,13 +77,13 @@ class Agent(abc.ABC):
|
|
77
77
|
finally:
|
78
78
|
context.shutdown()
|
79
79
|
|
80
|
-
async def _run_coro(self, input: Message, context: Context) -> None:
|
80
|
+
async def _run_coro(self, input: list[Message], context: Context) -> None:
|
81
81
|
try:
|
82
82
|
await context.yield_async(await self.run(input, context))
|
83
83
|
finally:
|
84
84
|
context.shutdown()
|
85
85
|
|
86
|
-
def _run_gen(self, input: Message, context: Context) -> None:
|
86
|
+
def _run_gen(self, input: list[Message], context: Context) -> None:
|
87
87
|
try:
|
88
88
|
gen: Generator[RunYield, RunYieldResume] = self.run(input, context)
|
89
89
|
value = None
|
@@ -94,7 +94,7 @@ class Agent(abc.ABC):
|
|
94
94
|
finally:
|
95
95
|
context.shutdown()
|
96
96
|
|
97
|
-
def _run_func(self, input: Message, context: Context) -> None:
|
97
|
+
def _run_func(self, input: list[Message], context: Context) -> None:
|
98
98
|
try:
|
99
99
|
context.yield_sync(self.run(input, context))
|
100
100
|
finally:
|
@@ -139,7 +139,7 @@ def agent(
|
|
139
139
|
if inspect.isasyncgenfunction(fn):
|
140
140
|
|
141
141
|
class AsyncGenDecoratorAgent(DecoratorAgentBase):
|
142
|
-
async def run(self, input: Message, context: Context) -> AsyncGenerator[RunYield, RunYieldResume]:
|
142
|
+
async def run(self, input: list[Message], context: Context) -> AsyncGenerator[RunYield, RunYieldResume]:
|
143
143
|
try:
|
144
144
|
gen: AsyncGenerator[RunYield, RunYieldResume] = (
|
145
145
|
fn(input, context) if has_context_param else fn(input)
|
@@ -154,21 +154,21 @@ def agent(
|
|
154
154
|
elif inspect.iscoroutinefunction(fn):
|
155
155
|
|
156
156
|
class CoroDecoratorAgent(DecoratorAgentBase):
|
157
|
-
async def run(self, input: Message, context: Context) -> Coroutine[RunYield]:
|
157
|
+
async def run(self, input: list[Message], context: Context) -> Coroutine[RunYield]:
|
158
158
|
return await (fn(input, context) if has_context_param else fn(input))
|
159
159
|
|
160
160
|
agent = CoroDecoratorAgent()
|
161
161
|
elif inspect.isgeneratorfunction(fn):
|
162
162
|
|
163
163
|
class GenDecoratorAgent(DecoratorAgentBase):
|
164
|
-
def run(self, input: Message, context: Context) -> Generator[RunYield, RunYieldResume]:
|
164
|
+
def run(self, input: list[Message], context: Context) -> Generator[RunYield, RunYieldResume]:
|
165
165
|
yield from (fn(input, context) if has_context_param else fn(input))
|
166
166
|
|
167
167
|
agent = GenDecoratorAgent()
|
168
168
|
else:
|
169
169
|
|
170
170
|
class FuncDecoratorAgent(DecoratorAgentBase):
|
171
|
-
def run(self, input: Message, context: Context) -> RunYield:
|
171
|
+
def run(self, input: list[Message], context: Context) -> RunYield:
|
172
172
|
return fn(input, context) if has_context_param else fn(input)
|
173
173
|
|
174
174
|
agent = FuncDecoratorAgent()
|
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
|
|
5
5
|
from enum import Enum
|
6
6
|
|
7
7
|
from cachetools import TTLCache
|
8
|
-
from fastapi import FastAPI, HTTPException, status
|
8
|
+
from fastapi import Depends, FastAPI, HTTPException, status
|
9
9
|
from fastapi.encoders import jsonable_encoder
|
10
10
|
from fastapi.responses import JSONResponse, StreamingResponse
|
11
11
|
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
@@ -47,7 +47,12 @@ class Headers(str, Enum):
|
|
47
47
|
RUN_ID = "Run-ID"
|
48
48
|
|
49
49
|
|
50
|
-
def create_app(
|
50
|
+
def create_app(
|
51
|
+
*agents: Agent,
|
52
|
+
run_limit: int = 1000,
|
53
|
+
run_ttl: timedelta = timedelta(hours=1),
|
54
|
+
dependencies: list[Depends] | None = None,
|
55
|
+
) -> FastAPI:
|
51
56
|
executor: ThreadPoolExecutor
|
52
57
|
|
53
58
|
@asynccontextmanager
|
@@ -57,7 +62,10 @@ def create_app(*agents: Agent, run_limit: int = 1000, run_ttl: timedelta = timed
|
|
57
62
|
executor = exec
|
58
63
|
yield
|
59
64
|
|
60
|
-
app = FastAPI(
|
65
|
+
app = FastAPI(
|
66
|
+
lifespan=lifespan,
|
67
|
+
dependencies=dependencies,
|
68
|
+
)
|
61
69
|
|
62
70
|
FastAPIInstrumentor.instrument_app(app)
|
63
71
|
|
@@ -96,6 +104,10 @@ def create_app(*agents: Agent, run_limit: int = 1000, run_ttl: timedelta = timed
|
|
96
104
|
agent = find_agent(name)
|
97
105
|
return AgentModel(name=agent.name, description=agent.description, metadata=agent.metadata)
|
98
106
|
|
107
|
+
@app.get("/healthcheck")
|
108
|
+
async def healthcheck() -> str:
|
109
|
+
return "OK"
|
110
|
+
|
99
111
|
@app.post("/runs")
|
100
112
|
async def create_run(request: RunCreateRequest) -> RunCreateResponse:
|
101
113
|
agent = find_agent(request.agent_name)
|
@@ -105,7 +117,7 @@ def create_app(*agents: Agent, run_limit: int = 1000, run_ttl: timedelta = timed
|
|
105
117
|
bundle = RunBundle(
|
106
118
|
agent=agent,
|
107
119
|
run=Run(agent_name=agent.name, session_id=session.id),
|
108
|
-
|
120
|
+
input=request.input,
|
109
121
|
history=list(session.history()),
|
110
122
|
executor=executor,
|
111
123
|
)
|
@@ -35,11 +35,11 @@ from acp_sdk.server.logging import logger
|
|
35
35
|
|
36
36
|
class RunBundle:
|
37
37
|
def __init__(
|
38
|
-
self, *, agent: Agent, run: Run,
|
38
|
+
self, *, agent: Agent, run: Run, input: list[Message], history: list[Message], executor: ThreadPoolExecutor
|
39
39
|
) -> None:
|
40
40
|
self.agent = agent
|
41
41
|
self.run = run
|
42
|
-
self.
|
42
|
+
self.input = input
|
43
43
|
self.history = history
|
44
44
|
|
45
45
|
self.stream_queue: asyncio.Queue[Event] = asyncio.Queue()
|
@@ -47,7 +47,7 @@ class RunBundle:
|
|
47
47
|
self.await_queue: asyncio.Queue[AwaitResume] = asyncio.Queue(maxsize=1)
|
48
48
|
self.await_or_terminate_event = asyncio.Event()
|
49
49
|
|
50
|
-
self.task = asyncio.create_task(self._execute(
|
50
|
+
self.task = asyncio.create_task(self._execute(input, executor=executor))
|
51
51
|
|
52
52
|
async def stream(self) -> AsyncGenerator[Event]:
|
53
53
|
while True:
|
@@ -83,7 +83,7 @@ class RunBundle:
|
|
83
83
|
async def join(self) -> None:
|
84
84
|
await self.await_or_terminate_event.wait()
|
85
85
|
|
86
|
-
async def _execute(self,
|
86
|
+
async def _execute(self, input: list[Message], *, executor: ThreadPoolExecutor) -> None:
|
87
87
|
with get_tracer().start_as_current_span("run"):
|
88
88
|
run_logger = logging.LoggerAdapter(logger, {"run_id": str(self.run.run_id)})
|
89
89
|
|
@@ -99,7 +99,7 @@ class RunBundle:
|
|
99
99
|
await self.emit(RunCreatedEvent(run=self.run))
|
100
100
|
|
101
101
|
generator = self.agent.execute(
|
102
|
-
|
102
|
+
input=self.history + input, session_id=self.run.session_id, executor=executor
|
103
103
|
)
|
104
104
|
run_logger.info("Run started")
|
105
105
|
|
@@ -4,6 +4,7 @@ from collections.abc import Awaitable
|
|
4
4
|
from datetime import timedelta
|
5
5
|
from typing import Any, Callable
|
6
6
|
|
7
|
+
import requests
|
7
8
|
import uvicorn
|
8
9
|
import uvicorn.config
|
9
10
|
|
@@ -12,7 +13,9 @@ from acp_sdk.server.agent import Agent
|
|
12
13
|
from acp_sdk.server.agent import agent as agent_decorator
|
13
14
|
from acp_sdk.server.app import create_app
|
14
15
|
from acp_sdk.server.logging import configure_logger as configure_logger_func
|
16
|
+
from acp_sdk.server.logging import logger
|
15
17
|
from acp_sdk.server.telemetry import configure_telemetry as configure_telemetry_func
|
18
|
+
from acp_sdk.server.utils import async_request_with_retry
|
16
19
|
|
17
20
|
|
18
21
|
class Server:
|
@@ -43,6 +46,7 @@ class Server:
|
|
43
46
|
self,
|
44
47
|
configure_logger: bool = True,
|
45
48
|
configure_telemetry: bool = False,
|
49
|
+
self_registration: bool = True,
|
46
50
|
run_limit: int = 1000,
|
47
51
|
run_ttl: timedelta = timedelta(hours=1),
|
48
52
|
host: str = "127.0.0.1",
|
@@ -158,7 +162,14 @@ class Server:
|
|
158
162
|
h11_max_incomplete_event_size,
|
159
163
|
)
|
160
164
|
self._server = uvicorn.Server(config)
|
161
|
-
|
165
|
+
|
166
|
+
asyncio.run(self._serve(self_registration=self_registration))
|
167
|
+
|
168
|
+
async def _serve(self, self_registration: bool = True) -> None:
|
169
|
+
registration_task = asyncio.create_task(self._register_agent()) if self_registration else None
|
170
|
+
await self._server.serve()
|
171
|
+
if registration_task:
|
172
|
+
registration_task.cancel()
|
162
173
|
|
163
174
|
@property
|
164
175
|
def should_exit(self) -> bool:
|
@@ -167,3 +178,45 @@ class Server:
|
|
167
178
|
@should_exit.setter
|
168
179
|
def should_exit(self, value: bool) -> None:
|
169
180
|
self._server.should_exit = value
|
181
|
+
|
182
|
+
async def _register_agent(self) -> None:
|
183
|
+
"""If not in PRODUCTION mode, register agent to the beeai platform and provide missing env variables"""
|
184
|
+
if os.getenv("PRODUCTION_MODE", False):
|
185
|
+
logger.debug("Agent is not automatically registered in the production mode.")
|
186
|
+
return
|
187
|
+
|
188
|
+
url = os.getenv("PLATFORM_URL", "http://127.0.0.1:8333")
|
189
|
+
for agent in self._agents:
|
190
|
+
request_data = {
|
191
|
+
"location": f"http://{self._server.config.host}:{self._server.config.port}",
|
192
|
+
"id": agent.name,
|
193
|
+
}
|
194
|
+
try:
|
195
|
+
await async_request_with_retry(
|
196
|
+
lambda client, data=request_data: client.post(
|
197
|
+
f"{url}/api/v1/provider/register/unmanaged", json=data
|
198
|
+
)
|
199
|
+
)
|
200
|
+
logger.info("Agent registered to the beeai server.")
|
201
|
+
|
202
|
+
# check missing env keyes
|
203
|
+
envs_request = await async_request_with_retry(lambda client: client.get(f"{url}/api/v1/env"))
|
204
|
+
envs = envs_request.get("env")
|
205
|
+
|
206
|
+
# register all available envs
|
207
|
+
missing_keyes = []
|
208
|
+
for env in agent.metadata.model_dump().get("env", []):
|
209
|
+
server_env = envs.get(env.get("name"))
|
210
|
+
if server_env:
|
211
|
+
logger.debug(f"Env variable {env['name']} = '{server_env}' added dynamically")
|
212
|
+
os.environ[env["name"]] = server_env
|
213
|
+
elif env.get("required"):
|
214
|
+
missing_keyes.append(env)
|
215
|
+
if len(missing_keyes):
|
216
|
+
logger.error(f"Can not run agent, missing required env variables: {missing_keyes}")
|
217
|
+
raise Exception("Missing env variables")
|
218
|
+
|
219
|
+
except requests.exceptions.ConnectionError as e:
|
220
|
+
logger.warning(f"Can not reach server, check if running on {url} : {e}")
|
221
|
+
except (requests.exceptions.HTTPError, Exception) as e:
|
222
|
+
logger.warning(f"Agent can not be registered to beeai server: {e}")
|
@@ -0,0 +1,49 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncGenerator, Coroutine
|
3
|
+
from typing import Any, Callable
|
4
|
+
|
5
|
+
import httpx
|
6
|
+
import requests
|
7
|
+
from pydantic import BaseModel
|
8
|
+
|
9
|
+
from acp_sdk.server.bundle import RunBundle
|
10
|
+
from acp_sdk.server.logging import logger
|
11
|
+
|
12
|
+
|
13
|
+
def encode_sse(model: BaseModel) -> str:
|
14
|
+
return f"data: {model.model_dump_json()}\n\n"
|
15
|
+
|
16
|
+
|
17
|
+
async def stream_sse(bundle: RunBundle) -> AsyncGenerator[str]:
|
18
|
+
async for event in bundle.stream():
|
19
|
+
yield encode_sse(event)
|
20
|
+
|
21
|
+
|
22
|
+
async def async_request_with_retry(
|
23
|
+
request_func: Callable[[httpx.AsyncClient], Coroutine[Any, Any, httpx.Response]],
|
24
|
+
max_retries: int = 5,
|
25
|
+
backoff_factor: float = 1,
|
26
|
+
) -> dict[str, Any]:
|
27
|
+
async with httpx.AsyncClient() as client:
|
28
|
+
retries = 0
|
29
|
+
while retries < max_retries:
|
30
|
+
try:
|
31
|
+
response = await request_func(client)
|
32
|
+
response.raise_for_status()
|
33
|
+
return response.json()
|
34
|
+
except httpx.HTTPStatusError as e:
|
35
|
+
if e.response.status_code in [429, 500, 502, 503, 504, 509]:
|
36
|
+
retries += 1
|
37
|
+
backoff = backoff_factor * (2 ** (retries - 1))
|
38
|
+
logger.warning(f"Request retry (try {retries}/{max_retries}), waiting {backoff} seconds...")
|
39
|
+
await asyncio.sleep(backoff)
|
40
|
+
else:
|
41
|
+
logger.debug("A non-retryable error was encountered.")
|
42
|
+
raise
|
43
|
+
except httpx.RequestError:
|
44
|
+
retries += 1
|
45
|
+
backoff = backoff_factor * (2 ** (retries - 1))
|
46
|
+
logger.warning(f"Request retry (try {retries}/{max_retries}), waiting {backoff} seconds...")
|
47
|
+
await asyncio.sleep(backoff)
|
48
|
+
|
49
|
+
raise requests.exceptions.ConnectionError(f"Request failed after {max_retries} retries.")
|
@@ -17,36 +17,36 @@ def server(request: pytest.FixtureRequest) -> Generator[None]:
|
|
17
17
|
server = Server()
|
18
18
|
|
19
19
|
@server.agent()
|
20
|
-
async def echo(
|
21
|
-
for message in
|
20
|
+
async def echo(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
21
|
+
for message in input:
|
22
22
|
yield message
|
23
23
|
|
24
24
|
@server.agent()
|
25
25
|
async def awaiter(
|
26
|
-
|
26
|
+
input: list[Message], context: Context
|
27
27
|
) -> AsyncGenerator[Message | MessageAwaitRequest, AwaitResume]:
|
28
28
|
yield MessageAwaitRequest(message=Message(parts=[]))
|
29
29
|
yield MessagePart(content="empty", content_type="text/plain")
|
30
30
|
|
31
31
|
@server.agent()
|
32
|
-
async def failer(
|
32
|
+
async def failer(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
33
33
|
yield Error(code=ErrorCode.INVALID_INPUT, message="Wrong question buddy!")
|
34
34
|
|
35
35
|
@server.agent()
|
36
|
-
async def sessioner(
|
36
|
+
async def sessioner(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
37
37
|
assert context.session_id is not None
|
38
38
|
|
39
39
|
yield MessagePart(content=str(context.session_id), content_type="text/plain")
|
40
40
|
|
41
41
|
@server.agent()
|
42
|
-
async def mime_types(
|
42
|
+
async def mime_types(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
43
43
|
yield MessagePart(content="<h1>HTML Content</h1>", content_type="text/html")
|
44
44
|
yield MessagePart(content='{"key": "value"}', content_type="application/json")
|
45
45
|
yield MessagePart(content="console.log('Hello');", content_type="application/javascript")
|
46
46
|
yield MessagePart(content="body { color: red; }", content_type="text/css")
|
47
47
|
|
48
48
|
@server.agent()
|
49
|
-
async def base64_encoding(
|
49
|
+
async def base64_encoding(input: list[Message], context: Context) -> AsyncIterator[Message]:
|
50
50
|
yield Message(
|
51
51
|
parts=[
|
52
52
|
MessagePart(
|
@@ -61,7 +61,7 @@ def server(request: pytest.FixtureRequest) -> Generator[None]:
|
|
61
61
|
)
|
62
62
|
|
63
63
|
@server.agent()
|
64
|
-
async def artifact_producer(
|
64
|
+
async def artifact_producer(input: list[Message], context: Context) -> AsyncGenerator[Message | Artifact, None]:
|
65
65
|
yield MessagePart(content="Processing with artifacts", content_type="text/plain")
|
66
66
|
yield Artifact(name="text-result.txt", content_type="text/plain", content="This is a text artifact result")
|
67
67
|
yield Artifact(
|
@@ -19,33 +19,33 @@ from acp_sdk.models import (
|
|
19
19
|
from acp_sdk.models.errors import ACPError
|
20
20
|
from acp_sdk.server import Server
|
21
21
|
|
22
|
-
|
22
|
+
input = [Message(parts=[MessagePart(content="Hello!")])]
|
23
23
|
await_resume = MessageAwaitResume(message=Message(parts=[]))
|
24
24
|
|
25
25
|
|
26
26
|
@pytest.mark.asyncio
|
27
27
|
async def test_run_sync(server: Server, client: Client) -> None:
|
28
|
-
run = await client.run_sync(agent="echo", input=
|
28
|
+
run = await client.run_sync(agent="echo", input=input)
|
29
29
|
assert run.status == RunStatus.COMPLETED
|
30
|
-
assert run.output ==
|
30
|
+
assert run.output == input
|
31
31
|
|
32
32
|
|
33
33
|
@pytest.mark.asyncio
|
34
34
|
async def test_run_async(server: Server, client: Client) -> None:
|
35
|
-
run = await client.run_async(agent="echo", input=
|
35
|
+
run = await client.run_async(agent="echo", input=input)
|
36
36
|
assert run.status == RunStatus.CREATED
|
37
37
|
|
38
38
|
|
39
39
|
@pytest.mark.asyncio
|
40
40
|
async def test_run_stream(server: Server, client: Client) -> None:
|
41
|
-
event_stream = [event async for event in client.run_stream(agent="echo", input=
|
41
|
+
event_stream = [event async for event in client.run_stream(agent="echo", input=input)]
|
42
42
|
assert isinstance(event_stream[0], RunCreatedEvent)
|
43
43
|
assert isinstance(event_stream[-1], RunCompletedEvent)
|
44
44
|
|
45
45
|
|
46
46
|
@pytest.mark.asyncio
|
47
47
|
async def test_run_status(server: Server, client: Client) -> None:
|
48
|
-
run = await client.run_async(agent="echo", input=
|
48
|
+
run = await client.run_async(agent="echo", input=input)
|
49
49
|
while run.status in (RunStatus.CREATED, RunStatus.IN_PROGRESS):
|
50
50
|
run = await client.run_status(run_id=run.run_id)
|
51
51
|
assert run.status == RunStatus.COMPLETED
|
@@ -53,7 +53,7 @@ async def test_run_status(server: Server, client: Client) -> None:
|
|
53
53
|
|
54
54
|
@pytest.mark.asyncio
|
55
55
|
async def test_failure(server: Server, client: Client) -> None:
|
56
|
-
run = await client.run_sync(agent="failer", input=
|
56
|
+
run = await client.run_sync(agent="failer", input=input)
|
57
57
|
assert run.status == RunStatus.FAILED
|
58
58
|
assert run.error is not None
|
59
59
|
assert run.error.code == ErrorCode.INVALID_INPUT
|
@@ -61,7 +61,7 @@ async def test_failure(server: Server, client: Client) -> None:
|
|
61
61
|
|
62
62
|
@pytest.mark.asyncio
|
63
63
|
async def test_run_cancel(server: Server, client: Client) -> None:
|
64
|
-
run = await client.run_sync(agent="awaiter", input=
|
64
|
+
run = await client.run_sync(agent="awaiter", input=input)
|
65
65
|
assert run.status == RunStatus.AWAITING
|
66
66
|
run = await client.run_cancel(run_id=run.run_id)
|
67
67
|
assert run.status == RunStatus.CANCELLING
|
@@ -69,7 +69,7 @@ async def test_run_cancel(server: Server, client: Client) -> None:
|
|
69
69
|
|
70
70
|
@pytest.mark.asyncio
|
71
71
|
async def test_run_resume_sync(server: Server, client: Client) -> None:
|
72
|
-
run = await client.run_sync(agent="awaiter", input=
|
72
|
+
run = await client.run_sync(agent="awaiter", input=input)
|
73
73
|
assert run.status == RunStatus.AWAITING
|
74
74
|
assert run.await_request is not None
|
75
75
|
|
@@ -79,7 +79,7 @@ async def test_run_resume_sync(server: Server, client: Client) -> None:
|
|
79
79
|
|
80
80
|
@pytest.mark.asyncio
|
81
81
|
async def test_run_resume_async(server: Server, client: Client) -> None:
|
82
|
-
run = await client.run_sync(agent="awaiter", input=
|
82
|
+
run = await client.run_sync(agent="awaiter", input=input)
|
83
83
|
assert run.status == RunStatus.AWAITING
|
84
84
|
assert run.await_request is not None
|
85
85
|
|
@@ -89,7 +89,7 @@ async def test_run_resume_async(server: Server, client: Client) -> None:
|
|
89
89
|
|
90
90
|
@pytest.mark.asyncio
|
91
91
|
async def test_run_resume_stream(server: Server, client: Client) -> None:
|
92
|
-
run = await client.run_sync(agent="awaiter", input=
|
92
|
+
run = await client.run_sync(agent="awaiter", input=input)
|
93
93
|
assert run.status == RunStatus.AWAITING
|
94
94
|
assert run.await_request is not None
|
95
95
|
|
@@ -101,15 +101,15 @@ async def test_run_resume_stream(server: Server, client: Client) -> None:
|
|
101
101
|
@pytest.mark.asyncio
|
102
102
|
async def test_run_session(server: Server, client: Client) -> None:
|
103
103
|
async with client.session() as session:
|
104
|
-
run = await session.run_sync(agent="echo", input=
|
105
|
-
assert run.output ==
|
106
|
-
run = await session.run_sync(agent="echo", input=
|
107
|
-
assert run.output ==
|
104
|
+
run = await session.run_sync(agent="echo", input=input)
|
105
|
+
assert run.output == input
|
106
|
+
run = await session.run_sync(agent="echo", input=input)
|
107
|
+
assert run.output == input + input + input
|
108
108
|
|
109
109
|
|
110
110
|
@pytest.mark.asyncio
|
111
111
|
async def test_mime_types(server: Server, client: Client) -> None:
|
112
|
-
run = await client.run_sync(agent="mime_types", input=
|
112
|
+
run = await client.run_sync(agent="mime_types", input=input)
|
113
113
|
assert run.status == RunStatus.COMPLETED
|
114
114
|
assert len(run.output) == 1
|
115
115
|
|
@@ -130,7 +130,7 @@ async def test_mime_types(server: Server, client: Client) -> None:
|
|
130
130
|
|
131
131
|
@pytest.mark.asyncio
|
132
132
|
async def test_base64_encoding(server: Server, client: Client) -> None:
|
133
|
-
run = await client.run_sync(agent="base64_encoding", input=
|
133
|
+
run = await client.run_sync(agent="base64_encoding", input=input)
|
134
134
|
assert run.status == RunStatus.COMPLETED
|
135
135
|
assert len(run.output) == 1
|
136
136
|
|
@@ -150,7 +150,7 @@ async def test_base64_encoding(server: Server, client: Client) -> None:
|
|
150
150
|
|
151
151
|
@pytest.mark.asyncio
|
152
152
|
async def test_artifacts(server: Server, client: Client) -> None:
|
153
|
-
run = await client.run_sync(agent="artifact_producer", input=
|
153
|
+
run = await client.run_sync(agent="artifact_producer", input=input)
|
154
154
|
assert run.status == RunStatus.COMPLETED
|
155
155
|
|
156
156
|
assert len(run.output) == 1
|
@@ -179,7 +179,7 @@ async def test_artifacts(server: Server, client: Client) -> None:
|
|
179
179
|
|
180
180
|
@pytest.mark.asyncio
|
181
181
|
async def test_artifact_streaming(server: Server, client: Client) -> None:
|
182
|
-
events = [event async for event in client.run_stream(agent="artifact_producer", input=
|
182
|
+
events = [event async for event in client.run_stream(agent="artifact_producer", input=input)]
|
183
183
|
|
184
184
|
assert isinstance(events[0], RunCreatedEvent)
|
185
185
|
assert isinstance(events[-1], RunCompletedEvent)
|
@@ -201,7 +201,7 @@ async def test_artifact_streaming(server: Server, client: Client) -> None:
|
|
201
201
|
@pytest.mark.asyncio
|
202
202
|
@pytest.mark.parametrize("server", [timedelta(seconds=5)], indirect=True)
|
203
203
|
async def test_run_ttl(server: Server, client: Client) -> None:
|
204
|
-
run = await client.run_async(agent="echo", input=
|
204
|
+
run = await client.run_async(agent="echo", input=input)
|
205
205
|
run = await client.run_status(run_id=run.run_id)
|
206
206
|
await asyncio.sleep(6)
|
207
207
|
try:
|
@@ -218,10 +218,10 @@ async def test_run_ttl(server: Server, client: Client) -> None:
|
|
218
218
|
@pytest.mark.parametrize("server", [timedelta(seconds=5)], indirect=True)
|
219
219
|
async def test_session_ttl(server: Server, client: Client) -> None:
|
220
220
|
async with client.session() as session:
|
221
|
-
run = await session.run_sync(agent="echo", input=
|
221
|
+
run = await session.run_sync(agent="echo", input=input)
|
222
222
|
await asyncio.sleep(3)
|
223
|
-
run = await session.run_sync(agent="echo", input=
|
223
|
+
run = await session.run_sync(agent="echo", input=input)
|
224
224
|
assert len(run.output) == 3
|
225
225
|
await asyncio.sleep(3)
|
226
|
-
run = await session.run_sync(agent="echo", input=
|
226
|
+
run = await session.run_sync(agent="echo", input=input)
|
227
227
|
assert len(run.output) == 7 # First run shall be forgotten
|
@@ -0,0 +1,145 @@
|
|
1
|
+
import json
|
2
|
+
import uuid
|
3
|
+
|
4
|
+
import pytest
|
5
|
+
from acp_sdk.client import Client
|
6
|
+
from acp_sdk.models import Agent, AgentsListResponse, Message, MessagePart, Run, RunCompletedEvent
|
7
|
+
from acp_sdk.models.models import MessageAwaitResume
|
8
|
+
from pytest_httpx import HTTPXMock
|
9
|
+
|
10
|
+
mock_agent = Agent(name="mock")
|
11
|
+
mock_agents = [mock_agent]
|
12
|
+
mock_run = Run(
|
13
|
+
agent_name=mock_agent.name, session_id=uuid.uuid4(), output=[Message(parts=[MessagePart(content="Hello!")])]
|
14
|
+
)
|
15
|
+
|
16
|
+
|
17
|
+
@pytest.mark.asyncio
|
18
|
+
async def test_agents(httpx_mock: HTTPXMock) -> None:
|
19
|
+
httpx_mock.add_response(
|
20
|
+
url="http://test/agents", method="GET", content=AgentsListResponse(agents=mock_agents).model_dump_json()
|
21
|
+
)
|
22
|
+
|
23
|
+
async with Client(base_url="http://test") as client:
|
24
|
+
agents = [agent async for agent in client.agents()]
|
25
|
+
assert agents == mock_agents
|
26
|
+
|
27
|
+
|
28
|
+
@pytest.mark.asyncio
|
29
|
+
async def test_agent(httpx_mock: HTTPXMock) -> None:
|
30
|
+
httpx_mock.add_response(
|
31
|
+
url=f"http://test/agents/{mock_agent.name}", method="GET", content=mock_agent.model_dump_json()
|
32
|
+
)
|
33
|
+
|
34
|
+
async with Client(base_url="http://test") as client:
|
35
|
+
agent = await client.agent(name=mock_agent.name)
|
36
|
+
assert agent == mock_agent
|
37
|
+
|
38
|
+
|
39
|
+
@pytest.mark.asyncio
|
40
|
+
async def test_run_sync(httpx_mock: HTTPXMock) -> None:
|
41
|
+
httpx_mock.add_response(url="http://test/runs", method="POST", content=mock_run.model_dump_json())
|
42
|
+
|
43
|
+
async with Client(base_url="http://test") as client:
|
44
|
+
run = await client.run_sync("Howdy!", agent=mock_run.agent_name)
|
45
|
+
assert run == mock_run
|
46
|
+
|
47
|
+
|
48
|
+
@pytest.mark.asyncio
|
49
|
+
async def test_run_async(httpx_mock: HTTPXMock) -> None:
|
50
|
+
httpx_mock.add_response(url="http://test/runs", method="POST", content=mock_run.model_dump_json())
|
51
|
+
|
52
|
+
async with Client(base_url="http://test") as client:
|
53
|
+
run = await client.run_async("Howdy!", agent=mock_run.agent_name)
|
54
|
+
assert run == mock_run
|
55
|
+
|
56
|
+
|
57
|
+
@pytest.mark.asyncio
|
58
|
+
async def test_run_stream(httpx_mock: HTTPXMock) -> None:
|
59
|
+
mock_event = RunCompletedEvent(run=mock_run)
|
60
|
+
httpx_mock.add_response(
|
61
|
+
url="http://test/runs",
|
62
|
+
method="POST",
|
63
|
+
headers={"content-type": "text/event-stream"},
|
64
|
+
content=f"data: {mock_event.model_dump_json()}\n\n",
|
65
|
+
)
|
66
|
+
|
67
|
+
async with Client(base_url="http://test") as client:
|
68
|
+
async for event in client.run_stream("Howdy!", agent=mock_run.agent_name):
|
69
|
+
assert event == mock_event
|
70
|
+
|
71
|
+
|
72
|
+
@pytest.mark.asyncio
|
73
|
+
async def test_run_status(httpx_mock: HTTPXMock) -> None:
|
74
|
+
httpx_mock.add_response(url=f"http://test/runs/{mock_run.run_id}", method="GET", content=mock_run.model_dump_json())
|
75
|
+
|
76
|
+
async with Client(base_url="http://test") as client:
|
77
|
+
run = await client.run_status(run_id=mock_run.run_id)
|
78
|
+
assert run == mock_run
|
79
|
+
|
80
|
+
|
81
|
+
@pytest.mark.asyncio
|
82
|
+
async def test_run_cancel(httpx_mock: HTTPXMock) -> None:
|
83
|
+
httpx_mock.add_response(
|
84
|
+
url=f"http://test/runs/{mock_run.run_id}/cancel", method="POST", content=mock_run.model_dump_json()
|
85
|
+
)
|
86
|
+
|
87
|
+
async with Client(base_url="http://test") as client:
|
88
|
+
run = await client.run_cancel(run_id=mock_run.run_id)
|
89
|
+
assert run == mock_run
|
90
|
+
|
91
|
+
|
92
|
+
@pytest.mark.asyncio
|
93
|
+
async def test_run_resume_sync(httpx_mock: HTTPXMock) -> None:
|
94
|
+
httpx_mock.add_response(
|
95
|
+
url=f"http://test/runs/{mock_run.run_id}", method="POST", content=mock_run.model_dump_json()
|
96
|
+
)
|
97
|
+
|
98
|
+
async with Client(base_url="http://test") as client:
|
99
|
+
run = await client.run_resume_sync(MessageAwaitResume(message=Message(parts=[])), run_id=mock_run.run_id)
|
100
|
+
assert run == mock_run
|
101
|
+
|
102
|
+
|
103
|
+
@pytest.mark.asyncio
|
104
|
+
async def test_run_resume_async(httpx_mock: HTTPXMock) -> None:
|
105
|
+
httpx_mock.add_response(
|
106
|
+
url=f"http://test/runs/{mock_run.run_id}", method="POST", content=mock_run.model_dump_json()
|
107
|
+
)
|
108
|
+
|
109
|
+
async with Client(base_url="http://test") as client:
|
110
|
+
run = await client.run_resume_async(MessageAwaitResume(message=Message(parts=[])), run_id=mock_run.run_id)
|
111
|
+
assert run == mock_run
|
112
|
+
|
113
|
+
|
114
|
+
@pytest.mark.asyncio
|
115
|
+
async def test_run_resume_stream(httpx_mock: HTTPXMock) -> None:
|
116
|
+
mock_event = RunCompletedEvent(run=mock_run)
|
117
|
+
httpx_mock.add_response(
|
118
|
+
url=f"http://test/runs/{mock_run.run_id}",
|
119
|
+
method="POST",
|
120
|
+
headers={"content-type": "text/event-stream"},
|
121
|
+
content=f"data: {mock_event.model_dump_json()}\n\n",
|
122
|
+
)
|
123
|
+
|
124
|
+
async with Client(base_url="http://test") as client:
|
125
|
+
async for event in client.run_resume_stream(
|
126
|
+
MessageAwaitResume(message=Message(parts=[])), run_id=mock_run.run_id
|
127
|
+
):
|
128
|
+
assert event == mock_event
|
129
|
+
|
130
|
+
|
131
|
+
@pytest.mark.asyncio
|
132
|
+
async def test_session(httpx_mock: HTTPXMock) -> None:
|
133
|
+
httpx_mock.add_response(url="http://test/runs", method="POST", content=mock_run.model_dump_json(), is_reusable=True)
|
134
|
+
|
135
|
+
async with Client(base_url="http://test") as client, client.session(mock_run.session_id) as session:
|
136
|
+
assert session._session_id == mock_run.session_id
|
137
|
+
await session.run_sync("Howdy!", agent=mock_run.agent_name)
|
138
|
+
await client.run_sync("Howdy!", agent=mock_run.agent_name)
|
139
|
+
|
140
|
+
requests = httpx_mock.get_requests()
|
141
|
+
body = json.loads(requests[0].content)
|
142
|
+
assert body["session_id"] == str(mock_run.session_id)
|
143
|
+
|
144
|
+
body = json.loads(requests[1].content)
|
145
|
+
assert body["session_id"] is None
|
@@ -0,0 +1,30 @@
|
|
1
|
+
import pytest
|
2
|
+
from acp_sdk.client.types import Input
|
3
|
+
from acp_sdk.client.utils import input_to_messages
|
4
|
+
from acp_sdk.models import Message, MessagePart
|
5
|
+
|
6
|
+
|
7
|
+
@pytest.mark.parametrize(
|
8
|
+
"input,messages",
|
9
|
+
[
|
10
|
+
([], []),
|
11
|
+
("Hello", [Message(parts=[MessagePart(content="Hello")])]),
|
12
|
+
(["Hello"], [Message(parts=[MessagePart(content="Hello")])]),
|
13
|
+
(MessagePart(content="Hello"), [Message(parts=[MessagePart(content="Hello")])]),
|
14
|
+
([MessagePart(content="Hello")], [Message(parts=[MessagePart(content="Hello")])]),
|
15
|
+
(Message(parts=[MessagePart(content="Hello")]), [Message(parts=[MessagePart(content="Hello")])]),
|
16
|
+
([Message(parts=[MessagePart(content="Hello")])], [Message(parts=[MessagePart(content="Hello")])]),
|
17
|
+
],
|
18
|
+
)
|
19
|
+
def test_input_to_messages(input: Input, messages: list[Message]) -> None:
|
20
|
+
result = input_to_messages(input)
|
21
|
+
assert result == messages
|
22
|
+
|
23
|
+
|
24
|
+
@pytest.mark.parametrize(
|
25
|
+
"input",
|
26
|
+
[["foo", Message(parts=[])], ["foo", MessagePart(content="foo")], [Message(parts=[]), MessagePart(content="foo")]],
|
27
|
+
)
|
28
|
+
def test_input_to_messages_mixed_input(input: Input) -> None:
|
29
|
+
with pytest.raises(TypeError):
|
30
|
+
input_to_messages(["foo", Message(parts=[])])
|
@@ -2,6 +2,25 @@ import pytest
|
|
2
2
|
from acp_sdk.models.models import Message, MessagePart
|
3
3
|
|
4
4
|
|
5
|
+
@pytest.mark.parametrize(
|
6
|
+
"first,second,result",
|
7
|
+
[
|
8
|
+
(
|
9
|
+
Message(parts=[MessagePart(content_type="text/plain", content="Foo")]),
|
10
|
+
Message(parts=[MessagePart(content_type="text/plain", content="Bar")]),
|
11
|
+
Message(
|
12
|
+
parts=[
|
13
|
+
MessagePart(content_type="text/plain", content="Foo"),
|
14
|
+
MessagePart(content_type="text/plain", content="Bar"),
|
15
|
+
]
|
16
|
+
),
|
17
|
+
)
|
18
|
+
],
|
19
|
+
)
|
20
|
+
def test_message_add(first: Message, second: Message, result: Message) -> None:
|
21
|
+
assert first + second == result
|
22
|
+
|
23
|
+
|
5
24
|
@pytest.mark.parametrize(
|
6
25
|
"uncompressed,compressed",
|
7
26
|
[
|
@@ -1,14 +0,0 @@
|
|
1
|
-
from collections.abc import AsyncGenerator
|
2
|
-
|
3
|
-
from pydantic import BaseModel
|
4
|
-
|
5
|
-
from acp_sdk.server.bundle import RunBundle
|
6
|
-
|
7
|
-
|
8
|
-
def encode_sse(model: BaseModel) -> str:
|
9
|
-
return f"data: {model.model_dump_json()}\n\n"
|
10
|
-
|
11
|
-
|
12
|
-
async def stream_sse(bundle: RunBundle) -> AsyncGenerator[str]:
|
13
|
-
async for event in bundle.stream():
|
14
|
-
yield encode_sse(event)
|
@@ -1,36 +0,0 @@
|
|
1
|
-
import pytest
|
2
|
-
from acp_sdk.client import Client
|
3
|
-
from acp_sdk.models import Message, MessagePart, Run, RunCompletedEvent
|
4
|
-
from pytest_httpx import HTTPXMock
|
5
|
-
|
6
|
-
mock_run = Run(agent_name="mock", output=[Message(parts=[MessagePart(content="Hello!")])])
|
7
|
-
|
8
|
-
|
9
|
-
@pytest.mark.asyncio
|
10
|
-
async def test_run_sync(httpx_mock: HTTPXMock) -> None:
|
11
|
-
httpx_mock.add_response(content=mock_run.model_dump_json())
|
12
|
-
|
13
|
-
async with Client(base_url="http://localhost:8000") as client:
|
14
|
-
run = await client.run_sync("Howdy!", agent="mock")
|
15
|
-
assert run == mock_run
|
16
|
-
|
17
|
-
|
18
|
-
@pytest.mark.asyncio
|
19
|
-
async def test_run_async(httpx_mock: HTTPXMock) -> None:
|
20
|
-
httpx_mock.add_response(content=mock_run.model_dump_json())
|
21
|
-
|
22
|
-
async with Client(base_url="http://localhost:8000") as client:
|
23
|
-
run = await client.run_async("Howdy!", agent="mock")
|
24
|
-
assert run == mock_run
|
25
|
-
|
26
|
-
|
27
|
-
@pytest.mark.asyncio
|
28
|
-
async def test_run_stream(httpx_mock: HTTPXMock) -> None:
|
29
|
-
mock_event = RunCompletedEvent(run=mock_run)
|
30
|
-
httpx_mock.add_response(
|
31
|
-
headers={"content-type": "text/event-stream"}, content=f"data: {mock_event.model_dump_json()}\n\n"
|
32
|
-
)
|
33
|
-
|
34
|
-
async with Client(base_url="http://localhost:8000") as client:
|
35
|
-
async for event in client.run_stream("Howdy!", agent="mock"):
|
36
|
-
assert event == mock_event
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|