acp-sdk 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acp_sdk/__init__.py +2 -0
- acp_sdk/client/__init__.py +1 -0
- acp_sdk/client/client.py +178 -0
- acp_sdk/models/__init__.py +3 -0
- acp_sdk/models/errors.py +23 -0
- acp_sdk/models/models.py +192 -0
- acp_sdk/models/schemas.py +41 -0
- acp_sdk/server/__init__.py +7 -0
- acp_sdk/server/agent.py +178 -0
- acp_sdk/server/app.py +176 -0
- acp_sdk/server/bundle.py +146 -0
- acp_sdk/server/context.py +33 -0
- acp_sdk/server/errors.py +54 -0
- acp_sdk/server/logging.py +16 -0
- acp_sdk/server/server.py +166 -0
- acp_sdk/server/session.py +21 -0
- acp_sdk/server/telemetry.py +57 -0
- acp_sdk/server/types.py +6 -0
- acp_sdk/server/utils.py +14 -0
- acp_sdk/version.py +3 -0
- acp_sdk-0.1.0.dist-info/METADATA +113 -0
- acp_sdk-0.1.0.dist-info/RECORD +24 -0
- acp/__init__.py +0 -138
- acp/cli/__init__.py +0 -6
- acp/cli/claude.py +0 -139
- acp/cli/cli.py +0 -471
- acp/client/__init__.py +0 -0
- acp/client/__main__.py +0 -79
- acp/client/session.py +0 -372
- acp/client/sse.py +0 -145
- acp/client/stdio.py +0 -153
- acp/server/__init__.py +0 -3
- acp/server/__main__.py +0 -50
- acp/server/highlevel/__init__.py +0 -9
- acp/server/highlevel/agents/__init__.py +0 -5
- acp/server/highlevel/agents/agent_manager.py +0 -110
- acp/server/highlevel/agents/base.py +0 -20
- acp/server/highlevel/agents/templates.py +0 -21
- acp/server/highlevel/context.py +0 -185
- acp/server/highlevel/exceptions.py +0 -25
- acp/server/highlevel/prompts/__init__.py +0 -4
- acp/server/highlevel/prompts/base.py +0 -167
- acp/server/highlevel/prompts/manager.py +0 -50
- acp/server/highlevel/prompts/prompt_manager.py +0 -33
- acp/server/highlevel/resources/__init__.py +0 -23
- acp/server/highlevel/resources/base.py +0 -48
- acp/server/highlevel/resources/resource_manager.py +0 -94
- acp/server/highlevel/resources/templates.py +0 -80
- acp/server/highlevel/resources/types.py +0 -185
- acp/server/highlevel/server.py +0 -705
- acp/server/highlevel/tools/__init__.py +0 -4
- acp/server/highlevel/tools/base.py +0 -83
- acp/server/highlevel/tools/tool_manager.py +0 -53
- acp/server/highlevel/utilities/__init__.py +0 -1
- acp/server/highlevel/utilities/func_metadata.py +0 -210
- acp/server/highlevel/utilities/logging.py +0 -43
- acp/server/highlevel/utilities/types.py +0 -54
- acp/server/lowlevel/__init__.py +0 -3
- acp/server/lowlevel/helper_types.py +0 -9
- acp/server/lowlevel/server.py +0 -643
- acp/server/models.py +0 -17
- acp/server/session.py +0 -315
- acp/server/sse.py +0 -175
- acp/server/stdio.py +0 -83
- acp/server/websocket.py +0 -61
- acp/shared/__init__.py +0 -0
- acp/shared/context.py +0 -14
- acp/shared/exceptions.py +0 -14
- acp/shared/memory.py +0 -87
- acp/shared/progress.py +0 -40
- acp/shared/session.py +0 -413
- acp/shared/version.py +0 -3
- acp/types.py +0 -1258
- acp_sdk-0.0.6.dist-info/METADATA +0 -46
- acp_sdk-0.0.6.dist-info/RECORD +0 -57
- acp_sdk-0.0.6.dist-info/entry_points.txt +0 -2
- acp_sdk-0.0.6.dist-info/licenses/LICENSE +0 -22
- {acp → acp_sdk}/py.typed +0 -0
- {acp_sdk-0.0.6.dist-info → acp_sdk-0.1.0.dist-info}/WHEEL +0 -0
acp_sdk/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
from acp_sdk.client.client import Client as Client
|
acp_sdk/client/client.py
ADDED
@@ -0,0 +1,178 @@
|
|
1
|
+
import uuid
|
2
|
+
from collections.abc import AsyncGenerator, AsyncIterator
|
3
|
+
from contextlib import asynccontextmanager
|
4
|
+
from types import TracebackType
|
5
|
+
from typing import Self
|
6
|
+
|
7
|
+
import httpx
|
8
|
+
from httpx_sse import EventSource, aconnect_sse
|
9
|
+
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
10
|
+
from pydantic import TypeAdapter
|
11
|
+
|
12
|
+
from acp_sdk.models import (
|
13
|
+
ACPError,
|
14
|
+
Agent,
|
15
|
+
AgentName,
|
16
|
+
AgentReadResponse,
|
17
|
+
AgentsListResponse,
|
18
|
+
AwaitResume,
|
19
|
+
CreatedEvent,
|
20
|
+
Error,
|
21
|
+
Message,
|
22
|
+
Run,
|
23
|
+
RunCancelResponse,
|
24
|
+
RunCreateRequest,
|
25
|
+
RunCreateResponse,
|
26
|
+
RunEvent,
|
27
|
+
RunId,
|
28
|
+
RunMode,
|
29
|
+
RunResumeRequest,
|
30
|
+
RunResumeResponse,
|
31
|
+
SessionId,
|
32
|
+
)
|
33
|
+
|
34
|
+
|
35
|
+
class Client:
    """Asynchronous client for an ACP server.

    Wraps an ``httpx.AsyncClient`` (one is created from ``base_url`` when no
    client is supplied) and remembers the ``session_id`` reported by the server
    for created runs, so later runs from the same ``Client`` carry that session.
    Use it as an async context manager so the underlying HTTP client is entered
    and exited together with this object.
    """

    def __init__(
        self,
        *,
        base_url: httpx.URL | str = "",
        session_id: SessionId | None = None,
        client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a client.

        Args:
            base_url: Server base URL; only used when ``client`` is not given.
            session_id: Session attached to new runs, or ``None`` to send none.
            client: Pre-configured ``httpx.AsyncClient`` to reuse.
        """
        self.base_url = base_url
        self.session_id = session_id

        self._client = self._init_client(client)

    def _init_client(self, client: httpx.AsyncClient | None = None) -> httpx.AsyncClient:
        """Return ``client`` (or a new one for ``base_url``) with OpenTelemetry instrumentation."""
        if client is None:
            client = httpx.AsyncClient(base_url=self.base_url)
        HTTPXClientInstrumentor.instrument_client(client)
        return client

    async def __aenter__(self) -> Self:
        await self._client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        await self._client.__aexit__(exc_type, exc_value, traceback)

    @asynccontextmanager
    async def session(self, session_id: SessionId | None = None) -> AsyncGenerator[Self]:
        """Yield a client bound to ``session_id`` (a fresh UUID when omitted).

        The yielded client shares this client's underlying ``httpx.AsyncClient``.
        """
        yield Client(base_url=self.base_url, client=self._client, session_id=session_id or uuid.uuid4())

    async def agents(self) -> AsyncIterator[Agent]:
        """Iterate over the agents listed by ``GET /agents``."""
        response = await self._client.get("/agents")
        self._raise_error(response)
        for agent in AgentsListResponse.model_validate(response.json()).agents:
            yield agent

    async def agent(self, *, name: AgentName) -> Agent:
        """Fetch a single agent description by name."""
        response = await self._client.get(f"/agents/{name}")
        self._raise_error(response)
        return AgentReadResponse.model_validate(response.json())

    async def run_sync(self, *, agent: AgentName, inputs: list[Message]) -> Run:
        """Create a run in ``sync`` mode and return the run reported by the server."""
        return await self._create_run(agent, inputs, RunMode.SYNC)

    async def run_async(self, *, agent: AgentName, inputs: list[Message]) -> Run:
        """Create a run in ``async`` mode and return the run reported by the server."""
        return await self._create_run(agent, inputs, RunMode.ASYNC)

    async def run_stream(self, *, agent: AgentName, inputs: list[Message]) -> AsyncIterator[RunEvent]:
        """Create a run in ``stream`` mode and yield its events as they arrive.

        The session id is updated from the ``CreatedEvent`` of the stream.
        """
        async with aconnect_sse(
            self._client,
            "POST",
            "/runs",
            content=self._run_create_content(agent, inputs, RunMode.STREAM),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                if isinstance(event, CreatedEvent):
                    self._set_session(event.run)
                yield event

    async def run_status(self, *, run_id: RunId) -> Run:
        """Fetch the current state of a run."""
        response = await self._client.get(f"/runs/{run_id}")
        self._raise_error(response)
        return Run.model_validate(response.json())

    async def run_cancel(self, *, run_id: RunId) -> Run:
        """Request cancellation of a run."""
        response = await self._client.post(f"/runs/{run_id}/cancel")
        self._raise_error(response)
        return RunCancelResponse.model_validate(response.json())

    async def run_resume_sync(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume an awaiting run in ``sync`` mode."""
        response = await self._client.post(
            f"/runs/{run_id}",
            json=self._run_resume_body(await_, RunMode.SYNC),
        )
        self._raise_error(response)
        return RunResumeResponse.model_validate(response.json())

    async def run_resume_async(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume an awaiting run in ``async`` mode."""
        response = await self._client.post(
            f"/runs/{run_id}",
            json=self._run_resume_body(await_, RunMode.ASYNC),
        )
        self._raise_error(response)
        return RunResumeResponse.model_validate(response.json())

    async def run_resume_stream(self, *, run_id: RunId, await_: AwaitResume) -> AsyncIterator[RunEvent]:
        """Resume an awaiting run in ``stream`` mode and yield its events."""
        async with aconnect_sse(
            self._client,
            "POST",
            f"/runs/{run_id}",
            json=self._run_resume_body(await_, RunMode.STREAM),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    async def _create_run(self, agent: AgentName, inputs: list[Message], mode: RunMode) -> Run:
        """POST a non-streaming run request and record the returned session."""
        response = await self._client.post("/runs", content=self._run_create_content(agent, inputs, mode))
        self._raise_error(response)
        run = RunCreateResponse.model_validate(response.json())
        self._set_session(run)
        return run

    def _run_create_content(self, agent: AgentName, inputs: list[Message], mode: RunMode) -> str:
        """Serialize a ``RunCreateRequest`` for the current session."""
        return RunCreateRequest(
            agent_name=agent,
            inputs=inputs,
            mode=mode,
            session_id=self.session_id,
        ).model_dump_json()

    @staticmethod
    def _run_resume_body(await_: AwaitResume, mode: RunMode) -> dict:
        """Build the JSON body of a resume request.

        ``mode="json"`` converts every value to a JSON-native type before httpx
        encodes it, and ``by_alias`` emits the wire name ``await`` (the model
        also accepts ``await_`` because of ``populate_by_name``).
        """
        return RunResumeRequest(await_=await_, mode=mode).model_dump(mode="json", by_alias=True)

    async def _validate_stream(
        self,
        event_source: EventSource,
    ) -> AsyncIterator[RunEvent]:
        """Parse each server-sent event into a ``RunEvent``.

        Raises:
            ACPError: if the server answered with an error status and an ACP error body.
        """
        response = event_source.response
        if response.is_error:
            # An error reply is presumably a JSON error body rather than an event
            # stream; read it so it is reported like non-streaming errors.
            await response.aread()
            self._raise_error(response)
        adapter = TypeAdapter(RunEvent)  # build once per stream, not once per event
        async for sse in event_source.aiter_sse():
            yield adapter.validate_json(sse.data)

    def _raise_error(self, response: httpx.Response) -> None:
        """Raise if ``response`` carries an error status.

        Raises:
            ACPError: when the error body is a valid ACP ``Error``.
            httpx.HTTPError: the original HTTP error when the body is not a
                valid ACP ``Error`` (e.g. non-JSON content).
        """
        try:
            response.raise_for_status()
        except httpx.HTTPError as http_error:
            try:
                error = Error.model_validate(response.json())
            except ValueError:  # JSONDecodeError and pydantic ValidationError both subclass ValueError
                raise http_error from None
            raise ACPError(error) from http_error

    def _set_session(self, run: Run) -> None:
        """Adopt the session id the server reported for ``run``."""
        self.session_id = run.session_id
|
acp_sdk/models/errors.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
|
6
|
+
class ErrorCode(str, Enum):
    # Categories an ACP ``Error`` can carry; members compare equal to their
    # string values because the enum mixes in ``str``.
    SERVER_ERROR = "server_error"
    INVALID_INPUT = "invalid_input"
    NOT_FOUND = "not_found"
|
10
|
+
|
11
|
+
|
12
|
+
class Error(BaseModel):
    # Error payload: a categorized ``code`` and a ``message`` (the latter is
    # what ``str(ACPError(...))`` returns).
    code: ErrorCode
    message: str
|
15
|
+
|
16
|
+
|
17
|
+
class ACPError(Exception):
    """Exception wrapping an ACP ``Error`` payload.

    ``str(exc)`` is the error message; the full payload is ``exc.error``.
    """

    def __init__(self, error: Error) -> None:
        # Passing ``error`` to Exception keeps ``args`` populated. That gives a
        # useful repr and lets the exception survive pickling, because
        # unpickling re-creates it as ``ACPError(*args)``.
        super().__init__(error)
        self.error = error

    def __str__(self) -> str:
        return str(self.error.message)
|
acp_sdk/models/models.py
ADDED
@@ -0,0 +1,192 @@
|
|
1
|
+
import uuid
|
2
|
+
from enum import Enum
|
3
|
+
from typing import Any, Literal, Optional, Union
|
4
|
+
|
5
|
+
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
6
|
+
|
7
|
+
from acp_sdk.models.errors import Error
|
8
|
+
|
9
|
+
|
10
|
+
class Metadata(BaseModel):
    # Open-ended bag of attributes: unknown fields are kept, not rejected.
    model_config = ConfigDict(extra="allow")
|
12
|
+
|
13
|
+
|
14
|
+
class AnyModel(BaseModel):
    # Schemaless model: accepts and keeps arbitrary fields (used for GenericEvent payloads).
    model_config = ConfigDict(extra="allow")
|
16
|
+
|
17
|
+
|
18
|
+
class MessagePart(BaseModel):
    # One piece of a message. Exactly one of ``content`` (inline, encoded as
    # "plain" or "base64" per ``content_encoding``) or ``content_url`` is set.
    name: Optional[str] = None
    content_type: str
    content: Optional[str] = None
    content_encoding: Optional[Literal["plain", "base64"]] = "plain"
    content_url: Optional[AnyUrl] = None

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, __context: Any) -> None:
        """Reject parts that set neither or both of ``content`` / ``content_url``."""
        inline = self.content is not None
        remote = self.content_url is not None
        if not inline and not remote:
            raise ValueError("Either content or content_url must be provided")
        if inline and remote:
            raise ValueError("Only one of content or content_url can be provided")
|
32
|
+
|
33
|
+
|
34
|
+
class Message(BaseModel):
    # An ordered list of message parts; supports ``+`` and ``str()``.
    parts: list[MessagePart]

    def __add__(self, other: "Message") -> "Message":
        """Return a new message with ``self.parts`` followed by ``other.parts``.

        Raises:
            TypeError: if ``other`` is not a ``Message``.
        """
        if not isinstance(other, Message):
            raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
        # BaseModel.__init__ only accepts keyword arguments, so pass ``parts=``.
        return Message(parts=self.parts + other.parts)

    def __str__(self) -> str:
        """Concatenate the inline content of every ``text/plain`` part."""
        return "".join(
            part.content for part in self.parts if part.content is not None and part.content_type == "text/plain"
        )
|
46
|
+
|
47
|
+
|
48
|
+
# Identifier aliases: agents are addressed by plain string names, while
# sessions and runs are identified by UUIDs.
AgentName = str
SessionId = uuid.UUID
RunId = uuid.UUID
|
51
|
+
|
52
|
+
|
53
|
+
class RunMode(str, Enum):
    # How a run request is served; the client sends one of these with each
    # create/resume request.
    SYNC = "sync"
    ASYNC = "async"
    STREAM = "stream"
|
57
|
+
|
58
|
+
|
59
|
+
class RunStatus(str, Enum):
    # Lifecycle states of a run.
    CREATED = "created"
    IN_PROGRESS = "in-progress"
    AWAITING = "awaiting"
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"
    COMPLETED = "completed"
    FAILED = "failed"

    @property
    def is_terminal(self) -> bool:
        """True for states a run cannot leave: completed, failed or cancelled."""
        return self in (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELLED)
|
72
|
+
|
73
|
+
|
74
|
+
class Await(BaseModel):
    # Request for external input attached to a run; only a "placeholder"
    # type tag is defined in this version.
    type: Literal["placeholder"] = "placeholder"
|
76
|
+
|
77
|
+
|
78
|
+
class AwaitResume(BaseModel):
    # Payload sent back to resume an awaiting run; it declares no fields yet.
    pass
|
80
|
+
|
81
|
+
|
82
|
+
class Artifact(BaseModel):
    # Named output of a run. Exactly one of ``content`` (inline, encoded as
    # "plain" or "base64" per ``content_encoding``) or ``content_url`` is set.
    name: str
    content_type: str
    content: Optional[str] = None
    content_encoding: Optional[Literal["plain", "base64"]] = "plain"
    content_url: Optional[AnyUrl] = None

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, __context: Any) -> None:
        """Reject artifacts that set neither or both of ``content`` / ``content_url``."""
        inline = self.content is not None
        remote = self.content_url is not None
        if not inline and not remote:
            raise ValueError("Either content or content_url must be provided")
        if inline and remote:
            raise ValueError("Only one of content or content_url can be provided")
|
96
|
+
|
97
|
+
|
98
|
+
class Run(BaseModel):
    # State of one agent run. ``await_`` travels on the wire under the alias
    # "await" (a Python keyword); ``populate_by_name`` also accepts "await_".
    run_id: RunId = Field(default_factory=uuid.uuid4)
    agent_name: AgentName
    session_id: SessionId | None = None
    status: RunStatus = RunStatus.CREATED
    await_: Await | None = Field(None, alias="await")
    outputs: list[Message] = []
    artifacts: list[Artifact] = []
    error: Error | None = None

    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON using field aliases unless ``by_alias`` is passed explicitly.

        Using ``setdefault`` avoids the ``TypeError`` ("multiple values for keyword
        argument 'by_alias'") that a hard-coded ``by_alias=True`` would raise.
        """
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
|
118
|
+
|
119
|
+
|
120
|
+
class MessageEvent(BaseModel):
    # Stream event wrapping a single ``Message``; tagged "message".
    type: Literal["message"] = "message"
    message: Message
|
123
|
+
|
124
|
+
|
125
|
+
class ArtifactEvent(BaseModel):
    # Stream event wrapping a single ``Artifact``; tagged "artifact".
    type: Literal["artifact"] = "artifact"
    artifact: Artifact
|
128
|
+
|
129
|
+
|
130
|
+
class AwaitEvent(BaseModel):
    # Stream event announcing an ``Await``; tagged "await". The payload field
    # uses the wire alias "await", and ``populate_by_name`` also accepts "await_".
    type: Literal["await"] = "await"
    await_: Await | None = Field(alias="await")

    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON using field aliases unless ``by_alias`` is passed explicitly.

        Using ``setdefault`` avoids the ``TypeError`` ("multiple values for keyword
        argument 'by_alias'") that a hard-coded ``by_alias=True`` would raise.
        """
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
|
144
|
+
|
145
|
+
|
146
|
+
class GenericEvent(BaseModel):
    # Stream event carrying an arbitrary, schemaless payload; tagged "generic".
    type: Literal["generic"] = "generic"
    generic: AnyModel
|
149
|
+
|
150
|
+
|
151
|
+
class CreatedEvent(BaseModel):
    # Stream event with the ``Run`` snapshot, tagged "created"; the client
    # reads its session id from this event.
    type: Literal["created"] = "created"
    run: Run
|
154
|
+
|
155
|
+
|
156
|
+
class InProgressEvent(BaseModel):
    # Stream event with the ``Run`` snapshot, tagged "in-progress".
    type: Literal["in-progress"] = "in-progress"
    run: Run
|
159
|
+
|
160
|
+
|
161
|
+
class FailedEvent(BaseModel):
    # Stream event with the ``Run`` snapshot, tagged "failed".
    type: Literal["failed"] = "failed"
    run: Run
|
164
|
+
|
165
|
+
|
166
|
+
class CancelledEvent(BaseModel):
    # Stream event with the ``Run`` snapshot, tagged "cancelled".
    type: Literal["cancelled"] = "cancelled"
    run: Run
|
169
|
+
|
170
|
+
|
171
|
+
class CompletedEvent(BaseModel):
    # Stream event with the ``Run`` snapshot, tagged "completed".
    type: Literal["completed"] = "completed"
    run: Run
|
174
|
+
|
175
|
+
|
176
|
+
# Every event a run stream can emit. Each member has a distinct ``type``
# literal; the client parses events with ``TypeAdapter(RunEvent)``.
RunEvent = Union[
    CreatedEvent,
    InProgressEvent,
    MessageEvent,
    AwaitEvent,
    GenericEvent,
    CancelledEvent,
    FailedEvent,
    CompletedEvent,
    ArtifactEvent,
]
|
187
|
+
|
188
|
+
|
189
|
+
class Agent(BaseModel):
    # Public description of an agent, as returned in agent listings and lookups.
    name: str
    description: str | None = None
    metadata: Metadata = Metadata()
|
acp_sdk/models/schemas.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
from pydantic import BaseModel, ConfigDict, Field
|
2
|
+
|
3
|
+
from acp_sdk.models.models import Agent, AgentName, AwaitResume, Message, Run, RunMode, SessionId
|
4
|
+
|
5
|
+
|
6
|
+
class AgentsListResponse(BaseModel):
    # Response body of the agent listing endpoint.
    agents: list[Agent]
|
8
|
+
|
9
|
+
|
10
|
+
class AgentReadResponse(Agent):
    # Response body for reading one agent: the plain ``Agent`` model.
    pass
|
12
|
+
|
13
|
+
|
14
|
+
class RunCreateRequest(BaseModel):
    # Request body for creating a run; ``mode`` defaults to synchronous.
    agent_name: AgentName
    session_id: SessionId | None = None
    inputs: list[Message]
    mode: RunMode = RunMode.SYNC
|
19
|
+
|
20
|
+
|
21
|
+
class RunCreateResponse(Run):
    # Response body for creating a run: the resulting ``Run``.
    pass
|
23
|
+
|
24
|
+
|
25
|
+
class RunResumeRequest(BaseModel):
    # Request body for resuming an awaiting run. ``await_`` is read and written
    # as "await" on the wire; ``populate_by_name`` also accepts "await_".
    await_: AwaitResume = Field(alias="await")
    mode: RunMode

    model_config = ConfigDict(populate_by_name=True)
|
30
|
+
|
31
|
+
|
32
|
+
class RunResumeResponse(Run):
    # Response body for resuming a run: the resulting ``Run``.
    pass
|
34
|
+
|
35
|
+
|
36
|
+
class RunReadResponse(Run):
    # Response body for reading a run: the current ``Run``.
    pass
|
38
|
+
|
39
|
+
|
40
|
+
class RunCancelResponse(Run):
    # Response body for cancelling a run: the resulting ``Run``.
    pass
|
acp_sdk/server/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
from acp_sdk.server.agent import Agent as Agent
|
2
|
+
from acp_sdk.server.agent import agent as agent
|
3
|
+
from acp_sdk.server.app import create_app as create_app
|
4
|
+
from acp_sdk.server.context import Context as Context
|
5
|
+
from acp_sdk.server.server import Server as Server
|
6
|
+
from acp_sdk.server.types import RunYield as RunYield
|
7
|
+
from acp_sdk.server.types import RunYieldResume as RunYieldResume
|
acp_sdk/server/agent.py
ADDED
@@ -0,0 +1,178 @@
|
|
1
|
+
import abc
|
2
|
+
import asyncio
|
3
|
+
import inspect
|
4
|
+
from collections.abc import AsyncGenerator, Coroutine, Generator
|
5
|
+
from concurrent.futures import ThreadPoolExecutor
|
6
|
+
from typing import Callable
|
7
|
+
|
8
|
+
import janus
|
9
|
+
|
10
|
+
from acp_sdk.models import (
|
11
|
+
AgentName,
|
12
|
+
Message,
|
13
|
+
SessionId,
|
14
|
+
)
|
15
|
+
from acp_sdk.models.models import Metadata
|
16
|
+
from acp_sdk.server.context import Context
|
17
|
+
from acp_sdk.server.types import RunYield, RunYieldResume
|
18
|
+
|
19
|
+
|
20
|
+
class Agent(abc.ABC):
    """Base class for server-side agents.

    Subclasses implement ``run``, which may be an async generator function, a
    coroutine function, a generator function or a plain function. ``execute``
    inspects which flavour it is and drives it accordingly.
    """

    @property
    def name(self) -> AgentName:
        # Defaults to the concrete subclass name.
        return self.__class__.__name__

    @property
    def description(self) -> str:
        return ""

    @property
    def metadata(self) -> Metadata:
        return Metadata()

    @abc.abstractmethod
    def run(
        self, inputs: list[Message], context: Context
    ) -> (
        AsyncGenerator[RunYield, RunYieldResume] | Generator[RunYield, RunYieldResume] | Coroutine[RunYield] | RunYield
    ):
        """Agent logic: consume ``inputs`` and yield/return ``RunYield`` values.

        Generator flavours receive the caller's ``RunYieldResume`` as the value
        of each ``yield`` expression.
        """
        pass

    async def execute(
        self, inputs: list[Message], session_id: SessionId | None, executor: ThreadPoolExecutor
    ) -> AsyncGenerator[RunYield, RunYieldResume]:
        """Run the agent and relay its yields to the caller.

        Every value the agent hands to ``yield_queue`` is re-yielded here, and
        the value the caller sends back (``asend``) is put on
        ``yield_resume_queue`` for the agent. Async ``run`` flavours become tasks
        on the running loop; sync flavours run on ``executor``.
        """
        yield_queue: janus.Queue[RunYield] = janus.Queue()
        yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()

        context = Context(
            session_id=session_id, executor=executor, yield_queue=yield_queue, yield_resume_queue=yield_resume_queue
        )

        # Dispatch on the flavour of ``run``.
        if inspect.isasyncgenfunction(self.run):
            run = asyncio.create_task(self._run_async_gen(inputs, context))
        elif inspect.iscoroutinefunction(self.run):
            run = asyncio.create_task(self._run_coro(inputs, context))
        elif inspect.isgeneratorfunction(self.run):
            run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen, inputs, context)
        else:
            run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, inputs, context)

        try:
            while True:
                # Relay one agent yield out, then hand the caller's reply back in.
                value = yield await yield_queue.async_q.get()
                await yield_resume_queue.async_q.put(value)
        except janus.AsyncQueueShutDown:
            # Presumably raised once the runner calls context.shutdown() (Context is
            # not shown here) — TODO confirm.
            pass
        finally:
            # NOTE(review): if the caller closes this generator early, the run
            # task/future is awaited but never cancelled here; confirm Context
            # handles that case.
            await run  # Raise exceptions

    async def _run_async_gen(self, input: list[Message], context: Context) -> None:
        """Drive an async-generator ``run``, shuttling each yield through ``context``."""
        try:
            gen: AsyncGenerator[RunYield, RunYieldResume] = self.run(input, context)
            value = None  # the first asend(None) starts the generator
            while True:
                value = await context.yield_async(await gen.asend(value))
        except StopAsyncIteration:
            pass
        finally:
            context.shutdown()

    async def _run_coro(self, input: list[Message], context: Context) -> None:
        """Await a coroutine ``run`` and publish its single result through ``context``."""
        try:
            await context.yield_async(await self.run(input, context))
        finally:
            context.shutdown()

    def _run_gen(self, input: list[Message], context: Context) -> None:
        """Drive a generator ``run`` (on an executor thread), shuttling each yield through ``context``."""
        try:
            gen: Generator[RunYield, RunYieldResume] = self.run(input, context)
            value = None  # the first send(None) starts the generator
            while True:
                value = context.yield_sync(gen.send(value))
        except StopIteration:
            pass
        finally:
            context.shutdown()

    def _run_func(self, input: list[Message], context: Context) -> None:
        """Call a plain ``run`` (on an executor thread) and publish its result through ``context``."""
        try:
            context.yield_sync(self.run(input, context))
        finally:
            context.shutdown()
|
102
|
+
|
103
|
+
|
104
|
+
def agent(
    name: str | None = None,
    description: str | None = None,
    *,
    metadata: Metadata | None = None,
) -> Callable[[Callable], Agent]:
    """Decorator to create an agent.

    The decorated function takes ``input`` and, optionally, a second parameter
    that must be named ``context``. It may be an async generator function, a
    coroutine function, a generator function or a plain function. The returned
    agent's ``run`` has the same flavour, so ``Agent.execute`` dispatches it
    the same way.

    Args:
        name: Agent name; defaults to the function's ``__name__``.
        description: Agent description; defaults to the function's docstring
            (indentation cleaned with ``inspect.cleandoc``) or ``""``.
        metadata: Agent metadata; defaults to an empty ``Metadata``.

    Returns:
        A decorator that turns the function into an ``Agent`` instance.

    Raises:
        TypeError: from the decorator, when the function's signature is not
            ``(input)`` or ``(input, context)``.
    """

    def decorator(fn: Callable) -> Agent:
        parameters = list(inspect.signature(fn).parameters.values())

        if len(parameters) == 0:
            raise TypeError("The agent function must have at least 'input' argument")
        if len(parameters) > 2:
            raise TypeError("The agent function must have only 'input' and 'context' arguments")
        if len(parameters) == 2 and parameters[1].name != "context":
            raise TypeError("The second argument of the agent function must be 'context'")

        has_context_param = len(parameters) == 2
        # Docstrings keep their source indentation; cleandoc strips it so the
        # published description reads cleanly.
        doc_description = inspect.cleandoc(fn.__doc__) if fn.__doc__ else ""

        def call_fn(inputs: list[Message], context: Context):
            # Forward ``context`` only when the wrapped function declared it.
            return fn(inputs, context) if has_context_param else fn(inputs)

        class DecoratorAgentBase(Agent):
            @property
            def name(self) -> str:
                return name or fn.__name__

            @property
            def description(self) -> str:
                return description or doc_description

            @property
            def metadata(self) -> Metadata:
                return metadata or Metadata()

        # Each subclass's ``run`` must keep the wrapped function's flavour:
        # Agent.execute inspects it to choose how to drive the run.
        agent: Agent
        if inspect.isasyncgenfunction(fn):

            class AsyncGenDecoratorAgent(DecoratorAgentBase):
                async def run(
                    self, inputs: list[Message], context: Context
                ) -> AsyncGenerator[RunYield, RunYieldResume]:
                    try:
                        gen: AsyncGenerator[RunYield, RunYieldResume] = call_fn(inputs, context)
                        value = None
                        while True:
                            # Relay each yield out and send the resume value back in.
                            value = yield await gen.asend(value)
                    except StopAsyncIteration:
                        pass

            agent = AsyncGenDecoratorAgent()
        elif inspect.iscoroutinefunction(fn):

            class CoroDecoratorAgent(DecoratorAgentBase):
                async def run(self, inputs: list[Message], context: Context) -> RunYield:
                    return await call_fn(inputs, context)

            agent = CoroDecoratorAgent()
        elif inspect.isgeneratorfunction(fn):

            class GenDecoratorAgent(DecoratorAgentBase):
                def run(self, inputs: list[Message], context: Context) -> Generator[RunYield, RunYieldResume]:
                    yield from call_fn(inputs, context)

            agent = GenDecoratorAgent()
        else:

            class FuncDecoratorAgent(DecoratorAgentBase):
                def run(self, inputs: list[Message], context: Context) -> RunYield:
                    return call_fn(inputs, context)

            agent = FuncDecoratorAgent()

        return agent

    return decorator
|