acp-sdk 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acp_sdk/client/client.py +1 -5
- acp_sdk/server/app.py +76 -44
- acp_sdk/server/executor.py +198 -0
- acp_sdk/server/server.py +9 -10
- acp_sdk/server/session.py +18 -15
- acp_sdk/server/store/__init__.py +4 -0
- acp_sdk/server/store/memory_store.py +35 -0
- acp_sdk/server/store/postgresql_store.py +69 -0
- acp_sdk/server/store/redis_store.py +40 -0
- acp_sdk/server/store/store.py +55 -0
- acp_sdk/server/store/utils.py +5 -0
- acp_sdk/server/telemetry.py +6 -2
- acp_sdk/server/utils.py +28 -4
- {acp_sdk-0.9.0.dist-info → acp_sdk-0.10.0.dist-info}/METADATA +13 -11
- {acp_sdk-0.9.0.dist-info → acp_sdk-0.10.0.dist-info}/RECORD +16 -10
- acp_sdk/server/bundle.py +0 -182
- {acp_sdk-0.9.0.dist-info → acp_sdk-0.10.0.dist-info}/WHEEL +0 -0
acp_sdk/client/client.py
CHANGED
@@ -8,7 +8,6 @@ from typing import Self
|
|
8
8
|
|
9
9
|
import httpx
|
10
10
|
from httpx_sse import EventSource, aconnect_sse
|
11
|
-
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
12
11
|
from pydantic import TypeAdapter
|
13
12
|
|
14
13
|
from acp_sdk.client.types import Input
|
@@ -44,7 +43,6 @@ class Client:
|
|
44
43
|
*,
|
45
44
|
session_id: SessionId | None = None,
|
46
45
|
client: httpx.AsyncClient | None = None,
|
47
|
-
instrument: bool = True,
|
48
46
|
auth: httpx._types.AuthTypes | None = None,
|
49
47
|
params: httpx._types.QueryParamTypes | None = None,
|
50
48
|
headers: httpx._types.HeaderTypes | None = None,
|
@@ -85,8 +83,6 @@ class Client:
|
|
85
83
|
transport=transport,
|
86
84
|
trust_env=trust_env,
|
87
85
|
)
|
88
|
-
if instrument:
|
89
|
-
HTTPXClientInstrumentor.instrument_client(self._client)
|
90
86
|
|
91
87
|
@property
|
92
88
|
def client(self) -> httpx.AsyncClient:
|
@@ -108,7 +104,7 @@ class Client:
|
|
108
104
|
async def session(self, session_id: SessionId | None = None) -> AsyncGenerator[Self]:
|
109
105
|
session_id = session_id or uuid.uuid4()
|
110
106
|
with get_tracer().start_as_current_span("session", attributes={"acp.session": str(session_id)}):
|
111
|
-
yield Client(client=self._client, session_id=session_id
|
107
|
+
yield Client(client=self._client, session_id=session_id)
|
112
108
|
|
113
109
|
async def agents(self) -> AsyncIterator[Agent]:
|
114
110
|
response = await self._client.get("/agents")
|
acp_sdk/server/app.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
|
+
import asyncio
|
1
2
|
from collections.abc import AsyncGenerator
|
2
3
|
from concurrent.futures import ThreadPoolExecutor
|
3
4
|
from contextlib import asynccontextmanager
|
4
|
-
from datetime import
|
5
|
+
from datetime import timedelta
|
5
6
|
from enum import Enum
|
6
7
|
|
7
|
-
from cachetools import TTLCache
|
8
8
|
from fastapi import Depends, FastAPI, HTTPException, status
|
9
9
|
from fastapi.applications import AppType, Lifespan
|
10
10
|
from fastapi.encoders import jsonable_encoder
|
11
|
+
from fastapi.middleware.cors import CORSMiddleware
|
11
12
|
from fastapi.responses import JSONResponse, StreamingResponse
|
12
|
-
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
13
13
|
|
14
14
|
from acp_sdk.models import (
|
15
15
|
Agent as AgentModel,
|
@@ -28,12 +28,11 @@ from acp_sdk.models import (
|
|
28
28
|
RunReadResponse,
|
29
29
|
RunResumeRequest,
|
30
30
|
RunResumeResponse,
|
31
|
-
SessionId,
|
32
31
|
)
|
33
32
|
from acp_sdk.models.errors import ACPError
|
33
|
+
from acp_sdk.models.models import AwaitResume, RunStatus
|
34
34
|
from acp_sdk.models.schemas import PingResponse
|
35
35
|
from acp_sdk.server.agent import Agent
|
36
|
-
from acp_sdk.server.bundle import RunBundle
|
37
36
|
from acp_sdk.server.errors import (
|
38
37
|
RequestValidationError,
|
39
38
|
StarletteHTTPException,
|
@@ -42,8 +41,10 @@ from acp_sdk.server.errors import (
|
|
42
41
|
http_exception_handler,
|
43
42
|
validation_exception_handler,
|
44
43
|
)
|
44
|
+
from acp_sdk.server.executor import CancelData, Executor, RunData
|
45
45
|
from acp_sdk.server.session import Session
|
46
|
-
from acp_sdk.server.
|
46
|
+
from acp_sdk.server.store import MemoryStore, Store
|
47
|
+
from acp_sdk.server.utils import stream_sse, wait_util_stop
|
47
48
|
|
48
49
|
|
49
50
|
class Headers(str, Enum):
|
@@ -52,8 +53,7 @@ class Headers(str, Enum):
|
|
52
53
|
|
53
54
|
def create_app(
|
54
55
|
*agents: Agent,
|
55
|
-
|
56
|
-
run_ttl: timedelta = timedelta(hours=1),
|
56
|
+
store: Store | None = None,
|
57
57
|
lifespan: Lifespan[AppType] | None = None,
|
58
58
|
dependencies: list[Depends] | None = None,
|
59
59
|
) -> FastAPI:
|
@@ -75,22 +75,37 @@ def create_app(
|
|
75
75
|
dependencies=dependencies,
|
76
76
|
)
|
77
77
|
|
78
|
-
|
78
|
+
app.add_middleware(
|
79
|
+
CORSMiddleware,
|
80
|
+
allow_origins=["https://agentcommunicationprotocol.dev"],
|
81
|
+
allow_methods=["*"],
|
82
|
+
allow_headers=["*"],
|
83
|
+
allow_credentials=True,
|
84
|
+
)
|
79
85
|
|
80
86
|
agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
|
81
|
-
|
82
|
-
|
87
|
+
|
88
|
+
store = store or MemoryStore(limit=1000, ttl=timedelta(hours=1))
|
89
|
+
run_store = store.as_store(model=RunData, prefix="run_")
|
90
|
+
run_cancel_store = store.as_store(model=CancelData, prefix="run_cancel_")
|
91
|
+
run_resume_store = store.as_store(model=AwaitResume, prefix="run_resume_")
|
92
|
+
session_store = store.as_store(model=Session, prefix="session_")
|
83
93
|
|
84
94
|
app.exception_handler(ACPError)(acp_error_handler)
|
85
95
|
app.exception_handler(StarletteHTTPException)(http_exception_handler)
|
86
96
|
app.exception_handler(RequestValidationError)(validation_exception_handler)
|
87
97
|
app.exception_handler(Exception)(catch_all_exception_handler)
|
88
98
|
|
89
|
-
def
|
90
|
-
|
91
|
-
if not
|
99
|
+
async def find_run_data(run_id: RunId) -> RunData:
|
100
|
+
run_data = await run_store.get(run_id)
|
101
|
+
if not run_data:
|
92
102
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
93
|
-
|
103
|
+
if run_data.run.status.is_terminal:
|
104
|
+
return run_data
|
105
|
+
cancel_data = await run_cancel_store.get(run_data.key)
|
106
|
+
if cancel_data is not None:
|
107
|
+
run_data.run.status = RunStatus.CANCELLING
|
108
|
+
return run_data
|
94
109
|
|
95
110
|
def find_agent(agent_name: AgentName) -> Agent:
|
96
111
|
agent = agents.get(agent_name, None)
|
@@ -120,94 +135,111 @@ def create_app(
|
|
120
135
|
async def create_run(request: RunCreateRequest) -> RunCreateResponse:
|
121
136
|
agent = find_agent(request.agent_name)
|
122
137
|
|
123
|
-
session =
|
138
|
+
session = (
|
139
|
+
(await session_store.get(request.session_id)) or Session(id=request.session_id)
|
140
|
+
if request.session_id
|
141
|
+
else Session()
|
142
|
+
)
|
124
143
|
nonlocal executor
|
125
|
-
|
126
|
-
agent=agent,
|
144
|
+
run_data = RunData(
|
127
145
|
run=Run(agent_name=agent.name, session_id=session.id),
|
128
146
|
input=request.input,
|
129
|
-
history=list(session.history()),
|
130
|
-
executor=executor,
|
131
147
|
)
|
132
|
-
|
148
|
+
await run_store.set(run_data.key, run_data)
|
133
149
|
|
134
|
-
|
135
|
-
|
150
|
+
session.append(run_data.run.run_id)
|
151
|
+
await session_store.set(session.id, session)
|
136
152
|
|
137
|
-
headers = {Headers.RUN_ID: str(
|
153
|
+
headers = {Headers.RUN_ID: str(run_data.run.run_id)}
|
154
|
+
ready = asyncio.Event()
|
155
|
+
|
156
|
+
Executor(
|
157
|
+
agent=agent,
|
158
|
+
run_data=run_data,
|
159
|
+
history=await session.history(run_store),
|
160
|
+
run_store=run_store,
|
161
|
+
cancel_store=run_cancel_store,
|
162
|
+
resume_store=run_resume_store,
|
163
|
+
executor=executor,
|
164
|
+
).execute(wait=ready)
|
138
165
|
|
139
166
|
match request.mode:
|
140
167
|
case RunMode.STREAM:
|
141
168
|
return StreamingResponse(
|
142
|
-
stream_sse(
|
169
|
+
stream_sse(run_data, run_store, 0, ready=ready),
|
143
170
|
headers=headers,
|
144
171
|
media_type="text/event-stream",
|
145
172
|
)
|
146
173
|
case RunMode.SYNC:
|
147
|
-
await
|
174
|
+
await wait_util_stop(run_data, run_store, ready=ready)
|
148
175
|
return JSONResponse(
|
149
176
|
headers=headers,
|
150
|
-
content=jsonable_encoder(
|
177
|
+
content=jsonable_encoder(run_data.run),
|
151
178
|
)
|
152
179
|
case RunMode.ASYNC:
|
180
|
+
ready.set()
|
153
181
|
return JSONResponse(
|
154
182
|
status_code=status.HTTP_202_ACCEPTED,
|
155
183
|
headers=headers,
|
156
|
-
content=jsonable_encoder(
|
184
|
+
content=jsonable_encoder(run_data.run),
|
157
185
|
)
|
158
186
|
case _:
|
159
187
|
raise NotImplementedError()
|
160
188
|
|
161
189
|
@app.get("/runs/{run_id}")
|
162
190
|
async def read_run(run_id: RunId) -> RunReadResponse:
|
163
|
-
bundle =
|
191
|
+
bundle = await find_run_data(run_id)
|
164
192
|
return bundle.run
|
165
193
|
|
166
194
|
@app.get("/runs/{run_id}/events")
|
167
195
|
async def list_run_events(run_id: RunId) -> RunEventsListResponse:
|
168
|
-
bundle =
|
196
|
+
bundle = await find_run_data(run_id)
|
169
197
|
return RunEventsListResponse(events=bundle.events)
|
170
198
|
|
171
199
|
@app.post("/runs/{run_id}")
|
172
200
|
async def resume_run(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
|
173
|
-
|
201
|
+
run_data = await find_run_data(run_id)
|
174
202
|
|
175
|
-
if
|
203
|
+
if run_data.run.await_request is None:
|
176
204
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Run {run_id} has no await request")
|
177
205
|
|
178
|
-
if
|
206
|
+
if run_data.run.await_request.type != request.await_resume.type:
|
179
207
|
raise HTTPException(
|
180
208
|
status_code=status.HTTP_403_FORBIDDEN,
|
181
|
-
detail=f"Run {run_id} is expecting resume of type {
|
209
|
+
detail=f"Run {run_id} is expecting resume of type {run_data.run.await_request.type}",
|
182
210
|
)
|
183
211
|
|
184
|
-
|
212
|
+
run_data.run.status = RunStatus.IN_PROGRESS
|
213
|
+
await run_store.set(run_data.key, run_data)
|
214
|
+
await run_resume_store.set(run_data.key, request.await_resume)
|
215
|
+
|
185
216
|
match request.mode:
|
186
217
|
case RunMode.STREAM:
|
187
218
|
return StreamingResponse(
|
188
|
-
stream_sse(
|
219
|
+
stream_sse(run_data, run_store, len(run_data.events)),
|
189
220
|
media_type="text/event-stream",
|
190
221
|
)
|
191
222
|
case RunMode.SYNC:
|
192
|
-
await
|
193
|
-
return
|
223
|
+
run_data = await wait_util_stop(run_data, run_store)
|
224
|
+
return run_data.run
|
194
225
|
case RunMode.ASYNC:
|
195
226
|
return JSONResponse(
|
196
227
|
status_code=status.HTTP_202_ACCEPTED,
|
197
|
-
content=jsonable_encoder(
|
228
|
+
content=jsonable_encoder(run_data.run),
|
198
229
|
)
|
199
230
|
case _:
|
200
231
|
raise NotImplementedError()
|
201
232
|
|
202
233
|
@app.post("/runs/{run_id}/cancel")
|
203
234
|
async def cancel_run(run_id: RunId) -> RunCancelResponse:
|
204
|
-
|
205
|
-
if
|
235
|
+
run_data = await find_run_data(run_id)
|
236
|
+
if run_data.run.status.is_terminal:
|
206
237
|
raise HTTPException(
|
207
238
|
status_code=status.HTTP_403_FORBIDDEN,
|
208
|
-
detail=f"Run in terminal status {
|
239
|
+
detail=f"Run in terminal status {run_data.run.status} can't be cancelled",
|
209
240
|
)
|
210
|
-
await
|
211
|
-
|
241
|
+
await run_cancel_store.set(run_data.key, CancelData())
|
242
|
+
run_data.run.status = RunStatus.CANCELLING
|
243
|
+
return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder(run_data.run))
|
212
244
|
|
213
245
|
return app
|
acp_sdk/server/executor.py
ADDED
@@ -0,0 +1,198 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from collections.abc import AsyncIterator
|
4
|
+
from concurrent.futures import ThreadPoolExecutor
|
5
|
+
from datetime import datetime, timezone
|
6
|
+
from typing import Self
|
7
|
+
|
8
|
+
from pydantic import BaseModel, ValidationError
|
9
|
+
|
10
|
+
from acp_sdk.instrumentation import get_tracer
|
11
|
+
from acp_sdk.models import (
|
12
|
+
ACPError,
|
13
|
+
AnyModel,
|
14
|
+
AwaitRequest,
|
15
|
+
AwaitResume,
|
16
|
+
Error,
|
17
|
+
ErrorCode,
|
18
|
+
Event,
|
19
|
+
GenericEvent,
|
20
|
+
Message,
|
21
|
+
MessageCompletedEvent,
|
22
|
+
MessageCreatedEvent,
|
23
|
+
MessagePart,
|
24
|
+
MessagePartEvent,
|
25
|
+
Run,
|
26
|
+
RunAwaitingEvent,
|
27
|
+
RunCancelledEvent,
|
28
|
+
RunCompletedEvent,
|
29
|
+
RunCreatedEvent,
|
30
|
+
RunFailedEvent,
|
31
|
+
RunInProgressEvent,
|
32
|
+
RunStatus,
|
33
|
+
)
|
34
|
+
from acp_sdk.server.agent import Agent
|
35
|
+
from acp_sdk.server.logging import logger
|
36
|
+
from acp_sdk.server.store import Store
|
37
|
+
|
38
|
+
|
39
|
+
class RunData(BaseModel):
|
40
|
+
run: Run
|
41
|
+
input: list[Message]
|
42
|
+
events: list[Event] = []
|
43
|
+
|
44
|
+
@property
|
45
|
+
def key(self) -> str:
|
46
|
+
return str(self.run.run_id)
|
47
|
+
|
48
|
+
async def watch(self, store: Store[Self], *, ready: asyncio.Event | None = None) -> AsyncIterator[Self]:
|
49
|
+
async for data in store.watch(self.key, ready=ready):
|
50
|
+
if data is None:
|
51
|
+
raise RuntimeError("Missing data")
|
52
|
+
yield data
|
53
|
+
if data.run.status.is_terminal:
|
54
|
+
break
|
55
|
+
|
56
|
+
|
57
|
+
class CancelData(BaseModel):
|
58
|
+
pass
|
59
|
+
|
60
|
+
|
61
|
+
class Executor:
|
62
|
+
def __init__(
|
63
|
+
self,
|
64
|
+
*,
|
65
|
+
agent: Agent,
|
66
|
+
run_data: RunData,
|
67
|
+
history: list[Message],
|
68
|
+
executor: ThreadPoolExecutor,
|
69
|
+
run_store: Store[RunData],
|
70
|
+
cancel_store: Store[CancelData],
|
71
|
+
resume_store: Store[AwaitResume],
|
72
|
+
) -> None:
|
73
|
+
self.agent = agent
|
74
|
+
self.history = history
|
75
|
+
self.run_data = run_data
|
76
|
+
self.executor = executor
|
77
|
+
|
78
|
+
self.run_store = run_store
|
79
|
+
self.cancel_store = cancel_store
|
80
|
+
self.resume_store = resume_store
|
81
|
+
|
82
|
+
self.logger = logging.LoggerAdapter(logger, {"run_id": str(run_data.run.run_id)})
|
83
|
+
|
84
|
+
def execute(self, *, wait: asyncio.Event) -> None:
|
85
|
+
self.task = asyncio.create_task(self._execute(self.run_data, executor=self.executor, wait=wait))
|
86
|
+
self.watcher = asyncio.create_task(self._watch_for_cancellation())
|
87
|
+
|
88
|
+
async def _push(self) -> None:
|
89
|
+
await self.run_store.set(self.run_data.run.run_id, self.run_data)
|
90
|
+
|
91
|
+
async def _emit(self, event: Event) -> None:
|
92
|
+
freeze = event.model_copy(deep=True)
|
93
|
+
self.run_data.events.append(freeze)
|
94
|
+
await self._push()
|
95
|
+
|
96
|
+
async def _await(self) -> AwaitResume:
|
97
|
+
async for resume in self.resume_store.watch(self.run_data.key):
|
98
|
+
if resume is not None:
|
99
|
+
await self.resume_store.set(self.run_data.key, None)
|
100
|
+
return resume
|
101
|
+
|
102
|
+
async def _watch_for_cancellation(self) -> None:
|
103
|
+
while not self.task.done():
|
104
|
+
try:
|
105
|
+
async for data in self.cancel_store.watch(self.run_data.key):
|
106
|
+
if data is not None:
|
107
|
+
self.task.cancel()
|
108
|
+
except Exception:
|
109
|
+
logger.warning("Cancellation watcher failed, restarting")
|
110
|
+
|
111
|
+
async def _execute(self, run_data: RunData, *, executor: ThreadPoolExecutor, wait: asyncio.Event) -> None:
|
112
|
+
with get_tracer().start_as_current_span("run"):
|
113
|
+
in_message = False
|
114
|
+
|
115
|
+
async def flush_message() -> None:
|
116
|
+
nonlocal in_message
|
117
|
+
if in_message:
|
118
|
+
message = run_data.run.output[-1]
|
119
|
+
message.completed_at = datetime.now(timezone.utc)
|
120
|
+
await self._emit(MessageCompletedEvent(message=message))
|
121
|
+
in_message = False
|
122
|
+
|
123
|
+
try:
|
124
|
+
await wait.wait()
|
125
|
+
|
126
|
+
await self._emit(RunCreatedEvent(run=run_data.run))
|
127
|
+
|
128
|
+
generator = self.agent.execute(
|
129
|
+
input=self.history + run_data.input, session_id=run_data.run.session_id, executor=executor
|
130
|
+
)
|
131
|
+
self.logger.info("Run started")
|
132
|
+
|
133
|
+
run_data.run.status = RunStatus.IN_PROGRESS
|
134
|
+
await self._emit(RunInProgressEvent(run=run_data.run))
|
135
|
+
|
136
|
+
await_resume = None
|
137
|
+
while True:
|
138
|
+
next = await generator.asend(await_resume)
|
139
|
+
|
140
|
+
if isinstance(next, (MessagePart, str)):
|
141
|
+
if isinstance(next, str):
|
142
|
+
next = MessagePart(content=next)
|
143
|
+
if not in_message:
|
144
|
+
run_data.run.output.append(Message(parts=[], completed_at=None))
|
145
|
+
in_message = True
|
146
|
+
await self._emit(MessageCreatedEvent(message=run_data.run.output[-1]))
|
147
|
+
run_data.run.output[-1].parts.append(next)
|
148
|
+
await self._emit(MessagePartEvent(part=next))
|
149
|
+
elif isinstance(next, Message):
|
150
|
+
await flush_message()
|
151
|
+
run_data.run.output.append(next)
|
152
|
+
await self._emit(MessageCreatedEvent(message=next))
|
153
|
+
for part in next.parts:
|
154
|
+
await self._emit(MessagePartEvent(part=part))
|
155
|
+
await self._emit(MessageCompletedEvent(message=next))
|
156
|
+
elif isinstance(next, AwaitRequest):
|
157
|
+
run_data.run.await_request = next
|
158
|
+
run_data.run.status = RunStatus.AWAITING
|
159
|
+
await self._emit(RunAwaitingEvent(run=run_data.run))
|
160
|
+
self.logger.info("Run awaited")
|
161
|
+
await_resume = await self._await()
|
162
|
+
run_data.run.status = RunStatus.IN_PROGRESS
|
163
|
+
await self._emit(RunInProgressEvent(run=run_data.run))
|
164
|
+
self.logger.info("Run resumed")
|
165
|
+
elif isinstance(next, Error):
|
166
|
+
raise ACPError(error=next)
|
167
|
+
elif isinstance(next, BaseException):
|
168
|
+
raise next
|
169
|
+
elif next is None:
|
170
|
+
await flush_message()
|
171
|
+
elif isinstance(next, BaseModel):
|
172
|
+
await self._emit(GenericEvent(generic=AnyModel(**next.model_dump())))
|
173
|
+
else:
|
174
|
+
try:
|
175
|
+
generic = AnyModel.model_validate(next)
|
176
|
+
await self._emit(GenericEvent(generic=generic))
|
177
|
+
except ValidationError:
|
178
|
+
raise TypeError("Invalid yield")
|
179
|
+
except StopAsyncIteration:
|
180
|
+
await flush_message()
|
181
|
+
run_data.run.status = RunStatus.COMPLETED
|
182
|
+
run_data.run.finished_at = datetime.now(timezone.utc)
|
183
|
+
await self._emit(RunCompletedEvent(run=run_data.run))
|
184
|
+
self.logger.info("Run completed")
|
185
|
+
except asyncio.CancelledError:
|
186
|
+
run_data.run.status = RunStatus.CANCELLED
|
187
|
+
run_data.run.finished_at = datetime.now(timezone.utc)
|
188
|
+
await self._emit(RunCancelledEvent(run=run_data.run))
|
189
|
+
self.logger.info("Run cancelled")
|
190
|
+
except Exception as e:
|
191
|
+
if isinstance(e, ACPError):
|
192
|
+
run_data.run.error = e.error
|
193
|
+
else:
|
194
|
+
run_data.run.error = Error(code=ErrorCode.SERVER_ERROR, message=str(e))
|
195
|
+
run_data.run.status = RunStatus.FAILED
|
196
|
+
run_data.run.finished_at = datetime.now(timezone.utc)
|
197
|
+
await self._emit(RunFailedEvent(run=run_data.run))
|
198
|
+
self.logger.exception("Run failed")
|
acp_sdk/server/server.py
CHANGED
@@ -2,7 +2,6 @@ import asyncio
|
|
2
2
|
import os
|
3
3
|
from collections.abc import AsyncGenerator, Awaitable
|
4
4
|
from contextlib import asynccontextmanager
|
5
|
-
from datetime import timedelta
|
6
5
|
from typing import Any, Callable
|
7
6
|
|
8
7
|
import requests
|
@@ -16,6 +15,7 @@ from acp_sdk.server.agent import agent as agent_decorator
|
|
16
15
|
from acp_sdk.server.app import create_app
|
17
16
|
from acp_sdk.server.logging import configure_logger as configure_logger_func
|
18
17
|
from acp_sdk.server.logging import logger
|
18
|
+
from acp_sdk.server.store import Store
|
19
19
|
from acp_sdk.server.telemetry import configure_telemetry as configure_telemetry_func
|
20
20
|
from acp_sdk.server.utils import async_request_with_retry
|
21
21
|
|
@@ -54,8 +54,7 @@ class Server:
|
|
54
54
|
configure_logger: bool = True,
|
55
55
|
configure_telemetry: bool = False,
|
56
56
|
self_registration: bool = True,
|
57
|
-
|
58
|
-
run_ttl: timedelta = timedelta(hours=1),
|
57
|
+
store: Store | None = None,
|
59
58
|
host: str = "127.0.0.1",
|
60
59
|
port: int = 8000,
|
61
60
|
uds: str | None = None,
|
@@ -118,13 +117,15 @@ class Server:
|
|
118
117
|
|
119
118
|
import uvicorn
|
120
119
|
|
120
|
+
app = create_app(*self.agents, lifespan=self.lifespan, store=store)
|
121
|
+
|
121
122
|
if configure_logger:
|
122
123
|
configure_logger_func()
|
123
124
|
if configure_telemetry:
|
124
|
-
configure_telemetry_func()
|
125
|
+
configure_telemetry_func(app)
|
125
126
|
|
126
127
|
config = uvicorn.Config(
|
127
|
-
|
128
|
+
app,
|
128
129
|
host,
|
129
130
|
port,
|
130
131
|
uds,
|
@@ -182,8 +183,7 @@ class Server:
|
|
182
183
|
configure_logger: bool = True,
|
183
184
|
configure_telemetry: bool = False,
|
184
185
|
self_registration: bool = True,
|
185
|
-
|
186
|
-
run_ttl: timedelta = timedelta(hours=1),
|
186
|
+
store: Store | None = None,
|
187
187
|
host: str = "127.0.0.1",
|
188
188
|
port: int = 8000,
|
189
189
|
uds: str | None = None,
|
@@ -241,8 +241,7 @@ class Server:
|
|
241
241
|
configure_logger=configure_logger,
|
242
242
|
configure_telemetry=configure_telemetry,
|
243
243
|
self_registration=self_registration,
|
244
|
-
|
245
|
-
run_ttl=run_ttl,
|
244
|
+
store=store,
|
246
245
|
host=host,
|
247
246
|
port=port,
|
248
247
|
uds=uds,
|
@@ -309,7 +308,7 @@ class Server:
|
|
309
308
|
|
310
309
|
async def _register_agent(self) -> None:
|
311
310
|
"""If not in PRODUCTION mode, register agent to the beeai platform and provide missing env variables"""
|
312
|
-
if os.getenv("PRODUCTION_MODE",
|
311
|
+
if os.getenv("PRODUCTION_MODE", "").lower() in ["true", "1"]:
|
313
312
|
logger.debug("Agent is not automatically registered in the production mode.")
|
314
313
|
return
|
315
314
|
|
acp_sdk/server/session.py
CHANGED
@@ -1,21 +1,24 @@
|
|
1
1
|
import uuid
|
2
|
-
from collections.abc import Iterator
|
3
2
|
|
4
|
-
from
|
5
|
-
from acp_sdk.models.models import RunStatus
|
6
|
-
from acp_sdk.server.bundle import RunBundle
|
3
|
+
from pydantic import BaseModel, Field
|
7
4
|
|
5
|
+
from acp_sdk.models import Message, RunId, RunStatus, SessionId
|
6
|
+
from acp_sdk.server.executor import RunData
|
7
|
+
from acp_sdk.server.store import Store
|
8
8
|
|
9
|
-
class Session:
|
10
|
-
def __init__(self, id: SessionId | None = None) -> None:
|
11
|
-
self.id: SessionId = id or uuid.uuid4()
|
12
|
-
self.bundles: list[RunBundle] = []
|
13
9
|
|
14
|
-
|
15
|
-
|
10
|
+
class Session(BaseModel):
|
11
|
+
id: SessionId = Field(default_factory=uuid.uuid4)
|
12
|
+
runs: list[RunId] = []
|
16
13
|
|
17
|
-
def
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
14
|
+
def append(self, run_id: RunId) -> None:
|
15
|
+
self.runs.append(run_id)
|
16
|
+
|
17
|
+
async def history(self, store: Store[RunData]) -> list[Message]:
|
18
|
+
history = []
|
19
|
+
for run_id in self.runs:
|
20
|
+
run_data = await store.get(run_id)
|
21
|
+
if run_data is not None and run_data.run.status == RunStatus.COMPLETED:
|
22
|
+
history.extend(run_data.input)
|
23
|
+
history.extend(run_data.run.output)
|
24
|
+
return history
|
acp_sdk/server/store/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
1
|
+
from acp_sdk.server.store.memory_store import MemoryStore as MemoryStore
|
2
|
+
from acp_sdk.server.store.postgresql_store import PostgreSQLStore as PostgreSQLStore
|
3
|
+
from acp_sdk.server.store.redis_store import RedisStore as RedisStore
|
4
|
+
from acp_sdk.server.store.store import Store as Store
|
acp_sdk/server/store/memory_store.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncIterator
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Generic
|
5
|
+
|
6
|
+
from cachetools import TTLCache
|
7
|
+
|
8
|
+
from acp_sdk.server.store.store import Store, T
|
9
|
+
from acp_sdk.server.store.utils import Stringable
|
10
|
+
|
11
|
+
|
12
|
+
class MemoryStore(Store[T], Generic[T]):
|
13
|
+
def __init__(self, *, limit: int, ttl: int | None = None) -> None:
|
14
|
+
super().__init__()
|
15
|
+
self._cache: TTLCache[str, T] = TTLCache(maxsize=limit, ttl=ttl, timer=datetime.now)
|
16
|
+
self._event = asyncio.Event()
|
17
|
+
|
18
|
+
async def get(self, key: Stringable) -> T | None:
|
19
|
+
value = self._cache.get(str(key))
|
20
|
+
return value.model_copy(deep=True) if value else value
|
21
|
+
|
22
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
23
|
+
if value is None:
|
24
|
+
del self._cache[str(key)]
|
25
|
+
else:
|
26
|
+
self._cache[str(key)] = value.model_copy(deep=True)
|
27
|
+
self._event.set()
|
28
|
+
|
29
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
|
30
|
+
if ready:
|
31
|
+
ready.set()
|
32
|
+
while True:
|
33
|
+
await self._event.wait()
|
34
|
+
self._event.clear()
|
35
|
+
yield await self.get(key)
|
acp_sdk/server/store/postgresql_store.py
ADDED
@@ -0,0 +1,69 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncIterator
|
3
|
+
from typing import Generic
|
4
|
+
|
5
|
+
from psycopg import AsyncConnection
|
6
|
+
from psycopg.rows import dict_row
|
7
|
+
|
8
|
+
from acp_sdk.server.store.store import Store, StoreModel, T
|
9
|
+
from acp_sdk.server.store.utils import Stringable
|
10
|
+
|
11
|
+
|
12
|
+
class PostgreSQLStore(Store[T], Generic[T]):
|
13
|
+
def __init__(self, *, aconn: AsyncConnection, table: str = "acp_store", channel: str = "acp_update") -> None:
|
14
|
+
super().__init__()
|
15
|
+
self._aconn = aconn
|
16
|
+
self._table = table
|
17
|
+
self._channel = channel
|
18
|
+
|
19
|
+
async def get(self, key: Stringable) -> T | None:
|
20
|
+
await self._ensure_table()
|
21
|
+
async with self._aconn.cursor(row_factory=dict_row) as cur:
|
22
|
+
await cur.execute(f"SELECT value FROM {self._table} WHERE key = %s", (str(key),))
|
23
|
+
result = await cur.fetchone()
|
24
|
+
if result is None:
|
25
|
+
return None
|
26
|
+
return StoreModel.model_validate(result["value"])
|
27
|
+
|
28
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
29
|
+
await self._ensure_table()
|
30
|
+
async with self._aconn.cursor() as cur:
|
31
|
+
if value is None:
|
32
|
+
await cur.execute(
|
33
|
+
f"DELETE FROM {self._table} WHERE key = %s",
|
34
|
+
(str(key),),
|
35
|
+
)
|
36
|
+
else:
|
37
|
+
await cur.execute(
|
38
|
+
f"""
|
39
|
+
INSERT INTO {self._table} (key, value)
|
40
|
+
VALUES (%s, %s)
|
41
|
+
ON CONFLICT (key)
|
42
|
+
DO UPDATE SET value = EXCLUDED.value
|
43
|
+
""",
|
44
|
+
(str(key), value.model_dump_json()),
|
45
|
+
)
|
46
|
+
await cur.execute(f"NOTIFY {self._channel}, '{key!s}'") # NOTIFY appears not to accept params
|
47
|
+
await self._aconn.commit()
|
48
|
+
|
49
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
|
50
|
+
notify_conn = await AsyncConnection.connect(
|
51
|
+
conninfo=f"{self._aconn.info.dsn} password={self._aconn.info.password}", autocommit=True
|
52
|
+
)
|
53
|
+
async with notify_conn:
|
54
|
+
await notify_conn.execute(f"LISTEN {self._channel}")
|
55
|
+
if ready:
|
56
|
+
ready.set()
|
57
|
+
async for notify in notify_conn.notifies():
|
58
|
+
if notify.payload == str(key):
|
59
|
+
yield await self.get(key)
|
60
|
+
|
61
|
+
async def _ensure_table(self) -> None:
|
62
|
+
async with self._aconn.cursor() as cur:
|
63
|
+
await cur.execute(f"""
|
64
|
+
CREATE TABLE IF NOT EXISTS {self._table} (
|
65
|
+
key TEXT PRIMARY KEY,
|
66
|
+
value JSONB NOT NULL
|
67
|
+
)
|
68
|
+
""")
|
69
|
+
await self._aconn.commit()
|
acp_sdk/server/store/redis_store.py
ADDED
@@ -0,0 +1,40 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncIterator
|
3
|
+
from typing import Generic
|
4
|
+
|
5
|
+
from redis.asyncio import Redis
|
6
|
+
|
7
|
+
from acp_sdk.server.store.store import Store, StoreModel, T
|
8
|
+
from acp_sdk.server.store.utils import Stringable
|
9
|
+
|
10
|
+
|
11
|
+
class RedisStore(Store[T], Generic[T]):
|
12
|
+
def __init__(self, *, redis: Redis) -> None:
|
13
|
+
super().__init__()
|
14
|
+
self._redis = redis
|
15
|
+
|
16
|
+
async def get(self, key: Stringable) -> T | None:
|
17
|
+
value = await self._redis.get(str(key))
|
18
|
+
return StoreModel.model_validate_json(value) if value else value
|
19
|
+
|
20
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
21
|
+
if value is None:
|
22
|
+
await self._redis.delete(str(key))
|
23
|
+
else:
|
24
|
+
await self._redis.set(name=str(key), value=value.model_dump_json())
|
25
|
+
|
26
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T]:
|
27
|
+
await self._redis.config_set("notify-keyspace-events", "KEA")
|
28
|
+
|
29
|
+
pubsub = self._redis.pubsub()
|
30
|
+
channel = f"__keyspace@0__:{key!s}"
|
31
|
+
await pubsub.subscribe(channel)
|
32
|
+
if ready:
|
33
|
+
ready.set()
|
34
|
+
try:
|
35
|
+
async for message in pubsub.listen():
|
36
|
+
if message["type"] == "message":
|
37
|
+
yield await self.get(key)
|
38
|
+
finally:
|
39
|
+
await pubsub.unsubscribe(channel)
|
40
|
+
await pubsub.close()
|
@@ -0,0 +1,55 @@
|
|
1
|
+
import asyncio
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from collections.abc import AsyncIterator
|
4
|
+
from typing import Generic, TypeVar
|
5
|
+
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
7
|
+
|
8
|
+
from acp_sdk.server.store.utils import Stringable
|
9
|
+
|
10
|
+
|
11
|
+
class StoreModel(BaseModel):
|
12
|
+
model_config = ConfigDict(extra="allow")
|
13
|
+
|
14
|
+
|
15
|
+
T = TypeVar("T", bound=BaseModel)
|
16
|
+
U = TypeVar("U", bound=BaseModel)
|
17
|
+
|
18
|
+
|
19
|
+
class Store(Generic[T], ABC):
|
20
|
+
@abstractmethod
|
21
|
+
async def get(self, key: Stringable) -> T | None:
|
22
|
+
pass
|
23
|
+
|
24
|
+
@abstractmethod
|
25
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
26
|
+
pass
|
27
|
+
|
28
|
+
@abstractmethod
|
29
|
+
def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
|
30
|
+
pass
|
31
|
+
|
32
|
+
def as_store(self, model: type[U], prefix: Stringable = "") -> "Store[U]":
|
33
|
+
return StoreView(model=model, store=self, prefix=prefix)
|
34
|
+
|
35
|
+
|
36
|
+
class StoreView(Store[U], Generic[U]):
|
37
|
+
def __init__(self, *, model: type[U], store: Store[T], prefix: Stringable = "") -> None:
|
38
|
+
super().__init__()
|
39
|
+
self._model = model
|
40
|
+
self._store = store
|
41
|
+
self._prefix = prefix
|
42
|
+
|
43
|
+
async def get(self, key: Stringable) -> U | None:
|
44
|
+
value = await self._store.get(self._get_key(key))
|
45
|
+
return self._model.model_validate(value.model_dump()) if value else value
|
46
|
+
|
47
|
+
async def set(self, key: Stringable, value: U | None) -> None:
|
48
|
+
await self._store.set(self._get_key(key), value)
|
49
|
+
|
50
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[U | None]:
|
51
|
+
async for value in self._store.watch(self._get_key(key), ready=ready):
|
52
|
+
yield self._model.model_validate(value.model_dump()) if value else value
|
53
|
+
|
54
|
+
def _get_key(self, key: Stringable) -> str:
|
55
|
+
return f"{self._prefix!s}{key!s}"
|
acp_sdk/server/telemetry.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
import logging
|
2
2
|
|
3
|
+
from fastapi import FastAPI
|
3
4
|
from opentelemetry import metrics, trace
|
4
5
|
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
|
5
6
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
6
7
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
8
|
+
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
7
9
|
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
|
8
10
|
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
|
9
11
|
from opentelemetry.sdk.metrics import MeterProvider
|
@@ -22,8 +24,10 @@ from acp_sdk.version import __version__
|
|
22
24
|
root_logger = logging.getLogger()
|
23
25
|
|
24
26
|
|
25
|
-
def configure_telemetry() -> None:
|
26
|
-
"""Utility that configures opentelemetry with OTLP exporter"""
|
27
|
+
def configure_telemetry(app: FastAPI) -> None:
|
28
|
+
"""Utility that configures opentelemetry with OTLP exporter and FastAPI instrumentation"""
|
29
|
+
|
30
|
+
FastAPIInstrumentor.instrument_app(app)
|
27
31
|
|
28
32
|
resource = Resource(
|
29
33
|
attributes={
|
acp_sdk/server/utils.py
CHANGED
@@ -6,17 +6,41 @@ import httpx
|
|
6
6
|
import requests
|
7
7
|
from pydantic import BaseModel
|
8
8
|
|
9
|
-
from acp_sdk.
|
9
|
+
from acp_sdk.models import RunStatus
|
10
|
+
from acp_sdk.server.executor import RunData
|
10
11
|
from acp_sdk.server.logging import logger
|
12
|
+
from acp_sdk.server.store.store import Store
|
11
13
|
|
12
14
|
|
13
15
|
def encode_sse(model: BaseModel) -> str:
    """Format ``model`` as one server-sent-event frame.

    The frame is ``data: `` followed by ``model.model_dump_json()`` and a
    blank-line terminator.
    """
    payload = model.model_dump_json()
    return f"data: {payload}\n\n"
|
15
17
|
|
16
18
|
|
17
|
-
async def
|
18
|
-
|
19
|
-
|
19
|
+
async def watch_util_stop(
    run_data: RunData, store: Store[RunData], *, ready: asyncio.Event | None = None
) -> AsyncGenerator[RunData]:
    """Relay snapshots from ``run_data.watch`` until the run is awaiting.

    Each item produced by ``run_data.watch(store, ready=ready)`` is yielded
    unchanged. The snapshot whose ``run.status`` equals ``RunStatus.AWAITING``
    is yielded too, and then iteration ends. Iteration also ends whenever the
    underlying watch is exhausted.
    """
    async for snapshot in run_data.watch(store, ready=ready):
        yield snapshot
        awaiting = snapshot.run.status == RunStatus.AWAITING
        if awaiting:
            return
|
26
|
+
|
27
|
+
|
28
|
+
async def wait_util_stop(run_data: RunData, store: Store[RunData], *, ready: asyncio.Event | None = None) -> RunData:
    """Drain ``watch_util_stop`` and return the last snapshot it produced.

    If the watch yields nothing, the ``run_data`` argument itself is returned.
    """
    snapshots = watch_util_stop(run_data, store, ready=ready)
    latest = run_data
    async for snapshot in snapshots:
        latest = snapshot
    return latest
|
33
|
+
|
34
|
+
|
35
|
+
async def stream_sse(
    run_data: RunData, store: Store[RunData], idx: int, *, ready: asyncio.Event | None = None
) -> AsyncGenerator[str]:
    """Yield SSE frames for run events, starting at event index ``idx``.

    For every snapshot from ``watch_util_stop`` only the events not yet sent
    (those at or after the current cursor) are encoded with ``encode_sse``;
    the cursor then moves to the end of that snapshot's event list.
    """
    cursor = idx
    async for snapshot in watch_util_stop(run_data, store, ready=ready):
        pending = snapshot.events[cursor:]
        cursor = len(snapshot.events)
        for item in pending:
            yield encode_sse(item)
|
20
44
|
|
21
45
|
|
22
46
|
async def async_request_with_retry(
|
@@ -1,27 +1,29 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: acp-sdk
|
3
|
-
Version: 0.9.0
|
3
|
+
Version: 0.10.0
|
4
4
|
Summary: Agent Communication Protocol SDK
|
5
5
|
Author: IBM Corp.
|
6
6
|
Maintainer-email: Tomas Pilar <thomas7pilar@gmail.com>
|
7
7
|
License-Expression: Apache-2.0
|
8
8
|
Requires-Python: <4.0,>=3.11
|
9
|
-
Requires-Dist: cachetools>=5.5
|
10
|
-
Requires-Dist: fastapi[standard]>=0.115
|
11
|
-
Requires-Dist: httpx-sse>=0.4
|
12
|
-
Requires-Dist: httpx>=0.26
|
13
|
-
Requires-Dist: janus>=2.0
|
14
|
-
Requires-Dist: opentelemetry-api>=1.31
|
15
|
-
Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31
|
9
|
+
Requires-Dist: cachetools>=5.5
|
10
|
+
Requires-Dist: fastapi[standard]>=0.115
|
11
|
+
Requires-Dist: httpx-sse>=0.4
|
12
|
+
Requires-Dist: httpx>=0.26
|
13
|
+
Requires-Dist: janus>=2.0
|
14
|
+
Requires-Dist: opentelemetry-api>=1.31
|
15
|
+
Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31
|
16
16
|
Requires-Dist: opentelemetry-instrumentation-fastapi>=0.52b1
|
17
17
|
Requires-Dist: opentelemetry-instrumentation-httpx>=0.52b1
|
18
|
-
Requires-Dist: opentelemetry-sdk>=1.31
|
19
|
-
Requires-Dist:
|
18
|
+
Requires-Dist: opentelemetry-sdk>=1.31
|
19
|
+
Requires-Dist: psycopg[binary]>=3.2
|
20
|
+
Requires-Dist: pydantic>=2.0
|
21
|
+
Requires-Dist: redis>=6.1
|
20
22
|
Description-Content-Type: text/markdown
|
21
23
|
|
22
24
|
# Agent Communication Protocol SDK for Python
|
23
25
|
|
24
|
-
Agent Communication Protocol SDK for Python
|
26
|
+
Agent Communication Protocol SDK for Python helps developers to serve and consume agents over the Agent Communication Protocol.
|
25
27
|
|
26
28
|
## Prerequisites
|
27
29
|
|
@@ -3,7 +3,7 @@ acp_sdk/instrumentation.py,sha256=JqSyvILN3sGAfOZrmckQq4-M_4_5alyPn95DK0o5lfA,16
|
|
3
3
|
acp_sdk/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
acp_sdk/version.py,sha256=Niy83rgvigB4hL_rR-O4ySvI7dj6xnqkyOe_JTymi9s,73
|
5
5
|
acp_sdk/client/__init__.py,sha256=Bca1DORrswxzZsrR2aUFpATuNG2xNSmYvF1Z2WJaVbc,51
|
6
|
-
acp_sdk/client/client.py,sha256=
|
6
|
+
acp_sdk/client/client.py,sha256=7GVD7eVlaSU78O3puMlyoOTqKjTwZYaBKLHyYvbvr3I,8504
|
7
7
|
acp_sdk/client/types.py,sha256=_H6zYt-2OHOOYRtssRnbDIiwmgsl2-KIXc9lb-mJLFA,133
|
8
8
|
acp_sdk/client/utils.py,sha256=2jhJyrPJmVFRoDJh0q_JMqOMlC3IxCh-6HXed-PIZS8,924
|
9
9
|
acp_sdk/models/__init__.py,sha256=numSDBDT1QHx7n_Y3Deb5VOvKWcUBxbOEaMwQBSRHxc,151
|
@@ -12,16 +12,22 @@ acp_sdk/models/models.py,sha256=So5D9VjJxX2jdOF8BGc3f9D2ssbAyFEO2Rgr1KykI-M,7730
|
|
12
12
|
acp_sdk/models/schemas.py,sha256=_ah7_zHsQJGxDXvnzsBvASdRsQHVphFQ7Sum6A04iRw,759
|
13
13
|
acp_sdk/server/__init__.py,sha256=mxBBBFaZuMEUENRMLwp1XZkuLeT9QghcFmNvjnqvAAU,377
|
14
14
|
acp_sdk/server/agent.py,sha256=6VBKn_qVXqUl79G8T7grwhnuLMwr67d4UGagMGX1hMs,6586
|
15
|
-
acp_sdk/server/app.py,sha256=
|
16
|
-
acp_sdk/server/bundle.py,sha256=umD2GgDp17lUddu0adpp1zUcm1JJvDrDpIZ0uR-6VeY,7204
|
15
|
+
acp_sdk/server/app.py,sha256=XeMn8hfU2J2O2QxSvsH5q6lowxCfwS1-DTWz-gUVotA,8733
|
17
16
|
acp_sdk/server/context.py,sha256=MgnLV6qcDIhc_0BjW7r4Jj1tHts4ZuwpdTGIBnz2Mgo,1036
|
18
17
|
acp_sdk/server/errors.py,sha256=GSO8yYIqEeX8Y4Lz86ks35dMTHiQiXuOrLYYx0eXsbI,2110
|
18
|
+
acp_sdk/server/executor.py,sha256=YL0J9cVY1QZtdTeqwjJaKDpB_T6_sByHlHc52kgNAJo,7742
|
19
19
|
acp_sdk/server/logging.py,sha256=Oc8yZigCsuDnHHPsarRzu0RX3NKaLEgpELM2yovGKDI,411
|
20
|
-
acp_sdk/server/server.py,sha256=
|
21
|
-
acp_sdk/server/session.py,sha256=
|
22
|
-
acp_sdk/server/telemetry.py,sha256=
|
20
|
+
acp_sdk/server/server.py,sha256=lgpvokjd3f_dqEkyfLn3Nr4oKXNHPUkvV0uNimUmXU8,13497
|
21
|
+
acp_sdk/server/session.py,sha256=vGUVpKzUGefI1c7LeK08Bvd8zvJIRfsdJEt2KhYoEg0,764
|
22
|
+
acp_sdk/server/telemetry.py,sha256=lbB2ppijUcqbHUOn0e-15LGcVvT_qrMguq8qBokICac,2016
|
23
23
|
acp_sdk/server/types.py,sha256=gLb5wCkMYhmu2laj_ymK-TPfN9LSjRgKOP1H_893UzA,304
|
24
|
-
acp_sdk/server/utils.py,sha256=
|
25
|
-
acp_sdk
|
26
|
-
acp_sdk
|
27
|
-
acp_sdk
|
24
|
+
acp_sdk/server/utils.py,sha256=BYSn4Bd95Bn-oEH1W1yE_pWpYUOdtYPh-vMnou4nsdk,2721
|
25
|
+
acp_sdk/server/store/__init__.py,sha256=zzKic0byQTM86cyC2whwZeNP4prfy_HZrbSriTaV5j8,282
|
26
|
+
acp_sdk/server/store/memory_store.py,sha256=9hOoJiHdaDz-0hNHWhO_qZpp-I_pnTqYBIhQnlWilu4,1194
|
27
|
+
acp_sdk/server/store/postgresql_store.py,sha256=bHyAgf4vSn_07wGXXq6juFwm3JldYNOjU9RARcDSYQo,2717
|
28
|
+
acp_sdk/server/store/redis_store.py,sha256=IKXvDseOFMcoGjVYPOkOBhPnJAchy_RyeMayKLoVCGA,1378
|
29
|
+
acp_sdk/server/store/store.py,sha256=jGmYy9oiuVjhYYJY8QRo4g2J2Qyt1HLTmq_eHy4aI7c,1806
|
30
|
+
acp_sdk/server/store/utils.py,sha256=JumEOMs1h1uGlnHnUGeguee-srGzT7_Y2NVEYt01QuY,92
|
31
|
+
acp_sdk-0.10.0.dist-info/METADATA,sha256=wmPPGrdoeCoe7M8DM-r3qlLgpwVteHARkZpbKaVGk9I,1685
|
32
|
+
acp_sdk-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
33
|
+
acp_sdk-0.10.0.dist-info/RECORD,,
|
acp_sdk/server/bundle.py
DELETED
@@ -1,182 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import logging
|
3
|
-
from collections.abc import AsyncGenerator
|
4
|
-
from concurrent.futures import ThreadPoolExecutor
|
5
|
-
from datetime import datetime, timezone
|
6
|
-
|
7
|
-
from pydantic import BaseModel, ValidationError
|
8
|
-
|
9
|
-
from acp_sdk.instrumentation import get_tracer
|
10
|
-
from acp_sdk.models import (
|
11
|
-
ACPError,
|
12
|
-
AnyModel,
|
13
|
-
AwaitRequest,
|
14
|
-
AwaitResume,
|
15
|
-
Error,
|
16
|
-
ErrorCode,
|
17
|
-
Event,
|
18
|
-
GenericEvent,
|
19
|
-
Message,
|
20
|
-
MessageCompletedEvent,
|
21
|
-
MessageCreatedEvent,
|
22
|
-
MessagePart,
|
23
|
-
MessagePartEvent,
|
24
|
-
Run,
|
25
|
-
RunAwaitingEvent,
|
26
|
-
RunCancelledEvent,
|
27
|
-
RunCompletedEvent,
|
28
|
-
RunCreatedEvent,
|
29
|
-
RunFailedEvent,
|
30
|
-
RunInProgressEvent,
|
31
|
-
RunStatus,
|
32
|
-
)
|
33
|
-
from acp_sdk.server.agent import Agent
|
34
|
-
from acp_sdk.server.logging import logger
|
35
|
-
|
36
|
-
|
37
|
-
class RunBundle:
    """Couples one agent run with the asyncio task executing it.

    Construction immediately schedules ``_execute`` as a task. Events emitted
    during execution are recorded in ``events`` and pushed onto
    ``stream_queue``; a ``None`` on that queue marks the end of the current
    stream segment (the run started awaiting, or it terminated).
    """

    def __init__(
        self, *, agent: Agent, run: Run, input: list[Message], history: list[Message], executor: ThreadPoolExecutor
    ) -> None:
        self.agent = agent
        self.run = run
        self.input = input
        self.history = history

        # ``None`` is enqueued as an end-of-stream sentinel (see await_ and
        # the finally clause of _execute), so it is part of the item type.
        self.stream_queue: asyncio.Queue[Event | None] = asyncio.Queue()
        self.events: list[Event] = []

        self.await_queue: asyncio.Queue[AwaitResume] = asyncio.Queue(maxsize=1)
        self.await_or_terminate_event = asyncio.Event()

        # Must be created inside a running event loop.
        self.task = asyncio.create_task(self._execute(input, executor=executor))

    async def stream(self) -> AsyncGenerator[Event]:
        """Yield queued events until the ``None`` sentinel is received."""
        while True:
            event = await self.stream_queue.get()
            if event is None:
                break
            yield event
            self.stream_queue.task_done()

    async def emit(self, event: Event) -> None:
        """Record a deep copy of ``event`` and enqueue that copy for streaming."""
        # Deep copy so later mutation of the run/message does not alter
        # events that were already recorded.
        freeze = event.model_copy(deep=True)
        self.events.append(freeze)
        await self.stream_queue.put(freeze)

    async def await_(self) -> AwaitResume:
        """End the current stream segment and block until ``resume`` delivers a value."""
        await self.stream_queue.put(None)
        # set() followed by clear() pulses the event: tasks currently blocked
        # in join() are woken, later join() calls block again.
        self.await_or_terminate_event.set()
        self.await_or_terminate_event.clear()
        resume = await self.await_queue.get()
        self.await_queue.task_done()
        return resume

    async def resume(self, resume: AwaitResume) -> None:
        """Hand ``resume`` to the awaiting run and mark the run in progress."""
        # Fresh queue so a new stream() does not see the earlier sentinel.
        self.stream_queue = asyncio.Queue()
        await self.await_queue.put(resume)
        self.run.status = RunStatus.IN_PROGRESS
        self.run.await_request = None

    async def cancel(self) -> None:
        """Request cancellation of the task and mark the run as cancelling."""
        self.task.cancel()
        self.run.status = RunStatus.CANCELLING
        self.run.await_request = None

    async def join(self) -> None:
        """Wait until the run starts awaiting or terminates."""
        await self.await_or_terminate_event.wait()

    async def _execute(self, input: list[Message], *, executor: ThreadPoolExecutor) -> None:
        """Drive the agent generator and translate its yields into run events.

        Exceptions other than cancellation are recorded on the run, emitted as
        a ``RunFailedEvent`` and re-raised.
        """
        with get_tracer().start_as_current_span("run"):
            run_logger = logging.LoggerAdapter(logger, {"run_id": str(self.run.run_id)})

            # True while message parts are being accumulated into the last
            # output message.
            in_message = False

            async def flush_message() -> None:
                """Complete the message currently being accumulated, if any."""
                nonlocal in_message
                if in_message:
                    message = self.run.output[-1]
                    message.completed_at = datetime.now(timezone.utc)
                    await self.emit(MessageCompletedEvent(message=message))
                    in_message = False

            try:
                await self.emit(RunCreatedEvent(run=self.run))

                generator = self.agent.execute(
                    input=self.history + input, session_id=self.run.session_id, executor=executor
                )
                run_logger.info("Run started")

                self.run.status = RunStatus.IN_PROGRESS
                await self.emit(RunInProgressEvent(run=self.run))

                await_resume = None
                while True:
                    # Raises StopAsyncIteration when the agent finishes; that
                    # is handled below as normal completion.
                    next = await generator.asend(await_resume)

                    if isinstance(next, (MessagePart, str)):
                        if isinstance(next, str):
                            next = MessagePart(content=next)
                        if not in_message:
                            self.run.output.append(Message(parts=[], completed_at=None))
                            in_message = True
                            await self.emit(MessageCreatedEvent(message=self.run.output[-1]))
                        self.run.output[-1].parts.append(next)
                        await self.emit(MessagePartEvent(part=next))
                    elif isinstance(next, Message):
                        await flush_message()
                        self.run.output.append(next)
                        await self.emit(MessageCreatedEvent(message=next))
                        for part in next.parts:
                            await self.emit(MessagePartEvent(part=part))
                        await self.emit(MessageCompletedEvent(message=next))
                    elif isinstance(next, AwaitRequest):
                        self.run.await_request = next
                        self.run.status = RunStatus.AWAITING
                        await self.emit(RunAwaitingEvent(run=self.run))
                        run_logger.info("Run awaited")
                        await_resume = await self.await_()
                        await self.emit(RunInProgressEvent(run=self.run))
                        run_logger.info("Run resumed")
                    elif isinstance(next, Error):
                        raise ACPError(error=next)
                    elif isinstance(next, ACPError):
                        raise next
                    elif next is None:
                        await flush_message()
                    elif isinstance(next, BaseModel):
                        await self.emit(GenericEvent(generic=AnyModel(**next.model_dump())))
                    else:
                        try:
                            generic = AnyModel.model_validate(next)
                            await self.emit(GenericEvent(generic=generic))
                        except ValidationError as err:
                            # Keep the validation failure as the cause.
                            raise TypeError("Invalid yield") from err
            except StopAsyncIteration:
                await flush_message()
                self.run.status = RunStatus.COMPLETED
                self.run.finished_at = datetime.now(timezone.utc)
                await self.emit(RunCompletedEvent(run=self.run))
                run_logger.info("Run completed")
            except asyncio.CancelledError:
                self.run.status = RunStatus.CANCELLED
                self.run.finished_at = datetime.now(timezone.utc)
                await self.emit(RunCancelledEvent(run=self.run))
                run_logger.info("Run cancelled")
            except Exception as e:
                if isinstance(e, ACPError):
                    self.run.error = e.error
                else:
                    self.run.error = Error(code=ErrorCode.SERVER_ERROR, message=str(e))
                self.run.status = RunStatus.FAILED
                self.run.finished_at = datetime.now(timezone.utc)
                await self.emit(RunFailedEvent(run=self.run))
                run_logger.exception("Run failed")
                raise
            finally:
                self.await_or_terminate_event.set()
                await self.stream_queue.put(None)
                # NOTE(review): this coroutine is the body of self.task, so the
                # task is not done at this point and this requests cancellation
                # of the task from inside itself - confirm that is intended.
                if not self.task.done():
                    self.task.cancel()
|
File without changes
|