acp-sdk 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acp_sdk/client/client.py +133 -42
- acp_sdk/models/__init__.py +1 -0
- acp_sdk/models/models.py +75 -9
- acp_sdk/models/schemas.py +18 -3
- acp_sdk/models/types.py +9 -0
- acp_sdk/server/__init__.py +5 -1
- acp_sdk/server/agent.py +8 -84
- acp_sdk/server/app.py +87 -29
- acp_sdk/server/context.py +11 -3
- acp_sdk/server/executor.py +139 -14
- acp_sdk/server/server.py +31 -6
- acp_sdk/server/store/memory_store.py +4 -4
- acp_sdk/server/utils.py +1 -1
- acp_sdk/shared/__init__.py +2 -0
- acp_sdk/shared/resources.py +46 -0
- {acp_sdk-0.10.0.dist-info → acp_sdk-0.11.0.dist-info}/METADATA +2 -1
- acp_sdk-0.11.0.dist-info/RECORD +35 -0
- acp_sdk/server/session.py +0 -24
- acp_sdk-0.10.0.dist-info/RECORD +0 -33
- {acp_sdk-0.10.0.dist-info → acp_sdk-0.11.0.dist-info}/WHEEL +0 -0
acp_sdk/client/client.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
1
3
|
import ssl
|
2
4
|
import typing
|
3
|
-
import
|
4
|
-
from collections.abc import AsyncGenerator, AsyncIterator
|
5
|
-
from contextlib import asynccontextmanager
|
5
|
+
from collections.abc import AsyncIterator
|
6
6
|
from types import TracebackType
|
7
7
|
from typing import Self
|
8
8
|
|
@@ -15,12 +15,13 @@ from acp_sdk.client.utils import input_to_messages
|
|
15
15
|
from acp_sdk.instrumentation import get_tracer
|
16
16
|
from acp_sdk.models import (
|
17
17
|
ACPError,
|
18
|
-
|
18
|
+
AgentManifest,
|
19
19
|
AgentName,
|
20
20
|
AgentReadResponse,
|
21
21
|
AgentsListResponse,
|
22
22
|
AwaitResume,
|
23
23
|
Error,
|
24
|
+
ErrorCode,
|
24
25
|
ErrorEvent,
|
25
26
|
Event,
|
26
27
|
PingResponse,
|
@@ -33,16 +34,20 @@ from acp_sdk.models import (
|
|
33
34
|
RunMode,
|
34
35
|
RunResumeRequest,
|
35
36
|
RunResumeResponse,
|
36
|
-
|
37
|
+
Session,
|
38
|
+
SessionReadResponse,
|
37
39
|
)
|
38
40
|
|
41
|
+
logger = logging.getLogger(__name__)
|
42
|
+
|
39
43
|
|
40
44
|
class Client:
|
41
45
|
def __init__(
|
42
46
|
self,
|
43
47
|
*,
|
44
|
-
|
48
|
+
session: Session | None = None,
|
45
49
|
client: httpx.AsyncClient | None = None,
|
50
|
+
manage_client: bool = True,
|
46
51
|
auth: httpx._types.AuthTypes | None = None,
|
47
52
|
params: httpx._types.QueryParamTypes | None = None,
|
48
53
|
headers: httpx._types.HeaderTypes | None = None,
|
@@ -62,7 +67,10 @@ class Client:
|
|
62
67
|
transport: httpx.AsyncBaseTransport | None = None,
|
63
68
|
trust_env: bool = True,
|
64
69
|
) -> None:
|
65
|
-
self.
|
70
|
+
self._session = session
|
71
|
+
self._session_last_refresh_base_url: httpx.URL | None = None
|
72
|
+
self._session_refresh_lock = asyncio.Lock()
|
73
|
+
|
66
74
|
self._client = client or httpx.AsyncClient(
|
67
75
|
auth=auth,
|
68
76
|
params=params,
|
@@ -83,13 +91,24 @@ class Client:
|
|
83
91
|
transport=transport,
|
84
92
|
trust_env=trust_env,
|
85
93
|
)
|
94
|
+
self._manage_client = manage_client
|
86
95
|
|
87
96
|
@property
|
88
97
|
def client(self) -> httpx.AsyncClient:
|
89
98
|
return self._client
|
90
99
|
|
91
100
|
async def __aenter__(self) -> Self:
|
92
|
-
|
101
|
+
if self._manage_client:
|
102
|
+
await self._client.__aenter__()
|
103
|
+
self._session_span_manager = (
|
104
|
+
(
|
105
|
+
get_tracer()
|
106
|
+
.start_as_current_span("session", attributes={"acp.session": str(self._session.id)})
|
107
|
+
.__enter__()
|
108
|
+
)
|
109
|
+
if self._session
|
110
|
+
else None
|
111
|
+
)
|
93
112
|
return self
|
94
113
|
|
95
114
|
async def __aexit__(
|
@@ -98,121 +117,152 @@ class Client:
|
|
98
117
|
exc_value: BaseException | None = None,
|
99
118
|
traceback: TracebackType | None = None,
|
100
119
|
) -> None:
|
101
|
-
|
120
|
+
if self._session_span_manager:
|
121
|
+
self._session_span_manager.__exit__(exc_type, exc_value, traceback)
|
122
|
+
if self._manage_client:
|
123
|
+
await self._client.__aexit__(exc_type, exc_value, traceback)
|
102
124
|
|
103
|
-
|
104
|
-
|
105
|
-
session_id = session_id or uuid.uuid4()
|
106
|
-
with get_tracer().start_as_current_span("session", attributes={"acp.session": str(session_id)}):
|
107
|
-
yield Client(client=self._client, session_id=session_id)
|
125
|
+
def session(self, session: Session | None = None) -> Self:
|
126
|
+
return Client(client=self._client, manage_client=False, session=session or Session())
|
108
127
|
|
109
|
-
async def agents(self) -> AsyncIterator[
|
110
|
-
response = await self._client.get("/agents")
|
128
|
+
async def agents(self, *, base_url: httpx.URL | str | None = None) -> AsyncIterator[AgentManifest]:
|
129
|
+
response = await self._client.get(self._create_url("/agents", base_url=base_url))
|
111
130
|
self._raise_error(response)
|
112
131
|
for agent in AgentsListResponse.model_validate(response.json()).agents:
|
113
132
|
yield agent
|
114
133
|
|
115
|
-
async def agent(self, *, name: AgentName) ->
|
116
|
-
response = await self._client.get(f"/agents/{name}")
|
134
|
+
async def agent(self, *, name: AgentName, base_url: httpx.URL | str | None = None) -> AgentManifest:
|
135
|
+
response = await self._client.get(self._create_url(f"/agents/{name}", base_url=base_url))
|
117
136
|
self._raise_error(response)
|
118
137
|
response = AgentReadResponse.model_validate(response.json())
|
119
|
-
return
|
138
|
+
return AgentManifest(**response.model_dump())
|
120
139
|
|
121
|
-
async def ping(self) -> bool:
|
122
|
-
response = await self._client.get("/ping")
|
140
|
+
async def ping(self, *, base_url: httpx.URL | str | None = None) -> bool:
|
141
|
+
response = await self._client.get(self._create_url("/ping", base_url=base_url))
|
123
142
|
self._raise_error(response)
|
124
143
|
PingResponse.model_validate(response.json())
|
125
144
|
return
|
126
145
|
|
127
|
-
async def run_sync(self, input: Input, *, agent: AgentName) -> Run:
|
146
|
+
async def run_sync(self, input: Input, *, agent: AgentName, base_url: httpx.URL | str | None = None) -> Run:
|
128
147
|
response = await self._client.post(
|
129
|
-
"/runs",
|
148
|
+
self._create_url("/runs", base_url=base_url),
|
130
149
|
content=RunCreateRequest(
|
131
150
|
agent_name=agent,
|
132
151
|
input=input_to_messages(input),
|
133
152
|
mode=RunMode.SYNC,
|
134
|
-
|
153
|
+
**(await self._prepare_session_for_run(base_url=base_url)),
|
135
154
|
).model_dump_json(),
|
136
155
|
)
|
137
156
|
self._raise_error(response)
|
138
157
|
response = RunCreateResponse.model_validate(response.json())
|
139
158
|
return Run(**response.model_dump())
|
140
159
|
|
141
|
-
async def run_async(self, input: Input, *, agent: AgentName) -> Run:
|
160
|
+
async def run_async(self, input: Input, *, agent: AgentName, base_url: httpx.URL | str | None = None) -> Run:
|
142
161
|
response = await self._client.post(
|
143
|
-
"/runs",
|
162
|
+
self._create_url("/runs", base_url=base_url),
|
144
163
|
content=RunCreateRequest(
|
145
164
|
agent_name=agent,
|
146
165
|
input=input_to_messages(input),
|
147
166
|
mode=RunMode.ASYNC,
|
148
|
-
|
167
|
+
**(await self._prepare_session_for_run(base_url=base_url)),
|
149
168
|
).model_dump_json(),
|
150
169
|
)
|
151
170
|
self._raise_error(response)
|
152
171
|
response = RunCreateResponse.model_validate(response.json())
|
153
172
|
return Run(**response.model_dump())
|
154
173
|
|
155
|
-
async def run_stream(
|
174
|
+
async def run_stream(
|
175
|
+
self, input: Input, *, agent: AgentName, base_url: httpx.URL | str | None = None
|
176
|
+
) -> AsyncIterator[Event]:
|
156
177
|
async with aconnect_sse(
|
157
178
|
self._client,
|
158
179
|
"POST",
|
159
|
-
"/runs",
|
180
|
+
self._create_url("/runs", base_url=base_url),
|
160
181
|
content=RunCreateRequest(
|
161
182
|
agent_name=agent,
|
162
183
|
input=input_to_messages(input),
|
163
184
|
mode=RunMode.STREAM,
|
164
|
-
|
185
|
+
session=await self._prepare_session_for_run(base_url=base_url),
|
165
186
|
).model_dump_json(),
|
166
187
|
) as event_source:
|
167
188
|
async for event in self._validate_stream(event_source):
|
168
189
|
yield event
|
169
190
|
|
170
|
-
async def run_status(self, *, run_id: RunId) -> Run:
|
171
|
-
response = await self._client.get(f"/runs/{run_id}")
|
191
|
+
async def run_status(self, *, run_id: RunId, base_url: httpx.URL | str | None = None) -> Run:
|
192
|
+
response = await self._client.get(self._create_url(f"/runs/{run_id}", base_url=base_url))
|
172
193
|
self._raise_error(response)
|
173
194
|
return Run.model_validate(response.json())
|
174
195
|
|
175
|
-
async def run_events(self, *, run_id: RunId) -> AsyncIterator[Event]:
|
176
|
-
response = await self._client.get(f"/runs/{run_id}/events")
|
196
|
+
async def run_events(self, *, run_id: RunId, base_url: httpx.URL | str | None = None) -> AsyncIterator[Event]:
|
197
|
+
response = await self._client.get(self._create_url(f"/runs/{run_id}/events", base_url=base_url))
|
177
198
|
self._raise_error(response)
|
178
199
|
response = RunEventsListResponse.model_validate(response.json())
|
179
200
|
for event in response.events:
|
180
201
|
yield event
|
181
202
|
|
182
|
-
async def run_cancel(self, *, run_id: RunId) -> Run:
|
183
|
-
response = await self._client.post(f"/runs/{run_id}/cancel")
|
203
|
+
async def run_cancel(self, *, run_id: RunId, base_url: httpx.URL | str | None = None) -> Run:
|
204
|
+
response = await self._client.post(self._create_url(f"/runs/{run_id}/cancel", base_url=base_url))
|
184
205
|
self._raise_error(response)
|
185
206
|
response = RunCancelResponse.model_validate(response.json())
|
186
207
|
return Run(**response.model_dump())
|
187
208
|
|
188
|
-
async def run_resume_sync(
|
209
|
+
async def run_resume_sync(
|
210
|
+
self, await_resume: AwaitResume, *, run_id: RunId, base_url: httpx.URL | str | None = None
|
211
|
+
) -> Run:
|
189
212
|
response = await self._client.post(
|
190
|
-
f"/runs/{run_id}",
|
213
|
+
self._create_url(f"/runs/{run_id}", base_url=base_url),
|
191
214
|
content=RunResumeRequest(await_resume=await_resume, mode=RunMode.SYNC).model_dump_json(),
|
192
215
|
)
|
193
216
|
self._raise_error(response)
|
194
217
|
response = RunResumeResponse.model_validate(response.json())
|
195
218
|
return Run(**response.model_dump())
|
196
219
|
|
197
|
-
async def run_resume_async(
|
220
|
+
async def run_resume_async(
|
221
|
+
self, await_resume: AwaitResume, *, run_id: RunId, base_url: httpx.URL | str | None = None
|
222
|
+
) -> Run:
|
198
223
|
response = await self._client.post(
|
199
|
-
f"/runs/{run_id}",
|
224
|
+
self._create_url(f"/runs/{run_id}", base_url=base_url),
|
200
225
|
content=RunResumeRequest(await_resume=await_resume, mode=RunMode.ASYNC).model_dump_json(),
|
201
226
|
)
|
202
227
|
self._raise_error(response)
|
203
228
|
response = RunResumeResponse.model_validate(response.json())
|
204
229
|
return Run(**response.model_dump())
|
205
230
|
|
206
|
-
async def run_resume_stream(
|
231
|
+
async def run_resume_stream(
|
232
|
+
self, await_resume: AwaitResume, *, run_id: RunId, base_url: httpx.URL | str | None = None
|
233
|
+
) -> AsyncIterator[Event]:
|
207
234
|
async with aconnect_sse(
|
208
235
|
self._client,
|
209
236
|
"POST",
|
210
|
-
f"/runs/{run_id}",
|
237
|
+
self._create_url(f"/runs/{run_id}", base_url=base_url),
|
211
238
|
content=RunResumeRequest(await_resume=await_resume, mode=RunMode.STREAM).model_dump_json(),
|
212
239
|
) as event_source:
|
213
240
|
async for event in self._validate_stream(event_source):
|
214
241
|
yield event
|
215
242
|
|
243
|
+
async def refresh_session(
|
244
|
+
self, *, base_url: httpx.URL | str | None = None, timeout: httpx._types.TimeoutTypes = 5000
|
245
|
+
) -> Session:
|
246
|
+
if not self._session:
|
247
|
+
raise RuntimeError("Client is not in a session")
|
248
|
+
|
249
|
+
async with self._session_refresh_lock:
|
250
|
+
url = self._create_url(
|
251
|
+
f"/sessions/{self._session.id}",
|
252
|
+
base_url=base_url or self._session_last_refresh_base_url,
|
253
|
+
)
|
254
|
+
|
255
|
+
try:
|
256
|
+
response = await self._client.get(url, timeout=timeout)
|
257
|
+
response = SessionReadResponse.model_validate(response.json())
|
258
|
+
self._session = Session(**response.model_dump())
|
259
|
+
except ACPError as e:
|
260
|
+
if e.error.code == ErrorCode.NOT_FOUND:
|
261
|
+
pass
|
262
|
+
raise e
|
263
|
+
|
264
|
+
return self._session
|
265
|
+
|
216
266
|
async def _validate_stream(
|
217
267
|
self,
|
218
268
|
event_source: EventSource,
|
@@ -231,3 +281,44 @@ class Client:
|
|
231
281
|
response.raise_for_status()
|
232
282
|
except httpx.HTTPError:
|
233
283
|
raise ACPError(Error.model_validate(response.json()))
|
284
|
+
|
285
|
+
def _create_base_url(self, base_url: httpx.URL | str | None) -> httpx.URL:
|
286
|
+
base_url = httpx.URL(base_url or self._client.base_url)
|
287
|
+
if not base_url.raw_path.endswith(b"/"):
|
288
|
+
base_url = base_url.copy_with(raw_path=base_url.raw_path + b"/")
|
289
|
+
return base_url
|
290
|
+
|
291
|
+
def _create_url(self, endpoint: str, base_url: httpx.URL | str | None) -> httpx.URL:
|
292
|
+
merge_url = httpx.URL(endpoint)
|
293
|
+
|
294
|
+
if not merge_url.is_relative_url:
|
295
|
+
raise ValueError("Endpoint must be a relative URL")
|
296
|
+
|
297
|
+
base_url = self._create_base_url(base_url)
|
298
|
+
merge_raw_path = base_url.raw_path + merge_url.raw_path.lstrip(b"/")
|
299
|
+
return base_url.copy_with(raw_path=merge_raw_path)
|
300
|
+
|
301
|
+
async def _prepare_session_for_run(self, *, base_url: httpx.URL | str | None) -> dict:
|
302
|
+
if not self._session:
|
303
|
+
return {}
|
304
|
+
|
305
|
+
target_base_url = self._create_base_url(base_url=base_url)
|
306
|
+
try:
|
307
|
+
if not self._session_last_refresh_base_url:
|
308
|
+
return {"session": self._session}
|
309
|
+
if self._session_last_refresh_base_url == target_base_url:
|
310
|
+
# Same server, no need to forward session
|
311
|
+
return {"session_id": self._session.id}
|
312
|
+
|
313
|
+
session = await self.refresh_session()
|
314
|
+
return {"session": session}
|
315
|
+
except ACPError as e:
|
316
|
+
if e.error.code == ErrorCode.NOT_FOUND:
|
317
|
+
return {"session": self._session}
|
318
|
+
raise e
|
319
|
+
finally:
|
320
|
+
await self._update_session_refresh_url(target_base_url)
|
321
|
+
|
322
|
+
async def _update_session_refresh_url(self, url: httpx.URL) -> None:
|
323
|
+
async with self._session_refresh_lock:
|
324
|
+
self._session_last_refresh_base_url = url
|
acp_sdk/models/__init__.py
CHANGED
acp_sdk/models/models.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import asyncio
|
2
2
|
import uuid
|
3
|
+
from collections.abc import AsyncIterator
|
3
4
|
from datetime import datetime, timezone
|
4
5
|
from enum import Enum
|
5
6
|
from typing import Any, Literal, Optional, Union
|
@@ -7,6 +8,8 @@ from typing import Any, Literal, Optional, Union
|
|
7
8
|
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
8
9
|
|
9
10
|
from acp_sdk.models.errors import ACPError, Error
|
11
|
+
from acp_sdk.models.types import AgentName, ResourceId, ResourceUrl, RunId, SessionId
|
12
|
+
from acp_sdk.shared import ResourceLoader, ResourceStore
|
10
13
|
|
11
14
|
|
12
15
|
class AnyModel(BaseModel):
|
@@ -74,6 +77,40 @@ class Metadata(BaseModel):
|
|
74
77
|
model_config = ConfigDict(extra="allow")
|
75
78
|
|
76
79
|
|
80
|
+
class CitationMetadata(BaseModel):
|
81
|
+
"""
|
82
|
+
Represents an inline citation, providing info about information source. This
|
83
|
+
is supposed to be rendered as an inline icon, optionally marking a text
|
84
|
+
range it belongs to.
|
85
|
+
|
86
|
+
If CitationMetadata is included together with content in the message part,
|
87
|
+
the citation belongs to that content and renders at the MessagePart position.
|
88
|
+
This way may be used for non-text content, like images and files.
|
89
|
+
|
90
|
+
Alternatively, `start_index` and `end_index` may define a text range,
|
91
|
+
counting characters in the current Message across all MessageParts with
|
92
|
+
content type `text/*`, where the citation will be rendered. If one of
|
93
|
+
`start_index` and `end_index` is missing or their values are equal, the
|
94
|
+
citation renders only as an inline icon at that position.
|
95
|
+
|
96
|
+
If both `start_index` and `end_index` are not present and MessagePart has no
|
97
|
+
content, the citation renders as inline icon only at the MessagePart position.
|
98
|
+
|
99
|
+
Properties:
|
100
|
+
- url: URL of the source document.
|
101
|
+
- title: Title of the source document.
|
102
|
+
- description: Accompanying text, which may be a general description of the
|
103
|
+
source document, or a specific snippet.
|
104
|
+
"""
|
105
|
+
|
106
|
+
kind: Literal["citation"] = "citation"
|
107
|
+
start_index: Optional[int]
|
108
|
+
end_index: Optional[int]
|
109
|
+
url: Optional[str]
|
110
|
+
title: Optional[str]
|
111
|
+
description: Optional[str]
|
112
|
+
|
113
|
+
|
77
114
|
class MessagePart(BaseModel):
|
78
115
|
name: Optional[str] = None
|
79
116
|
content_type: Optional[str] = "text/plain"
|
@@ -83,9 +120,9 @@ class MessagePart(BaseModel):
|
|
83
120
|
|
84
121
|
model_config = ConfigDict(extra="allow")
|
85
122
|
|
123
|
+
metadata: Optional[CitationMetadata] = Field(discriminator="kind", default=None)
|
124
|
+
|
86
125
|
def model_post_init(self, __context: Any) -> None:
|
87
|
-
if self.content is None and self.content_url is None:
|
88
|
-
raise ValueError("Either content or content_url must be provided")
|
89
126
|
if self.content is not None and self.content_url is not None:
|
90
127
|
raise ValueError("Only one of content or content_url can be provided")
|
91
128
|
|
@@ -95,6 +132,7 @@ class Artifact(MessagePart):
|
|
95
132
|
|
96
133
|
|
97
134
|
class Message(BaseModel):
|
135
|
+
role: Literal["user"] | Literal["agent"] | str = Field("user", pattern=r"^(user|agent(\/[a-zA-Z0-9_\-]+)?)$")
|
98
136
|
parts: list[MessagePart]
|
99
137
|
created_at: datetime | None = Field(default_factory=lambda: datetime.now(timezone.utc))
|
100
138
|
completed_at: datetime | None = Field(default_factory=lambda: datetime.now(timezone.utc))
|
@@ -102,7 +140,10 @@ class Message(BaseModel):
|
|
102
140
|
def __add__(self, other: "Message") -> "Message":
|
103
141
|
if not isinstance(other, Message):
|
104
142
|
raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
|
143
|
+
if self.role != other.role:
|
144
|
+
raise ValueError("Cannot concatenate messages with different roles")
|
105
145
|
return Message(
|
146
|
+
role=self.role,
|
106
147
|
parts=self.parts + other.parts,
|
107
148
|
created_at=min(self.created_at, other.created_at) if self.created_at and other.created_at else None,
|
108
149
|
completed_at=max(self.completed_at, other.completed_at)
|
@@ -146,11 +187,6 @@ class Message(BaseModel):
|
|
146
187
|
return Message(parts=parts, created_at=self.created_at, completed_at=self.completed_at)
|
147
188
|
|
148
189
|
|
149
|
-
AgentName = str
|
150
|
-
SessionId = uuid.UUID
|
151
|
-
RunId = uuid.UUID
|
152
|
-
|
153
|
-
|
154
190
|
class RunMode(str, Enum):
|
155
191
|
SYNC = "sync"
|
156
192
|
ASYNC = "async"
|
@@ -280,11 +316,41 @@ Event = Union[
|
|
280
316
|
RunCancelledEvent,
|
281
317
|
RunFailedEvent,
|
282
318
|
RunCompletedEvent,
|
283
|
-
MessagePartEvent,
|
284
319
|
]
|
285
320
|
|
286
321
|
|
287
|
-
class
|
322
|
+
class AgentManifest(BaseModel):
|
288
323
|
name: str
|
289
324
|
description: str | None = None
|
290
325
|
metadata: Metadata = Metadata()
|
326
|
+
|
327
|
+
|
328
|
+
class Session(BaseModel):
|
329
|
+
id: SessionId = Field(default_factory=uuid.uuid4)
|
330
|
+
history: list[ResourceUrl] = Field(default_factory=list)
|
331
|
+
state: ResourceUrl | None = None
|
332
|
+
|
333
|
+
loader: ResourceLoader | None = Field(None, exclude=True)
|
334
|
+
store: ResourceStore | None = Field(None, exclude=True)
|
335
|
+
|
336
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
337
|
+
|
338
|
+
async def load_history(self, *, loader: ResourceLoader | None = None) -> AsyncIterator[Message]:
|
339
|
+
loader = loader or self.loader or ResourceLoader()
|
340
|
+
for url in self.history:
|
341
|
+
data = await loader.load(url)
|
342
|
+
yield Message.model_validate_json(data)
|
343
|
+
|
344
|
+
async def load_state(self, *, loader: ResourceLoader | None = None) -> bytes:
|
345
|
+
loader = loader or self.loader or ResourceLoader()
|
346
|
+
data = await loader.load(self.state)
|
347
|
+
return data
|
348
|
+
|
349
|
+
async def store_state(self, data: bytes, *, store: ResourceStore | None = None) -> ResourceUrl:
|
350
|
+
store = store or self.store
|
351
|
+
if not store:
|
352
|
+
raise ValueError("Store must be specified")
|
353
|
+
|
354
|
+
id = ResourceId()
|
355
|
+
await store.store(id, data)
|
356
|
+
return await store.url(id)
|
acp_sdk/models/schemas.py
CHANGED
@@ -1,6 +1,16 @@
|
|
1
1
|
from pydantic import BaseModel
|
2
2
|
|
3
|
-
from acp_sdk.models.models import
|
3
|
+
from acp_sdk.models.models import (
|
4
|
+
AgentManifest,
|
5
|
+
AgentName,
|
6
|
+
AwaitResume,
|
7
|
+
Event,
|
8
|
+
Message,
|
9
|
+
Run,
|
10
|
+
RunMode,
|
11
|
+
Session,
|
12
|
+
SessionId,
|
13
|
+
)
|
4
14
|
|
5
15
|
|
6
16
|
class PingResponse(BaseModel):
|
@@ -8,16 +18,17 @@ class PingResponse(BaseModel):
|
|
8
18
|
|
9
19
|
|
10
20
|
class AgentsListResponse(BaseModel):
|
11
|
-
agents: list[
|
21
|
+
agents: list[AgentManifest]
|
12
22
|
|
13
23
|
|
14
|
-
class AgentReadResponse(
|
24
|
+
class AgentReadResponse(AgentManifest):
|
15
25
|
pass
|
16
26
|
|
17
27
|
|
18
28
|
class RunCreateRequest(BaseModel):
|
19
29
|
agent_name: AgentName
|
20
30
|
session_id: SessionId | None = None
|
31
|
+
session: Session | None = None
|
21
32
|
input: list[Message]
|
22
33
|
mode: RunMode = RunMode.SYNC
|
23
34
|
|
@@ -45,3 +56,7 @@ class RunCancelResponse(Run):
|
|
45
56
|
|
46
57
|
class RunEventsListResponse(BaseModel):
|
47
58
|
events: list[Event]
|
59
|
+
|
60
|
+
|
61
|
+
class SessionReadResponse(Session):
|
62
|
+
pass
|
acp_sdk/models/types.py
ADDED
acp_sdk/server/__init__.py
CHANGED
@@ -1,7 +1,11 @@
|
|
1
|
-
from acp_sdk.server.agent import
|
1
|
+
from acp_sdk.server.agent import AgentManifest as AgentManifest
|
2
2
|
from acp_sdk.server.agent import agent as agent
|
3
3
|
from acp_sdk.server.app import create_app as create_app
|
4
4
|
from acp_sdk.server.context import Context as Context
|
5
5
|
from acp_sdk.server.server import Server as Server
|
6
|
+
from acp_sdk.server.store import MemoryStore as MemoryStore
|
7
|
+
from acp_sdk.server.store import PostgreSQLStore as PostgreSQLStore
|
8
|
+
from acp_sdk.server.store import RedisStore as RedisStore
|
9
|
+
from acp_sdk.server.store import Store as Store
|
6
10
|
from acp_sdk.server.types import RunYield as RunYield
|
7
11
|
from acp_sdk.server.types import RunYieldResume as RunYieldResume
|
acp_sdk/server/agent.py
CHANGED
@@ -1,23 +1,14 @@
|
|
1
1
|
import abc
|
2
|
-
import asyncio
|
3
2
|
import inspect
|
4
3
|
from collections.abc import AsyncGenerator, Coroutine, Generator
|
5
|
-
from concurrent.futures import ThreadPoolExecutor
|
6
4
|
from typing import Callable
|
7
5
|
|
8
|
-
import
|
9
|
-
|
10
|
-
from acp_sdk.models import (
|
11
|
-
AgentName,
|
12
|
-
Message,
|
13
|
-
SessionId,
|
14
|
-
)
|
15
|
-
from acp_sdk.models.models import Metadata
|
6
|
+
from acp_sdk.models import AgentName, Message, Metadata
|
16
7
|
from acp_sdk.server.context import Context
|
17
8
|
from acp_sdk.server.types import RunYield, RunYieldResume
|
18
9
|
|
19
10
|
|
20
|
-
class
|
11
|
+
class AgentManifest(abc.ABC):
|
21
12
|
@property
|
22
13
|
def name(self) -> AgentName:
|
23
14
|
return self.__class__.__name__
|
@@ -38,75 +29,8 @@ class Agent(abc.ABC):
|
|
38
29
|
):
|
39
30
|
pass
|
40
31
|
|
41
|
-
|
42
|
-
|
43
|
-
) -> AsyncGenerator[RunYield, RunYieldResume]:
|
44
|
-
yield_queue: janus.Queue[RunYield] = janus.Queue()
|
45
|
-
yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()
|
46
|
-
|
47
|
-
context = Context(
|
48
|
-
session_id=session_id, executor=executor, yield_queue=yield_queue, yield_resume_queue=yield_resume_queue
|
49
|
-
)
|
50
|
-
|
51
|
-
if inspect.isasyncgenfunction(self.run):
|
52
|
-
run = asyncio.create_task(self._run_async_gen(input, context))
|
53
|
-
elif inspect.iscoroutinefunction(self.run):
|
54
|
-
run = asyncio.create_task(self._run_coro(input, context))
|
55
|
-
elif inspect.isgeneratorfunction(self.run):
|
56
|
-
run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen, input, context)
|
57
|
-
else:
|
58
|
-
run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, input, context)
|
59
|
-
|
60
|
-
try:
|
61
|
-
while not run.done() or yield_queue.async_q.qsize() > 0:
|
62
|
-
value = yield await yield_queue.async_q.get()
|
63
|
-
if isinstance(value, Exception):
|
64
|
-
raise value
|
65
|
-
await yield_resume_queue.async_q.put(value)
|
66
|
-
except janus.AsyncQueueShutDown:
|
67
|
-
pass
|
68
|
-
|
69
|
-
async def _run_async_gen(self, input: list[Message], context: Context) -> None:
|
70
|
-
try:
|
71
|
-
gen: AsyncGenerator[RunYield, RunYieldResume] = self.run(input, context)
|
72
|
-
value = None
|
73
|
-
while True:
|
74
|
-
value = await context.yield_async(await gen.asend(value))
|
75
|
-
except StopAsyncIteration:
|
76
|
-
pass
|
77
|
-
except Exception as e:
|
78
|
-
await context.yield_async(e)
|
79
|
-
finally:
|
80
|
-
context.shutdown()
|
81
|
-
|
82
|
-
async def _run_coro(self, input: list[Message], context: Context) -> None:
|
83
|
-
try:
|
84
|
-
await context.yield_async(await self.run(input, context))
|
85
|
-
except Exception as e:
|
86
|
-
await context.yield_async(e)
|
87
|
-
finally:
|
88
|
-
context.shutdown()
|
89
|
-
|
90
|
-
def _run_gen(self, input: list[Message], context: Context) -> None:
|
91
|
-
try:
|
92
|
-
gen: Generator[RunYield, RunYieldResume] = self.run(input, context)
|
93
|
-
value = None
|
94
|
-
while True:
|
95
|
-
value = context.yield_sync(gen.send(value))
|
96
|
-
except StopIteration:
|
97
|
-
pass
|
98
|
-
except Exception as e:
|
99
|
-
context.yield_sync(e)
|
100
|
-
finally:
|
101
|
-
context.shutdown()
|
102
|
-
|
103
|
-
def _run_func(self, input: list[Message], context: Context) -> None:
|
104
|
-
try:
|
105
|
-
context.yield_sync(self.run(input, context))
|
106
|
-
except Exception as e:
|
107
|
-
context.yield_sync(e)
|
108
|
-
finally:
|
109
|
-
context.shutdown()
|
32
|
+
|
33
|
+
Agent = AgentManifest
|
110
34
|
|
111
35
|
|
112
36
|
def agent(
|
@@ -114,10 +38,10 @@ def agent(
|
|
114
38
|
description: str | None = None,
|
115
39
|
*,
|
116
40
|
metadata: Metadata | None = None,
|
117
|
-
) -> Callable[[Callable],
|
41
|
+
) -> Callable[[Callable], AgentManifest]:
|
118
42
|
"""Decorator to create an agent."""
|
119
43
|
|
120
|
-
def decorator(fn: Callable) ->
|
44
|
+
def decorator(fn: Callable) -> AgentManifest:
|
121
45
|
signature = inspect.signature(fn)
|
122
46
|
parameters = list(signature.parameters.values())
|
123
47
|
|
@@ -130,7 +54,7 @@ def agent(
|
|
130
54
|
|
131
55
|
has_context_param = len(parameters) == 2
|
132
56
|
|
133
|
-
class DecoratorAgentBase(
|
57
|
+
class DecoratorAgentBase(AgentManifest):
|
134
58
|
@property
|
135
59
|
def name(self) -> str:
|
136
60
|
return name or fn.__name__
|
@@ -143,7 +67,7 @@ def agent(
|
|
143
67
|
def metadata(self) -> Metadata:
|
144
68
|
return metadata or Metadata()
|
145
69
|
|
146
|
-
agent:
|
70
|
+
agent: AgentManifest
|
147
71
|
if inspect.isasyncgenfunction(fn):
|
148
72
|
|
149
73
|
class AsyncGenDecoratorAgent(DecoratorAgentBase):
|