acp-sdk 0.0.6__py3-none-any.whl → 1.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acp_sdk/client/__init__.py +1 -0
- acp_sdk/client/client.py +135 -0
- acp_sdk/models.py +219 -0
- acp_sdk/server/__init__.py +2 -0
- acp_sdk/server/agent.py +32 -0
- acp_sdk/server/bundle.py +133 -0
- acp_sdk/server/context.py +6 -0
- acp_sdk/server/server.py +137 -0
- acp_sdk/server/telemetry.py +45 -0
- acp_sdk/server/utils.py +12 -0
- acp_sdk-1.0.0rc1.dist-info/METADATA +53 -0
- acp_sdk-1.0.0rc1.dist-info/RECORD +15 -0
- acp/__init__.py +0 -138
- acp/cli/__init__.py +0 -6
- acp/cli/claude.py +0 -139
- acp/cli/cli.py +0 -471
- acp/client/__main__.py +0 -79
- acp/client/session.py +0 -372
- acp/client/sse.py +0 -145
- acp/client/stdio.py +0 -153
- acp/server/__init__.py +0 -3
- acp/server/__main__.py +0 -50
- acp/server/highlevel/__init__.py +0 -9
- acp/server/highlevel/agents/__init__.py +0 -5
- acp/server/highlevel/agents/agent_manager.py +0 -110
- acp/server/highlevel/agents/base.py +0 -20
- acp/server/highlevel/agents/templates.py +0 -21
- acp/server/highlevel/context.py +0 -185
- acp/server/highlevel/exceptions.py +0 -25
- acp/server/highlevel/prompts/__init__.py +0 -4
- acp/server/highlevel/prompts/base.py +0 -167
- acp/server/highlevel/prompts/manager.py +0 -50
- acp/server/highlevel/prompts/prompt_manager.py +0 -33
- acp/server/highlevel/resources/__init__.py +0 -23
- acp/server/highlevel/resources/base.py +0 -48
- acp/server/highlevel/resources/resource_manager.py +0 -94
- acp/server/highlevel/resources/templates.py +0 -80
- acp/server/highlevel/resources/types.py +0 -185
- acp/server/highlevel/server.py +0 -705
- acp/server/highlevel/tools/__init__.py +0 -4
- acp/server/highlevel/tools/base.py +0 -83
- acp/server/highlevel/tools/tool_manager.py +0 -53
- acp/server/highlevel/utilities/__init__.py +0 -1
- acp/server/highlevel/utilities/func_metadata.py +0 -210
- acp/server/highlevel/utilities/logging.py +0 -43
- acp/server/highlevel/utilities/types.py +0 -54
- acp/server/lowlevel/__init__.py +0 -3
- acp/server/lowlevel/helper_types.py +0 -9
- acp/server/lowlevel/server.py +0 -643
- acp/server/models.py +0 -17
- acp/server/session.py +0 -315
- acp/server/sse.py +0 -175
- acp/server/stdio.py +0 -83
- acp/server/websocket.py +0 -61
- acp/shared/__init__.py +0 -0
- acp/shared/context.py +0 -14
- acp/shared/exceptions.py +0 -14
- acp/shared/memory.py +0 -87
- acp/shared/progress.py +0 -40
- acp/shared/session.py +0 -413
- acp/shared/version.py +0 -3
- acp/types.py +0 -1258
- acp_sdk-0.0.6.dist-info/METADATA +0 -46
- acp_sdk-0.0.6.dist-info/RECORD +0 -57
- acp_sdk-0.0.6.dist-info/entry_points.txt +0 -2
- acp_sdk-0.0.6.dist-info/licenses/LICENSE +0 -22
- {acp/client → acp_sdk}/__init__.py +0 -0
- {acp → acp_sdk}/py.typed +0 -0
- {acp_sdk-0.0.6.dist-info → acp_sdk-1.0.0rc1.dist-info}/WHEEL +0 -0
@@ -0,0 +1 @@
|
|
1
|
+
from acp_sdk.client.client import Client
|
acp_sdk/client/client.py
ADDED
@@ -0,0 +1,135 @@
|
|
1
|
+
from types import TracebackType
|
2
|
+
from typing import AsyncIterator
|
3
|
+
|
4
|
+
import httpx
|
5
|
+
from httpx_sse import EventSource, aconnect_sse
|
6
|
+
|
7
|
+
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
8
|
+
|
9
|
+
from acp_sdk.models import (
|
10
|
+
Agent,
|
11
|
+
AgentName,
|
12
|
+
AgentReadResponse,
|
13
|
+
AgentsListResponse,
|
14
|
+
AwaitResume,
|
15
|
+
Message,
|
16
|
+
RunCancelResponse,
|
17
|
+
RunCreateRequest,
|
18
|
+
Run,
|
19
|
+
RunCreateResponse,
|
20
|
+
RunEvent,
|
21
|
+
RunId,
|
22
|
+
RunMode,
|
23
|
+
RunResumeRequest,
|
24
|
+
RunResumeResponse,
|
25
|
+
)
|
26
|
+
from pydantic import TypeAdapter
|
27
|
+
|
28
|
+
|
29
|
+
class Client:
    """Asynchronous client for an ACP agent server.

    Wraps an ``httpx.AsyncClient`` (the one passed in, or a new one bound to
    ``base_url``) that is instrumented with OpenTelemetry. Use it as an async
    context manager so the underlying HTTP client is opened and closed.

    Request methods raise ``httpx.HTTPStatusError`` on a non-2xx response
    instead of trying to parse the error body as a response model.
    """

    # Built once and reused for every streamed event (building a TypeAdapter
    # per event is needlessly expensive).
    _event_adapter = TypeAdapter(RunEvent)

    def __init__(
        self, *, base_url: httpx.URL | str = "", client: httpx.AsyncClient | None = None
    ):
        self.base_url = base_url
        self._client = self._init_client(client)

    def _init_client(self, client: httpx.AsyncClient | None = None):
        """Return ``client`` (or a new one bound to ``self.base_url``), instrumented for tracing.

        A caller-supplied client keeps its own base URL; ``base_url`` is only
        used when the client is created here.
        """
        if client is None:
            client = httpx.AsyncClient(base_url=self.base_url)
        HTTPXClientInstrumentor.instrument_client(client)
        return client

    async def __aenter__(self):
        await self._client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ):
        await self._client.__aexit__(exc_type, exc_value, traceback)

    async def agents(self) -> AsyncIterator[Agent]:
        """Yield every agent advertised by the server (``GET /agents``)."""
        response = await self._client.get("/agents")
        response.raise_for_status()
        for agent in AgentsListResponse.model_validate(response.json()).agents:
            yield agent

    async def agent(self, *, name: AgentName) -> Agent:
        """Fetch a single agent by name (``GET /agents/{name}``)."""
        response = await self._client.get(f"/agents/{name}")
        response.raise_for_status()
        return AgentReadResponse.model_validate(response.json())

    async def run_sync(self, *, agent: AgentName, input: Message) -> Run:
        """Start a run and return once it terminates or pauses on an await."""
        response = await self._post(
            "/runs", self._create_body(agent, input, RunMode.SYNC)
        )
        return RunCreateResponse.model_validate(response.json())

    async def run_async(self, *, agent: AgentName, input: Message) -> Run:
        """Start a run and return immediately with its initial state."""
        response = await self._post(
            "/runs", self._create_body(agent, input, RunMode.ASYNC)
        )
        return RunCreateResponse.model_validate(response.json())

    async def run_stream(
        self, *, agent: AgentName, input: Message
    ) -> AsyncIterator[RunEvent]:
        """Start a run and yield its events as they arrive over server-sent events."""
        async for event in self._stream(
            "/runs", self._create_body(agent, input, RunMode.STREAM)
        ):
            yield event

    async def run_status(self, *, run_id: RunId) -> Run:
        """Fetch the current state of a run (``GET /runs/{run_id}``)."""
        response = await self._client.get(f"/runs/{run_id}")
        response.raise_for_status()
        return Run.model_validate(response.json())

    async def run_cancel(self, *, run_id: RunId) -> Run:
        """Request cancellation of a run (``POST /runs/{run_id}/cancel``)."""
        response = await self._post(f"/runs/{run_id}/cancel")
        return RunCancelResponse.model_validate(response.json())

    async def run_resume_sync(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume an awaiting run and return once it terminates or awaits again."""
        response = await self._post(
            f"/runs/{run_id}", self._resume_body(await_, RunMode.SYNC)
        )
        return RunResumeResponse.model_validate(response.json())

    async def run_resume_async(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume an awaiting run and return immediately."""
        response = await self._post(
            f"/runs/{run_id}", self._resume_body(await_, RunMode.ASYNC)
        )
        return RunResumeResponse.model_validate(response.json())

    async def run_resume_stream(
        self, *, run_id: RunId, await_: AwaitResume
    ) -> AsyncIterator[RunEvent]:
        """Resume an awaiting run and yield its subsequent events."""
        async for event in self._stream(
            f"/runs/{run_id}", self._resume_body(await_, RunMode.STREAM)
        ):
            yield event

    async def _post(self, url: str, body: dict | None = None) -> httpx.Response:
        """POST ``body`` as JSON (no body when ``None``) and raise on a non-2xx status."""
        response = await self._client.post(url, json=body)
        response.raise_for_status()
        return response

    async def _stream(self, url: str, body: dict) -> AsyncIterator[RunEvent]:
        """POST ``body`` and yield the validated run events of the SSE response."""
        async with aconnect_sse(self._client, "POST", url, json=body) as event_source:
            event_source.response.raise_for_status()
            async for event in self._validate_stream(event_source):
                yield event

    @staticmethod
    def _create_body(agent: AgentName, input: Message, mode: RunMode) -> dict:
        """Build the JSON body of ``POST /runs``."""
        # mode="json" converts URLs and enums to JSON-native values; a plain
        # model_dump() leaves pydantic Url objects that httpx's json= cannot encode.
        return RunCreateRequest(agent_name=agent, input=input, mode=mode).model_dump(
            mode="json"
        )

    @staticmethod
    def _resume_body(await_: AwaitResume, mode: RunMode) -> dict:
        """Build the JSON body of ``POST /runs/{run_id}``."""
        # The field is declared as ``await_`` with alias "await": populate it via
        # the alias (works whether or not the model allows population by name)
        # and dump by alias so the server receives the "await" key it parses.
        return RunResumeRequest.model_validate(
            {"await": await_, "mode": mode}
        ).model_dump(mode="json", by_alias=True)

    async def _validate_stream(
        self,
        event_source: EventSource,
    ) -> AsyncIterator[RunEvent]:
        """Parse each SSE message's data as a ``RunEvent``."""
        async for sse in event_source.aiter_sse():
            yield self._event_adapter.validate_json(sse.data)
|
acp_sdk/models.py
ADDED
@@ -0,0 +1,219 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Annotated, Literal, Union
|
3
|
+
import uuid
|
4
|
+
|
5
|
+
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel
|
6
|
+
|
7
|
+
|
8
|
+
class ACPError(BaseModel):
    """Error details attached to a run that failed."""

    # Short error code (the server bundle currently uses "unspecified").
    code: str
    # Human-readable description of what went wrong.
    message: str
|
11
|
+
|
12
|
+
|
13
|
+
class AnyModel(BaseModel):
    """Schema-less model: every key it is given is kept as an extra attribute."""

    model_config = ConfigDict(extra="allow")
|
15
|
+
|
16
|
+
|
17
|
+
class MessagePartBase(BaseModel):
    """Common base of the parts that make up a Message; ``type`` tags the variant."""

    type: Literal["text", "image", "artifact"]


class TextMessagePart(MessagePartBase):
    """A part carrying inline text."""

    type: Literal["text"] = "text"
    content: str


class ImageMessagePart(MessagePartBase):
    """A part referring to an image by URL."""

    type: Literal["image"] = "image"
    content_url: AnyUrl


class ArtifactMessagePart(MessagePartBase):
    """A part referring to a named artifact by URL."""

    type: Literal["artifact"] = "artifact"
    name: str
    content_url: AnyUrl


# Any one part of a Message.
MessagePart = Union[TextMessagePart, ImageMessagePart, ArtifactMessagePart]
|
38
|
+
|
39
|
+
|
40
|
+
class Message(RootModel):
    """An ordered list of message parts.

    Built from positional parts, e.g. ``Message(TextMessagePart(content="hi"))``.
    Supports iteration, indexing and ``+`` concatenation; ``str()`` returns the
    concatenated text content of its text parts.
    """

    root: list[MessagePart]

    def __init__(self, *items: MessagePart):
        super().__init__(root=list(items))

    def __iter__(self):
        return iter(self.root)

    def __getitem__(self, item):
        return self.root[item]

    def __add__(self, other: "Message") -> "Message":
        """Return a new Message with ``other``'s parts appended; raises TypeError for non-Messages."""
        if not isinstance(other, Message):
            raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
        return Message(*(self.root + other.root))

    def __str__(self):
        # Join the *content* of text parts; str(part) would render the pydantic
        # model form ("type='text' content='...'") rather than the text itself.
        return "".join(
            part.content for part in self.root if isinstance(part, TextMessagePart)
        )
|
61
|
+
|
62
|
+
|
63
|
+
# Identifiers are plain strings; the aliases document intent in signatures.
AgentName = str  # name an agent is registered under
SessionId = str  # session identifier
RunId = str  # run identifier
|
66
|
+
|
67
|
+
|
68
|
+
class RunMode(str, Enum):
    """How a run request is answered: wait for the result, return at once, or stream events."""

    SYNC = "sync"
    ASYNC = "async"
    STREAM = "stream"
|
72
|
+
|
73
|
+
|
74
|
+
class RunStatus(str, Enum):
    """Lifecycle state of a run."""

    CREATED = "created"
    IN_PROGRESS = "in-progress"
    AWAITING = "awaiting"
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"
    COMPLETED = "completed"
    FAILED = "failed"

    @property
    def is_terminal(self) -> bool:
        """True for the final states: completed, failed and cancelled."""
        return self in (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELLED)
|
87
|
+
|
88
|
+
|
89
|
+
class Await(BaseModel):
    """Yielded by an agent to pause its run until the client resumes it."""

    # Only a placeholder variant is defined so far.
    type: Literal["placeholder"] = "placeholder"


class AwaitResume(BaseModel):
    """Payload a client sends to resume an awaiting run; it has no fields yet."""
|
95
|
+
|
96
|
+
|
97
|
+
class Run(BaseModel):
    """State of one agent run as tracked by the server and returned to clients."""

    # default_factory gives every Run its own id. A plain ``str(uuid.uuid4())``
    # default is evaluated once at class definition and would be shared by all
    # runs, making them collide in the server's run registry.
    run_id: RunId = Field(default_factory=lambda: str(uuid.uuid4()))
    agent_name: AgentName
    session_id: SessionId | None = None
    status: RunStatus = RunStatus.CREATED
    # Serialized under the "await" key ("await" is a Python keyword).
    await_: Await | None = Field(None, alias="await")
    output: Message | None = None
    error: ACPError | None = None

    # Accept both ``await_=`` and the "await" alias when constructing.
    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs,
    ):
        """Serialize to JSON using field aliases unless the caller passes ``by_alias``."""
        # setdefault (rather than by_alias=True, **kwargs) avoids a duplicate
        # keyword TypeError when a caller supplies by_alias explicitly.
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
|
116
|
+
|
117
|
+
|
118
|
+
class MessageEvent(BaseModel):
    """Carries one Message chunk yielded by the agent."""

    type: Literal["message"] = "message"
    message: Message
|
121
|
+
|
122
|
+
|
123
|
+
class AwaitEvent(BaseModel):
    """Emitted when the agent pauses on an ``Await``."""

    type: Literal["await"] = "await"
    # Required but nullable; serialized under the "await" key (a Python keyword).
    await_: Await | None = Field(alias="await")

    # Accept both ``await_=`` and the "await" alias when constructing.
    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs,
    ):
        """Serialize to JSON using field aliases unless the caller passes ``by_alias``."""
        # setdefault avoids a duplicate keyword TypeError when by_alias is supplied.
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
|
137
|
+
|
138
|
+
|
139
|
+
class GenericEvent(BaseModel):
    """Carries an arbitrary object yielded by the agent (neither a Message nor an Await)."""

    type: Literal["generic"] = "generic"
    generic: AnyModel


class CreatedEvent(BaseModel):
    """The run was created."""

    type: Literal["created"] = "created"
    run: Run


class InProgressEvent(BaseModel):
    """The run started executing, or resumed after an await."""

    type: Literal["in-progress"] = "in-progress"
    run: Run


class FailedEvent(BaseModel):
    """The run ended with an error."""

    type: Literal["failed"] = "failed"
    run: Run


class CancelledEvent(BaseModel):
    """The run was cancelled."""

    type: Literal["cancelled"] = "cancelled"
    run: Run


class CompletedEvent(BaseModel):
    """The run finished; ``run.output`` holds the composed output."""

    type: Literal["completed"] = "completed"
    run: Run
|
167
|
+
|
168
|
+
|
169
|
+
# Every event a run can emit; ``type`` distinguishes the variants on the wire.
RunEvent = Union[
    CreatedEvent, InProgressEvent, MessageEvent, AwaitEvent,
    GenericEvent, CancelledEvent, FailedEvent, CompletedEvent,
]
|
179
|
+
|
180
|
+
|
181
|
+
class RunCreateRequest(BaseModel):
    """Body of ``POST /runs``: start ``agent_name`` on ``input``."""

    agent_name: AgentName
    session_id: SessionId | None = None  # optional session to run within
    input: Message
    mode: RunMode = RunMode.SYNC  # how the server should answer; sync by default
|
186
|
+
|
187
|
+
|
188
|
+
class RunCreateResponse(Run):
    """Response of ``POST /runs``: the run's state when the request returns."""
|
190
|
+
|
191
|
+
|
192
|
+
class RunResumeRequest(BaseModel):
    """Body of ``POST /runs/{run_id}``: resume an awaiting run."""

    # Serialized under the "await" key ("await" is a Python keyword).
    await_: AwaitResume = Field(alias="await")
    mode: RunMode

    # Like Run and AwaitEvent, accept ``await_=`` as well as the alias. Without
    # this, ``RunResumeRequest(await_=..., mode=...)`` fails with a missing "await".
    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs,
    ):
        """Serialize to JSON using field aliases unless the caller passes ``by_alias``."""
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
|
195
|
+
|
196
|
+
|
197
|
+
class RunResumeResponse(Run):
    """Response of ``POST /runs/{run_id}``: the run after it was resumed."""


class RunReadResponse(Run):
    """Response of ``GET /runs/{run_id}``: the run's current state."""


class RunCancelResponse(Run):
    """Response of ``POST /runs/{run_id}/cancel``: the run after cancellation was requested."""
|
207
|
+
|
208
|
+
|
209
|
+
class Agent(BaseModel):
    """Public description of an agent served by the server."""

    name: str
    description: str | None = None


class AgentsListResponse(BaseModel):
    """Response of ``GET /agents``."""

    agents: list[Agent]


class AgentReadResponse(Agent):
    """Response of ``GET /agents/{name}``."""
|
acp_sdk/server/agent.py
ADDED
@@ -0,0 +1,32 @@
|
|
1
|
+
import abc
|
2
|
+
from typing import AsyncGenerator
|
3
|
+
|
4
|
+
from acp_sdk.models import (
|
5
|
+
AgentName,
|
6
|
+
Message,
|
7
|
+
Await,
|
8
|
+
AwaitResume,
|
9
|
+
SessionId,
|
10
|
+
)
|
11
|
+
from acp_sdk.server.context import Context
|
12
|
+
|
13
|
+
|
14
|
+
class Agent(abc.ABC):
    """Base class for agents served by the ACP server.

    Subclasses implement ``run``; ``name`` and ``description`` may be overridden.
    """

    @property
    def name(self) -> AgentName:
        """Name the agent is registered under; defaults to the subclass name."""
        return type(self).__name__

    @property
    def description(self) -> str:
        """Human-readable description; empty unless overridden."""
        return ""

    @abc.abstractmethod
    def run(
        self, input: Message, *, context: Context
    ) -> AsyncGenerator[Message | Await, AwaitResume]:
        """Execute the agent on ``input``.

        Implementations are async generators. The server appends each yielded
        ``Message`` to the run output. It pauses the run on an ``Await`` and
        sends the client's ``AwaitResume`` back in when the run is resumed.
        """

    async def session(self, session_id: SessionId | None) -> SessionId | None:
        """Resolve the session for a run.

        Sessions are unsupported by default: no id yields None, and any truthy
        id raises NotImplementedError.
        """
        if not session_id:
            return None
        raise NotImplementedError()
|
acp_sdk/server/bundle.py
ADDED
@@ -0,0 +1,133 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
|
4
|
+
from opentelemetry import trace
|
5
|
+
from pydantic import ValidationError
|
6
|
+
|
7
|
+
from acp_sdk.server.agent import Agent
|
8
|
+
from acp_sdk.models import (
|
9
|
+
ACPError,
|
10
|
+
AnyModel,
|
11
|
+
Await,
|
12
|
+
AwaitEvent,
|
13
|
+
CancelledEvent,
|
14
|
+
CompletedEvent,
|
15
|
+
CreatedEvent,
|
16
|
+
FailedEvent,
|
17
|
+
GenericEvent,
|
18
|
+
InProgressEvent,
|
19
|
+
Message,
|
20
|
+
MessageEvent,
|
21
|
+
Run,
|
22
|
+
AwaitResume,
|
23
|
+
RunEvent,
|
24
|
+
RunStatus,
|
25
|
+
)
|
26
|
+
from acp_sdk.server.context import Context
|
27
|
+
|
28
|
+
# Reuse uvicorn's error logger, presumably so run logs appear with the server's own output.
logger = logging.getLogger(name="uvicorn.error")
|
29
|
+
|
30
|
+
|
31
|
+
class RunBundle:
    """Server-side execution state for one run.

    ``execute`` drives the agent's async generator and pushes run events onto
    ``stream_queue``, which ``stream`` consumes. When the agent yields an
    ``Await``, the current stream is shut down and execution blocks on
    ``await_queue`` until ``resume`` delivers the client's ``AwaitResume``.
    ``join`` waits until the run pauses on an await or terminates.
    """

    def __init__(self, *, agent: Agent, run: Run, task: asyncio.Task | None = None):
        self.agent = agent
        self.run = run
        self.task = task

        self.stream_queue: asyncio.Queue[RunEvent] = asyncio.Queue()
        self.composed_message = Message()

        # Holds at most one pending resume payload.
        self.await_queue: asyncio.Queue[AwaitResume] = asyncio.Queue(maxsize=1)
        self.await_or_terminate_event = asyncio.Event()

    async def stream(self):
        """Yield events from the current stream queue until it is shut down and drained."""
        try:
            while True:
                event = await self.stream_queue.get()
                yield event
                self.stream_queue.task_done()
        except asyncio.QueueShutDown:
            pass

    async def emit(self, event: RunEvent):
        """Queue ``event`` for the stream consumer.

        While the run is paused on an await, the stream queue is shut down and
        nothing consumes it. This happens, for example, when the run is cancelled
        in that state. Such events are dropped instead of letting QueueShutDown
        escape from ``execute``.
        """
        try:
            await self.stream_queue.put(event)
        except asyncio.QueueShutDown:
            pass

    async def await_(self) -> AwaitResume:
        """Pause until a resume payload arrives, then return it."""
        # Close the current stream so its consumer finishes after the await event.
        self.stream_queue.shutdown()
        # set() releases everyone currently waiting in join(); clear() re-arms
        # the event so a later join() waits for the next await or termination.
        self.await_or_terminate_event.set()
        self.await_or_terminate_event.clear()
        resume = await self.await_queue.get()
        self.await_queue.task_done()
        return resume

    async def resume(self, resume: AwaitResume):
        """Open a fresh stream queue and hand ``resume`` to the paused run."""
        self.stream_queue = asyncio.Queue()
        await self.await_queue.put(resume)

    async def join(self):
        """Wait until the run pauses on an await or terminates."""
        await self.await_or_terminate_event.wait()

    def _snapshot(self) -> Run:
        """Return a copy of the run's current state for embedding in an event."""
        # Events are serialized later by the stream consumer while execution keeps
        # mutating self.run. Without a copy, an early event (e.g. "created") could
        # report a later status.
        return self.run.model_copy()

    async def execute(self, input: Message):
        """Run the agent to the end, translating its yields into run events.

        - Yielded ``Message``s are appended to the run output and emitted.
        - An ``Await`` pauses the run until it is resumed.
        - Any other value is validated into a ``GenericEvent``; a value that
          fails validation fails the run with "Invalid yield".

        The final status (completed, cancelled or failed) is recorded on
        ``self.run``. Cancellation is re-raised after it has been recorded.
        """
        with trace.get_tracer(__name__).start_as_current_span("execute"):
            run_logger = logging.LoggerAdapter(logger, {"run_id": self.run.run_id})

            await self.emit(CreatedEvent(run=self._snapshot()))
            try:
                self.run.session_id = await self.agent.session(self.run.session_id)
                run_logger.info("Session loaded")

                generator = self.agent.run(
                    input=input, context=Context(session_id=self.run.session_id)
                )
                run_logger.info("Run started")

                self.run.status = RunStatus.IN_PROGRESS
                await self.emit(InProgressEvent(run=self._snapshot()))

                # Value sent into the generator on its next step. It is None
                # except immediately after an Await, when it carries the resume
                # payload (sent exactly once).
                to_send = None
                while True:
                    item = await generator.asend(to_send)
                    to_send = None
                    if isinstance(item, Message):
                        self.composed_message += item
                        await self.emit(MessageEvent(message=item))
                    elif isinstance(item, Await):
                        self.run.await_ = item
                        self.run.status = RunStatus.AWAITING
                        await self.emit(
                            AwaitEvent.model_validate({"type": "await", "await": item})
                        )
                        run_logger.info("Run awaited")
                        to_send = await self.await_()
                        # The await has been answered; don't report it any more.
                        self.run.await_ = None
                        self.run.status = RunStatus.IN_PROGRESS
                        await self.emit(InProgressEvent(run=self._snapshot()))
                        run_logger.info("Run resumed")
                    else:
                        try:
                            generic = AnyModel.model_validate(item)
                        except ValidationError as e:
                            raise TypeError("Invalid yield") from e
                        await self.emit(GenericEvent(generic=generic))
            except StopAsyncIteration:
                self.run.output = self.composed_message
                self.run.status = RunStatus.COMPLETED
                await self.emit(CompletedEvent(run=self._snapshot()))
                run_logger.info("Run completed")
            except asyncio.CancelledError:
                self.run.status = RunStatus.CANCELLED
                await self.emit(CancelledEvent(run=self._snapshot()))
                run_logger.info("Run cancelled")
                # Propagate so the task is reported as cancelled.
                raise
            except Exception as e:
                self.run.error = ACPError(code="unspecified", message=str(e))
                self.run.status = RunStatus.FAILED
                await self.emit(FailedEvent(run=self._snapshot()))
                run_logger.exception("Run failed")
            finally:
                self.await_or_terminate_event.set()
                self.stream_queue.shutdown()
|
acp_sdk/server/server.py
ADDED
@@ -0,0 +1,137 @@
|
|
1
|
+
import asyncio
|
2
|
+
|
3
|
+
from acp_sdk.server.telemetry import configure_telemetry
|
4
|
+
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
5
|
+
|
6
|
+
from fastapi import FastAPI, HTTPException, status
|
7
|
+
from fastapi.responses import JSONResponse, StreamingResponse
|
8
|
+
|
9
|
+
from acp_sdk.server.agent import Agent
|
10
|
+
from acp_sdk.models import (
|
11
|
+
AgentName,
|
12
|
+
Agent as AgentModel,
|
13
|
+
AgentsListResponse,
|
14
|
+
AgentReadResponse,
|
15
|
+
Run,
|
16
|
+
RunCancelResponse,
|
17
|
+
RunCreateRequest,
|
18
|
+
RunCreateResponse,
|
19
|
+
RunId,
|
20
|
+
RunMode,
|
21
|
+
RunReadResponse,
|
22
|
+
RunResumeRequest,
|
23
|
+
RunResumeResponse,
|
24
|
+
RunStatus,
|
25
|
+
)
|
26
|
+
from acp_sdk.server.bundle import RunBundle
|
27
|
+
from acp_sdk.server.utils import stream_sse
|
28
|
+
|
29
|
+
|
30
|
+
def create_app(*agents: Agent) -> FastAPI:
|
31
|
+
app = FastAPI(title="acp-agents")
|
32
|
+
|
33
|
+
configure_telemetry()
|
34
|
+
FastAPIInstrumentor.instrument_app(app)
|
35
|
+
|
36
|
+
agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
|
37
|
+
runs: dict[RunId, RunBundle] = dict()
|
38
|
+
|
39
|
+
def find_run_bundle(run_id: RunId):
|
40
|
+
bundle = runs.get(run_id, None)
|
41
|
+
if not bundle:
|
42
|
+
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
43
|
+
return bundle
|
44
|
+
|
45
|
+
def find_agent(agent_name: AgentName):
|
46
|
+
agent = agents.get(agent_name, None)
|
47
|
+
if not agent:
|
48
|
+
raise HTTPException(status_code=404, detail=f"Agent {agent_name} not found")
|
49
|
+
return agent
|
50
|
+
|
51
|
+
@app.get("/agents")
|
52
|
+
async def list() -> AgentsListResponse:
|
53
|
+
return AgentsListResponse(
|
54
|
+
agents=[
|
55
|
+
AgentModel(name=agent.name, description=agent.description)
|
56
|
+
for agent in agents.values()
|
57
|
+
]
|
58
|
+
)
|
59
|
+
|
60
|
+
@app.get("/agents/{name}")
|
61
|
+
async def read(name: AgentName) -> AgentReadResponse:
|
62
|
+
agent = find_agent(name)
|
63
|
+
return AgentModel(name=agent.name, description=agent.description)
|
64
|
+
|
65
|
+
@app.post("/runs")
|
66
|
+
async def create(request: RunCreateRequest) -> RunCreateResponse:
|
67
|
+
agent = find_agent(request.agent_name)
|
68
|
+
bundle = RunBundle(
|
69
|
+
agent=agent,
|
70
|
+
run=Run(
|
71
|
+
agent_name=agent.name,
|
72
|
+
session_id=request.session_id,
|
73
|
+
),
|
74
|
+
)
|
75
|
+
|
76
|
+
bundle.task = asyncio.create_task(bundle.execute(request.input))
|
77
|
+
runs[bundle.run.run_id] = bundle
|
78
|
+
|
79
|
+
match request.mode:
|
80
|
+
case RunMode.STREAM:
|
81
|
+
return StreamingResponse(
|
82
|
+
stream_sse(bundle),
|
83
|
+
media_type="text/event-stream",
|
84
|
+
)
|
85
|
+
case RunMode.SYNC:
|
86
|
+
await bundle.join()
|
87
|
+
return bundle.run
|
88
|
+
case RunMode.ASYNC:
|
89
|
+
return JSONResponse(
|
90
|
+
status_code=status.HTTP_202_ACCEPTED,
|
91
|
+
content=bundle.run.model_dump(),
|
92
|
+
)
|
93
|
+
case _:
|
94
|
+
raise NotImplementedError()
|
95
|
+
|
96
|
+
@app.get("/runs/{run_id}")
|
97
|
+
async def read(run_id: RunId) -> RunReadResponse:
|
98
|
+
bundle = find_run_bundle(run_id)
|
99
|
+
return bundle.run
|
100
|
+
|
101
|
+
@app.post("/runs/{run_id}")
|
102
|
+
async def resume(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
|
103
|
+
bundle = find_run_bundle(run_id)
|
104
|
+
bundle.stream_queue = asyncio.Queue() # TODO improve
|
105
|
+
await bundle.await_queue.put(request.await_)
|
106
|
+
match request.mode:
|
107
|
+
case RunMode.STREAM:
|
108
|
+
return StreamingResponse(
|
109
|
+
stream_sse(bundle),
|
110
|
+
media_type="text/event-stream",
|
111
|
+
)
|
112
|
+
case RunMode.SYNC:
|
113
|
+
await bundle.join()
|
114
|
+
return bundle.run
|
115
|
+
case RunMode.ASYNC:
|
116
|
+
return JSONResponse(
|
117
|
+
status_code=status.HTTP_202_ACCEPTED,
|
118
|
+
content=bundle.run.model_dump(),
|
119
|
+
)
|
120
|
+
case _:
|
121
|
+
raise NotImplementedError()
|
122
|
+
|
123
|
+
@app.post("/runs/{run_id}/cancel")
|
124
|
+
async def cancel(run_id: RunId) -> RunCancelResponse:
|
125
|
+
bundle = find_run_bundle(run_id)
|
126
|
+
if bundle.run.status.is_terminal:
|
127
|
+
raise HTTPException(
|
128
|
+
status_code=403,
|
129
|
+
detail=f"Run with terminal status {bundle.run.status} can't be cancelled",
|
130
|
+
)
|
131
|
+
bundle.task.cancel()
|
132
|
+
bundle.run.status = RunStatus.CANCELLING
|
133
|
+
return JSONResponse(
|
134
|
+
status_code=status.HTTP_202_ACCEPTED, content=bundle.run.model_dump()
|
135
|
+
)
|
136
|
+
|
137
|
+
return app
|