acp-sdk 0.0.6__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acp_sdk/__init__.py +1 -0
- acp_sdk/client/__init__.py +1 -0
- acp_sdk/client/client.py +139 -0
- acp_sdk/models/__init__.py +3 -0
- acp_sdk/models/errors.py +23 -0
- acp_sdk/models/models.py +181 -0
- acp_sdk/models/schemas.py +39 -0
- acp_sdk/server/__init__.py +6 -0
- acp_sdk/server/agent.py +105 -0
- acp_sdk/server/app.py +161 -0
- acp_sdk/server/bundle.py +131 -0
- acp_sdk/server/context.py +33 -0
- acp_sdk/server/errors.py +54 -0
- acp_sdk/server/logging.py +16 -0
- acp_sdk/server/server.py +122 -0
- acp_sdk/server/telemetry.py +52 -0
- acp_sdk/server/types.py +6 -0
- acp_sdk/server/utils.py +14 -0
- acp_sdk-0.1.0rc6.dist-info/METADATA +78 -0
- acp_sdk-0.1.0rc6.dist-info/RECORD +22 -0
- acp/__init__.py +0 -138
- acp/cli/__init__.py +0 -6
- acp/cli/claude.py +0 -139
- acp/cli/cli.py +0 -471
- acp/client/__init__.py +0 -0
- acp/client/__main__.py +0 -79
- acp/client/session.py +0 -372
- acp/client/sse.py +0 -145
- acp/client/stdio.py +0 -153
- acp/server/__init__.py +0 -3
- acp/server/__main__.py +0 -50
- acp/server/highlevel/__init__.py +0 -9
- acp/server/highlevel/agents/__init__.py +0 -5
- acp/server/highlevel/agents/agent_manager.py +0 -110
- acp/server/highlevel/agents/base.py +0 -20
- acp/server/highlevel/agents/templates.py +0 -21
- acp/server/highlevel/context.py +0 -185
- acp/server/highlevel/exceptions.py +0 -25
- acp/server/highlevel/prompts/__init__.py +0 -4
- acp/server/highlevel/prompts/base.py +0 -167
- acp/server/highlevel/prompts/manager.py +0 -50
- acp/server/highlevel/prompts/prompt_manager.py +0 -33
- acp/server/highlevel/resources/__init__.py +0 -23
- acp/server/highlevel/resources/base.py +0 -48
- acp/server/highlevel/resources/resource_manager.py +0 -94
- acp/server/highlevel/resources/templates.py +0 -80
- acp/server/highlevel/resources/types.py +0 -185
- acp/server/highlevel/server.py +0 -705
- acp/server/highlevel/tools/__init__.py +0 -4
- acp/server/highlevel/tools/base.py +0 -83
- acp/server/highlevel/tools/tool_manager.py +0 -53
- acp/server/highlevel/utilities/__init__.py +0 -1
- acp/server/highlevel/utilities/func_metadata.py +0 -210
- acp/server/highlevel/utilities/logging.py +0 -43
- acp/server/highlevel/utilities/types.py +0 -54
- acp/server/lowlevel/__init__.py +0 -3
- acp/server/lowlevel/helper_types.py +0 -9
- acp/server/lowlevel/server.py +0 -643
- acp/server/models.py +0 -17
- acp/server/session.py +0 -315
- acp/server/sse.py +0 -175
- acp/server/stdio.py +0 -83
- acp/server/websocket.py +0 -61
- acp/shared/__init__.py +0 -0
- acp/shared/context.py +0 -14
- acp/shared/exceptions.py +0 -14
- acp/shared/memory.py +0 -87
- acp/shared/progress.py +0 -40
- acp/shared/session.py +0 -413
- acp/shared/version.py +0 -3
- acp/types.py +0 -1258
- acp_sdk-0.0.6.dist-info/METADATA +0 -46
- acp_sdk-0.0.6.dist-info/RECORD +0 -57
- acp_sdk-0.0.6.dist-info/entry_points.txt +0 -2
- acp_sdk-0.0.6.dist-info/licenses/LICENSE +0 -22
- {acp → acp_sdk}/py.typed +0 -0
- {acp_sdk-0.0.6.dist-info → acp_sdk-0.1.0rc6.dist-info}/WHEEL +0 -0
acp_sdk/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
from acp_sdk.models import * # noqa: F403
|
@@ -0,0 +1 @@
|
|
1
|
+
from acp_sdk.client.client import Client as Client
|
acp_sdk/client/client.py
ADDED
@@ -0,0 +1,139 @@
|
|
1
|
+
from collections.abc import AsyncIterator
|
2
|
+
from types import TracebackType
|
3
|
+
from typing import Self
|
4
|
+
|
5
|
+
import httpx
|
6
|
+
from httpx_sse import EventSource, aconnect_sse
|
7
|
+
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
8
|
+
from pydantic import TypeAdapter
|
9
|
+
|
10
|
+
from acp_sdk.models import (
|
11
|
+
ACPError,
|
12
|
+
Agent,
|
13
|
+
AgentName,
|
14
|
+
AgentReadResponse,
|
15
|
+
AgentsListResponse,
|
16
|
+
AwaitResume,
|
17
|
+
Error,
|
18
|
+
Message,
|
19
|
+
Run,
|
20
|
+
RunCancelResponse,
|
21
|
+
RunCreateRequest,
|
22
|
+
RunCreateResponse,
|
23
|
+
RunEvent,
|
24
|
+
RunId,
|
25
|
+
RunMode,
|
26
|
+
RunResumeRequest,
|
27
|
+
RunResumeResponse,
|
28
|
+
)
|
29
|
+
|
30
|
+
|
31
|
+
class Client:
    """Asynchronous HTTP client for an ACP server.

    Wraps an ``httpx.AsyncClient`` and exposes one coroutine (or async
    generator) per endpoint: listing/reading agents and creating, reading,
    cancelling and resuming runs. Use it as an async context manager so the
    underlying HTTP client is opened and closed properly.
    """

    def __init__(self, *, base_url: httpx.URL | str = "", client: httpx.AsyncClient | None = None) -> None:
        """Create the client.

        Args:
            base_url: Base URL used when a new ``httpx.AsyncClient`` is created.
                It is not applied to a caller-supplied ``client``.
            client: Optional pre-configured ``httpx.AsyncClient`` to reuse.
        """
        self.base_url = base_url

        self._client = self._init_client(client)

    def _init_client(self, client: httpx.AsyncClient | None = None) -> httpx.AsyncClient:
        """Return the HTTP client to use, instrumented for OpenTelemetry."""
        client = client or httpx.AsyncClient(base_url=self.base_url)
        HTTPXClientInstrumentor.instrument_client(client)
        return client

    async def __aenter__(self) -> Self:
        await self._client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        await self._client.__aexit__(exc_type, exc_value, traceback)

    @staticmethod
    def _create_payload(agent: AgentName, input: Message, mode: RunMode) -> dict:
        """Build the JSON body for ``POST /runs``.

        ``mode="json"`` converts values such as URLs into plain JSON types so
        that httpx can encode the body.
        """
        return RunCreateRequest(agent_name=agent, input=input, mode=mode).model_dump(mode="json")

    @staticmethod
    def _resume_payload(await_: AwaitResume, mode: RunMode) -> dict:
        """Build the JSON body for ``POST /runs/{run_id}``.

        The request model exposes its resume payload under the alias
        ``"await"``; it is populated and dumped by that alias so the body
        carries the key the server model declares.
        """
        request = RunResumeRequest.model_validate({"await": await_, "mode": mode})
        return request.model_dump(mode="json", by_alias=True)

    async def agents(self) -> AsyncIterator[Agent]:
        """Yield every agent returned by ``GET /agents``."""
        response = await self._client.get("/agents")
        self._raise_error(response)
        for agent in AgentsListResponse.model_validate(response.json()).agents:
            yield agent

    async def agent(self, *, name: AgentName) -> Agent:
        """Fetch a single agent by name."""
        response = await self._client.get(f"/agents/{name}")
        self._raise_error(response)
        return AgentReadResponse.model_validate(response.json())

    async def run_sync(self, *, agent: AgentName, input: Message) -> Run:
        """Create a run in ``sync`` mode and return the run from the response."""
        response = await self._client.post("/runs", json=self._create_payload(agent, input, RunMode.SYNC))
        self._raise_error(response)
        return RunCreateResponse.model_validate(response.json())

    async def run_async(self, *, agent: AgentName, input: Message) -> Run:
        """Create a run in ``async`` mode and return the run from the response."""
        response = await self._client.post("/runs", json=self._create_payload(agent, input, RunMode.ASYNC))
        self._raise_error(response)
        return RunCreateResponse.model_validate(response.json())

    async def run_stream(self, *, agent: AgentName, input: Message) -> AsyncIterator[RunEvent]:
        """Create a run in ``stream`` mode and yield its server-sent events."""
        async with aconnect_sse(
            self._client,
            "POST",
            "/runs",
            json=self._create_payload(agent, input, RunMode.STREAM),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    async def run_status(self, *, run_id: RunId) -> Run:
        """Fetch the current state of a run."""
        response = await self._client.get(f"/runs/{run_id}")
        self._raise_error(response)
        return Run.model_validate(response.json())

    async def run_cancel(self, *, run_id: RunId) -> Run:
        """Request cancellation of a run and return the run from the response."""
        response = await self._client.post(f"/runs/{run_id}/cancel")
        self._raise_error(response)
        return RunCancelResponse.model_validate(response.json())

    async def run_resume_sync(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume a run in ``sync`` mode."""
        response = await self._client.post(f"/runs/{run_id}", json=self._resume_payload(await_, RunMode.SYNC))
        self._raise_error(response)
        return RunResumeResponse.model_validate(response.json())

    async def run_resume_async(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume a run in ``async`` mode."""
        response = await self._client.post(f"/runs/{run_id}", json=self._resume_payload(await_, RunMode.ASYNC))
        self._raise_error(response)
        return RunResumeResponse.model_validate(response.json())

    async def run_resume_stream(self, *, run_id: RunId, await_: AwaitResume) -> AsyncIterator[RunEvent]:
        """Resume a run in ``stream`` mode and yield its server-sent events."""
        async with aconnect_sse(
            self._client,
            "POST",
            f"/runs/{run_id}",
            json=self._resume_payload(await_, RunMode.STREAM),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    async def _validate_stream(
        self,
        event_source: EventSource,
    ) -> AsyncIterator[RunEvent]:
        """Parse each server-sent event's data into a ``RunEvent``."""
        # Build the adapter once instead of once per event.
        adapter = TypeAdapter(RunEvent)
        async for sse in event_source.aiter_sse():
            yield adapter.validate_json(sse.data)

    def _raise_error(self, response: httpx.Response) -> None:
        """Raise ``ACPError`` for an error response carrying an ACP error body.

        Raises:
            ACPError: The response has an error status and its body parses as
                an ``Error``; the HTTP error is attached as the cause.
            httpx.HTTPError: The response has an error status but its body is
                not a valid ``Error`` payload (e.g. not JSON).
        """
        try:
            response.raise_for_status()
        except httpx.HTTPError as http_err:
            try:
                error = Error.model_validate(response.json())
            except ValueError:
                # JSON decode and pydantic validation errors are ValueErrors;
                # surface the HTTP failure itself rather than a parse error.
                raise http_err from None
            raise ACPError(error) from http_err
|
acp_sdk/models/errors.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
|
6
|
+
class ErrorCode(str, Enum):
    """String-valued error categories carried in an ``Error`` payload."""

    SERVER_ERROR = "server_error"
    INVALID_INPUT = "invalid_input"
    NOT_FOUND = "not_found"
|
10
|
+
|
11
|
+
|
12
|
+
class Error(BaseModel):
    """Error payload: a machine-readable code plus a human-readable message."""

    # Category of the failure.
    code: ErrorCode
    # Free-form description of what went wrong.
    message: str
|
15
|
+
|
16
|
+
|
17
|
+
class ACPError(Exception):
    """Exception wrapping an ``Error`` payload.

    The payload is available as ``.error``; ``str()`` of the exception is the
    payload's message.
    """

    def __init__(self, error: "Error") -> None:
        # Forward the payload to Exception so ``args`` and ``repr()`` carry it
        # instead of being empty.
        super().__init__(error)
        self.error = error

    def __str__(self) -> str:
        return str(self.error.message)
|
acp_sdk/models/models.py
ADDED
@@ -0,0 +1,181 @@
|
|
1
|
+
import uuid
|
2
|
+
from collections.abc import Iterator
|
3
|
+
from enum import Enum
|
4
|
+
from typing import Any, Literal, Union
|
5
|
+
|
6
|
+
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel
|
7
|
+
|
8
|
+
from acp_sdk.models.errors import Error
|
9
|
+
|
10
|
+
|
11
|
+
class Metadata(BaseModel):
    """Open-ended metadata container; undeclared fields are accepted."""

    model_config = ConfigDict(extra="allow")
|
13
|
+
|
14
|
+
|
15
|
+
class AnyModel(BaseModel):
    """Schema-less model; undeclared fields are accepted."""

    model_config = ConfigDict(extra="allow")
|
17
|
+
|
18
|
+
|
19
|
+
class MessagePartBase(BaseModel):
    """Common base of all message parts, tagged by ``type``."""

    # Tag naming the kind of part; subclasses narrow it to a single value.
    type: Literal["text", "image", "artifact"]
|
21
|
+
|
22
|
+
|
23
|
+
class TextMessagePart(MessagePartBase):
    """Message part holding plain text."""

    type: Literal["text"] = "text"
    # The text itself.
    content: str
|
26
|
+
|
27
|
+
|
28
|
+
class ImageMessagePart(MessagePartBase):
    """Message part referring to an image by URL."""

    type: Literal["image"] = "image"
    # URL of the image content.
    content_url: AnyUrl
|
31
|
+
|
32
|
+
|
33
|
+
class ArtifactMessagePart(MessagePartBase):
    """Message part referring to a named artifact by URL."""

    type: Literal["artifact"] = "artifact"
    # Name of the artifact.
    name: str
    # URL of the artifact content.
    content_url: AnyUrl
|
37
|
+
|
38
|
+
|
39
|
+
MessagePart = Union[TextMessagePart, ImageMessagePart, ArtifactMessagePart]
|
40
|
+
|
41
|
+
|
42
|
+
class Message(RootModel):
    """Ordered sequence of message parts.

    A pydantic root model whose value is the list of parts. Instances are
    built from positional parts (``Message(part_a, part_b)``), can be
    iterated, and can be concatenated with ``+``.
    """

    root: list[MessagePart]

    def __init__(self, *items: MessagePart) -> None:
        super().__init__(root=list(items))

    def __iter__(self) -> Iterator[MessagePart]:
        return iter(self.root)

    def __add__(self, other: "Message") -> "Message":
        """Return a new message holding this message's parts followed by ``other``'s.

        Raises:
            TypeError: ``other`` is not a ``Message``.
        """
        if not isinstance(other, Message):
            raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
        return Message(*self.root, *other.root)

    def __str__(self) -> str:
        """Return the concatenated text of all text parts; other parts are skipped."""
        # Use the part's text, not str(part): the latter is the model's
        # field dump (``type='text' content='...'``), not the message text.
        return "".join(part.content for part in self.root if isinstance(part, TextMessagePart))
|
58
|
+
|
59
|
+
|
60
|
+
AgentName = str
|
61
|
+
SessionId = str
|
62
|
+
RunId = str
|
63
|
+
|
64
|
+
|
65
|
+
class RunMode(str, Enum):
    """How the caller wants a run request to be answered."""

    SYNC = "sync"
    ASYNC = "async"
    STREAM = "stream"
|
69
|
+
|
70
|
+
|
71
|
+
class RunStatus(str, Enum):
    """Lifecycle states of a run."""

    CREATED = "created"
    IN_PROGRESS = "in-progress"
    AWAITING = "awaiting"
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"
    COMPLETED = "completed"
    FAILED = "failed"

    @property
    def is_terminal(self) -> bool:
        """Whether this status is one of completed, failed or cancelled."""
        return self in (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELLED)
|
84
|
+
|
85
|
+
|
86
|
+
class Await(BaseModel):
    """Payload describing what a run is waiting for; currently only a placeholder tag."""

    type: Literal["placeholder"] = "placeholder"
|
88
|
+
|
89
|
+
|
90
|
+
class AwaitResume(BaseModel):
    """Payload sent to resume an awaiting run; declares no fields."""
|
92
|
+
|
93
|
+
|
94
|
+
class Run(BaseModel):
    """State of a single agent run.

    The ``await_`` field is exposed under the alias ``"await"`` (a Python
    keyword, hence the trailing underscore) and may be populated by either
    name.
    """

    # Generated per instance. A plain ``str(uuid.uuid4())`` default would be
    # evaluated once at class definition and shared by every Run.
    run_id: RunId = Field(default_factory=lambda: str(uuid.uuid4()))
    agent_name: AgentName
    session_id: SessionId | None = None
    status: RunStatus = RunStatus.CREATED
    await_: Await | None = Field(None, alias="await")
    output: Message | None = None
    error: Error | None = None

    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON, using field aliases unless the caller says otherwise."""
        # setdefault (rather than a hard-coded keyword) lets callers pass
        # ``by_alias`` themselves without a duplicate-keyword TypeError.
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
|
113
|
+
|
114
|
+
|
115
|
+
class MessageEvent(BaseModel):
    """Run event carrying a message."""

    type: Literal["message"] = "message"
    # The message carried by this event.
    message: Message
|
118
|
+
|
119
|
+
|
120
|
+
class AwaitEvent(BaseModel):
    """Run event carrying an await payload.

    ``await_`` is required (it may be ``None``) and is exposed under the alias
    ``"await"``; it may be populated by either name.
    """

    type: Literal["await"] = "await"
    await_: Await | None = Field(alias="await")

    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON, using field aliases unless the caller says otherwise."""
        # setdefault (rather than a hard-coded keyword) lets callers pass
        # ``by_alias`` themselves without a duplicate-keyword TypeError.
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
|
134
|
+
|
135
|
+
|
136
|
+
class GenericEvent(BaseModel):
    """Run event carrying an arbitrary, schema-less payload."""

    type: Literal["generic"] = "generic"
    # Free-form payload.
    generic: AnyModel
|
139
|
+
|
140
|
+
|
141
|
+
class CreatedEvent(BaseModel):
    """Run event tagged ``created``, carrying the run."""

    type: Literal["created"] = "created"
    # The run this event refers to.
    run: Run
|
144
|
+
|
145
|
+
|
146
|
+
class InProgressEvent(BaseModel):
    """Run event tagged ``in-progress``, carrying the run."""

    type: Literal["in-progress"] = "in-progress"
    # The run this event refers to.
    run: Run
|
149
|
+
|
150
|
+
|
151
|
+
class FailedEvent(BaseModel):
    """Run event tagged ``failed``, carrying the run."""

    type: Literal["failed"] = "failed"
    # The run this event refers to.
    run: Run
|
154
|
+
|
155
|
+
|
156
|
+
class CancelledEvent(BaseModel):
    """Run event tagged ``cancelled``, carrying the run."""

    type: Literal["cancelled"] = "cancelled"
    # The run this event refers to.
    run: Run
|
159
|
+
|
160
|
+
|
161
|
+
class CompletedEvent(BaseModel):
    """Run event tagged ``completed``, carrying the run."""

    type: Literal["completed"] = "completed"
    # The run this event refers to.
    run: Run
|
164
|
+
|
165
|
+
|
166
|
+
RunEvent = Union[
|
167
|
+
CreatedEvent,
|
168
|
+
InProgressEvent,
|
169
|
+
MessageEvent,
|
170
|
+
AwaitEvent,
|
171
|
+
GenericEvent,
|
172
|
+
CancelledEvent,
|
173
|
+
FailedEvent,
|
174
|
+
CompletedEvent,
|
175
|
+
]
|
176
|
+
|
177
|
+
|
178
|
+
class Agent(BaseModel):
    """Description of an agent: its name, optional description and metadata."""

    name: str
    description: str | None = None
    # Defaults to an empty Metadata object.
    metadata: Metadata = Metadata()
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from pydantic import BaseModel, Field
|
2
|
+
|
3
|
+
from acp_sdk.models.models import Agent, AgentName, AwaitResume, Message, Run, RunMode, SessionId
|
4
|
+
|
5
|
+
|
6
|
+
class AgentsListResponse(BaseModel):
    """Response body listing agents."""

    # The listed agents.
    agents: list[Agent]
|
8
|
+
|
9
|
+
|
10
|
+
class AgentReadResponse(Agent):
    """Response body for reading one agent; same fields as ``Agent``."""
|
12
|
+
|
13
|
+
|
14
|
+
class RunCreateRequest(BaseModel):
    """Request body for creating a run."""

    # Name of the agent to run.
    agent_name: AgentName
    # Optional session to attach the run to.
    session_id: SessionId | None = None
    # Input message handed to the agent.
    input: Message
    # Response mode; defaults to sync.
    mode: RunMode = RunMode.SYNC
|
19
|
+
|
20
|
+
|
21
|
+
class RunCreateResponse(Run):
    """Response body for creating a run; same fields as ``Run``."""
|
23
|
+
|
24
|
+
|
25
|
+
class RunResumeRequest(BaseModel):
    """Request body for resuming a run.

    The resume payload is exposed under the alias ``"await"`` and may be
    populated either by that alias or by the field name ``await_``.
    """

    # Plain-dict form of the pydantic config. Without populate_by_name,
    # ``RunResumeRequest(await_=...)`` fails validation because only the
    # alias would be accepted.
    model_config = {"populate_by_name": True}

    await_: AwaitResume = Field(alias="await")
    # Response mode for the resumed run (required).
    mode: RunMode
|
28
|
+
|
29
|
+
|
30
|
+
class RunResumeResponse(Run):
    """Response body for resuming a run; same fields as ``Run``."""
|
32
|
+
|
33
|
+
|
34
|
+
class RunReadResponse(Run):
    """Response body for reading a run; same fields as ``Run``."""
|
36
|
+
|
37
|
+
|
38
|
+
class RunCancelResponse(Run):
    """Response body for cancelling a run; same fields as ``Run``."""
|
@@ -0,0 +1,6 @@
|
|
1
|
+
from acp_sdk.server.agent import Agent as Agent
|
2
|
+
from acp_sdk.server.app import create_app as create_app
|
3
|
+
from acp_sdk.server.context import Context as Context
|
4
|
+
from acp_sdk.server.server import Server as Server
|
5
|
+
from acp_sdk.server.types import RunYield as RunYield
|
6
|
+
from acp_sdk.server.types import RunYieldResume as RunYieldResume
|
acp_sdk/server/agent.py
ADDED
@@ -0,0 +1,105 @@
|
|
1
|
+
import abc
|
2
|
+
import asyncio
|
3
|
+
import inspect
|
4
|
+
from collections.abc import AsyncGenerator, Coroutine, Generator
|
5
|
+
from concurrent.futures import ThreadPoolExecutor
|
6
|
+
|
7
|
+
import janus
|
8
|
+
|
9
|
+
from acp_sdk.models import (
|
10
|
+
AgentName,
|
11
|
+
Message,
|
12
|
+
SessionId,
|
13
|
+
)
|
14
|
+
from acp_sdk.models.models import Metadata
|
15
|
+
from acp_sdk.server.context import Context
|
16
|
+
from acp_sdk.server.types import RunYield, RunYieldResume
|
17
|
+
|
18
|
+
|
19
|
+
class Agent(abc.ABC):
    """Abstract base class for server-side agents.

    Subclasses implement ``run`` as an async generator, a coroutine, a
    generator or a plain function. ``execute`` detects which form was used and
    drives it, exposing the produced values as a single async generator.
    """

    @property
    def name(self) -> AgentName:
        """Agent name; defaults to the class name."""
        return self.__class__.__name__

    @property
    def description(self) -> str:
        """Agent description; defaults to the empty string."""
        return ""

    @property
    def metadata(self) -> Metadata:
        """Agent metadata; defaults to an empty ``Metadata``."""
        return Metadata()

    @abc.abstractmethod
    def run(
        self, input: Message, context: Context
    ) -> (
        AsyncGenerator[RunYield, RunYieldResume] | Generator[RunYield, RunYieldResume] | Coroutine[RunYield] | RunYield
    ):
        """Produce the agent's output for ``input``; must be overridden."""
        pass

    async def session(self, session_id: SessionId | None) -> SessionId | None:
        """Resolve the session for a run.

        The default implementation does not support sessions.

        Raises:
            NotImplementedError: A truthy ``session_id`` was given.
        """
        if not session_id:
            return None
        raise NotImplementedError()

    async def execute(
        self, input: Message, session_id: SessionId | None, executor: ThreadPoolExecutor
    ) -> AsyncGenerator[RunYield, RunYieldResume]:
        """Drive ``run`` and relay its values through a pair of queues.

        Async forms of ``run`` are scheduled as asyncio tasks; sync forms are
        submitted to ``executor``. Each item taken from the yield queue is
        yielded to the caller, and whatever the caller sends back is put on
        the resume queue.
        """
        yield_queue: janus.Queue[RunYield] = janus.Queue()
        yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()

        context = Context(
            session_id=session_id, executor=executor, yield_queue=yield_queue, yield_resume_queue=yield_resume_queue
        )

        if inspect.isasyncgenfunction(self.run):
            runner = asyncio.create_task(self._run_async_gen(input, context))
        elif inspect.iscoroutinefunction(self.run):
            runner = asyncio.create_task(self._run_coro(input, context))
        else:
            # Sync forms run on the thread pool: generator or plain function.
            target = self._run_gen if inspect.isgeneratorfunction(self.run) else self._run_func
            runner = asyncio.get_running_loop().run_in_executor(executor, target, input, context)

        try:
            while True:
                resume = yield await yield_queue.async_q.get()
                await yield_resume_queue.async_q.put(resume)
        except janus.AsyncQueueShutDown:
            # NOTE(review): presumably raised once context.shutdown() has shut
            # the queues down, i.e. the runner is done - confirm in Context.
            pass
        finally:
            await runner  # Raise exceptions

    async def _run_async_gen(self, input: Message, context: Context) -> None:
        """Drive an async-generator ``run``, passing each item through the context."""
        try:
            agen: AsyncGenerator[RunYield, RunYieldResume] = self.run(input, context)
            resume = None
            while True:
                # The result of yield_async is sent back into the generator.
                resume = await context.yield_async(await agen.asend(resume))
        except StopAsyncIteration:
            pass
        finally:
            context.shutdown()

    async def _run_coro(self, input: Message, context: Context) -> None:
        """Await a coroutine ``run`` and pass its result through the context."""
        try:
            result = await self.run(input, context)
            await context.yield_async(result)
        finally:
            context.shutdown()

    def _run_gen(self, input: Message, context: Context) -> None:
        """Drive a generator ``run``, passing each item through the context."""
        try:
            gen: Generator[RunYield, RunYieldResume] = self.run(input, context)
            resume = None
            while True:
                # The result of yield_sync is sent back into the generator.
                resume = context.yield_sync(gen.send(resume))
        except StopIteration:
            pass
        finally:
            context.shutdown()

    def _run_func(self, input: Message, context: Context) -> None:
        """Call a plain-function ``run`` and pass its result through the context."""
        try:
            result = self.run(input, context)
            context.yield_sync(result)
        finally:
            context.shutdown()
|
acp_sdk/server/app.py
ADDED
@@ -0,0 +1,161 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncGenerator
|
3
|
+
from concurrent.futures import ThreadPoolExecutor
|
4
|
+
from contextlib import asynccontextmanager
|
5
|
+
|
6
|
+
from fastapi import FastAPI, HTTPException, status
|
7
|
+
from fastapi.responses import JSONResponse, StreamingResponse
|
8
|
+
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
9
|
+
|
10
|
+
from acp_sdk.models import (
|
11
|
+
Agent as AgentModel,
|
12
|
+
)
|
13
|
+
from acp_sdk.models import (
|
14
|
+
AgentName,
|
15
|
+
AgentReadResponse,
|
16
|
+
AgentsListResponse,
|
17
|
+
Run,
|
18
|
+
RunCancelResponse,
|
19
|
+
RunCreateRequest,
|
20
|
+
RunCreateResponse,
|
21
|
+
RunId,
|
22
|
+
RunMode,
|
23
|
+
RunReadResponse,
|
24
|
+
RunResumeRequest,
|
25
|
+
RunResumeResponse,
|
26
|
+
RunStatus,
|
27
|
+
)
|
28
|
+
from acp_sdk.models.errors import ACPError
|
29
|
+
from acp_sdk.server.agent import Agent
|
30
|
+
from acp_sdk.server.bundle import RunBundle
|
31
|
+
from acp_sdk.server.errors import (
|
32
|
+
RequestValidationError,
|
33
|
+
StarletteHTTPException,
|
34
|
+
acp_error_handler,
|
35
|
+
catch_all_exception_handler,
|
36
|
+
http_exception_handler,
|
37
|
+
validation_exception_handler,
|
38
|
+
)
|
39
|
+
from acp_sdk.server.utils import stream_sse
|
40
|
+
|
41
|
+
|
42
|
+
def create_app(*agents: Agent) -> FastAPI:
|
43
|
+
executor: ThreadPoolExecutor
|
44
|
+
|
45
|
+
@asynccontextmanager
|
46
|
+
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
|
47
|
+
nonlocal executor
|
48
|
+
with ThreadPoolExecutor(max_workers=5) as exec:
|
49
|
+
executor = exec
|
50
|
+
yield
|
51
|
+
|
52
|
+
app = FastAPI(lifespan=lifespan)
|
53
|
+
|
54
|
+
FastAPIInstrumentor.instrument_app(app)
|
55
|
+
|
56
|
+
agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
|
57
|
+
runs: dict[RunId, RunBundle] = {}
|
58
|
+
|
59
|
+
app.exception_handler(ACPError)(acp_error_handler)
|
60
|
+
app.exception_handler(StarletteHTTPException)(http_exception_handler)
|
61
|
+
app.exception_handler(RequestValidationError)(validation_exception_handler)
|
62
|
+
app.exception_handler(Exception)(catch_all_exception_handler)
|
63
|
+
|
64
|
+
def find_run_bundle(run_id: RunId) -> RunBundle:
|
65
|
+
bundle = runs.get(run_id)
|
66
|
+
if not bundle:
|
67
|
+
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
68
|
+
return bundle
|
69
|
+
|
70
|
+
def find_agent(agent_name: AgentName) -> Agent:
|
71
|
+
agent = agents.get(agent_name, None)
|
72
|
+
if not agent:
|
73
|
+
raise HTTPException(status_code=404, detail=f"Agent {agent_name} not found")
|
74
|
+
return agent
|
75
|
+
|
76
|
+
@app.get("/agents")
|
77
|
+
async def list_agents() -> AgentsListResponse:
|
78
|
+
return AgentsListResponse(
|
79
|
+
agents=[
|
80
|
+
AgentModel(name=agent.name, description=agent.description, metadata=agent.metadata)
|
81
|
+
for agent in agents.values()
|
82
|
+
]
|
83
|
+
)
|
84
|
+
|
85
|
+
@app.get("/agents/{name}")
|
86
|
+
async def read_agent(name: AgentName) -> AgentReadResponse:
|
87
|
+
agent = find_agent(name)
|
88
|
+
return AgentModel(name=agent.name, description=agent.description, metadata=agent.metadata)
|
89
|
+
|
90
|
+
@app.post("/runs")
|
91
|
+
async def create_run(request: RunCreateRequest) -> RunCreateResponse:
|
92
|
+
agent = find_agent(request.agent_name)
|
93
|
+
bundle = RunBundle(
|
94
|
+
agent=agent,
|
95
|
+
run=Run(
|
96
|
+
agent_name=agent.name,
|
97
|
+
session_id=request.session_id,
|
98
|
+
),
|
99
|
+
)
|
100
|
+
|
101
|
+
nonlocal executor
|
102
|
+
bundle.task = asyncio.create_task(bundle.execute(request.input, executor=executor))
|
103
|
+
runs[bundle.run.run_id] = bundle
|
104
|
+
|
105
|
+
match request.mode:
|
106
|
+
case RunMode.STREAM:
|
107
|
+
return StreamingResponse(
|
108
|
+
stream_sse(bundle),
|
109
|
+
media_type="text/event-stream",
|
110
|
+
)
|
111
|
+
case RunMode.SYNC:
|
112
|
+
await bundle.join()
|
113
|
+
return bundle.run
|
114
|
+
case RunMode.ASYNC:
|
115
|
+
return JSONResponse(
|
116
|
+
status_code=status.HTTP_202_ACCEPTED,
|
117
|
+
content=bundle.run.model_dump(),
|
118
|
+
)
|
119
|
+
case _:
|
120
|
+
raise NotImplementedError()
|
121
|
+
|
122
|
+
@app.get("/runs/{run_id}")
|
123
|
+
async def read_run(run_id: RunId) -> RunReadResponse:
|
124
|
+
bundle = find_run_bundle(run_id)
|
125
|
+
return bundle.run
|
126
|
+
|
127
|
+
@app.post("/runs/{run_id}")
|
128
|
+
async def resume_run(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
|
129
|
+
bundle = find_run_bundle(run_id)
|
130
|
+
bundle.stream_queue = asyncio.Queue() # TODO improve
|
131
|
+
await bundle.await_queue.put(request.await_)
|
132
|
+
match request.mode:
|
133
|
+
case RunMode.STREAM:
|
134
|
+
return StreamingResponse(
|
135
|
+
stream_sse(bundle),
|
136
|
+
media_type="text/event-stream",
|
137
|
+
)
|
138
|
+
case RunMode.SYNC:
|
139
|
+
await bundle.join()
|
140
|
+
return bundle.run
|
141
|
+
case RunMode.ASYNC:
|
142
|
+
return JSONResponse(
|
143
|
+
status_code=status.HTTP_202_ACCEPTED,
|
144
|
+
content=bundle.run.model_dump(),
|
145
|
+
)
|
146
|
+
case _:
|
147
|
+
raise NotImplementedError()
|
148
|
+
|
149
|
+
@app.post("/runs/{run_id}/cancel")
|
150
|
+
async def cancel_run(run_id: RunId) -> RunCancelResponse:
|
151
|
+
bundle = find_run_bundle(run_id)
|
152
|
+
if bundle.run.status.is_terminal:
|
153
|
+
raise HTTPException(
|
154
|
+
status_code=403,
|
155
|
+
detail=f"Run with terminal status {bundle.run.status} can't be cancelled",
|
156
|
+
)
|
157
|
+
bundle.task.cancel()
|
158
|
+
bundle.run.status = RunStatus.CANCELLING
|
159
|
+
return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=bundle.run.model_dump())
|
160
|
+
|
161
|
+
return app
|