acp-sdk 0.9.1__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acp_sdk/server/app.py +77 -42
- acp_sdk/server/executor.py +198 -0
- acp_sdk/server/server.py +20 -11
- acp_sdk/server/session.py +18 -15
- acp_sdk/server/store/__init__.py +4 -0
- acp_sdk/server/store/memory_store.py +35 -0
- acp_sdk/server/store/postgresql_store.py +69 -0
- acp_sdk/server/store/redis_store.py +40 -0
- acp_sdk/server/store/store.py +55 -0
- acp_sdk/server/store/utils.py +5 -0
- acp_sdk/server/utils.py +28 -4
- {acp_sdk-0.9.1.dist-info → acp_sdk-0.10.1.dist-info}/METADATA +13 -11
- {acp_sdk-0.9.1.dist-info → acp_sdk-0.10.1.dist-info}/RECORD +14 -8
- acp_sdk/server/bundle.py +0 -182
- {acp_sdk-0.9.1.dist-info → acp_sdk-0.10.1.dist-info}/WHEEL +0 -0
acp_sdk/server/app.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1
|
+
import asyncio
|
1
2
|
from collections.abc import AsyncGenerator
|
2
3
|
from concurrent.futures import ThreadPoolExecutor
|
3
4
|
from contextlib import asynccontextmanager
|
4
|
-
from datetime import
|
5
|
+
from datetime import timedelta
|
5
6
|
from enum import Enum
|
6
7
|
|
7
|
-
from cachetools import TTLCache
|
8
8
|
from fastapi import Depends, FastAPI, HTTPException, status
|
9
9
|
from fastapi.applications import AppType, Lifespan
|
10
10
|
from fastapi.encoders import jsonable_encoder
|
11
|
+
from fastapi.middleware.cors import CORSMiddleware
|
11
12
|
from fastapi.responses import JSONResponse, StreamingResponse
|
12
13
|
|
13
14
|
from acp_sdk.models import (
|
@@ -27,12 +28,11 @@ from acp_sdk.models import (
|
|
27
28
|
RunReadResponse,
|
28
29
|
RunResumeRequest,
|
29
30
|
RunResumeResponse,
|
30
|
-
SessionId,
|
31
31
|
)
|
32
32
|
from acp_sdk.models.errors import ACPError
|
33
|
+
from acp_sdk.models.models import AwaitResume, RunStatus
|
33
34
|
from acp_sdk.models.schemas import PingResponse
|
34
35
|
from acp_sdk.server.agent import Agent
|
35
|
-
from acp_sdk.server.bundle import RunBundle
|
36
36
|
from acp_sdk.server.errors import (
|
37
37
|
RequestValidationError,
|
38
38
|
StarletteHTTPException,
|
@@ -41,8 +41,10 @@ from acp_sdk.server.errors import (
|
|
41
41
|
http_exception_handler,
|
42
42
|
validation_exception_handler,
|
43
43
|
)
|
44
|
+
from acp_sdk.server.executor import CancelData, Executor, RunData
|
44
45
|
from acp_sdk.server.session import Session
|
45
|
-
from acp_sdk.server.
|
46
|
+
from acp_sdk.server.store import MemoryStore, Store
|
47
|
+
from acp_sdk.server.utils import stream_sse, wait_util_stop
|
46
48
|
|
47
49
|
|
48
50
|
class Headers(str, Enum):
|
@@ -51,8 +53,7 @@ class Headers(str, Enum):
|
|
51
53
|
|
52
54
|
def create_app(
|
53
55
|
*agents: Agent,
|
54
|
-
|
55
|
-
run_ttl: timedelta = timedelta(hours=1),
|
56
|
+
store: Store | None = None,
|
56
57
|
lifespan: Lifespan[AppType] | None = None,
|
57
58
|
dependencies: list[Depends] | None = None,
|
58
59
|
) -> FastAPI:
|
@@ -74,20 +75,37 @@ def create_app(
|
|
74
75
|
dependencies=dependencies,
|
75
76
|
)
|
76
77
|
|
78
|
+
app.add_middleware(
|
79
|
+
CORSMiddleware,
|
80
|
+
allow_origins=["https://agentcommunicationprotocol.dev"],
|
81
|
+
allow_methods=["*"],
|
82
|
+
allow_headers=["*"],
|
83
|
+
allow_credentials=True,
|
84
|
+
)
|
85
|
+
|
77
86
|
agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
|
78
|
-
|
79
|
-
|
87
|
+
|
88
|
+
store = store or MemoryStore(limit=1000, ttl=timedelta(hours=1))
|
89
|
+
run_store = store.as_store(model=RunData, prefix="run_")
|
90
|
+
run_cancel_store = store.as_store(model=CancelData, prefix="run_cancel_")
|
91
|
+
run_resume_store = store.as_store(model=AwaitResume, prefix="run_resume_")
|
92
|
+
session_store = store.as_store(model=Session, prefix="session_")
|
80
93
|
|
81
94
|
app.exception_handler(ACPError)(acp_error_handler)
|
82
95
|
app.exception_handler(StarletteHTTPException)(http_exception_handler)
|
83
96
|
app.exception_handler(RequestValidationError)(validation_exception_handler)
|
84
97
|
app.exception_handler(Exception)(catch_all_exception_handler)
|
85
98
|
|
86
|
-
def
|
87
|
-
|
88
|
-
if not
|
99
|
+
async def find_run_data(run_id: RunId) -> RunData:
|
100
|
+
run_data = await run_store.get(run_id)
|
101
|
+
if not run_data:
|
89
102
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
90
|
-
|
103
|
+
if run_data.run.status.is_terminal:
|
104
|
+
return run_data
|
105
|
+
cancel_data = await run_cancel_store.get(run_data.key)
|
106
|
+
if cancel_data is not None:
|
107
|
+
run_data.run.status = RunStatus.CANCELLING
|
108
|
+
return run_data
|
91
109
|
|
92
110
|
def find_agent(agent_name: AgentName) -> Agent:
|
93
111
|
agent = agents.get(agent_name, None)
|
@@ -117,94 +135,111 @@ def create_app(
|
|
117
135
|
async def create_run(request: RunCreateRequest) -> RunCreateResponse:
|
118
136
|
agent = find_agent(request.agent_name)
|
119
137
|
|
120
|
-
session =
|
138
|
+
session = (
|
139
|
+
(await session_store.get(request.session_id)) or Session(id=request.session_id)
|
140
|
+
if request.session_id
|
141
|
+
else Session()
|
142
|
+
)
|
121
143
|
nonlocal executor
|
122
|
-
|
123
|
-
agent=agent,
|
144
|
+
run_data = RunData(
|
124
145
|
run=Run(agent_name=agent.name, session_id=session.id),
|
125
146
|
input=request.input,
|
126
|
-
history=list(session.history()),
|
127
|
-
executor=executor,
|
128
147
|
)
|
129
|
-
|
148
|
+
await run_store.set(run_data.key, run_data)
|
130
149
|
|
131
|
-
|
132
|
-
|
150
|
+
session.append(run_data.run.run_id)
|
151
|
+
await session_store.set(session.id, session)
|
133
152
|
|
134
|
-
headers = {Headers.RUN_ID: str(
|
153
|
+
headers = {Headers.RUN_ID: str(run_data.run.run_id)}
|
154
|
+
ready = asyncio.Event()
|
155
|
+
|
156
|
+
Executor(
|
157
|
+
agent=agent,
|
158
|
+
run_data=run_data,
|
159
|
+
history=await session.history(run_store),
|
160
|
+
run_store=run_store,
|
161
|
+
cancel_store=run_cancel_store,
|
162
|
+
resume_store=run_resume_store,
|
163
|
+
executor=executor,
|
164
|
+
).execute(wait=ready)
|
135
165
|
|
136
166
|
match request.mode:
|
137
167
|
case RunMode.STREAM:
|
138
168
|
return StreamingResponse(
|
139
|
-
stream_sse(
|
169
|
+
stream_sse(run_data, run_store, 0, ready=ready),
|
140
170
|
headers=headers,
|
141
171
|
media_type="text/event-stream",
|
142
172
|
)
|
143
173
|
case RunMode.SYNC:
|
144
|
-
await
|
174
|
+
await wait_util_stop(run_data, run_store, ready=ready)
|
145
175
|
return JSONResponse(
|
146
176
|
headers=headers,
|
147
|
-
content=jsonable_encoder(
|
177
|
+
content=jsonable_encoder(run_data.run),
|
148
178
|
)
|
149
179
|
case RunMode.ASYNC:
|
180
|
+
ready.set()
|
150
181
|
return JSONResponse(
|
151
182
|
status_code=status.HTTP_202_ACCEPTED,
|
152
183
|
headers=headers,
|
153
|
-
content=jsonable_encoder(
|
184
|
+
content=jsonable_encoder(run_data.run),
|
154
185
|
)
|
155
186
|
case _:
|
156
187
|
raise NotImplementedError()
|
157
188
|
|
158
189
|
@app.get("/runs/{run_id}")
|
159
190
|
async def read_run(run_id: RunId) -> RunReadResponse:
|
160
|
-
bundle =
|
191
|
+
bundle = await find_run_data(run_id)
|
161
192
|
return bundle.run
|
162
193
|
|
163
194
|
@app.get("/runs/{run_id}/events")
|
164
195
|
async def list_run_events(run_id: RunId) -> RunEventsListResponse:
|
165
|
-
bundle =
|
196
|
+
bundle = await find_run_data(run_id)
|
166
197
|
return RunEventsListResponse(events=bundle.events)
|
167
198
|
|
168
199
|
@app.post("/runs/{run_id}")
|
169
200
|
async def resume_run(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
|
170
|
-
|
201
|
+
run_data = await find_run_data(run_id)
|
171
202
|
|
172
|
-
if
|
203
|
+
if run_data.run.await_request is None:
|
173
204
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Run {run_id} has no await request")
|
174
205
|
|
175
|
-
if
|
206
|
+
if run_data.run.await_request.type != request.await_resume.type:
|
176
207
|
raise HTTPException(
|
177
208
|
status_code=status.HTTP_403_FORBIDDEN,
|
178
|
-
detail=f"Run {run_id} is expecting resume of type {
|
209
|
+
detail=f"Run {run_id} is expecting resume of type {run_data.run.await_request.type}",
|
179
210
|
)
|
180
211
|
|
181
|
-
|
212
|
+
run_data.run.status = RunStatus.IN_PROGRESS
|
213
|
+
await run_store.set(run_data.key, run_data)
|
214
|
+
await run_resume_store.set(run_data.key, request.await_resume)
|
215
|
+
|
182
216
|
match request.mode:
|
183
217
|
case RunMode.STREAM:
|
184
218
|
return StreamingResponse(
|
185
|
-
stream_sse(
|
219
|
+
stream_sse(run_data, run_store, len(run_data.events)),
|
186
220
|
media_type="text/event-stream",
|
187
221
|
)
|
188
222
|
case RunMode.SYNC:
|
189
|
-
await
|
190
|
-
return
|
223
|
+
run_data = await wait_util_stop(run_data, run_store)
|
224
|
+
return run_data.run
|
191
225
|
case RunMode.ASYNC:
|
192
226
|
return JSONResponse(
|
193
227
|
status_code=status.HTTP_202_ACCEPTED,
|
194
|
-
content=jsonable_encoder(
|
228
|
+
content=jsonable_encoder(run_data.run),
|
195
229
|
)
|
196
230
|
case _:
|
197
231
|
raise NotImplementedError()
|
198
232
|
|
199
233
|
@app.post("/runs/{run_id}/cancel")
|
200
234
|
async def cancel_run(run_id: RunId) -> RunCancelResponse:
|
201
|
-
|
202
|
-
if
|
235
|
+
run_data = await find_run_data(run_id)
|
236
|
+
if run_data.run.status.is_terminal:
|
203
237
|
raise HTTPException(
|
204
238
|
status_code=status.HTTP_403_FORBIDDEN,
|
205
|
-
detail=f"Run in terminal status {
|
239
|
+
detail=f"Run in terminal status {run_data.run.status} can't be cancelled",
|
206
240
|
)
|
207
|
-
await
|
208
|
-
|
241
|
+
await run_cancel_store.set(run_data.key, CancelData())
|
242
|
+
run_data.run.status = RunStatus.CANCELLING
|
243
|
+
return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder(run_data.run))
|
209
244
|
|
210
245
|
return app
|
@@ -0,0 +1,198 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from collections.abc import AsyncIterator
|
4
|
+
from concurrent.futures import ThreadPoolExecutor
|
5
|
+
from datetime import datetime, timezone
|
6
|
+
from typing import Self
|
7
|
+
|
8
|
+
from pydantic import BaseModel, ValidationError
|
9
|
+
|
10
|
+
from acp_sdk.instrumentation import get_tracer
|
11
|
+
from acp_sdk.models import (
|
12
|
+
ACPError,
|
13
|
+
AnyModel,
|
14
|
+
AwaitRequest,
|
15
|
+
AwaitResume,
|
16
|
+
Error,
|
17
|
+
ErrorCode,
|
18
|
+
Event,
|
19
|
+
GenericEvent,
|
20
|
+
Message,
|
21
|
+
MessageCompletedEvent,
|
22
|
+
MessageCreatedEvent,
|
23
|
+
MessagePart,
|
24
|
+
MessagePartEvent,
|
25
|
+
Run,
|
26
|
+
RunAwaitingEvent,
|
27
|
+
RunCancelledEvent,
|
28
|
+
RunCompletedEvent,
|
29
|
+
RunCreatedEvent,
|
30
|
+
RunFailedEvent,
|
31
|
+
RunInProgressEvent,
|
32
|
+
RunStatus,
|
33
|
+
)
|
34
|
+
from acp_sdk.server.agent import Agent
|
35
|
+
from acp_sdk.server.logging import logger
|
36
|
+
from acp_sdk.server.store import Store
|
37
|
+
|
38
|
+
|
39
|
+
class RunData(BaseModel):
|
40
|
+
run: Run
|
41
|
+
input: list[Message]
|
42
|
+
events: list[Event] = []
|
43
|
+
|
44
|
+
@property
|
45
|
+
def key(self) -> str:
|
46
|
+
return str(self.run.run_id)
|
47
|
+
|
48
|
+
async def watch(self, store: Store[Self], *, ready: asyncio.Event | None = None) -> AsyncIterator[Self]:
|
49
|
+
async for data in store.watch(self.key, ready=ready):
|
50
|
+
if data is None:
|
51
|
+
raise RuntimeError("Missing data")
|
52
|
+
yield data
|
53
|
+
if data.run.status.is_terminal:
|
54
|
+
break
|
55
|
+
|
56
|
+
|
57
|
+
class CancelData(BaseModel):
|
58
|
+
pass
|
59
|
+
|
60
|
+
|
61
|
+
class Executor:
|
62
|
+
def __init__(
|
63
|
+
self,
|
64
|
+
*,
|
65
|
+
agent: Agent,
|
66
|
+
run_data: RunData,
|
67
|
+
history: list[Message],
|
68
|
+
executor: ThreadPoolExecutor,
|
69
|
+
run_store: Store[RunData],
|
70
|
+
cancel_store: Store[CancelData],
|
71
|
+
resume_store: Store[AwaitResume],
|
72
|
+
) -> None:
|
73
|
+
self.agent = agent
|
74
|
+
self.history = history
|
75
|
+
self.run_data = run_data
|
76
|
+
self.executor = executor
|
77
|
+
|
78
|
+
self.run_store = run_store
|
79
|
+
self.cancel_store = cancel_store
|
80
|
+
self.resume_store = resume_store
|
81
|
+
|
82
|
+
self.logger = logging.LoggerAdapter(logger, {"run_id": str(run_data.run.run_id)})
|
83
|
+
|
84
|
+
def execute(self, *, wait: asyncio.Event) -> None:
|
85
|
+
self.task = asyncio.create_task(self._execute(self.run_data, executor=self.executor, wait=wait))
|
86
|
+
self.watcher = asyncio.create_task(self._watch_for_cancellation())
|
87
|
+
|
88
|
+
async def _push(self) -> None:
|
89
|
+
await self.run_store.set(self.run_data.run.run_id, self.run_data)
|
90
|
+
|
91
|
+
async def _emit(self, event: Event) -> None:
|
92
|
+
freeze = event.model_copy(deep=True)
|
93
|
+
self.run_data.events.append(freeze)
|
94
|
+
await self._push()
|
95
|
+
|
96
|
+
async def _await(self) -> AwaitResume:
|
97
|
+
async for resume in self.resume_store.watch(self.run_data.key):
|
98
|
+
if resume is not None:
|
99
|
+
await self.resume_store.set(self.run_data.key, None)
|
100
|
+
return resume
|
101
|
+
|
102
|
+
async def _watch_for_cancellation(self) -> None:
|
103
|
+
while not self.task.done():
|
104
|
+
try:
|
105
|
+
async for data in self.cancel_store.watch(self.run_data.key):
|
106
|
+
if data is not None:
|
107
|
+
self.task.cancel()
|
108
|
+
except Exception:
|
109
|
+
logger.warning("Cancellation watcher failed, restarting")
|
110
|
+
|
111
|
+
async def _execute(self, run_data: RunData, *, executor: ThreadPoolExecutor, wait: asyncio.Event) -> None:
|
112
|
+
with get_tracer().start_as_current_span("run"):
|
113
|
+
in_message = False
|
114
|
+
|
115
|
+
async def flush_message() -> None:
|
116
|
+
nonlocal in_message
|
117
|
+
if in_message:
|
118
|
+
message = run_data.run.output[-1]
|
119
|
+
message.completed_at = datetime.now(timezone.utc)
|
120
|
+
await self._emit(MessageCompletedEvent(message=message))
|
121
|
+
in_message = False
|
122
|
+
|
123
|
+
try:
|
124
|
+
await wait.wait()
|
125
|
+
|
126
|
+
await self._emit(RunCreatedEvent(run=run_data.run))
|
127
|
+
|
128
|
+
generator = self.agent.execute(
|
129
|
+
input=self.history + run_data.input, session_id=run_data.run.session_id, executor=executor
|
130
|
+
)
|
131
|
+
self.logger.info("Run started")
|
132
|
+
|
133
|
+
run_data.run.status = RunStatus.IN_PROGRESS
|
134
|
+
await self._emit(RunInProgressEvent(run=run_data.run))
|
135
|
+
|
136
|
+
await_resume = None
|
137
|
+
while True:
|
138
|
+
next = await generator.asend(await_resume)
|
139
|
+
|
140
|
+
if isinstance(next, (MessagePart, str)):
|
141
|
+
if isinstance(next, str):
|
142
|
+
next = MessagePart(content=next)
|
143
|
+
if not in_message:
|
144
|
+
run_data.run.output.append(Message(parts=[], completed_at=None))
|
145
|
+
in_message = True
|
146
|
+
await self._emit(MessageCreatedEvent(message=run_data.run.output[-1]))
|
147
|
+
run_data.run.output[-1].parts.append(next)
|
148
|
+
await self._emit(MessagePartEvent(part=next))
|
149
|
+
elif isinstance(next, Message):
|
150
|
+
await flush_message()
|
151
|
+
run_data.run.output.append(next)
|
152
|
+
await self._emit(MessageCreatedEvent(message=next))
|
153
|
+
for part in next.parts:
|
154
|
+
await self._emit(MessagePartEvent(part=part))
|
155
|
+
await self._emit(MessageCompletedEvent(message=next))
|
156
|
+
elif isinstance(next, AwaitRequest):
|
157
|
+
run_data.run.await_request = next
|
158
|
+
run_data.run.status = RunStatus.AWAITING
|
159
|
+
await self._emit(RunAwaitingEvent(run=run_data.run))
|
160
|
+
self.logger.info("Run awaited")
|
161
|
+
await_resume = await self._await()
|
162
|
+
run_data.run.status = RunStatus.IN_PROGRESS
|
163
|
+
await self._emit(RunInProgressEvent(run=run_data.run))
|
164
|
+
self.logger.info("Run resumed")
|
165
|
+
elif isinstance(next, Error):
|
166
|
+
raise ACPError(error=next)
|
167
|
+
elif isinstance(next, BaseException):
|
168
|
+
raise next
|
169
|
+
elif next is None:
|
170
|
+
await flush_message()
|
171
|
+
elif isinstance(next, BaseModel):
|
172
|
+
await self._emit(GenericEvent(generic=AnyModel(**next.model_dump())))
|
173
|
+
else:
|
174
|
+
try:
|
175
|
+
generic = AnyModel.model_validate(next)
|
176
|
+
await self._emit(GenericEvent(generic=generic))
|
177
|
+
except ValidationError:
|
178
|
+
raise TypeError("Invalid yield")
|
179
|
+
except StopAsyncIteration:
|
180
|
+
await flush_message()
|
181
|
+
run_data.run.status = RunStatus.COMPLETED
|
182
|
+
run_data.run.finished_at = datetime.now(timezone.utc)
|
183
|
+
await self._emit(RunCompletedEvent(run=run_data.run))
|
184
|
+
self.logger.info("Run completed")
|
185
|
+
except asyncio.CancelledError:
|
186
|
+
run_data.run.status = RunStatus.CANCELLED
|
187
|
+
run_data.run.finished_at = datetime.now(timezone.utc)
|
188
|
+
await self._emit(RunCancelledEvent(run=run_data.run))
|
189
|
+
self.logger.info("Run cancelled")
|
190
|
+
except Exception as e:
|
191
|
+
if isinstance(e, ACPError):
|
192
|
+
run_data.run.error = e.error
|
193
|
+
else:
|
194
|
+
run_data.run.error = Error(code=ErrorCode.SERVER_ERROR, message=str(e))
|
195
|
+
run_data.run.status = RunStatus.FAILED
|
196
|
+
run_data.run.finished_at = datetime.now(timezone.utc)
|
197
|
+
await self._emit(RunFailedEvent(run=run_data.run))
|
198
|
+
self.logger.exception("Run failed")
|
acp_sdk/server/server.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
import asyncio
|
2
2
|
import os
|
3
|
+
import re
|
3
4
|
from collections.abc import AsyncGenerator, Awaitable
|
4
5
|
from contextlib import asynccontextmanager
|
5
|
-
from datetime import timedelta
|
6
6
|
from typing import Any, Callable
|
7
7
|
|
8
8
|
import requests
|
@@ -16,6 +16,7 @@ from acp_sdk.server.agent import agent as agent_decorator
|
|
16
16
|
from acp_sdk.server.app import create_app
|
17
17
|
from acp_sdk.server.logging import configure_logger as configure_logger_func
|
18
18
|
from acp_sdk.server.logging import logger
|
19
|
+
from acp_sdk.server.store import Store
|
19
20
|
from acp_sdk.server.telemetry import configure_telemetry as configure_telemetry_func
|
20
21
|
from acp_sdk.server.utils import async_request_with_retry
|
21
22
|
|
@@ -54,8 +55,7 @@ class Server:
|
|
54
55
|
configure_logger: bool = True,
|
55
56
|
configure_telemetry: bool = False,
|
56
57
|
self_registration: bool = True,
|
57
|
-
|
58
|
-
run_ttl: timedelta = timedelta(hours=1),
|
58
|
+
store: Store | None = None,
|
59
59
|
host: str = "127.0.0.1",
|
60
60
|
port: int = 8000,
|
61
61
|
uds: str | None = None,
|
@@ -118,7 +118,7 @@ class Server:
|
|
118
118
|
|
119
119
|
import uvicorn
|
120
120
|
|
121
|
-
app = create_app(*self.agents, lifespan=self.lifespan,
|
121
|
+
app = create_app(*self.agents, lifespan=self.lifespan, store=store)
|
122
122
|
|
123
123
|
if configure_logger:
|
124
124
|
configure_logger_func()
|
@@ -184,8 +184,7 @@ class Server:
|
|
184
184
|
configure_logger: bool = True,
|
185
185
|
configure_telemetry: bool = False,
|
186
186
|
self_registration: bool = True,
|
187
|
-
|
188
|
-
run_ttl: timedelta = timedelta(hours=1),
|
187
|
+
store: Store | None = None,
|
189
188
|
host: str = "127.0.0.1",
|
190
189
|
port: int = 8000,
|
191
190
|
uds: str | None = None,
|
@@ -243,8 +242,7 @@ class Server:
|
|
243
242
|
configure_logger=configure_logger,
|
244
243
|
configure_telemetry=configure_telemetry,
|
245
244
|
self_registration=self_registration,
|
246
|
-
|
247
|
-
run_ttl=run_ttl,
|
245
|
+
store=store,
|
248
246
|
host=host,
|
249
247
|
port=port,
|
250
248
|
uds=uds,
|
@@ -311,27 +309,38 @@ class Server:
|
|
311
309
|
|
312
310
|
async def _register_agent(self) -> None:
|
313
311
|
"""If not in PRODUCTION mode, register agent to the beeai platform and provide missing env variables"""
|
314
|
-
if os.getenv("PRODUCTION_MODE",
|
312
|
+
if os.getenv("PRODUCTION_MODE", "").lower() in ["true", "1"]:
|
315
313
|
logger.debug("Agent is not automatically registered in the production mode.")
|
316
314
|
return
|
317
315
|
|
318
316
|
url = os.getenv("PLATFORM_URL", "http://127.0.0.1:8333")
|
317
|
+
host = re.sub(r"localhost|127\.0\.0\.1", "host.docker.internal", self.server.config.host)
|
319
318
|
request_data = {
|
320
|
-
"location": f"http://{
|
319
|
+
"location": f"http://{host}:{self.server.config.port}",
|
321
320
|
}
|
321
|
+
await async_request_with_retry(lambda client, data=request_data: client.get(f"{url}/api/v1/providers"))
|
322
322
|
try:
|
323
323
|
await async_request_with_retry(
|
324
|
-
lambda client, data=request_data: client.post(
|
324
|
+
lambda client, data=request_data: client.post(
|
325
|
+
f"{url}/api/v1/providers", json=data, params={"auto_remove": True}
|
326
|
+
)
|
325
327
|
)
|
326
328
|
logger.info("Agent registered to the beeai server.")
|
327
329
|
|
328
330
|
# check missing env keyes
|
329
331
|
envs_request = await async_request_with_retry(lambda client: client.get(f"{url}/api/v1/variables"))
|
330
332
|
envs = envs_request.get("env")
|
333
|
+
os.environ["LLM_MODEL"] = "dummy"
|
334
|
+
os.environ["LLM_API_KEY"] = "dummy"
|
335
|
+
os.environ["LLM_API_BASE"] = f"{url.rstrip('/')}/api/v1/llm"
|
336
|
+
|
331
337
|
for agent in self.agents:
|
332
338
|
# register all available envs
|
333
339
|
missing_keyes = []
|
334
340
|
for env in agent.metadata.model_dump().get("env", []):
|
341
|
+
# Those envs are set to use LLM gateway from platform server
|
342
|
+
if env["name"] in {"LLM_MODEL", "LLM_API_KEY", "LLM_API_BASE"}:
|
343
|
+
continue
|
335
344
|
server_env = envs.get(env.get("name"))
|
336
345
|
if server_env:
|
337
346
|
logger.debug(f"Env variable {env['name']} = '{server_env}' added dynamically")
|
acp_sdk/server/session.py
CHANGED
@@ -1,21 +1,24 @@
|
|
1
1
|
import uuid
|
2
|
-
from collections.abc import Iterator
|
3
2
|
|
4
|
-
from
|
5
|
-
from acp_sdk.models.models import RunStatus
|
6
|
-
from acp_sdk.server.bundle import RunBundle
|
3
|
+
from pydantic import BaseModel, Field
|
7
4
|
|
5
|
+
from acp_sdk.models import Message, RunId, RunStatus, SessionId
|
6
|
+
from acp_sdk.server.executor import RunData
|
7
|
+
from acp_sdk.server.store import Store
|
8
8
|
|
9
|
-
class Session:
|
10
|
-
def __init__(self, id: SessionId | None = None) -> None:
|
11
|
-
self.id: SessionId = id or uuid.uuid4()
|
12
|
-
self.bundles: list[RunBundle] = []
|
13
9
|
|
14
|
-
|
15
|
-
|
10
|
+
class Session(BaseModel):
|
11
|
+
id: SessionId = Field(default_factory=uuid.uuid4)
|
12
|
+
runs: list[RunId] = []
|
16
13
|
|
17
|
-
def
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
14
|
+
def append(self, run_id: RunId) -> None:
|
15
|
+
self.runs.append(run_id)
|
16
|
+
|
17
|
+
async def history(self, store: Store[RunData]) -> list[Message]:
|
18
|
+
history = []
|
19
|
+
for run_id in self.runs:
|
20
|
+
run_data = await store.get(run_id)
|
21
|
+
if run_data is not None and run_data.run.status == RunStatus.COMPLETED:
|
22
|
+
history.extend(run_data.input)
|
23
|
+
history.extend(run_data.run.output)
|
24
|
+
return history
|
@@ -0,0 +1,4 @@
|
|
1
|
+
from acp_sdk.server.store.memory_store import MemoryStore as MemoryStore
|
2
|
+
from acp_sdk.server.store.postgresql_store import PostgreSQLStore as PostgreSQLStore
|
3
|
+
from acp_sdk.server.store.redis_store import RedisStore as RedisStore
|
4
|
+
from acp_sdk.server.store.store import Store as Store
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncIterator
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Generic
|
5
|
+
|
6
|
+
from cachetools import TTLCache
|
7
|
+
|
8
|
+
from acp_sdk.server.store.store import Store, T
|
9
|
+
from acp_sdk.server.store.utils import Stringable
|
10
|
+
|
11
|
+
|
12
|
+
class MemoryStore(Store[T], Generic[T]):
|
13
|
+
def __init__(self, *, limit: int, ttl: int | None = None) -> None:
|
14
|
+
super().__init__()
|
15
|
+
self._cache: TTLCache[str, T] = TTLCache(maxsize=limit, ttl=ttl, timer=datetime.now)
|
16
|
+
self._event = asyncio.Event()
|
17
|
+
|
18
|
+
async def get(self, key: Stringable) -> T | None:
|
19
|
+
value = self._cache.get(str(key))
|
20
|
+
return value.model_copy(deep=True) if value else value
|
21
|
+
|
22
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
23
|
+
if value is None:
|
24
|
+
del self._cache[str(key)]
|
25
|
+
else:
|
26
|
+
self._cache[str(key)] = value.model_copy(deep=True)
|
27
|
+
self._event.set()
|
28
|
+
|
29
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
|
30
|
+
if ready:
|
31
|
+
ready.set()
|
32
|
+
while True:
|
33
|
+
await self._event.wait()
|
34
|
+
self._event.clear()
|
35
|
+
yield await self.get(key)
|
@@ -0,0 +1,69 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncIterator
|
3
|
+
from typing import Generic
|
4
|
+
|
5
|
+
from psycopg import AsyncConnection
|
6
|
+
from psycopg.rows import dict_row
|
7
|
+
|
8
|
+
from acp_sdk.server.store.store import Store, StoreModel, T
|
9
|
+
from acp_sdk.server.store.utils import Stringable
|
10
|
+
|
11
|
+
|
12
|
+
class PostgreSQLStore(Store[T], Generic[T]):
|
13
|
+
def __init__(self, *, aconn: AsyncConnection, table: str = "acp_store", channel: str = "acp_update") -> None:
|
14
|
+
super().__init__()
|
15
|
+
self._aconn = aconn
|
16
|
+
self._table = table
|
17
|
+
self._channel = channel
|
18
|
+
|
19
|
+
async def get(self, key: Stringable) -> T | None:
|
20
|
+
await self._ensure_table()
|
21
|
+
async with self._aconn.cursor(row_factory=dict_row) as cur:
|
22
|
+
await cur.execute(f"SELECT value FROM {self._table} WHERE key = %s", (str(key),))
|
23
|
+
result = await cur.fetchone()
|
24
|
+
if result is None:
|
25
|
+
return None
|
26
|
+
return StoreModel.model_validate(result["value"])
|
27
|
+
|
28
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
29
|
+
await self._ensure_table()
|
30
|
+
async with self._aconn.cursor() as cur:
|
31
|
+
if value is None:
|
32
|
+
await cur.execute(
|
33
|
+
f"DELETE FROM {self._table} WHERE key = %s",
|
34
|
+
(str(key),),
|
35
|
+
)
|
36
|
+
else:
|
37
|
+
await cur.execute(
|
38
|
+
f"""
|
39
|
+
INSERT INTO {self._table} (key, value)
|
40
|
+
VALUES (%s, %s)
|
41
|
+
ON CONFLICT (key)
|
42
|
+
DO UPDATE SET value = EXCLUDED.value
|
43
|
+
""",
|
44
|
+
(str(key), value.model_dump_json()),
|
45
|
+
)
|
46
|
+
await cur.execute(f"NOTIFY {self._channel}, '{key!s}'") # NOTIFY appears not to accept params
|
47
|
+
await self._aconn.commit()
|
48
|
+
|
49
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
|
50
|
+
notify_conn = await AsyncConnection.connect(
|
51
|
+
conninfo=f"{self._aconn.info.dsn} password={self._aconn.info.password}", autocommit=True
|
52
|
+
)
|
53
|
+
async with notify_conn:
|
54
|
+
await notify_conn.execute(f"LISTEN {self._channel}")
|
55
|
+
if ready:
|
56
|
+
ready.set()
|
57
|
+
async for notify in notify_conn.notifies():
|
58
|
+
if notify.payload == str(key):
|
59
|
+
yield await self.get(key)
|
60
|
+
|
61
|
+
async def _ensure_table(self) -> None:
|
62
|
+
async with self._aconn.cursor() as cur:
|
63
|
+
await cur.execute(f"""
|
64
|
+
CREATE TABLE IF NOT EXISTS {self._table} (
|
65
|
+
key TEXT PRIMARY KEY,
|
66
|
+
value JSONB NOT NULL
|
67
|
+
)
|
68
|
+
""")
|
69
|
+
await self._aconn.commit()
|
@@ -0,0 +1,40 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections.abc import AsyncIterator
|
3
|
+
from typing import Generic
|
4
|
+
|
5
|
+
from redis.asyncio import Redis
|
6
|
+
|
7
|
+
from acp_sdk.server.store.store import Store, StoreModel, T
|
8
|
+
from acp_sdk.server.store.utils import Stringable
|
9
|
+
|
10
|
+
|
11
|
+
class RedisStore(Store[T], Generic[T]):
|
12
|
+
def __init__(self, *, redis: Redis) -> None:
|
13
|
+
super().__init__()
|
14
|
+
self._redis = redis
|
15
|
+
|
16
|
+
async def get(self, key: Stringable) -> T | None:
|
17
|
+
value = await self._redis.get(str(key))
|
18
|
+
return StoreModel.model_validate_json(value) if value else value
|
19
|
+
|
20
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
21
|
+
if value is None:
|
22
|
+
await self._redis.delete(str(key))
|
23
|
+
else:
|
24
|
+
await self._redis.set(name=str(key), value=value.model_dump_json())
|
25
|
+
|
26
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T]:
|
27
|
+
await self._redis.config_set("notify-keyspace-events", "KEA")
|
28
|
+
|
29
|
+
pubsub = self._redis.pubsub()
|
30
|
+
channel = f"__keyspace@0__:{key!s}"
|
31
|
+
await pubsub.subscribe(channel)
|
32
|
+
if ready:
|
33
|
+
ready.set()
|
34
|
+
try:
|
35
|
+
async for message in pubsub.listen():
|
36
|
+
if message["type"] == "message":
|
37
|
+
yield await self.get(key)
|
38
|
+
finally:
|
39
|
+
await pubsub.unsubscribe(channel)
|
40
|
+
await pubsub.close()
|
@@ -0,0 +1,55 @@
|
|
1
|
+
import asyncio
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from collections.abc import AsyncIterator
|
4
|
+
from typing import Generic, TypeVar
|
5
|
+
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
7
|
+
|
8
|
+
from acp_sdk.server.store.utils import Stringable
|
9
|
+
|
10
|
+
|
11
|
+
class StoreModel(BaseModel):
|
12
|
+
model_config = ConfigDict(extra="allow")
|
13
|
+
|
14
|
+
|
15
|
+
T = TypeVar("T", bound=BaseModel)
|
16
|
+
U = TypeVar("U", bound=BaseModel)
|
17
|
+
|
18
|
+
|
19
|
+
class Store(Generic[T], ABC):
|
20
|
+
@abstractmethod
|
21
|
+
async def get(self, key: Stringable) -> T | None:
|
22
|
+
pass
|
23
|
+
|
24
|
+
@abstractmethod
|
25
|
+
async def set(self, key: Stringable, value: T | None) -> None:
|
26
|
+
pass
|
27
|
+
|
28
|
+
@abstractmethod
|
29
|
+
def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
|
30
|
+
pass
|
31
|
+
|
32
|
+
def as_store(self, model: type[U], prefix: Stringable = "") -> "Store[U]":
|
33
|
+
return StoreView(model=model, store=self, prefix=prefix)
|
34
|
+
|
35
|
+
|
36
|
+
class StoreView(Store[U], Generic[U]):
|
37
|
+
def __init__(self, *, model: type[U], store: Store[T], prefix: Stringable = "") -> None:
|
38
|
+
super().__init__()
|
39
|
+
self._model = model
|
40
|
+
self._store = store
|
41
|
+
self._prefix = prefix
|
42
|
+
|
43
|
+
async def get(self, key: Stringable) -> U | None:
|
44
|
+
value = await self._store.get(self._get_key(key))
|
45
|
+
return self._model.model_validate(value.model_dump()) if value else value
|
46
|
+
|
47
|
+
async def set(self, key: Stringable, value: U | None) -> None:
|
48
|
+
await self._store.set(self._get_key(key), value)
|
49
|
+
|
50
|
+
async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[U | None]:
|
51
|
+
async for value in self._store.watch(self._get_key(key), ready=ready):
|
52
|
+
yield self._model.model_validate(value.model_dump()) if value else value
|
53
|
+
|
54
|
+
def _get_key(self, key: Stringable) -> str:
|
55
|
+
return f"{self._prefix!s}{key!s}"
|
acp_sdk/server/utils.py
CHANGED
@@ -6,17 +6,41 @@ import httpx
|
|
6
6
|
import requests
|
7
7
|
from pydantic import BaseModel
|
8
8
|
|
9
|
-
from acp_sdk.
|
9
|
+
from acp_sdk.models import RunStatus
|
10
|
+
from acp_sdk.server.executor import RunData
|
10
11
|
from acp_sdk.server.logging import logger
|
12
|
+
from acp_sdk.server.store.store import Store
|
11
13
|
|
12
14
|
|
13
15
|
def encode_sse(model: BaseModel) -> str:
|
14
16
|
return f"data: {model.model_dump_json()}\n\n"
|
15
17
|
|
16
18
|
|
17
|
-
async def
|
18
|
-
|
19
|
-
|
19
|
+
async def watch_util_stop(
|
20
|
+
run_data: RunData, store: Store[RunData], *, ready: asyncio.Event | None = None
|
21
|
+
) -> AsyncGenerator[RunData]:
|
22
|
+
async for data in run_data.watch(store, ready=ready):
|
23
|
+
yield data
|
24
|
+
if data.run.status == RunStatus.AWAITING:
|
25
|
+
break
|
26
|
+
|
27
|
+
|
28
|
+
async def wait_util_stop(run_data: RunData, store: Store[RunData], *, ready: asyncio.Event | None = None) -> RunData:
    """Drain the watch stream and return the last observed run state.

    Falls back to the initial *run_data* when the stream produces nothing.
    """
    latest = run_data
    async for latest in watch_util_stop(run_data, store, ready=ready):
        pass  # keep only the most recent snapshot
    return latest
async def stream_sse(
    run_data: RunData, store: Store[RunData], idx: int, *, ready: asyncio.Event | None = None
) -> AsyncGenerator[str]:
    """Stream run events past index *idx* as SSE frames until the run stops.

    Each stored snapshot carries the full event list; the cursor tracks how far
    we have already streamed so only new events are emitted.
    """
    cursor = idx
    async for snapshot in watch_util_stop(run_data, store, ready=ready):
        pending, cursor = snapshot.events[cursor:], len(snapshot.events)
        for event in pending:
            yield encode_sse(event)
|
20
44
|
|
21
45
|
|
22
46
|
async def async_request_with_retry(
|
@@ -1,27 +1,29 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: acp-sdk
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.10.1
|
4
4
|
Summary: Agent Communication Protocol SDK
|
5
5
|
Author: IBM Corp.
|
6
6
|
Maintainer-email: Tomas Pilar <thomas7pilar@gmail.com>
|
7
7
|
License-Expression: Apache-2.0
|
8
8
|
Requires-Python: <4.0,>=3.11
|
9
|
-
Requires-Dist: cachetools>=5.5
|
10
|
-
Requires-Dist: fastapi[standard]>=0.115
|
11
|
-
Requires-Dist: httpx-sse>=0.4
|
12
|
-
Requires-Dist: httpx>=0.26
|
13
|
-
Requires-Dist: janus>=2.0
|
14
|
-
Requires-Dist: opentelemetry-api>=1.31
|
15
|
-
Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31
|
9
|
+
Requires-Dist: cachetools>=5.5
|
10
|
+
Requires-Dist: fastapi[standard]>=0.115
|
11
|
+
Requires-Dist: httpx-sse>=0.4
|
12
|
+
Requires-Dist: httpx>=0.26
|
13
|
+
Requires-Dist: janus>=2.0
|
14
|
+
Requires-Dist: opentelemetry-api>=1.31
|
15
|
+
Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31
|
16
16
|
Requires-Dist: opentelemetry-instrumentation-fastapi>=0.52b1
|
17
17
|
Requires-Dist: opentelemetry-instrumentation-httpx>=0.52b1
|
18
|
-
Requires-Dist: opentelemetry-sdk>=1.31
|
19
|
-
Requires-Dist:
|
18
|
+
Requires-Dist: opentelemetry-sdk>=1.31
|
19
|
+
Requires-Dist: psycopg[binary]>=3.2
|
20
|
+
Requires-Dist: pydantic>=2.0
|
21
|
+
Requires-Dist: redis>=6.1
|
20
22
|
Description-Content-Type: text/markdown
|
21
23
|
|
22
24
|
# Agent Communication Protocol SDK for Python
|
23
25
|
|
24
|
-
Agent Communication Protocol SDK for Python
|
26
|
+
Agent Communication Protocol SDK for Python helps developers to serve and consume agents over the Agent Communication Protocol.
|
25
27
|
|
26
28
|
## Prerequisites
|
27
29
|
|
@@ -12,16 +12,22 @@ acp_sdk/models/models.py,sha256=So5D9VjJxX2jdOF8BGc3f9D2ssbAyFEO2Rgr1KykI-M,7730
|
|
12
12
|
acp_sdk/models/schemas.py,sha256=_ah7_zHsQJGxDXvnzsBvASdRsQHVphFQ7Sum6A04iRw,759
|
13
13
|
acp_sdk/server/__init__.py,sha256=mxBBBFaZuMEUENRMLwp1XZkuLeT9QghcFmNvjnqvAAU,377
|
14
14
|
acp_sdk/server/agent.py,sha256=6VBKn_qVXqUl79G8T7grwhnuLMwr67d4UGagMGX1hMs,6586
|
15
|
-
acp_sdk/server/app.py,sha256=
|
16
|
-
acp_sdk/server/bundle.py,sha256=umD2GgDp17lUddu0adpp1zUcm1JJvDrDpIZ0uR-6VeY,7204
|
15
|
+
acp_sdk/server/app.py,sha256=XeMn8hfU2J2O2QxSvsH5q6lowxCfwS1-DTWz-gUVotA,8733
|
17
16
|
acp_sdk/server/context.py,sha256=MgnLV6qcDIhc_0BjW7r4Jj1tHts4ZuwpdTGIBnz2Mgo,1036
|
18
17
|
acp_sdk/server/errors.py,sha256=GSO8yYIqEeX8Y4Lz86ks35dMTHiQiXuOrLYYx0eXsbI,2110
|
18
|
+
acp_sdk/server/executor.py,sha256=YL0J9cVY1QZtdTeqwjJaKDpB_T6_sByHlHc52kgNAJo,7742
|
19
19
|
acp_sdk/server/logging.py,sha256=Oc8yZigCsuDnHHPsarRzu0RX3NKaLEgpELM2yovGKDI,411
|
20
|
-
acp_sdk/server/server.py,sha256=
|
21
|
-
acp_sdk/server/session.py,sha256=
|
20
|
+
acp_sdk/server/server.py,sha256=slePybTpocXD8k9HxOvxCiyvigiDUdZmkz3kiI-dCFU,14113
|
21
|
+
acp_sdk/server/session.py,sha256=vGUVpKzUGefI1c7LeK08Bvd8zvJIRfsdJEt2KhYoEg0,764
|
22
22
|
acp_sdk/server/telemetry.py,sha256=lbB2ppijUcqbHUOn0e-15LGcVvT_qrMguq8qBokICac,2016
|
23
23
|
acp_sdk/server/types.py,sha256=gLb5wCkMYhmu2laj_ymK-TPfN9LSjRgKOP1H_893UzA,304
|
24
|
-
acp_sdk/server/utils.py,sha256=
|
25
|
-
acp_sdk
|
26
|
-
acp_sdk
|
27
|
-
acp_sdk
|
24
|
+
acp_sdk/server/utils.py,sha256=BYSn4Bd95Bn-oEH1W1yE_pWpYUOdtYPh-vMnou4nsdk,2721
|
25
|
+
acp_sdk/server/store/__init__.py,sha256=zzKic0byQTM86cyC2whwZeNP4prfy_HZrbSriTaV5j8,282
|
26
|
+
acp_sdk/server/store/memory_store.py,sha256=9hOoJiHdaDz-0hNHWhO_qZpp-I_pnTqYBIhQnlWilu4,1194
|
27
|
+
acp_sdk/server/store/postgresql_store.py,sha256=bHyAgf4vSn_07wGXXq6juFwm3JldYNOjU9RARcDSYQo,2717
|
28
|
+
acp_sdk/server/store/redis_store.py,sha256=IKXvDseOFMcoGjVYPOkOBhPnJAchy_RyeMayKLoVCGA,1378
|
29
|
+
acp_sdk/server/store/store.py,sha256=jGmYy9oiuVjhYYJY8QRo4g2J2Qyt1HLTmq_eHy4aI7c,1806
|
30
|
+
acp_sdk/server/store/utils.py,sha256=JumEOMs1h1uGlnHnUGeguee-srGzT7_Y2NVEYt01QuY,92
|
31
|
+
acp_sdk-0.10.1.dist-info/METADATA,sha256=2CkcWAS4e-6bIeoFsv2eyxLcbeTSgy15gq3VCP1dxbE,1685
|
32
|
+
acp_sdk-0.10.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
33
|
+
acp_sdk-0.10.1.dist-info/RECORD,,
|
acp_sdk/server/bundle.py
DELETED
@@ -1,182 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import logging
|
3
|
-
from collections.abc import AsyncGenerator
|
4
|
-
from concurrent.futures import ThreadPoolExecutor
|
5
|
-
from datetime import datetime, timezone
|
6
|
-
|
7
|
-
from pydantic import BaseModel, ValidationError
|
8
|
-
|
9
|
-
from acp_sdk.instrumentation import get_tracer
|
10
|
-
from acp_sdk.models import (
|
11
|
-
ACPError,
|
12
|
-
AnyModel,
|
13
|
-
AwaitRequest,
|
14
|
-
AwaitResume,
|
15
|
-
Error,
|
16
|
-
ErrorCode,
|
17
|
-
Event,
|
18
|
-
GenericEvent,
|
19
|
-
Message,
|
20
|
-
MessageCompletedEvent,
|
21
|
-
MessageCreatedEvent,
|
22
|
-
MessagePart,
|
23
|
-
MessagePartEvent,
|
24
|
-
Run,
|
25
|
-
RunAwaitingEvent,
|
26
|
-
RunCancelledEvent,
|
27
|
-
RunCompletedEvent,
|
28
|
-
RunCreatedEvent,
|
29
|
-
RunFailedEvent,
|
30
|
-
RunInProgressEvent,
|
31
|
-
RunStatus,
|
32
|
-
)
|
33
|
-
from acp_sdk.server.agent import Agent
|
34
|
-
from acp_sdk.server.logging import logger
|
35
|
-
|
36
|
-
|
37
|
-
class RunBundle:
|
38
|
-
def __init__(
|
39
|
-
self, *, agent: Agent, run: Run, input: list[Message], history: list[Message], executor: ThreadPoolExecutor
|
40
|
-
) -> None:
|
41
|
-
self.agent = agent
|
42
|
-
self.run = run
|
43
|
-
self.input = input
|
44
|
-
self.history = history
|
45
|
-
|
46
|
-
self.stream_queue: asyncio.Queue[Event] = asyncio.Queue()
|
47
|
-
self.events: list[Event] = []
|
48
|
-
|
49
|
-
self.await_queue: asyncio.Queue[AwaitResume] = asyncio.Queue(maxsize=1)
|
50
|
-
self.await_or_terminate_event = asyncio.Event()
|
51
|
-
|
52
|
-
self.task = asyncio.create_task(self._execute(input, executor=executor))
|
53
|
-
|
54
|
-
async def stream(self) -> AsyncGenerator[Event]:
|
55
|
-
while True:
|
56
|
-
event = await self.stream_queue.get()
|
57
|
-
if event is None:
|
58
|
-
break
|
59
|
-
yield event
|
60
|
-
self.stream_queue.task_done()
|
61
|
-
|
62
|
-
async def emit(self, event: Event) -> None:
|
63
|
-
freeze = event.model_copy(deep=True)
|
64
|
-
self.events.append(freeze)
|
65
|
-
await self.stream_queue.put(freeze)
|
66
|
-
|
67
|
-
async def await_(self) -> AwaitResume:
|
68
|
-
await self.stream_queue.put(None)
|
69
|
-
self.await_queue.empty()
|
70
|
-
self.await_or_terminate_event.set()
|
71
|
-
self.await_or_terminate_event.clear()
|
72
|
-
resume = await self.await_queue.get()
|
73
|
-
self.await_queue.task_done()
|
74
|
-
return resume
|
75
|
-
|
76
|
-
async def resume(self, resume: AwaitResume) -> None:
|
77
|
-
self.stream_queue = asyncio.Queue()
|
78
|
-
await self.await_queue.put(resume)
|
79
|
-
self.run.status = RunStatus.IN_PROGRESS
|
80
|
-
self.run.await_request = None
|
81
|
-
|
82
|
-
async def cancel(self) -> None:
|
83
|
-
self.task.cancel()
|
84
|
-
self.run.status = RunStatus.CANCELLING
|
85
|
-
self.run.await_request = None
|
86
|
-
|
87
|
-
async def join(self) -> None:
|
88
|
-
await self.await_or_terminate_event.wait()
|
89
|
-
|
90
|
-
async def _execute(self, input: list[Message], *, executor: ThreadPoolExecutor) -> None:
|
91
|
-
with get_tracer().start_as_current_span("run"):
|
92
|
-
run_logger = logging.LoggerAdapter(logger, {"run_id": str(self.run.run_id)})
|
93
|
-
|
94
|
-
in_message = False
|
95
|
-
|
96
|
-
async def flush_message() -> None:
|
97
|
-
nonlocal in_message
|
98
|
-
if in_message:
|
99
|
-
message = self.run.output[-1]
|
100
|
-
message.completed_at = datetime.now(timezone.utc)
|
101
|
-
await self.emit(MessageCompletedEvent(message=message))
|
102
|
-
in_message = False
|
103
|
-
|
104
|
-
try:
|
105
|
-
await self.emit(RunCreatedEvent(run=self.run))
|
106
|
-
|
107
|
-
generator = self.agent.execute(
|
108
|
-
input=self.history + input, session_id=self.run.session_id, executor=executor
|
109
|
-
)
|
110
|
-
run_logger.info("Run started")
|
111
|
-
|
112
|
-
self.run.status = RunStatus.IN_PROGRESS
|
113
|
-
await self.emit(RunInProgressEvent(run=self.run))
|
114
|
-
|
115
|
-
await_resume = None
|
116
|
-
while True:
|
117
|
-
next = await generator.asend(await_resume)
|
118
|
-
|
119
|
-
if isinstance(next, (MessagePart, str)):
|
120
|
-
if isinstance(next, str):
|
121
|
-
next = MessagePart(content=next)
|
122
|
-
if not in_message:
|
123
|
-
self.run.output.append(Message(parts=[], completed_at=None))
|
124
|
-
in_message = True
|
125
|
-
await self.emit(MessageCreatedEvent(message=self.run.output[-1]))
|
126
|
-
self.run.output[-1].parts.append(next)
|
127
|
-
await self.emit(MessagePartEvent(part=next))
|
128
|
-
elif isinstance(next, Message):
|
129
|
-
await flush_message()
|
130
|
-
self.run.output.append(next)
|
131
|
-
await self.emit(MessageCreatedEvent(message=next))
|
132
|
-
for part in next.parts:
|
133
|
-
await self.emit(MessagePartEvent(part=part))
|
134
|
-
await self.emit(MessageCompletedEvent(message=next))
|
135
|
-
elif isinstance(next, AwaitRequest):
|
136
|
-
self.run.await_request = next
|
137
|
-
self.run.status = RunStatus.AWAITING
|
138
|
-
await self.emit(RunAwaitingEvent(run=self.run))
|
139
|
-
run_logger.info("Run awaited")
|
140
|
-
await_resume = await self.await_()
|
141
|
-
await self.emit(RunInProgressEvent(run=self.run))
|
142
|
-
run_logger.info("Run resumed")
|
143
|
-
elif isinstance(next, Error):
|
144
|
-
raise ACPError(error=next)
|
145
|
-
elif isinstance(next, ACPError):
|
146
|
-
raise next
|
147
|
-
elif next is None:
|
148
|
-
await flush_message()
|
149
|
-
elif isinstance(next, BaseModel):
|
150
|
-
await self.emit(GenericEvent(generic=AnyModel(**next.model_dump())))
|
151
|
-
else:
|
152
|
-
try:
|
153
|
-
generic = AnyModel.model_validate(next)
|
154
|
-
await self.emit(GenericEvent(generic=generic))
|
155
|
-
except ValidationError:
|
156
|
-
raise TypeError("Invalid yield")
|
157
|
-
except StopAsyncIteration:
|
158
|
-
await flush_message()
|
159
|
-
self.run.status = RunStatus.COMPLETED
|
160
|
-
self.run.finished_at = datetime.now(timezone.utc)
|
161
|
-
await self.emit(RunCompletedEvent(run=self.run))
|
162
|
-
run_logger.info("Run completed")
|
163
|
-
except asyncio.CancelledError:
|
164
|
-
self.run.status = RunStatus.CANCELLED
|
165
|
-
self.run.finished_at = datetime.now(timezone.utc)
|
166
|
-
await self.emit(RunCancelledEvent(run=self.run))
|
167
|
-
run_logger.info("Run cancelled")
|
168
|
-
except Exception as e:
|
169
|
-
if isinstance(e, ACPError):
|
170
|
-
self.run.error = e.error
|
171
|
-
else:
|
172
|
-
self.run.error = Error(code=ErrorCode.SERVER_ERROR, message=str(e))
|
173
|
-
self.run.status = RunStatus.FAILED
|
174
|
-
self.run.finished_at = datetime.now(timezone.utc)
|
175
|
-
await self.emit(RunFailedEvent(run=self.run))
|
176
|
-
run_logger.exception("Run failed")
|
177
|
-
raise
|
178
|
-
finally:
|
179
|
-
self.await_or_terminate_event.set()
|
180
|
-
await self.stream_queue.put(None)
|
181
|
-
if not self.task.done():
|
182
|
-
self.task.cancel()
|
File without changes
|