acp-sdk 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
acp_sdk/client/client.py CHANGED
@@ -8,7 +8,6 @@ from typing import Self
 
 import httpx
 from httpx_sse import EventSource, aconnect_sse
-from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
 from pydantic import TypeAdapter
 
 from acp_sdk.client.types import Input
@@ -44,7 +43,6 @@ class Client:
         *,
         session_id: SessionId | None = None,
         client: httpx.AsyncClient | None = None,
-        instrument: bool = True,
         auth: httpx._types.AuthTypes | None = None,
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
@@ -85,8 +83,6 @@ class Client:
             transport=transport,
             trust_env=trust_env,
         )
-        if instrument:
-            HTTPXClientInstrumentor.instrument_client(self._client)
 
     @property
     def client(self) -> httpx.AsyncClient:
@@ -108,7 +104,7 @@ class Client:
     async def session(self, session_id: SessionId | None = None) -> AsyncGenerator[Self]:
         session_id = session_id or uuid.uuid4()
         with get_tracer().start_as_current_span("session", attributes={"acp.session": str(session_id)}):
-            yield Client(client=self._client, session_id=session_id, instrument=False)
+            yield Client(client=self._client, session_id=session_id)
 
     async def agents(self) -> AsyncIterator[Agent]:
         response = await self._client.get("/agents")
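Note: 0.10.0 drops the client-side OpenTelemetry auto-instrumentation together with the `instrument` flag. A minimal sketch, assuming you still want the 0.9.0 behaviour and have `opentelemetry-instrumentation-httpx` installed (the base URL and the `Client` import path are assumptions based on the package layout above):

```python
# Hedged sketch: re-create the old per-client instrumentation outside the SDK.
import httpx
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor

from acp_sdk.client import Client

http_client = httpx.AsyncClient(base_url="http://localhost:8000")  # illustrative URL
HTTPXClientInstrumentor.instrument_client(http_client)  # the call the 0.9.0 constructor used to make

client = Client(client=http_client)
```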
acp_sdk/server/app.py CHANGED
@@ -1,15 +1,15 @@
+import asyncio
 from collections.abc import AsyncGenerator
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
-from datetime import datetime, timedelta
+from datetime import timedelta
 from enum import Enum
 
-from cachetools import TTLCache
 from fastapi import Depends, FastAPI, HTTPException, status
 from fastapi.applications import AppType, Lifespan
 from fastapi.encoders import jsonable_encoder
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
 
 from acp_sdk.models import (
     Agent as AgentModel,
@@ -28,12 +28,11 @@ from acp_sdk.models import (
     RunReadResponse,
     RunResumeRequest,
     RunResumeResponse,
-    SessionId,
 )
 from acp_sdk.models.errors import ACPError
+from acp_sdk.models.models import AwaitResume, RunStatus
 from acp_sdk.models.schemas import PingResponse
 from acp_sdk.server.agent import Agent
-from acp_sdk.server.bundle import RunBundle
 from acp_sdk.server.errors import (
     RequestValidationError,
     StarletteHTTPException,
@@ -42,8 +41,10 @@ from acp_sdk.server.errors import (
     http_exception_handler,
     validation_exception_handler,
 )
+from acp_sdk.server.executor import CancelData, Executor, RunData
 from acp_sdk.server.session import Session
-from acp_sdk.server.utils import stream_sse
+from acp_sdk.server.store import MemoryStore, Store
+from acp_sdk.server.utils import stream_sse, wait_util_stop
 
 
 class Headers(str, Enum):
@@ -52,8 +53,7 @@ class Headers(str, Enum):
 
 def create_app(
     *agents: Agent,
-    run_limit: int = 1000,
-    run_ttl: timedelta = timedelta(hours=1),
+    store: Store | None = None,
     lifespan: Lifespan[AppType] | None = None,
     dependencies: list[Depends] | None = None,
 ) -> FastAPI:
@@ -75,22 +75,37 @@ def create_app(
         dependencies=dependencies,
     )
 
-    FastAPIInstrumentor.instrument_app(app)
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["https://agentcommunicationprotocol.dev"],
+        allow_methods=["*"],
+        allow_headers=["*"],
+        allow_credentials=True,
+    )
 
     agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
-    runs: TTLCache[RunId, RunBundle] = TTLCache(maxsize=run_limit, ttl=run_ttl, timer=datetime.now)
-    sessions: TTLCache[SessionId, Session] = TTLCache(maxsize=run_limit, ttl=run_ttl, timer=datetime.now)
+
+    store = store or MemoryStore(limit=1000, ttl=timedelta(hours=1))
+    run_store = store.as_store(model=RunData, prefix="run_")
+    run_cancel_store = store.as_store(model=CancelData, prefix="run_cancel_")
+    run_resume_store = store.as_store(model=AwaitResume, prefix="run_resume_")
+    session_store = store.as_store(model=Session, prefix="session_")
 
     app.exception_handler(ACPError)(acp_error_handler)
     app.exception_handler(StarletteHTTPException)(http_exception_handler)
     app.exception_handler(RequestValidationError)(validation_exception_handler)
     app.exception_handler(Exception)(catch_all_exception_handler)
 
-    def find_run_bundle(run_id: RunId) -> RunBundle:
-        bundle = runs.get(run_id)
-        if not bundle:
+    async def find_run_data(run_id: RunId) -> RunData:
+        run_data = await run_store.get(run_id)
+        if not run_data:
             raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
-        return bundle
+        if run_data.run.status.is_terminal:
+            return run_data
+        cancel_data = await run_cancel_store.get(run_data.key)
+        if cancel_data is not None:
+            run_data.run.status = RunStatus.CANCELLING
+        return run_data
 
     def find_agent(agent_name: AgentName) -> Agent:
         agent = agents.get(agent_name, None)
@@ -120,94 +135,111 @@ def create_app(
     async def create_run(request: RunCreateRequest) -> RunCreateResponse:
         agent = find_agent(request.agent_name)
 
-        session = sessions.get(request.session_id, Session(id=request.session_id)) if request.session_id else Session()
+        session = (
+            (await session_store.get(request.session_id)) or Session(id=request.session_id)
+            if request.session_id
+            else Session()
+        )
         nonlocal executor
-        bundle = RunBundle(
-            agent=agent,
+        run_data = RunData(
             run=Run(agent_name=agent.name, session_id=session.id),
             input=request.input,
-            history=list(session.history()),
-            executor=executor,
         )
-        session.append(bundle)
+        await run_store.set(run_data.key, run_data)
 
-        runs[bundle.run.run_id] = bundle
-        sessions[session.id] = session
+        session.append(run_data.run.run_id)
+        await session_store.set(session.id, session)
 
-        headers = {Headers.RUN_ID: str(bundle.run.run_id)}
+        headers = {Headers.RUN_ID: str(run_data.run.run_id)}
+        ready = asyncio.Event()
+
+        Executor(
+            agent=agent,
+            run_data=run_data,
+            history=await session.history(run_store),
+            run_store=run_store,
+            cancel_store=run_cancel_store,
+            resume_store=run_resume_store,
+            executor=executor,
+        ).execute(wait=ready)
 
         match request.mode:
             case RunMode.STREAM:
                 return StreamingResponse(
-                    stream_sse(bundle),
+                    stream_sse(run_data, run_store, 0, ready=ready),
                     headers=headers,
                     media_type="text/event-stream",
                 )
             case RunMode.SYNC:
-                await bundle.join()
+                await wait_util_stop(run_data, run_store, ready=ready)
                 return JSONResponse(
                     headers=headers,
-                    content=jsonable_encoder(bundle.run),
+                    content=jsonable_encoder(run_data.run),
                 )
             case RunMode.ASYNC:
+                ready.set()
                 return JSONResponse(
                     status_code=status.HTTP_202_ACCEPTED,
                     headers=headers,
-                    content=jsonable_encoder(bundle.run),
+                    content=jsonable_encoder(run_data.run),
                 )
             case _:
                 raise NotImplementedError()
 
     @app.get("/runs/{run_id}")
     async def read_run(run_id: RunId) -> RunReadResponse:
-        bundle = find_run_bundle(run_id)
+        bundle = await find_run_data(run_id)
         return bundle.run
 
     @app.get("/runs/{run_id}/events")
     async def list_run_events(run_id: RunId) -> RunEventsListResponse:
-        bundle = find_run_bundle(run_id)
+        bundle = await find_run_data(run_id)
         return RunEventsListResponse(events=bundle.events)
 
     @app.post("/runs/{run_id}")
     async def resume_run(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
-        bundle = find_run_bundle(run_id)
+        run_data = await find_run_data(run_id)
 
-        if bundle.run.await_request is None:
+        if run_data.run.await_request is None:
             raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Run {run_id} has no await request")
 
-        if bundle.run.await_request.type != request.await_resume.type:
+        if run_data.run.await_request.type != request.await_resume.type:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Run {run_id} is expecting resume of type {bundle.run.await_request.type}",
+                detail=f"Run {run_id} is expecting resume of type {run_data.run.await_request.type}",
             )
 
-        await bundle.resume(request.await_resume)
+        run_data.run.status = RunStatus.IN_PROGRESS
+        await run_store.set(run_data.key, run_data)
+        await run_resume_store.set(run_data.key, request.await_resume)
+
         match request.mode:
             case RunMode.STREAM:
                 return StreamingResponse(
-                    stream_sse(bundle),
+                    stream_sse(run_data, run_store, len(run_data.events)),
                     media_type="text/event-stream",
                 )
             case RunMode.SYNC:
-                await bundle.join()
-                return bundle.run
+                run_data = await wait_util_stop(run_data, run_store)
+                return run_data.run
            case RunMode.ASYNC:
                 return JSONResponse(
                     status_code=status.HTTP_202_ACCEPTED,
-                    content=jsonable_encoder(bundle.run),
+                    content=jsonable_encoder(run_data.run),
                 )
             case _:
                 raise NotImplementedError()
 
     @app.post("/runs/{run_id}/cancel")
     async def cancel_run(run_id: RunId) -> RunCancelResponse:
-        bundle = find_run_bundle(run_id)
-        if bundle.run.status.is_terminal:
+        run_data = await find_run_data(run_id)
+        if run_data.run.status.is_terminal:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Run in terminal status {bundle.run.status} can't be cancelled",
+                detail=f"Run in terminal status {run_data.run.status} can't be cancelled",
             )
-        await bundle.cancel()
-        return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder(bundle.run))
+        await run_cancel_store.set(run_data.key, CancelData())
+        run_data.run.status = RunStatus.CANCELLING
+        return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder(run_data.run))
 
     return app
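Note: run, cancel, resume, and session state are now kept behind the pluggable `Store` interface instead of in-process `TTLCache`s, which is why the `run_limit`/`run_ttl` arguments disappear from `create_app`. A minimal sketch of supplying the store explicitly; it simply reproduces the fallback `create_app` uses when `store` is None, and `my_agent` is a placeholder for your `Agent` instance:

```python
from datetime import timedelta

from acp_sdk.server.app import create_app
from acp_sdk.server.store import MemoryStore

# Same bounds as the built-in fallback; swap in RedisStore/PostgreSQLStore for multi-process setups.
store = MemoryStore(limit=1000, ttl=timedelta(hours=1))
# app = create_app(my_agent, store=store)  # `my_agent` stands in for your Agent instance
```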
acp_sdk/server/executor.py ADDED
@@ -0,0 +1,198 @@
+import asyncio
+import logging
+from collections.abc import AsyncIterator
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime, timezone
+from typing import Self
+
+from pydantic import BaseModel, ValidationError
+
+from acp_sdk.instrumentation import get_tracer
+from acp_sdk.models import (
+    ACPError,
+    AnyModel,
+    AwaitRequest,
+    AwaitResume,
+    Error,
+    ErrorCode,
+    Event,
+    GenericEvent,
+    Message,
+    MessageCompletedEvent,
+    MessageCreatedEvent,
+    MessagePart,
+    MessagePartEvent,
+    Run,
+    RunAwaitingEvent,
+    RunCancelledEvent,
+    RunCompletedEvent,
+    RunCreatedEvent,
+    RunFailedEvent,
+    RunInProgressEvent,
+    RunStatus,
+)
+from acp_sdk.server.agent import Agent
+from acp_sdk.server.logging import logger
+from acp_sdk.server.store import Store
+
+
+class RunData(BaseModel):
+    run: Run
+    input: list[Message]
+    events: list[Event] = []
+
+    @property
+    def key(self) -> str:
+        return str(self.run.run_id)
+
+    async def watch(self, store: Store[Self], *, ready: asyncio.Event | None = None) -> AsyncIterator[Self]:
+        async for data in store.watch(self.key, ready=ready):
+            if data is None:
+                raise RuntimeError("Missing data")
+            yield data
+            if data.run.status.is_terminal:
+                break
+
+
+class CancelData(BaseModel):
+    pass
+
+
+class Executor:
+    def __init__(
+        self,
+        *,
+        agent: Agent,
+        run_data: RunData,
+        history: list[Message],
+        executor: ThreadPoolExecutor,
+        run_store: Store[RunData],
+        cancel_store: Store[CancelData],
+        resume_store: Store[AwaitResume],
+    ) -> None:
+        self.agent = agent
+        self.history = history
+        self.run_data = run_data
+        self.executor = executor
+
+        self.run_store = run_store
+        self.cancel_store = cancel_store
+        self.resume_store = resume_store
+
+        self.logger = logging.LoggerAdapter(logger, {"run_id": str(run_data.run.run_id)})
+
+    def execute(self, *, wait: asyncio.Event) -> None:
+        self.task = asyncio.create_task(self._execute(self.run_data, executor=self.executor, wait=wait))
+        self.watcher = asyncio.create_task(self._watch_for_cancellation())
+
+    async def _push(self) -> None:
+        await self.run_store.set(self.run_data.run.run_id, self.run_data)
+
+    async def _emit(self, event: Event) -> None:
+        freeze = event.model_copy(deep=True)
+        self.run_data.events.append(freeze)
+        await self._push()
+
+    async def _await(self) -> AwaitResume:
+        async for resume in self.resume_store.watch(self.run_data.key):
+            if resume is not None:
+                await self.resume_store.set(self.run_data.key, None)
+                return resume
+
+    async def _watch_for_cancellation(self) -> None:
+        while not self.task.done():
+            try:
+                async for data in self.cancel_store.watch(self.run_data.key):
+                    if data is not None:
+                        self.task.cancel()
+            except Exception:
+                logger.warning("Cancellation watcher failed, restarting")
+
+    async def _execute(self, run_data: RunData, *, executor: ThreadPoolExecutor, wait: asyncio.Event) -> None:
+        with get_tracer().start_as_current_span("run"):
+            in_message = False
+
+            async def flush_message() -> None:
+                nonlocal in_message
+                if in_message:
+                    message = run_data.run.output[-1]
+                    message.completed_at = datetime.now(timezone.utc)
+                    await self._emit(MessageCompletedEvent(message=message))
+                    in_message = False
+
+            try:
+                await wait.wait()
+
+                await self._emit(RunCreatedEvent(run=run_data.run))
+
+                generator = self.agent.execute(
+                    input=self.history + run_data.input, session_id=run_data.run.session_id, executor=executor
+                )
+                self.logger.info("Run started")
+
+                run_data.run.status = RunStatus.IN_PROGRESS
+                await self._emit(RunInProgressEvent(run=run_data.run))
+
+                await_resume = None
+                while True:
+                    next = await generator.asend(await_resume)
+
+                    if isinstance(next, (MessagePart, str)):
+                        if isinstance(next, str):
+                            next = MessagePart(content=next)
+                        if not in_message:
+                            run_data.run.output.append(Message(parts=[], completed_at=None))
+                            in_message = True
+                            await self._emit(MessageCreatedEvent(message=run_data.run.output[-1]))
+                        run_data.run.output[-1].parts.append(next)
+                        await self._emit(MessagePartEvent(part=next))
+                    elif isinstance(next, Message):
+                        await flush_message()
+                        run_data.run.output.append(next)
+                        await self._emit(MessageCreatedEvent(message=next))
+                        for part in next.parts:
+                            await self._emit(MessagePartEvent(part=part))
+                        await self._emit(MessageCompletedEvent(message=next))
+                    elif isinstance(next, AwaitRequest):
+                        run_data.run.await_request = next
+                        run_data.run.status = RunStatus.AWAITING
+                        await self._emit(RunAwaitingEvent(run=run_data.run))
+                        self.logger.info("Run awaited")
+                        await_resume = await self._await()
+                        run_data.run.status = RunStatus.IN_PROGRESS
+                        await self._emit(RunInProgressEvent(run=run_data.run))
+                        self.logger.info("Run resumed")
+                    elif isinstance(next, Error):
+                        raise ACPError(error=next)
+                    elif isinstance(next, BaseException):
+                        raise next
+                    elif next is None:
+                        await flush_message()
+                    elif isinstance(next, BaseModel):
+                        await self._emit(GenericEvent(generic=AnyModel(**next.model_dump())))
+                    else:
+                        try:
+                            generic = AnyModel.model_validate(next)
+                            await self._emit(GenericEvent(generic=generic))
+                        except ValidationError:
+                            raise TypeError("Invalid yield")
+            except StopAsyncIteration:
+                await flush_message()
+                run_data.run.status = RunStatus.COMPLETED
+                run_data.run.finished_at = datetime.now(timezone.utc)
+                await self._emit(RunCompletedEvent(run=run_data.run))
+                self.logger.info("Run completed")
+            except asyncio.CancelledError:
+                run_data.run.status = RunStatus.CANCELLED
+                run_data.run.finished_at = datetime.now(timezone.utc)
+                await self._emit(RunCancelledEvent(run=run_data.run))
+                self.logger.info("Run cancelled")
+            except Exception as e:
+                if isinstance(e, ACPError):
+                    run_data.run.error = e.error
+                else:
+                    run_data.run.error = Error(code=ErrorCode.SERVER_ERROR, message=str(e))
+                run_data.run.status = RunStatus.FAILED
+                run_data.run.finished_at = datetime.now(timezone.utc)
+                await self._emit(RunFailedEvent(run=run_data.run))
+                self.logger.exception("Run failed")
acp_sdk/server/server.py CHANGED
@@ -2,7 +2,6 @@ import asyncio
 import os
 from collections.abc import AsyncGenerator, Awaitable
 from contextlib import asynccontextmanager
-from datetime import timedelta
 from typing import Any, Callable
 
 import requests
@@ -16,6 +15,7 @@ from acp_sdk.server.agent import agent as agent_decorator
 from acp_sdk.server.app import create_app
 from acp_sdk.server.logging import configure_logger as configure_logger_func
 from acp_sdk.server.logging import logger
+from acp_sdk.server.store import Store
 from acp_sdk.server.telemetry import configure_telemetry as configure_telemetry_func
 from acp_sdk.server.utils import async_request_with_retry
 
@@ -54,8 +54,7 @@ class Server:
         configure_logger: bool = True,
         configure_telemetry: bool = False,
         self_registration: bool = True,
-        run_limit: int = 1000,
-        run_ttl: timedelta = timedelta(hours=1),
+        store: Store | None = None,
         host: str = "127.0.0.1",
         port: int = 8000,
         uds: str | None = None,
@@ -118,13 +117,15 @@ class Server:
 
         import uvicorn
 
+        app = create_app(*self.agents, lifespan=self.lifespan, store=store)
+
         if configure_logger:
             configure_logger_func()
         if configure_telemetry:
-            configure_telemetry_func()
+            configure_telemetry_func(app)
 
         config = uvicorn.Config(
-            create_app(*self.agents, lifespan=self.lifespan, run_limit=run_limit, run_ttl=run_ttl),
+            app,
             host,
             port,
             uds,
@@ -182,8 +183,7 @@ class Server:
         configure_logger: bool = True,
         configure_telemetry: bool = False,
         self_registration: bool = True,
-        run_limit: int = 1000,
-        run_ttl: timedelta = timedelta(hours=1),
+        store: Store | None = None,
         host: str = "127.0.0.1",
         port: int = 8000,
         uds: str | None = None,
@@ -241,8 +241,7 @@ class Server:
             configure_logger=configure_logger,
             configure_telemetry=configure_telemetry,
             self_registration=self_registration,
-            run_limit=run_limit,
-            run_ttl=run_ttl,
+            store=store,
             host=host,
             port=port,
             uds=uds,
@@ -309,7 +308,7 @@ class Server:
 
     async def _register_agent(self) -> None:
         """If not in PRODUCTION mode, register agent to the beeai platform and provide missing env variables"""
-        if os.getenv("PRODUCTION_MODE", False):
+        if os.getenv("PRODUCTION_MODE", "").lower() in ["true", "1"]:
            logger.debug("Agent is not automatically registered in the production mode.")
            return
 
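Note: `Server.run()` now accepts the same `store` argument, so several server processes can share run state through Redis or PostgreSQL. A hedged sketch, assuming the `Server()`/`@server.agent()` usage from earlier SDK versions and a local Redis at an illustrative URL (the agent body and its signature are simplified):

```python
from collections.abc import AsyncGenerator

from redis.asyncio import Redis

from acp_sdk.server import Server
from acp_sdk.server.store import RedisStore

server = Server()


@server.agent()  # registration decorator as in previous releases
async def echo(input, context) -> AsyncGenerator:
    # Echo every incoming message back to the caller.
    for message in input:
        yield message


server.run(
    store=RedisStore(redis=Redis.from_url("redis://localhost:6379/0")),  # shared store; URL is a placeholder
    port=8000,
)
```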
acp_sdk/server/session.py CHANGED
@@ -1,21 +1,24 @@
 import uuid
-from collections.abc import Iterator
 
-from acp_sdk.models import Message, SessionId
-from acp_sdk.models.models import RunStatus
-from acp_sdk.server.bundle import RunBundle
+from pydantic import BaseModel, Field
 
+from acp_sdk.models import Message, RunId, RunStatus, SessionId
+from acp_sdk.server.executor import RunData
+from acp_sdk.server.store import Store
 
-class Session:
-    def __init__(self, id: SessionId | None = None) -> None:
-        self.id: SessionId = id or uuid.uuid4()
-        self.bundles: list[RunBundle] = []
 
-    def append(self, bundle: RunBundle) -> None:
-        self.bundles.append(bundle)
+class Session(BaseModel):
+    id: SessionId = Field(default_factory=uuid.uuid4)
+    runs: list[RunId] = []
 
-    def history(self) -> Iterator[Message]:
-        for bundle in self.bundles:
-            if bundle.run.status == RunStatus.COMPLETED:
-                yield from bundle.input
-                yield from bundle.run.output
+    def append(self, run_id: RunId) -> None:
+        self.runs.append(run_id)
+
+    async def history(self, store: Store[RunData]) -> list[Message]:
+        history = []
+        for run_id in self.runs:
+            run_data = await store.get(run_id)
+            if run_data is not None and run_data.run.status == RunStatus.COMPLETED:
+                history.extend(run_data.input)
+                history.extend(run_data.run.output)
+        return history
acp_sdk/server/store/__init__.py ADDED
@@ -0,0 +1,4 @@
+from acp_sdk.server.store.memory_store import MemoryStore as MemoryStore
+from acp_sdk.server.store.postgresql_store import PostgreSQLStore as PostgreSQLStore
+from acp_sdk.server.store.redis_store import RedisStore as RedisStore
+from acp_sdk.server.store.store import Store as Store
acp_sdk/server/store/memory_store.py ADDED
@@ -0,0 +1,35 @@
+import asyncio
+from collections.abc import AsyncIterator
+from datetime import datetime
+from typing import Generic
+
+from cachetools import TTLCache
+
+from acp_sdk.server.store.store import Store, T
+from acp_sdk.server.store.utils import Stringable
+
+
+class MemoryStore(Store[T], Generic[T]):
+    def __init__(self, *, limit: int, ttl: int | None = None) -> None:
+        super().__init__()
+        self._cache: TTLCache[str, T] = TTLCache(maxsize=limit, ttl=ttl, timer=datetime.now)
+        self._event = asyncio.Event()
+
+    async def get(self, key: Stringable) -> T | None:
+        value = self._cache.get(str(key))
+        return value.model_copy(deep=True) if value else value
+
+    async def set(self, key: Stringable, value: T | None) -> None:
+        if value is None:
+            del self._cache[str(key)]
+        else:
+            self._cache[str(key)] = value.model_copy(deep=True)
+        self._event.set()
+
+    async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
+        if ready:
+            ready.set()
+        while True:
+            await self._event.wait()
+            self._event.clear()
+            yield await self.get(key)
acp_sdk/server/store/postgresql_store.py ADDED
@@ -0,0 +1,69 @@
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Generic
+
+from psycopg import AsyncConnection
+from psycopg.rows import dict_row
+
+from acp_sdk.server.store.store import Store, StoreModel, T
+from acp_sdk.server.store.utils import Stringable
+
+
+class PostgreSQLStore(Store[T], Generic[T]):
+    def __init__(self, *, aconn: AsyncConnection, table: str = "acp_store", channel: str = "acp_update") -> None:
+        super().__init__()
+        self._aconn = aconn
+        self._table = table
+        self._channel = channel
+
+    async def get(self, key: Stringable) -> T | None:
+        await self._ensure_table()
+        async with self._aconn.cursor(row_factory=dict_row) as cur:
+            await cur.execute(f"SELECT value FROM {self._table} WHERE key = %s", (str(key),))
+            result = await cur.fetchone()
+            if result is None:
+                return None
+            return StoreModel.model_validate(result["value"])
+
+    async def set(self, key: Stringable, value: T | None) -> None:
+        await self._ensure_table()
+        async with self._aconn.cursor() as cur:
+            if value is None:
+                await cur.execute(
+                    f"DELETE FROM {self._table} WHERE key = %s",
+                    (str(key),),
+                )
+            else:
+                await cur.execute(
+                    f"""
+                    INSERT INTO {self._table} (key, value)
+                    VALUES (%s, %s)
+                    ON CONFLICT (key)
+                    DO UPDATE SET value = EXCLUDED.value
+                    """,
+                    (str(key), value.model_dump_json()),
+                )
+            await cur.execute(f"NOTIFY {self._channel}, '{key!s}'")  # NOTIFY appears not to accept params
+        await self._aconn.commit()
+
+    async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
+        notify_conn = await AsyncConnection.connect(
+            conninfo=f"{self._aconn.info.dsn} password={self._aconn.info.password}", autocommit=True
+        )
+        async with notify_conn:
+            await notify_conn.execute(f"LISTEN {self._channel}")
+            if ready:
+                ready.set()
+            async for notify in notify_conn.notifies():
+                if notify.payload == str(key):
+                    yield await self.get(key)
+
+    async def _ensure_table(self) -> None:
+        async with self._aconn.cursor() as cur:
+            await cur.execute(f"""
+                CREATE TABLE IF NOT EXISTS {self._table} (
+                    key TEXT PRIMARY KEY,
+                    value JSONB NOT NULL
+                )
+            """)
+        await self._aconn.commit()
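Note: `PostgreSQLStore` takes an already-open psycopg `AsyncConnection`, lazily creates its key/value table, and opens a second connection for LISTEN/NOTIFY in `watch()`. A minimal construction sketch (the DSN is a placeholder):

```python
import asyncio

from psycopg import AsyncConnection

from acp_sdk.server.store import PostgreSQLStore


async def main() -> None:
    # Placeholder DSN; create the connection inside the event loop that will use the store.
    aconn = await AsyncConnection.connect("postgresql://acp:acp@localhost:5432/acp")
    store = PostgreSQLStore(aconn=aconn)
    # Hand `store` to create_app(..., store=store) or Server.run(store=store) within this loop.


asyncio.run(main())
```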
acp_sdk/server/store/redis_store.py ADDED
@@ -0,0 +1,40 @@
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Generic
+
+from redis.asyncio import Redis
+
+from acp_sdk.server.store.store import Store, StoreModel, T
+from acp_sdk.server.store.utils import Stringable
+
+
+class RedisStore(Store[T], Generic[T]):
+    def __init__(self, *, redis: Redis) -> None:
+        super().__init__()
+        self._redis = redis
+
+    async def get(self, key: Stringable) -> T | None:
+        value = await self._redis.get(str(key))
+        return StoreModel.model_validate_json(value) if value else value
+
+    async def set(self, key: Stringable, value: T | None) -> None:
+        if value is None:
+            await self._redis.delete(str(key))
+        else:
+            await self._redis.set(name=str(key), value=value.model_dump_json())
+
+    async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T]:
+        await self._redis.config_set("notify-keyspace-events", "KEA")
+
+        pubsub = self._redis.pubsub()
+        channel = f"__keyspace@0__:{key!s}"
+        await pubsub.subscribe(channel)
+        if ready:
+            ready.set()
+        try:
+            async for message in pubsub.listen():
+                if message["type"] == "message":
+                    yield await self.get(key)
+        finally:
+            await pubsub.unsubscribe(channel)
+            await pubsub.close()
acp_sdk/server/store/store.py ADDED
@@ -0,0 +1,55 @@
+import asyncio
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator
+from typing import Generic, TypeVar
+
+from pydantic import BaseModel, ConfigDict
+
+from acp_sdk.server.store.utils import Stringable
+
+
+class StoreModel(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+
+T = TypeVar("T", bound=BaseModel)
+U = TypeVar("U", bound=BaseModel)
+
+
+class Store(Generic[T], ABC):
+    @abstractmethod
+    async def get(self, key: Stringable) -> T | None:
+        pass
+
+    @abstractmethod
+    async def set(self, key: Stringable, value: T | None) -> None:
+        pass
+
+    @abstractmethod
+    def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[T | None]:
+        pass
+
+    def as_store(self, model: type[U], prefix: Stringable = "") -> "Store[U]":
+        return StoreView(model=model, store=self, prefix=prefix)
+
+
+class StoreView(Store[U], Generic[U]):
+    def __init__(self, *, model: type[U], store: Store[T], prefix: Stringable = "") -> None:
+        super().__init__()
+        self._model = model
+        self._store = store
+        self._prefix = prefix
+
+    async def get(self, key: Stringable) -> U | None:
+        value = await self._store.get(self._get_key(key))
+        return self._model.model_validate(value.model_dump()) if value else value
+
+    async def set(self, key: Stringable, value: U | None) -> None:
+        await self._store.set(self._get_key(key), value)
+
+    async def watch(self, key: Stringable, *, ready: asyncio.Event | None = None) -> AsyncIterator[U | None]:
+        async for value in self._store.watch(self._get_key(key), ready=ready):
+            yield self._model.model_validate(value.model_dump()) if value else value
+
+    def _get_key(self, key: Stringable) -> str:
+        return f"{self._prefix!s}{key!s}"
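Note: `Store` is a small async key-value contract (`get`/`set`/`watch`) over pydantic models, and `as_store()` wraps any backend in a typed, key-prefixed view; this is how `create_app` derives its run/cancel/resume/session stores from one shared store. A minimal sketch with an illustrative model and key:

```python
import asyncio
from datetime import timedelta

from pydantic import BaseModel

from acp_sdk.server.store import MemoryStore


class Counter(BaseModel):  # illustrative payload model
    value: int = 0


async def main() -> None:
    backend = MemoryStore(limit=100, ttl=timedelta(minutes=5))
    counters = backend.as_store(model=Counter, prefix="counter_")  # keys are stored as "counter_<key>"

    await counters.set("demo", Counter(value=1))
    print(await counters.get("demo"))  # Counter(value=1)


asyncio.run(main())
```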
acp_sdk/server/store/utils.py ADDED
@@ -0,0 +1,5 @@
+from typing import Protocol
+
+
+class Stringable(Protocol):
+    def __str__(self) -> str: ...
acp_sdk/server/telemetry.py CHANGED
@@ -1,9 +1,11 @@
 import logging
 
+from fastapi import FastAPI
 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
 from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
 from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
 from opentelemetry.sdk.metrics import MeterProvider
@@ -22,8 +24,10 @@ from acp_sdk.version import __version__
 root_logger = logging.getLogger()
 
 
-def configure_telemetry() -> None:
-    """Utility that configures opentelemetry with OTLP exporter"""
+def configure_telemetry(app: FastAPI) -> None:
+    """Utility that configures opentelemetry with OTLP exporter and FastAPI instrumentation"""
+
+    FastAPIInstrumentor.instrument_app(app)
 
     resource = Resource(
         attributes={
acp_sdk/server/utils.py CHANGED
@@ -6,17 +6,41 @@ import httpx
 import requests
 from pydantic import BaseModel
 
-from acp_sdk.server.bundle import RunBundle
+from acp_sdk.models import RunStatus
+from acp_sdk.server.executor import RunData
 from acp_sdk.server.logging import logger
+from acp_sdk.server.store.store import Store
 
 
 def encode_sse(model: BaseModel) -> str:
     return f"data: {model.model_dump_json()}\n\n"
 
 
-async def stream_sse(bundle: RunBundle) -> AsyncGenerator[str]:
-    async for event in bundle.stream():
-        yield encode_sse(event)
+async def watch_util_stop(
+    run_data: RunData, store: Store[RunData], *, ready: asyncio.Event | None = None
+) -> AsyncGenerator[RunData]:
+    async for data in run_data.watch(store, ready=ready):
+        yield data
+        if data.run.status == RunStatus.AWAITING:
+            break
+
+
+async def wait_util_stop(run_data: RunData, store: Store[RunData], *, ready: asyncio.Event | None = None) -> RunData:
+    data = run_data
+    async for latest_data in watch_util_stop(run_data, store, ready=ready):
+        data = latest_data
+    return data
+
+
+async def stream_sse(
+    run_data: RunData, store: Store[RunData], idx: int, *, ready: asyncio.Event | None = None
+) -> AsyncGenerator[str]:
+    next_event_idx = idx
+    async for data in watch_util_stop(run_data, store, ready=ready):
+        new_events = data.events[next_event_idx:]
+        next_event_idx = len(data.events)
+        for event in new_events:
+            yield encode_sse(event)
 
 
 async def async_request_with_retry(
acp_sdk-0.9.0.dist-info/METADATA → acp_sdk-0.10.0.dist-info/METADATA CHANGED
@@ -1,27 +1,29 @@
 Metadata-Version: 2.4
 Name: acp-sdk
-Version: 0.9.0
+Version: 0.10.0
 Summary: Agent Communication Protocol SDK
 Author: IBM Corp.
 Maintainer-email: Tomas Pilar <thomas7pilar@gmail.com>
 License-Expression: Apache-2.0
 Requires-Python: <4.0,>=3.11
-Requires-Dist: cachetools>=5.5.2
-Requires-Dist: fastapi[standard]>=0.115.8
-Requires-Dist: httpx-sse>=0.4.0
-Requires-Dist: httpx>=0.26.0
-Requires-Dist: janus>=2.0.0
-Requires-Dist: opentelemetry-api>=1.31.1
-Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31.1
+Requires-Dist: cachetools>=5.5
+Requires-Dist: fastapi[standard]>=0.115
+Requires-Dist: httpx-sse>=0.4
+Requires-Dist: httpx>=0.26
+Requires-Dist: janus>=2.0
+Requires-Dist: opentelemetry-api>=1.31
+Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.31
 Requires-Dist: opentelemetry-instrumentation-fastapi>=0.52b1
 Requires-Dist: opentelemetry-instrumentation-httpx>=0.52b1
-Requires-Dist: opentelemetry-sdk>=1.31.1
-Requires-Dist: pydantic>=2.0.0
+Requires-Dist: opentelemetry-sdk>=1.31
+Requires-Dist: psycopg[binary]>=3.2
+Requires-Dist: pydantic>=2.0
+Requires-Dist: redis>=6.1
 Description-Content-Type: text/markdown
 
 # Agent Communication Protocol SDK for Python
 
-Agent Communication Protocol SDK for Python provides allows developers to serve and consume agents over the Agent Communication Protocol.
+Agent Communication Protocol SDK for Python helps developers to serve and consume agents over the Agent Communication Protocol.
 
 ## Prerequisites
 
acp_sdk-0.9.0.dist-info/RECORD → acp_sdk-0.10.0.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ acp_sdk/instrumentation.py,sha256=JqSyvILN3sGAfOZrmckQq4-M_4_5alyPn95DK0o5lfA,16
 acp_sdk/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 acp_sdk/version.py,sha256=Niy83rgvigB4hL_rR-O4ySvI7dj6xnqkyOe_JTymi9s,73
 acp_sdk/client/__init__.py,sha256=Bca1DORrswxzZsrR2aUFpATuNG2xNSmYvF1Z2WJaVbc,51
-acp_sdk/client/client.py,sha256=xQWBVJZuT7JmPS3dw6W1CidT9Uqi-XqZ6R-Xj04IefA,8718
+acp_sdk/client/client.py,sha256=7GVD7eVlaSU78O3puMlyoOTqKjTwZYaBKLHyYvbvr3I,8504
 acp_sdk/client/types.py,sha256=_H6zYt-2OHOOYRtssRnbDIiwmgsl2-KIXc9lb-mJLFA,133
 acp_sdk/client/utils.py,sha256=2jhJyrPJmVFRoDJh0q_JMqOMlC3IxCh-6HXed-PIZS8,924
 acp_sdk/models/__init__.py,sha256=numSDBDT1QHx7n_Y3Deb5VOvKWcUBxbOEaMwQBSRHxc,151
@@ -12,16 +12,22 @@ acp_sdk/models/models.py,sha256=So5D9VjJxX2jdOF8BGc3f9D2ssbAyFEO2Rgr1KykI-M,7730
 acp_sdk/models/schemas.py,sha256=_ah7_zHsQJGxDXvnzsBvASdRsQHVphFQ7Sum6A04iRw,759
 acp_sdk/server/__init__.py,sha256=mxBBBFaZuMEUENRMLwp1XZkuLeT9QghcFmNvjnqvAAU,377
 acp_sdk/server/agent.py,sha256=6VBKn_qVXqUl79G8T7grwhnuLMwr67d4UGagMGX1hMs,6586
-acp_sdk/server/app.py,sha256=VR8UPR08SXebABvfvgdnjOR-mqV_Df8iSWdG_yG4wvg,7372
-acp_sdk/server/bundle.py,sha256=umD2GgDp17lUddu0adpp1zUcm1JJvDrDpIZ0uR-6VeY,7204
+acp_sdk/server/app.py,sha256=XeMn8hfU2J2O2QxSvsH5q6lowxCfwS1-DTWz-gUVotA,8733
 acp_sdk/server/context.py,sha256=MgnLV6qcDIhc_0BjW7r4Jj1tHts4ZuwpdTGIBnz2Mgo,1036
 acp_sdk/server/errors.py,sha256=GSO8yYIqEeX8Y4Lz86ks35dMTHiQiXuOrLYYx0eXsbI,2110
+acp_sdk/server/executor.py,sha256=YL0J9cVY1QZtdTeqwjJaKDpB_T6_sByHlHc52kgNAJo,7742
 acp_sdk/server/logging.py,sha256=Oc8yZigCsuDnHHPsarRzu0RX3NKaLEgpELM2yovGKDI,411
-acp_sdk/server/server.py,sha256=BPFs5vFAvSL2Xq-446gJHuWmvr3xUdqMQu4OXD4JIU8,13599
-acp_sdk/server/session.py,sha256=ekz1o6Sy1tQZlpaoS_VgbvFuUQh2qpiHG71mvBdvhgc,662
-acp_sdk/server/telemetry.py,sha256=1BUxNg-xL_Vqgs27PDWNc3HikrQW2lidAtT_FKlp_Qk,1833
+acp_sdk/server/server.py,sha256=lgpvokjd3f_dqEkyfLn3Nr4oKXNHPUkvV0uNimUmXU8,13497
+acp_sdk/server/session.py,sha256=vGUVpKzUGefI1c7LeK08Bvd8zvJIRfsdJEt2KhYoEg0,764
+acp_sdk/server/telemetry.py,sha256=lbB2ppijUcqbHUOn0e-15LGcVvT_qrMguq8qBokICac,2016
 acp_sdk/server/types.py,sha256=gLb5wCkMYhmu2laj_ymK-TPfN9LSjRgKOP1H_893UzA,304
-acp_sdk/server/utils.py,sha256=BhZKBNaLgczX6aYjxYva-6VI1bKmHtYQ5YDA5LrwF50,1831
-acp_sdk-0.9.0.dist-info/METADATA,sha256=p92B6qlLrhgAxulgcoBn7GBJP-TaQZ8MjsMgyO-L6EY,1650
-acp_sdk-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-acp_sdk-0.9.0.dist-info/RECORD,,
+acp_sdk/server/utils.py,sha256=BYSn4Bd95Bn-oEH1W1yE_pWpYUOdtYPh-vMnou4nsdk,2721
+acp_sdk/server/store/__init__.py,sha256=zzKic0byQTM86cyC2whwZeNP4prfy_HZrbSriTaV5j8,282
+acp_sdk/server/store/memory_store.py,sha256=9hOoJiHdaDz-0hNHWhO_qZpp-I_pnTqYBIhQnlWilu4,1194
+acp_sdk/server/store/postgresql_store.py,sha256=bHyAgf4vSn_07wGXXq6juFwm3JldYNOjU9RARcDSYQo,2717
+acp_sdk/server/store/redis_store.py,sha256=IKXvDseOFMcoGjVYPOkOBhPnJAchy_RyeMayKLoVCGA,1378
+acp_sdk/server/store/store.py,sha256=jGmYy9oiuVjhYYJY8QRo4g2J2Qyt1HLTmq_eHy4aI7c,1806
+acp_sdk/server/store/utils.py,sha256=JumEOMs1h1uGlnHnUGeguee-srGzT7_Y2NVEYt01QuY,92
+acp_sdk-0.10.0.dist-info/METADATA,sha256=wmPPGrdoeCoe7M8DM-r3qlLgpwVteHARkZpbKaVGk9I,1685
+acp_sdk-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+acp_sdk-0.10.0.dist-info/RECORD,,
acp_sdk/server/bundle.py DELETED
@@ -1,182 +0,0 @@
-import asyncio
-import logging
-from collections.abc import AsyncGenerator
-from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime, timezone
-
-from pydantic import BaseModel, ValidationError
-
-from acp_sdk.instrumentation import get_tracer
-from acp_sdk.models import (
-    ACPError,
-    AnyModel,
-    AwaitRequest,
-    AwaitResume,
-    Error,
-    ErrorCode,
-    Event,
-    GenericEvent,
-    Message,
-    MessageCompletedEvent,
-    MessageCreatedEvent,
-    MessagePart,
-    MessagePartEvent,
-    Run,
-    RunAwaitingEvent,
-    RunCancelledEvent,
-    RunCompletedEvent,
-    RunCreatedEvent,
-    RunFailedEvent,
-    RunInProgressEvent,
-    RunStatus,
-)
-from acp_sdk.server.agent import Agent
-from acp_sdk.server.logging import logger
-
-
-class RunBundle:
-    def __init__(
-        self, *, agent: Agent, run: Run, input: list[Message], history: list[Message], executor: ThreadPoolExecutor
-    ) -> None:
-        self.agent = agent
-        self.run = run
-        self.input = input
-        self.history = history
-
-        self.stream_queue: asyncio.Queue[Event] = asyncio.Queue()
-        self.events: list[Event] = []
-
-        self.await_queue: asyncio.Queue[AwaitResume] = asyncio.Queue(maxsize=1)
-        self.await_or_terminate_event = asyncio.Event()
-
-        self.task = asyncio.create_task(self._execute(input, executor=executor))
-
-    async def stream(self) -> AsyncGenerator[Event]:
-        while True:
-            event = await self.stream_queue.get()
-            if event is None:
-                break
-            yield event
-            self.stream_queue.task_done()
-
-    async def emit(self, event: Event) -> None:
-        freeze = event.model_copy(deep=True)
-        self.events.append(freeze)
-        await self.stream_queue.put(freeze)
-
-    async def await_(self) -> AwaitResume:
-        await self.stream_queue.put(None)
-        self.await_queue.empty()
-        self.await_or_terminate_event.set()
-        self.await_or_terminate_event.clear()
-        resume = await self.await_queue.get()
-        self.await_queue.task_done()
-        return resume
-
-    async def resume(self, resume: AwaitResume) -> None:
-        self.stream_queue = asyncio.Queue()
-        await self.await_queue.put(resume)
-        self.run.status = RunStatus.IN_PROGRESS
-        self.run.await_request = None
-
-    async def cancel(self) -> None:
-        self.task.cancel()
-        self.run.status = RunStatus.CANCELLING
-        self.run.await_request = None
-
-    async def join(self) -> None:
-        await self.await_or_terminate_event.wait()
-
-    async def _execute(self, input: list[Message], *, executor: ThreadPoolExecutor) -> None:
-        with get_tracer().start_as_current_span("run"):
-            run_logger = logging.LoggerAdapter(logger, {"run_id": str(self.run.run_id)})
-
-            in_message = False
-
-            async def flush_message() -> None:
-                nonlocal in_message
-                if in_message:
-                    message = self.run.output[-1]
-                    message.completed_at = datetime.now(timezone.utc)
-                    await self.emit(MessageCompletedEvent(message=message))
-                    in_message = False
-
-            try:
-                await self.emit(RunCreatedEvent(run=self.run))
-
-                generator = self.agent.execute(
-                    input=self.history + input, session_id=self.run.session_id, executor=executor
-                )
-                run_logger.info("Run started")
-
-                self.run.status = RunStatus.IN_PROGRESS
-                await self.emit(RunInProgressEvent(run=self.run))
-
-                await_resume = None
-                while True:
-                    next = await generator.asend(await_resume)
-
-                    if isinstance(next, (MessagePart, str)):
-                        if isinstance(next, str):
-                            next = MessagePart(content=next)
-                        if not in_message:
-                            self.run.output.append(Message(parts=[], completed_at=None))
-                            in_message = True
-                            await self.emit(MessageCreatedEvent(message=self.run.output[-1]))
-                        self.run.output[-1].parts.append(next)
-                        await self.emit(MessagePartEvent(part=next))
-                    elif isinstance(next, Message):
-                        await flush_message()
-                        self.run.output.append(next)
-                        await self.emit(MessageCreatedEvent(message=next))
-                        for part in next.parts:
-                            await self.emit(MessagePartEvent(part=part))
-                        await self.emit(MessageCompletedEvent(message=next))
-                    elif isinstance(next, AwaitRequest):
-                        self.run.await_request = next
-                        self.run.status = RunStatus.AWAITING
-                        await self.emit(RunAwaitingEvent(run=self.run))
-                        run_logger.info("Run awaited")
-                        await_resume = await self.await_()
-                        await self.emit(RunInProgressEvent(run=self.run))
-                        run_logger.info("Run resumed")
-                    elif isinstance(next, Error):
-                        raise ACPError(error=next)
-                    elif isinstance(next, ACPError):
-                        raise next
-                    elif next is None:
-                        await flush_message()
-                    elif isinstance(next, BaseModel):
-                        await self.emit(GenericEvent(generic=AnyModel(**next.model_dump())))
-                    else:
-                        try:
-                            generic = AnyModel.model_validate(next)
-                            await self.emit(GenericEvent(generic=generic))
-                        except ValidationError:
-                            raise TypeError("Invalid yield")
-            except StopAsyncIteration:
-                await flush_message()
-                self.run.status = RunStatus.COMPLETED
-                self.run.finished_at = datetime.now(timezone.utc)
-                await self.emit(RunCompletedEvent(run=self.run))
-                run_logger.info("Run completed")
-            except asyncio.CancelledError:
-                self.run.status = RunStatus.CANCELLED
-                self.run.finished_at = datetime.now(timezone.utc)
-                await self.emit(RunCancelledEvent(run=self.run))
-                run_logger.info("Run cancelled")
-            except Exception as e:
-                if isinstance(e, ACPError):
-                    self.run.error = e.error
-                else:
-                    self.run.error = Error(code=ErrorCode.SERVER_ERROR, message=str(e))
-                self.run.status = RunStatus.FAILED
-                self.run.finished_at = datetime.now(timezone.utc)
-                await self.emit(RunFailedEvent(run=self.run))
-                run_logger.exception("Run failed")
-                raise
-            finally:
-                self.await_or_terminate_event.set()
-                await self.stream_queue.put(None)
-                if not self.task.done():
-                    self.task.cancel()