acp-sdk 0.9.1__tar.gz → 0.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/PKG-INFO +13 -11
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/README.md +1 -1
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/pyproject.toml +17 -11
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/pytest.ini +2 -1
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/app.py +77 -42
- acp_sdk-0.10.0/src/acp_sdk/server/executor.py +198 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/server.py +6 -9
- acp_sdk-0.10.0/src/acp_sdk/server/session.py +24 -0
- acp_sdk-0.10.0/src/acp_sdk/server/store/__init__.py +4 -0
- acp_sdk-0.10.0/src/acp_sdk/server/store/memory_store.py +35 -0
- acp_sdk-0.10.0/src/acp_sdk/server/store/postgresql_store.py +69 -0
- acp_sdk-0.10.0/src/acp_sdk/server/store/redis_store.py +40 -0
- acp_sdk-0.10.0/src/acp_sdk/server/store/store.py +55 -0
- acp_sdk-0.10.0/src/acp_sdk/server/store/utils.py +5 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/utils.py +28 -4
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/fixtures/server.py +59 -5
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/test_suites/test_runs.py +1 -31
- acp_sdk-0.9.1/src/acp_sdk/server/bundle.py +0 -182
- acp_sdk-0.9.1/src/acp_sdk/server/session.py +0 -21
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/.gitignore +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/.python-version +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/docs/.gitignore +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/docs/Makefile +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/docs/conf.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/docs/index.rst +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/docs/make.bat +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/client/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/client/client.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/client/types.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/client/utils.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/instrumentation.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/models/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/models/errors.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/models/models.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/models/schemas.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/py.typed +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/agent.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/context.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/errors.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/logging.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/telemetry.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/types.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/version.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/conftest.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/config.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/fixtures/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/fixtures/client.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/test_suites/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/e2e/test_suites/test_discovery.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/unit/client/test_client.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/unit/client/test_utils.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/unit/models/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/unit/models/test_models.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/unit/server/__init__.py +0 -0
- {acp_sdk-0.9.1 → acp_sdk-0.10.0}/tests/unit/server/test_server.py +0 -0
{acp_sdk-0.9.1 → acp_sdk-0.10.0}/PKG-INFO
@@ -1,27 +1,29 @@
 Metadata-Version: 2.4
 Name: acp-sdk
-Version: 0.9.1
+Version: 0.10.0
 Summary: Agent Communication Protocol SDK
 Author: IBM Corp.
 Maintainer-email: Tomas Pilar <thomas7pilar@gmail.com>
 License-Expression: Apache-2.0
 Requires-Python: <4.0,>=3.11
-Requires-Dist: cachetools>=5.5
-Requires-Dist: fastapi[standard]>=0.115
-Requires-Dist: httpx-sse>=0.4
-Requires-Dist: httpx>=0.26
-Requires-Dist: janus>=2.0
-Requires-Dist: opentelemetry-api>=1.31
-Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31
+Requires-Dist: cachetools>=5.5
+Requires-Dist: fastapi[standard]>=0.115
+Requires-Dist: httpx-sse>=0.4
+Requires-Dist: httpx>=0.26
+Requires-Dist: janus>=2.0
+Requires-Dist: opentelemetry-api>=1.31
+Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31
 Requires-Dist: opentelemetry-instrumentation-fastapi>=0.52b1
 Requires-Dist: opentelemetry-instrumentation-httpx>=0.52b1
-Requires-Dist: opentelemetry-sdk>=1.31
-Requires-Dist: pydantic>=2.0
+Requires-Dist: opentelemetry-sdk>=1.31
+Requires-Dist: psycopg[binary]>=3.2
+Requires-Dist: pydantic>=2.0
+Requires-Dist: redis>=6.1
 Description-Content-Type: text/markdown
 
 # Agent Communication Protocol SDK for Python
 
-Agent Communication Protocol SDK for Python
+Agent Communication Protocol SDK for Python helps developers to serve and consume agents over the Agent Communication Protocol.
 
 ## Prerequisites
 
{acp_sdk-0.9.1 → acp_sdk-0.10.0}/README.md
@@ -1,6 +1,6 @@
 # Agent Communication Protocol SDK for Python
 
-Agent Communication Protocol SDK for Python
+Agent Communication Protocol SDK for Python helps developers to serve and consume agents over the Agent Communication Protocol.
 
 ## Prerequisites
 
{acp_sdk-0.9.1 → acp_sdk-0.10.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "acp-sdk"
-version = "0.9.1"
+version = "0.10.0"
 description = "Agent Communication Protocol SDK"
 license = "Apache-2.0"
 readme = "README.md"
@@ -8,17 +8,19 @@ authors = [{ name = "IBM Corp." }]
 maintainers = [{ name = "Tomas Pilar", email = "thomas7pilar@gmail.com" }]
 requires-python = ">=3.11, <4.0"
 dependencies = [
-    "opentelemetry-api>=1.31
-    "pydantic>=2.0
-    "httpx>=0.26
-    "httpx-sse>=0.4
+    "opentelemetry-api>=1.31",
+    "pydantic>=2.0",
+    "httpx>=0.26",
+    "httpx-sse>=0.4",
     "opentelemetry-instrumentation-httpx>=0.52b1",
-    "fastapi[standard]>=0.115
-    "opentelemetry-exporter-otlp-proto-http>=1.31
+    "fastapi[standard]>=0.115",
+    "opentelemetry-exporter-otlp-proto-http>=1.31",
     "opentelemetry-instrumentation-fastapi>=0.52b1",
-    "opentelemetry-sdk>=1.31
-    "janus>=2.0
-    "cachetools>=5.5
+    "opentelemetry-sdk>=1.31",
+    "janus>=2.0",
+    "cachetools>=5.5",
+    "redis>=6.1",
+    "psycopg[binary]>=3.2",
 ]
@@ -26,4 +28,8 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [dependency-groups]
-dev = [
+dev = [
+    "pytest-httpx>=0.35.0",
+    "pytest-postgresql>=7.0.2",
+    "pytest-redis>=3.1.3",
+]
{acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/app.py
@@ -1,13 +1,14 @@
+import asyncio
 from collections.abc import AsyncGenerator
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
-from datetime import
+from datetime import timedelta
 from enum import Enum
 
-from cachetools import TTLCache
 from fastapi import Depends, FastAPI, HTTPException, status
 from fastapi.applications import AppType, Lifespan
 from fastapi.encoders import jsonable_encoder
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 
 from acp_sdk.models import (
@@ -27,12 +28,11 @@ from acp_sdk.models import (
     RunReadResponse,
     RunResumeRequest,
     RunResumeResponse,
-    SessionId,
 )
 from acp_sdk.models.errors import ACPError
+from acp_sdk.models.models import AwaitResume, RunStatus
 from acp_sdk.models.schemas import PingResponse
 from acp_sdk.server.agent import Agent
-from acp_sdk.server.bundle import RunBundle
 from acp_sdk.server.errors import (
     RequestValidationError,
     StarletteHTTPException,
@@ -41,8 +41,10 @@ from acp_sdk.server.errors import (
     http_exception_handler,
     validation_exception_handler,
 )
+from acp_sdk.server.executor import CancelData, Executor, RunData
 from acp_sdk.server.session import Session
-from acp_sdk.server.
+from acp_sdk.server.store import MemoryStore, Store
+from acp_sdk.server.utils import stream_sse, wait_util_stop
 
 
 class Headers(str, Enum):
@@ -51,8 +53,7 @@ class Headers(str, Enum):
 
 def create_app(
     *agents: Agent,
-
-    run_ttl: timedelta = timedelta(hours=1),
+    store: Store | None = None,
     lifespan: Lifespan[AppType] | None = None,
     dependencies: list[Depends] | None = None,
 ) -> FastAPI:
@@ -74,20 +75,37 @@ def create_app(
         dependencies=dependencies,
     )
 
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["https://agentcommunicationprotocol.dev"],
+        allow_methods=["*"],
+        allow_headers=["*"],
+        allow_credentials=True,
+    )
+
     agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
-
-
+
+    store = store or MemoryStore(limit=1000, ttl=timedelta(hours=1))
+    run_store = store.as_store(model=RunData, prefix="run_")
+    run_cancel_store = store.as_store(model=CancelData, prefix="run_cancel_")
+    run_resume_store = store.as_store(model=AwaitResume, prefix="run_resume_")
+    session_store = store.as_store(model=Session, prefix="session_")
 
     app.exception_handler(ACPError)(acp_error_handler)
     app.exception_handler(StarletteHTTPException)(http_exception_handler)
     app.exception_handler(RequestValidationError)(validation_exception_handler)
     app.exception_handler(Exception)(catch_all_exception_handler)
 
-    def
-
-    if not
+    async def find_run_data(run_id: RunId) -> RunData:
+        run_data = await run_store.get(run_id)
+        if not run_data:
             raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
-
+        if run_data.run.status.is_terminal:
+            return run_data
+        cancel_data = await run_cancel_store.get(run_data.key)
+        if cancel_data is not None:
+            run_data.run.status = RunStatus.CANCELLING
+        return run_data
 
     def find_agent(agent_name: AgentName) -> Agent:
         agent = agents.get(agent_name, None)
@@ -117,94 +135,111 @@ def create_app(
     async def create_run(request: RunCreateRequest) -> RunCreateResponse:
         agent = find_agent(request.agent_name)
 
-        session =
+        session = (
+            (await session_store.get(request.session_id)) or Session(id=request.session_id)
+            if request.session_id
+            else Session()
+        )
         nonlocal executor
-
-            agent=agent,
+        run_data = RunData(
             run=Run(agent_name=agent.name, session_id=session.id),
             input=request.input,
-            history=list(session.history()),
-            executor=executor,
         )
-
+        await run_store.set(run_data.key, run_data)
 
-
-
+        session.append(run_data.run.run_id)
+        await session_store.set(session.id, session)
 
-        headers = {Headers.RUN_ID: str(
+        headers = {Headers.RUN_ID: str(run_data.run.run_id)}
+        ready = asyncio.Event()
+
+        Executor(
+            agent=agent,
+            run_data=run_data,
+            history=await session.history(run_store),
+            run_store=run_store,
+            cancel_store=run_cancel_store,
+            resume_store=run_resume_store,
+            executor=executor,
+        ).execute(wait=ready)
 
         match request.mode:
             case RunMode.STREAM:
                 return StreamingResponse(
-                    stream_sse(
+                    stream_sse(run_data, run_store, 0, ready=ready),
                    headers=headers,
                     media_type="text/event-stream",
                 )
             case RunMode.SYNC:
-                await
+                await wait_util_stop(run_data, run_store, ready=ready)
                 return JSONResponse(
                     headers=headers,
-                    content=jsonable_encoder(
+                    content=jsonable_encoder(run_data.run),
                 )
             case RunMode.ASYNC:
+                ready.set()
                 return JSONResponse(
                     status_code=status.HTTP_202_ACCEPTED,
                     headers=headers,
-                    content=jsonable_encoder(
+                    content=jsonable_encoder(run_data.run),
                 )
             case _:
                 raise NotImplementedError()
 
     @app.get("/runs/{run_id}")
     async def read_run(run_id: RunId) -> RunReadResponse:
-        bundle =
+        bundle = await find_run_data(run_id)
         return bundle.run
 
     @app.get("/runs/{run_id}/events")
     async def list_run_events(run_id: RunId) -> RunEventsListResponse:
-        bundle =
+        bundle = await find_run_data(run_id)
         return RunEventsListResponse(events=bundle.events)
 
     @app.post("/runs/{run_id}")
     async def resume_run(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
-
+        run_data = await find_run_data(run_id)
 
-        if
+        if run_data.run.await_request is None:
             raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Run {run_id} has no await request")
 
-        if
+        if run_data.run.await_request.type != request.await_resume.type:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Run {run_id} is expecting resume of type {
+                detail=f"Run {run_id} is expecting resume of type {run_data.run.await_request.type}",
             )
 
-
+        run_data.run.status = RunStatus.IN_PROGRESS
+        await run_store.set(run_data.key, run_data)
+        await run_resume_store.set(run_data.key, request.await_resume)
+
         match request.mode:
             case RunMode.STREAM:
                 return StreamingResponse(
-                    stream_sse(
+                    stream_sse(run_data, run_store, len(run_data.events)),
                     media_type="text/event-stream",
                 )
             case RunMode.SYNC:
-                await
-                return
+                run_data = await wait_util_stop(run_data, run_store)
+                return run_data.run
             case RunMode.ASYNC:
                 return JSONResponse(
                     status_code=status.HTTP_202_ACCEPTED,
-                    content=jsonable_encoder(
+                    content=jsonable_encoder(run_data.run),
                 )
             case _:
                 raise NotImplementedError()
 
     @app.post("/runs/{run_id}/cancel")
     async def cancel_run(run_id: RunId) -> RunCancelResponse:
-
-        if
+        run_data = await find_run_data(run_id)
+        if run_data.run.status.is_terminal:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Run in terminal status {
+                detail=f"Run in terminal status {run_data.run.status} can't be cancelled",
             )
-        await
-
+        await run_cancel_store.set(run_data.key, CancelData())
+        run_data.run.status = RunStatus.CANCELLING
+        return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder(run_data.run))
 
     return app
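The practical effect of the `app.py` changes: `create_app` no longer takes `run_ttl`; it takes a pluggable `store` and derives namespaced sub-stores for runs, cancellations, resumes, and sessions, falling back to an in-memory store when none is given. A minimal sketch of passing a store explicitly (`echo_agent` is a placeholder for any existing `Agent` instance and is not defined in this diff):

```python
from datetime import timedelta

from acp_sdk.server.app import create_app
from acp_sdk.server.store import MemoryStore

# echo_agent stands in for an Agent instance defined elsewhere.
app = create_app(
    echo_agent,
    store=MemoryStore(limit=1000, ttl=timedelta(hours=1)),  # same values as the default
)
# Omitting store= falls back to exactly this MemoryStore; a RedisStore or
# PostgreSQLStore can be supplied instead to share run state across workers.
```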
acp_sdk-0.10.0/src/acp_sdk/server/executor.py (new file)
@@ -0,0 +1,198 @@
+import asyncio
+import logging
+from collections.abc import AsyncIterator
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime, timezone
+from typing import Self
+
+from pydantic import BaseModel, ValidationError
+
+from acp_sdk.instrumentation import get_tracer
+from acp_sdk.models import (
+    ACPError,
+    AnyModel,
+    AwaitRequest,
+    AwaitResume,
+    Error,
+    ErrorCode,
+    Event,
+    GenericEvent,
+    Message,
+    MessageCompletedEvent,
+    MessageCreatedEvent,
+    MessagePart,
+    MessagePartEvent,
+    Run,
+    RunAwaitingEvent,
+    RunCancelledEvent,
+    RunCompletedEvent,
+    RunCreatedEvent,
+    RunFailedEvent,
+    RunInProgressEvent,
+    RunStatus,
+)
+from acp_sdk.server.agent import Agent
+from acp_sdk.server.logging import logger
+from acp_sdk.server.store import Store
+
+
+class RunData(BaseModel):
+    run: Run
+    input: list[Message]
+    events: list[Event] = []
+
+    @property
+    def key(self) -> str:
+        return str(self.run.run_id)
+
+    async def watch(self, store: Store[Self], *, ready: asyncio.Event | None = None) -> AsyncIterator[Self]:
+        async for data in store.watch(self.key, ready=ready):
+            if data is None:
+                raise RuntimeError("Missing data")
+            yield data
+            if data.run.status.is_terminal:
+                break
+
+
+class CancelData(BaseModel):
+    pass
+
+
+class Executor:
+    def __init__(
+        self,
+        *,
+        agent: Agent,
+        run_data: RunData,
+        history: list[Message],
+        executor: ThreadPoolExecutor,
+        run_store: Store[RunData],
+        cancel_store: Store[CancelData],
+        resume_store: Store[AwaitResume],
+    ) -> None:
+        self.agent = agent
+        self.history = history
+        self.run_data = run_data
+        self.executor = executor
+
+        self.run_store = run_store
+        self.cancel_store = cancel_store
+        self.resume_store = resume_store
+
+        self.logger = logging.LoggerAdapter(logger, {"run_id": str(run_data.run.run_id)})
+
+    def execute(self, *, wait: asyncio.Event) -> None:
+        self.task = asyncio.create_task(self._execute(self.run_data, executor=self.executor, wait=wait))
+        self.watcher = asyncio.create_task(self._watch_for_cancellation())
+
+    async def _push(self) -> None:
+        await self.run_store.set(self.run_data.run.run_id, self.run_data)
+
+    async def _emit(self, event: Event) -> None:
+        freeze = event.model_copy(deep=True)
+        self.run_data.events.append(freeze)
+        await self._push()
+
+    async def _await(self) -> AwaitResume:
+        async for resume in self.resume_store.watch(self.run_data.key):
+            if resume is not None:
+                await self.resume_store.set(self.run_data.key, None)
+                return resume
+
+    async def _watch_for_cancellation(self) -> None:
+        while not self.task.done():
+            try:
+                async for data in self.cancel_store.watch(self.run_data.key):
+                    if data is not None:
+                        self.task.cancel()
+            except Exception:
+                logger.warning("Cancellation watcher failed, restarting")
+
+    async def _execute(self, run_data: RunData, *, executor: ThreadPoolExecutor, wait: asyncio.Event) -> None:
+        with get_tracer().start_as_current_span("run"):
+            in_message = False
+
+            async def flush_message() -> None:
+                nonlocal in_message
+                if in_message:
+                    message = run_data.run.output[-1]
+                    message.completed_at = datetime.now(timezone.utc)
+                    await self._emit(MessageCompletedEvent(message=message))
+                    in_message = False
+
+            try:
+                await wait.wait()
+
+                await self._emit(RunCreatedEvent(run=run_data.run))
+
+                generator = self.agent.execute(
+                    input=self.history + run_data.input, session_id=run_data.run.session_id, executor=executor
+                )
+                self.logger.info("Run started")
+
+                run_data.run.status = RunStatus.IN_PROGRESS
+                await self._emit(RunInProgressEvent(run=run_data.run))
+
+                await_resume = None
+                while True:
+                    next = await generator.asend(await_resume)
+
+                    if isinstance(next, (MessagePart, str)):
+                        if isinstance(next, str):
+                            next = MessagePart(content=next)
+                        if not in_message:
+                            run_data.run.output.append(Message(parts=[], completed_at=None))
+                            in_message = True
+                            await self._emit(MessageCreatedEvent(message=run_data.run.output[-1]))
+                        run_data.run.output[-1].parts.append(next)
+                        await self._emit(MessagePartEvent(part=next))
+                    elif isinstance(next, Message):
+                        await flush_message()
+                        run_data.run.output.append(next)
+                        await self._emit(MessageCreatedEvent(message=next))
+                        for part in next.parts:
+                            await self._emit(MessagePartEvent(part=part))
+                        await self._emit(MessageCompletedEvent(message=next))
+                    elif isinstance(next, AwaitRequest):
+                        run_data.run.await_request = next
+                        run_data.run.status = RunStatus.AWAITING
+                        await self._emit(RunAwaitingEvent(run=run_data.run))
+                        self.logger.info("Run awaited")
+                        await_resume = await self._await()
+                        run_data.run.status = RunStatus.IN_PROGRESS
+                        await self._emit(RunInProgressEvent(run=run_data.run))
+                        self.logger.info("Run resumed")
+                    elif isinstance(next, Error):
+                        raise ACPError(error=next)
+                    elif isinstance(next, BaseException):
+                        raise next
+                    elif next is None:
+                        await flush_message()
+                    elif isinstance(next, BaseModel):
+                        await self._emit(GenericEvent(generic=AnyModel(**next.model_dump())))
+                    else:
+                        try:
+                            generic = AnyModel.model_validate(next)
+                            await self._emit(GenericEvent(generic=generic))
+                        except ValidationError:
+                            raise TypeError("Invalid yield")
+            except StopAsyncIteration:
+                await flush_message()
+                run_data.run.status = RunStatus.COMPLETED
+                run_data.run.finished_at = datetime.now(timezone.utc)
+                await self._emit(RunCompletedEvent(run=run_data.run))
+                self.logger.info("Run completed")
+            except asyncio.CancelledError:
+                run_data.run.status = RunStatus.CANCELLED
+                run_data.run.finished_at = datetime.now(timezone.utc)
+                await self._emit(RunCancelledEvent(run=run_data.run))
+                self.logger.info("Run cancelled")
+            except Exception as e:
+                if isinstance(e, ACPError):
+                    run_data.run.error = e.error
+                else:
+                    run_data.run.error = Error(code=ErrorCode.SERVER_ERROR, message=str(e))
+                run_data.run.status = RunStatus.FAILED
+                run_data.run.finished_at = datetime.now(timezone.utc)
+                await self._emit(RunFailedEvent(run=run_data.run))
+                self.logger.exception("Run failed")
{acp_sdk-0.9.1 → acp_sdk-0.10.0}/src/acp_sdk/server/server.py
@@ -2,7 +2,6 @@ import asyncio
 import os
 from collections.abc import AsyncGenerator, Awaitable
 from contextlib import asynccontextmanager
-from datetime import timedelta
 from typing import Any, Callable
 
 import requests
@@ -16,6 +15,7 @@ from acp_sdk.server.agent import agent as agent_decorator
 from acp_sdk.server.app import create_app
 from acp_sdk.server.logging import configure_logger as configure_logger_func
 from acp_sdk.server.logging import logger
+from acp_sdk.server.store import Store
 from acp_sdk.server.telemetry import configure_telemetry as configure_telemetry_func
 from acp_sdk.server.utils import async_request_with_retry
 
@@ -54,8 +54,7 @@ class Server:
         configure_logger: bool = True,
         configure_telemetry: bool = False,
         self_registration: bool = True,
-
-        run_ttl: timedelta = timedelta(hours=1),
+        store: Store | None = None,
         host: str = "127.0.0.1",
         port: int = 8000,
         uds: str | None = None,
@@ -118,7 +117,7 @@
 
         import uvicorn
 
-        app = create_app(*self.agents, lifespan=self.lifespan,
+        app = create_app(*self.agents, lifespan=self.lifespan, store=store)
 
         if configure_logger:
             configure_logger_func()
@@ -184,8 +183,7 @@
         configure_logger: bool = True,
         configure_telemetry: bool = False,
         self_registration: bool = True,
-
-        run_ttl: timedelta = timedelta(hours=1),
+        store: Store | None = None,
         host: str = "127.0.0.1",
         port: int = 8000,
         uds: str | None = None,
@@ -243,8 +241,7 @@
             configure_logger=configure_logger,
             configure_telemetry=configure_telemetry,
             self_registration=self_registration,
-
-            run_ttl=run_ttl,
+            store=store,
             host=host,
             port=port,
             uds=uds,
@@ -311,7 +308,7 @@
 
     async def _register_agent(self) -> None:
         """If not in PRODUCTION mode, register agent to the beeai platform and provide missing env variables"""
-        if os.getenv("PRODUCTION_MODE",
+        if os.getenv("PRODUCTION_MODE", "").lower() in ["true", "1"]:
             logger.debug("Agent is not automatically registered in the production mode.")
             return
 
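`Server` follows suit: both run entry points drop `run_ttl` in favour of `store`, which is passed straight through to `create_app`. A sketch of the new call, assuming `Server` is used the same way as in previous releases (only the `store=` keyword is new in this diff):

```python
from datetime import timedelta

from acp_sdk.server import Server
from acp_sdk.server.store import MemoryStore

server = Server()
# ... register agents on the server as before ...
server.run(
    store=MemoryStore(limit=1000, ttl=timedelta(hours=1)),  # replaces run_ttl=timedelta(hours=1)
    host="127.0.0.1",
    port=8000,
)
```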
acp_sdk-0.10.0/src/acp_sdk/server/session.py (new file)
@@ -0,0 +1,24 @@
+import uuid
+
+from pydantic import BaseModel, Field
+
+from acp_sdk.models import Message, RunId, RunStatus, SessionId
+from acp_sdk.server.executor import RunData
+from acp_sdk.server.store import Store
+
+
+class Session(BaseModel):
+    id: SessionId = Field(default_factory=uuid.uuid4)
+    runs: list[RunId] = []
+
+    def append(self, run_id: RunId) -> None:
+        self.runs.append(run_id)
+
+    async def history(self, store: Store[RunData]) -> list[Message]:
+        history = []
+        for run_id in self.runs:
+            run_data = await store.get(run_id)
+            if run_data is not None and run_data.run.status == RunStatus.COMPLETED:
+                history.extend(run_data.input)
+                history.extend(run_data.run.output)
+        return history
acp_sdk-0.10.0/src/acp_sdk/server/store/__init__.py (new file)
@@ -0,0 +1,4 @@
+from acp_sdk.server.store.memory_store import MemoryStore as MemoryStore
+from acp_sdk.server.store.postgresql_store import PostgreSQLStore as PostgreSQLStore
+from acp_sdk.server.store.redis_store import RedisStore as RedisStore
+from acp_sdk.server.store.store import Store as Store
acp_sdk-0.10.0/src/acp_sdk/server/store/memory_store.py (new file)
@@ -0,0 +1,35 @@
+import asyncio
+from collections.abc import AsyncIterator
+from datetime import datetime
+from typing import Generic
+
+from cachetools import TTLCache
+
+from acp_sdk.server.store.store import Store, T
+from acp_sdk.server.store.utils import Stringable
+
+
+class MemoryStore(Store[T], Generic[T]):
+    def __init__(self, *, limit: int, ttl: int | None = None) -> None:
+        super().__init__()
+        self._cache: TTLCache[str, T] = TTLCache(maxsize=limit, ttl=ttl, timer=datetime.now)
+        self._event = asyncio.Event()
+
+    async def get(self, key: Stringable) -> T | None:
+        value = self._cache.get(str(key))
+        return value.model_copy(deep=True) if value else value
+
+    async def set(self, key: Stringable, value: T | None) -> None:
+        if value is None:
+            del self._cache[str(key)]
+        else:
+            self._cache[str(key)] = value.model_copy(deep=True)
+        self._event.set()
+
+    async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
+        if ready:
+            ready.set()
+        while True:
+            await self._event.wait()
+            self._event.clear()
+            yield await self.get(key)