acp-sdk 0.10.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- acp_sdk/client/client.py +133 -42
- acp_sdk/models/__init__.py +1 -0
- acp_sdk/models/models.py +82 -10
- acp_sdk/models/platform.py +22 -0
- acp_sdk/models/schemas.py +18 -3
- acp_sdk/models/types.py +9 -0
- acp_sdk/server/__init__.py +5 -1
- acp_sdk/server/agent.py +8 -84
- acp_sdk/server/app.py +87 -29
- acp_sdk/server/context.py +11 -3
- acp_sdk/server/executor.py +139 -14
- acp_sdk/server/server.py +17 -4
- acp_sdk/server/store/memory_store.py +4 -4
- acp_sdk/server/utils.py +1 -1
- acp_sdk/shared/__init__.py +2 -0
- acp_sdk/shared/resources.py +46 -0
- {acp_sdk-0.10.1.dist-info → acp_sdk-0.12.0.dist-info}/METADATA +2 -1
- acp_sdk-0.12.0.dist-info/RECORD +36 -0
- acp_sdk/server/session.py +0 -24
- acp_sdk-0.10.1.dist-info/RECORD +0 -33
- {acp_sdk-0.10.1.dist-info → acp_sdk-0.12.0.dist-info}/WHEEL +0 -0
acp_sdk/server/agent.py
CHANGED
@@ -1,23 +1,14 @@
|
|
1
1
|
import abc
|
2
|
-
import asyncio
|
3
2
|
import inspect
|
4
3
|
from collections.abc import AsyncGenerator, Coroutine, Generator
|
5
|
-
from concurrent.futures import ThreadPoolExecutor
|
6
4
|
from typing import Callable
|
7
5
|
|
8
|
-
import janus
|
9
|
-
|
10
|
-
from acp_sdk.models import (
|
11
|
-
AgentName,
|
12
|
-
Message,
|
13
|
-
SessionId,
|
14
|
-
)
|
15
|
-
from acp_sdk.models.models import Metadata
|
6
|
+
from acp_sdk.models import AgentName, Message, Metadata
|
16
7
|
from acp_sdk.server.context import Context
|
17
8
|
from acp_sdk.server.types import RunYield, RunYieldResume
|
18
9
|
|
19
10
|
|
20
|
-
class Agent(abc.ABC):
|
11
|
+
class AgentManifest(abc.ABC):
|
21
12
|
@property
|
22
13
|
def name(self) -> AgentName:
|
23
14
|
return self.__class__.__name__
|
@@ -38,75 +29,8 @@ class Agent(abc.ABC):
|
|
38
29
|
):
|
39
30
|
pass
|
40
31
|
|
41
|
-
|
42
|
-
|
43
|
-
) -> AsyncGenerator[RunYield, RunYieldResume]:
|
44
|
-
yield_queue: janus.Queue[RunYield] = janus.Queue()
|
45
|
-
yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()
|
46
|
-
|
47
|
-
context = Context(
|
48
|
-
session_id=session_id, executor=executor, yield_queue=yield_queue, yield_resume_queue=yield_resume_queue
|
49
|
-
)
|
50
|
-
|
51
|
-
if inspect.isasyncgenfunction(self.run):
|
52
|
-
run = asyncio.create_task(self._run_async_gen(input, context))
|
53
|
-
elif inspect.iscoroutinefunction(self.run):
|
54
|
-
run = asyncio.create_task(self._run_coro(input, context))
|
55
|
-
elif inspect.isgeneratorfunction(self.run):
|
56
|
-
run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen, input, context)
|
57
|
-
else:
|
58
|
-
run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, input, context)
|
59
|
-
|
60
|
-
try:
|
61
|
-
while not run.done() or yield_queue.async_q.qsize() > 0:
|
62
|
-
value = yield await yield_queue.async_q.get()
|
63
|
-
if isinstance(value, Exception):
|
64
|
-
raise value
|
65
|
-
await yield_resume_queue.async_q.put(value)
|
66
|
-
except janus.AsyncQueueShutDown:
|
67
|
-
pass
|
68
|
-
|
69
|
-
async def _run_async_gen(self, input: list[Message], context: Context) -> None:
|
70
|
-
try:
|
71
|
-
gen: AsyncGenerator[RunYield, RunYieldResume] = self.run(input, context)
|
72
|
-
value = None
|
73
|
-
while True:
|
74
|
-
value = await context.yield_async(await gen.asend(value))
|
75
|
-
except StopAsyncIteration:
|
76
|
-
pass
|
77
|
-
except Exception as e:
|
78
|
-
await context.yield_async(e)
|
79
|
-
finally:
|
80
|
-
context.shutdown()
|
81
|
-
|
82
|
-
async def _run_coro(self, input: list[Message], context: Context) -> None:
|
83
|
-
try:
|
84
|
-
await context.yield_async(await self.run(input, context))
|
85
|
-
except Exception as e:
|
86
|
-
await context.yield_async(e)
|
87
|
-
finally:
|
88
|
-
context.shutdown()
|
89
|
-
|
90
|
-
def _run_gen(self, input: list[Message], context: Context) -> None:
|
91
|
-
try:
|
92
|
-
gen: Generator[RunYield, RunYieldResume] = self.run(input, context)
|
93
|
-
value = None
|
94
|
-
while True:
|
95
|
-
value = context.yield_sync(gen.send(value))
|
96
|
-
except StopIteration:
|
97
|
-
pass
|
98
|
-
except Exception as e:
|
99
|
-
context.yield_sync(e)
|
100
|
-
finally:
|
101
|
-
context.shutdown()
|
102
|
-
|
103
|
-
def _run_func(self, input: list[Message], context: Context) -> None:
|
104
|
-
try:
|
105
|
-
context.yield_sync(self.run(input, context))
|
106
|
-
except Exception as e:
|
107
|
-
context.yield_sync(e)
|
108
|
-
finally:
|
109
|
-
context.shutdown()
|
32
|
+
|
33
|
+
Agent = AgentManifest
|
110
34
|
|
111
35
|
|
112
36
|
def agent(
|
@@ -114,10 +38,10 @@ def agent(
|
|
114
38
|
description: str | None = None,
|
115
39
|
*,
|
116
40
|
metadata: Metadata | None = None,
|
117
|
-
) -> Callable[[Callable], Agent]:
|
41
|
+
) -> Callable[[Callable], AgentManifest]:
|
118
42
|
"""Decorator to create an agent."""
|
119
43
|
|
120
|
-
def decorator(fn: Callable) -> Agent:
|
44
|
+
def decorator(fn: Callable) -> AgentManifest:
|
121
45
|
signature = inspect.signature(fn)
|
122
46
|
parameters = list(signature.parameters.values())
|
123
47
|
|
@@ -130,7 +54,7 @@ def agent(
|
|
130
54
|
|
131
55
|
has_context_param = len(parameters) == 2
|
132
56
|
|
133
|
-
class DecoratorAgentBase(Agent):
|
57
|
+
class DecoratorAgentBase(AgentManifest):
|
134
58
|
@property
|
135
59
|
def name(self) -> str:
|
136
60
|
return name or fn.__name__
|
@@ -143,7 +67,7 @@ def agent(
|
|
143
67
|
def metadata(self) -> Metadata:
|
144
68
|
return metadata or Metadata()
|
145
69
|
|
146
|
-
agent: Agent
|
70
|
+
agent: AgentManifest
|
147
71
|
if inspect.isasyncgenfunction(fn):
|
148
72
|
|
149
73
|
class AsyncGenDecoratorAgent(DecoratorAgentBase):
|
acp_sdk/server/app.py
CHANGED
@@ -5,19 +5,24 @@ from contextlib import asynccontextmanager
|
|
5
5
|
from datetime import timedelta
|
6
6
|
from enum import Enum
|
7
7
|
|
8
|
-
|
8
|
+
import httpx
|
9
|
+
import obstore.store
|
10
|
+
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
9
11
|
from fastapi.applications import AppType, Lifespan
|
10
12
|
from fastapi.encoders import jsonable_encoder
|
11
13
|
from fastapi.middleware.cors import CORSMiddleware
|
12
14
|
from fastapi.responses import JSONResponse, StreamingResponse
|
15
|
+
from obstore.exceptions import NotFoundError
|
13
16
|
|
14
17
|
from acp_sdk.models import (
|
15
|
-
|
16
|
-
)
|
17
|
-
from acp_sdk.models import (
|
18
|
+
ACPError,
|
18
19
|
AgentName,
|
19
20
|
AgentReadResponse,
|
20
21
|
AgentsListResponse,
|
22
|
+
AwaitResume,
|
23
|
+
PingResponse,
|
24
|
+
ResourceId,
|
25
|
+
ResourceUrl,
|
21
26
|
Run,
|
22
27
|
RunCancelResponse,
|
23
28
|
RunCreateRequest,
|
@@ -28,11 +33,15 @@ from acp_sdk.models import (
|
|
28
33
|
RunReadResponse,
|
29
34
|
RunResumeRequest,
|
30
35
|
RunResumeResponse,
|
36
|
+
RunStatus,
|
37
|
+
Session,
|
38
|
+
SessionId,
|
39
|
+
SessionReadResponse,
|
40
|
+
)
|
41
|
+
from acp_sdk.models import (
|
42
|
+
AgentManifest as AgentModel,
|
31
43
|
)
|
32
|
-
from acp_sdk.
|
33
|
-
from acp_sdk.models.models import AwaitResume, RunStatus
|
34
|
-
from acp_sdk.models.schemas import PingResponse
|
35
|
-
from acp_sdk.server.agent import Agent
|
44
|
+
from acp_sdk.server.agent import AgentManifest
|
36
45
|
from acp_sdk.server.errors import (
|
37
46
|
RequestValidationError,
|
38
47
|
StarletteHTTPException,
|
@@ -42,9 +51,9 @@ from acp_sdk.server.errors import (
|
|
42
51
|
validation_exception_handler,
|
43
52
|
)
|
44
53
|
from acp_sdk.server.executor import CancelData, Executor, RunData
|
45
|
-
from acp_sdk.server.session import Session
|
46
54
|
from acp_sdk.server.store import MemoryStore, Store
|
47
55
|
from acp_sdk.server.utils import stream_sse, wait_util_stop
|
56
|
+
from acp_sdk.shared import ResourceLoader, ResourceStore
|
48
57
|
|
49
58
|
|
50
59
|
class Headers(str, Enum):
|
@@ -52,23 +61,34 @@ class Headers(str, Enum):
|
|
52
61
|
|
53
62
|
|
54
63
|
def create_app(
|
55
|
-
*agents: Agent,
|
64
|
+
*agents: AgentManifest,
|
56
65
|
store: Store | None = None,
|
66
|
+
resource_store: ResourceStore | None = None,
|
67
|
+
resource_loader: ResourceLoader | None = None,
|
68
|
+
forward_resources: bool = True,
|
57
69
|
lifespan: Lifespan[AppType] | None = None,
|
58
70
|
dependencies: list[Depends] | None = None,
|
59
71
|
) -> FastAPI:
|
72
|
+
if not forward_resources and (
|
73
|
+
resource_store is None
|
74
|
+
or isinstance(resource_store._store, (obstore.store.MemoryStore, obstore.store.LocalStore))
|
75
|
+
):
|
76
|
+
raise ValueError("Resource forwarding must be enabled when resource store does not support HTTP URLs")
|
77
|
+
|
60
78
|
executor: ThreadPoolExecutor
|
79
|
+
client = httpx.AsyncClient()
|
61
80
|
|
62
81
|
@asynccontextmanager
|
63
82
|
async def internal_lifespan(app: FastAPI) -> AsyncGenerator[None]:
|
64
83
|
nonlocal executor
|
65
|
-
with
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
84
|
+
async with client:
|
85
|
+
with ThreadPoolExecutor() as exec:
|
86
|
+
executor = exec
|
87
|
+
if not lifespan:
|
88
|
+
yield None
|
89
|
+
else:
|
90
|
+
async with lifespan(app) as state:
|
91
|
+
yield state
|
72
92
|
|
73
93
|
app = FastAPI(
|
74
94
|
lifespan=internal_lifespan,
|
@@ -83,7 +103,7 @@ def create_app(
|
|
83
103
|
allow_credentials=True,
|
84
104
|
)
|
85
105
|
|
86
|
-
agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
|
106
|
+
agents: dict[AgentName, AgentManifest] = {agent.name: agent for agent in agents}
|
87
107
|
|
88
108
|
store = store or MemoryStore(limit=1000, ttl=timedelta(hours=1))
|
89
109
|
run_store = store.as_store(model=RunData, prefix="run_")
|
@@ -91,6 +111,9 @@ def create_app(
|
|
91
111
|
run_resume_store = store.as_store(model=AwaitResume, prefix="run_resume_")
|
92
112
|
session_store = store.as_store(model=Session, prefix="session_")
|
93
113
|
|
114
|
+
resource_loader = resource_loader or ResourceLoader(client=client)
|
115
|
+
resource_store = resource_store or ResourceStore(store=obstore.store.MemoryStore())
|
116
|
+
|
94
117
|
app.exception_handler(ACPError)(acp_error_handler)
|
95
118
|
app.exception_handler(StarletteHTTPException)(http_exception_handler)
|
96
119
|
app.exception_handler(RequestValidationError)(validation_exception_handler)
|
@@ -107,7 +130,7 @@ def create_app(
|
|
107
130
|
run_data.run.status = RunStatus.CANCELLING
|
108
131
|
return run_data
|
109
132
|
|
110
|
-
def find_agent(agent_name: AgentName) -> Agent:
|
133
|
+
def find_agent(agent_name: AgentName) -> AgentManifest:
|
111
134
|
agent = agents.get(agent_name, None)
|
112
135
|
if not agent:
|
113
136
|
raise HTTPException(status_code=404, detail=f"Agent {agent_name} not found")
|
@@ -132,36 +155,54 @@ def create_app(
|
|
132
155
|
return PingResponse()
|
133
156
|
|
134
157
|
@app.post("/runs")
|
135
|
-
async def create_run(request: RunCreateRequest) -> RunCreateResponse:
|
158
|
+
async def create_run(request: RunCreateRequest, req: Request) -> RunCreateResponse:
|
136
159
|
agent = find_agent(request.agent_name)
|
137
160
|
|
138
|
-
session
|
139
|
-
|
161
|
+
if request.session_id and request.session and request.session_id != request.session.id:
|
162
|
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Session ID mismatch")
|
163
|
+
|
164
|
+
session = request.session or (
|
165
|
+
(
|
166
|
+
await session_store.get(request.session_id)
|
167
|
+
or Session(id=request.session_id, loader=resource_loader, store=resource_store)
|
168
|
+
)
|
140
169
|
if request.session_id
|
141
|
-
else Session()
|
170
|
+
else Session(loader=resource_loader, store=resource_store)
|
142
171
|
)
|
172
|
+
|
143
173
|
nonlocal executor
|
144
174
|
run_data = RunData(
|
145
|
-
run=Run(
|
146
|
-
|
175
|
+
run=Run(
|
176
|
+
agent_name=agent.name,
|
177
|
+
session_id=session.id,
|
178
|
+
)
|
147
179
|
)
|
148
180
|
await run_store.set(run_data.key, run_data)
|
149
|
-
|
150
|
-
session.append(run_data.run.run_id)
|
151
181
|
await session_store.set(session.id, session)
|
152
182
|
|
153
183
|
headers = {Headers.RUN_ID: str(run_data.run.run_id)}
|
154
184
|
ready = asyncio.Event()
|
155
185
|
|
186
|
+
async def create_resource_url(id: ResourceId) -> ResourceUrl:
|
187
|
+
if forward_resources:
|
188
|
+
return ResourceUrl(url=str(req.url_for("get_resource", resource_id=id)))
|
189
|
+
else:
|
190
|
+
return await resource_store.url(id)
|
191
|
+
|
156
192
|
Executor(
|
157
193
|
agent=agent,
|
158
194
|
run_data=run_data,
|
159
|
-
|
195
|
+
session=session,
|
196
|
+
session_store=session_store,
|
160
197
|
run_store=run_store,
|
161
198
|
cancel_store=run_cancel_store,
|
162
199
|
resume_store=run_resume_store,
|
163
200
|
executor=executor,
|
164
|
-
|
201
|
+
request=req,
|
202
|
+
resource_store=resource_store,
|
203
|
+
resource_loader=resource_loader,
|
204
|
+
create_resource_url=create_resource_url,
|
205
|
+
).execute(request.input, wait=ready)
|
165
206
|
|
166
207
|
match request.mode:
|
167
208
|
case RunMode.STREAM:
|
@@ -242,4 +283,21 @@ def create_app(
|
|
242
283
|
run_data.run.status = RunStatus.CANCELLING
|
243
284
|
return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder(run_data.run))
|
244
285
|
|
286
|
+
@app.get("/sessions/{session_id}")
|
287
|
+
async def read_session(session_id: SessionId) -> SessionReadResponse:
|
288
|
+
session = await session_store.get(session_id)
|
289
|
+
if not session:
|
290
|
+
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
291
|
+
return session
|
292
|
+
|
293
|
+
if forward_resources:
|
294
|
+
|
295
|
+
@app.get("/resources/{resource_id}", name="get_resource")
|
296
|
+
async def read_resource(resource_id: ResourceId) -> StreamingResponse:
|
297
|
+
try:
|
298
|
+
result = await resource_store.load(resource_id)
|
299
|
+
return StreamingResponse(result)
|
300
|
+
except NotFoundError:
|
301
|
+
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Resource {resource_id} not found")
|
302
|
+
|
245
303
|
return app
|
acp_sdk/server/context.py
CHANGED
@@ -1,22 +1,30 @@
|
|
1
1
|
from concurrent.futures import ThreadPoolExecutor
|
2
2
|
|
3
3
|
import janus
|
4
|
+
from fastapi import Request
|
4
5
|
|
5
|
-
from acp_sdk.models import
|
6
|
+
from acp_sdk.models import Session
|
6
7
|
from acp_sdk.server.types import RunYield, RunYieldResume
|
8
|
+
from acp_sdk.shared import ResourceLoader, ResourceStore
|
7
9
|
|
8
10
|
|
9
11
|
class Context:
|
10
12
|
def __init__(
|
11
13
|
self,
|
12
14
|
*,
|
13
|
-
|
15
|
+
session: Session,
|
16
|
+
store: ResourceStore,
|
17
|
+
loader: ResourceLoader,
|
14
18
|
executor: ThreadPoolExecutor,
|
19
|
+
request: Request,
|
15
20
|
yield_queue: janus.Queue[RunYield],
|
16
21
|
yield_resume_queue: janus.Queue[RunYieldResume],
|
17
22
|
) -> None:
|
18
|
-
self.
|
23
|
+
self.session = session
|
24
|
+
self.storage = store
|
25
|
+
self.loader = loader
|
19
26
|
self.executor = executor
|
27
|
+
self.request = request
|
20
28
|
self._yield_queue = yield_queue
|
21
29
|
self._yield_resume_queue = yield_resume_queue
|
22
30
|
|
acp_sdk/server/executor.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1
1
|
import asyncio
|
2
|
+
import inspect
|
2
3
|
import logging
|
3
|
-
|
4
|
+
import uuid
|
5
|
+
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Generator
|
4
6
|
from concurrent.futures import ThreadPoolExecutor
|
5
7
|
from datetime import datetime, timezone
|
6
|
-
from typing import Self
|
8
|
+
from typing import Callable, Self
|
7
9
|
|
10
|
+
import janus
|
11
|
+
from fastapi import Request
|
8
12
|
from pydantic import BaseModel, ValidationError
|
9
13
|
|
10
14
|
from acp_sdk.instrumentation import get_tracer
|
@@ -22,6 +26,8 @@ from acp_sdk.models import (
|
|
22
26
|
MessageCreatedEvent,
|
23
27
|
MessagePart,
|
24
28
|
MessagePartEvent,
|
29
|
+
ResourceId,
|
30
|
+
ResourceUrl,
|
25
31
|
Run,
|
26
32
|
RunAwaitingEvent,
|
27
33
|
RunCancelledEvent,
|
@@ -30,15 +36,18 @@ from acp_sdk.models import (
|
|
30
36
|
RunFailedEvent,
|
31
37
|
RunInProgressEvent,
|
32
38
|
RunStatus,
|
39
|
+
Session,
|
33
40
|
)
|
34
|
-
from acp_sdk.server.agent import Agent
|
41
|
+
from acp_sdk.server.agent import AgentManifest
|
42
|
+
from acp_sdk.server.context import Context
|
35
43
|
from acp_sdk.server.logging import logger
|
36
44
|
from acp_sdk.server.store import Store
|
45
|
+
from acp_sdk.server.types import RunYield, RunYieldResume
|
46
|
+
from acp_sdk.shared import ResourceLoader, ResourceStore
|
37
47
|
|
38
48
|
|
39
49
|
class RunData(BaseModel):
|
40
50
|
run: Run
|
41
|
-
input: list[Message]
|
42
51
|
events: list[Event] = []
|
43
52
|
|
44
53
|
@property
|
@@ -62,27 +71,38 @@ class Executor:
|
|
62
71
|
def __init__(
|
63
72
|
self,
|
64
73
|
*,
|
65
|
-
agent: Agent,
|
74
|
+
agent: AgentManifest,
|
66
75
|
run_data: RunData,
|
67
|
-
|
76
|
+
session: Session,
|
68
77
|
executor: ThreadPoolExecutor,
|
78
|
+
request: Request,
|
69
79
|
run_store: Store[RunData],
|
70
80
|
cancel_store: Store[CancelData],
|
71
81
|
resume_store: Store[AwaitResume],
|
82
|
+
session_store: Store[Session],
|
83
|
+
resource_store: ResourceStore,
|
84
|
+
resource_loader: ResourceLoader,
|
85
|
+
create_resource_url: Callable[[ResourceId], Awaitable[ResourceUrl]],
|
72
86
|
) -> None:
|
73
87
|
self.agent = agent
|
74
|
-
self.
|
88
|
+
self.session = session
|
75
89
|
self.run_data = run_data
|
76
90
|
self.executor = executor
|
91
|
+
self.request = request
|
77
92
|
|
78
93
|
self.run_store = run_store
|
79
94
|
self.cancel_store = cancel_store
|
80
95
|
self.resume_store = resume_store
|
96
|
+
self.session_store = session_store
|
97
|
+
self.resource_store = resource_store
|
98
|
+
self.resource_loader = resource_loader
|
99
|
+
|
100
|
+
self.create_resource_url = create_resource_url
|
81
101
|
|
82
102
|
self.logger = logging.LoggerAdapter(logger, {"run_id": str(run_data.run.run_id)})
|
83
103
|
|
84
|
-
def execute(self, *, wait: asyncio.Event) -> None:
|
85
|
-
self.task = asyncio.create_task(self._execute(
|
104
|
+
def execute(self, input: list[Message], *, wait: asyncio.Event) -> None:
|
105
|
+
self.task = asyncio.create_task(self._execute(input=input, executor=self.executor, wait=wait))
|
86
106
|
self.watcher = asyncio.create_task(self._watch_for_cancellation())
|
87
107
|
|
88
108
|
async def _push(self) -> None:
|
@@ -108,7 +128,16 @@ class Executor:
|
|
108
128
|
except Exception:
|
109
129
|
logger.warning("Cancellation watcher failed, restarting")
|
110
130
|
|
111
|
-
async def
|
131
|
+
async def _record_session(self, history: list[Message]) -> None:
|
132
|
+
for message in history:
|
133
|
+
id = uuid.uuid4()
|
134
|
+
url = await self.create_resource_url(id)
|
135
|
+
await self.resource_store.store(id, message.model_dump_json().encode())
|
136
|
+
self.session.history.append(url)
|
137
|
+
await self.session_store.set(self.session.id, self.session)
|
138
|
+
|
139
|
+
async def _execute(self, input: list[Message], *, executor: ThreadPoolExecutor, wait: asyncio.Event) -> None:
|
140
|
+
run_data = self.run_data
|
112
141
|
with get_tracer().start_as_current_span("run"):
|
113
142
|
in_message = False
|
114
143
|
|
@@ -118,15 +147,22 @@ class Executor:
|
|
118
147
|
message = run_data.run.output[-1]
|
119
148
|
message.completed_at = datetime.now(timezone.utc)
|
120
149
|
await self._emit(MessageCompletedEvent(message=message))
|
150
|
+
session_history.append(message)
|
121
151
|
in_message = False
|
122
152
|
|
153
|
+
session_history = input.copy()
|
123
154
|
try:
|
124
155
|
await wait.wait()
|
125
156
|
|
126
157
|
await self._emit(RunCreatedEvent(run=run_data.run))
|
127
158
|
|
128
|
-
generator = self.
|
129
|
-
input=
|
159
|
+
generator = self._execute_agent(
|
160
|
+
input=input,
|
161
|
+
session=self.session,
|
162
|
+
storage=self.resource_store,
|
163
|
+
loader=self.resource_loader,
|
164
|
+
executor=executor,
|
165
|
+
request=self.request,
|
130
166
|
)
|
131
167
|
self.logger.info("Run started")
|
132
168
|
|
@@ -141,18 +177,21 @@ class Executor:
|
|
141
177
|
if isinstance(next, str):
|
142
178
|
next = MessagePart(content=next)
|
143
179
|
if not in_message:
|
144
|
-
run_data.run.output.append(
|
180
|
+
run_data.run.output.append(
|
181
|
+
Message(role=f"agent/{self.agent.name}", parts=[], completed_at=None)
|
182
|
+
)
|
145
183
|
in_message = True
|
146
184
|
await self._emit(MessageCreatedEvent(message=run_data.run.output[-1]))
|
147
185
|
run_data.run.output[-1].parts.append(next)
|
148
186
|
await self._emit(MessagePartEvent(part=next))
|
149
187
|
elif isinstance(next, Message):
|
150
188
|
await flush_message()
|
151
|
-
run_data.run.output.append(next)
|
189
|
+
run_data.run.output.append(next.model_copy(update={"role": f"agent/{self.agent.name}"}))
|
152
190
|
await self._emit(MessageCreatedEvent(message=next))
|
153
191
|
for part in next.parts:
|
154
192
|
await self._emit(MessagePartEvent(part=part))
|
155
193
|
await self._emit(MessageCompletedEvent(message=next))
|
194
|
+
session_history.append(next)
|
156
195
|
elif isinstance(next, AwaitRequest):
|
157
196
|
run_data.run.await_request = next
|
158
197
|
run_data.run.status = RunStatus.AWAITING
|
@@ -180,6 +219,10 @@ class Executor:
|
|
180
219
|
await flush_message()
|
181
220
|
run_data.run.status = RunStatus.COMPLETED
|
182
221
|
run_data.run.finished_at = datetime.now(timezone.utc)
|
222
|
+
try:
|
223
|
+
await self._record_session(session_history)
|
224
|
+
except Exception as e:
|
225
|
+
self.logger.warning(f"Failed to record session: {e}")
|
183
226
|
await self._emit(RunCompletedEvent(run=run_data.run))
|
184
227
|
self.logger.info("Run completed")
|
185
228
|
except asyncio.CancelledError:
|
@@ -196,3 +239,85 @@ class Executor:
|
|
196
239
|
run_data.run.finished_at = datetime.now(timezone.utc)
|
197
240
|
await self._emit(RunFailedEvent(run=run_data.run))
|
198
241
|
self.logger.exception("Run failed")
|
242
|
+
|
243
|
+
async def _execute_agent(
|
244
|
+
self,
|
245
|
+
input: list[Message],
|
246
|
+
session: Session,
|
247
|
+
storage: ResourceStore,
|
248
|
+
loader: ResourceLoader,
|
249
|
+
executor: ThreadPoolExecutor,
|
250
|
+
request: Request,
|
251
|
+
) -> AsyncGenerator[RunYield, RunYieldResume]:
|
252
|
+
yield_queue: janus.Queue[RunYield] = janus.Queue()
|
253
|
+
yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()
|
254
|
+
|
255
|
+
context = Context(
|
256
|
+
session=session,
|
257
|
+
store=storage,
|
258
|
+
loader=loader,
|
259
|
+
executor=executor,
|
260
|
+
request=request,
|
261
|
+
yield_queue=yield_queue,
|
262
|
+
yield_resume_queue=yield_resume_queue,
|
263
|
+
)
|
264
|
+
|
265
|
+
if inspect.isasyncgenfunction(self.agent.run):
|
266
|
+
run = asyncio.create_task(self._run_async_gen(input, context))
|
267
|
+
elif inspect.iscoroutinefunction(self.agent.run):
|
268
|
+
run = asyncio.create_task(self._run_coro(input, context))
|
269
|
+
elif inspect.isgeneratorfunction(self.agent.run):
|
270
|
+
run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen, input, context)
|
271
|
+
else:
|
272
|
+
run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, input, context)
|
273
|
+
|
274
|
+
try:
|
275
|
+
while not run.done() or yield_queue.async_q.qsize() > 0:
|
276
|
+
value = yield await yield_queue.async_q.get()
|
277
|
+
if isinstance(value, Exception):
|
278
|
+
raise value
|
279
|
+
await yield_resume_queue.async_q.put(value)
|
280
|
+
except janus.AsyncQueueShutDown:
|
281
|
+
pass
|
282
|
+
|
283
|
+
async def _run_async_gen(self, input: list[Message], context: Context) -> None:
|
284
|
+
try:
|
285
|
+
gen: AsyncGenerator[RunYield, RunYieldResume] = self.agent.run(input, context)
|
286
|
+
value = None
|
287
|
+
while True:
|
288
|
+
value = await context.yield_async(await gen.asend(value))
|
289
|
+
except StopAsyncIteration:
|
290
|
+
pass
|
291
|
+
except Exception as e:
|
292
|
+
await context.yield_async(e)
|
293
|
+
finally:
|
294
|
+
context.shutdown()
|
295
|
+
|
296
|
+
async def _run_coro(self, input: list[Message], context: Context) -> None:
|
297
|
+
try:
|
298
|
+
await context.yield_async(await self.agent.run(input, context))
|
299
|
+
except Exception as e:
|
300
|
+
await context.yield_async(e)
|
301
|
+
finally:
|
302
|
+
context.shutdown()
|
303
|
+
|
304
|
+
def _run_gen(self, input: list[Message], context: Context) -> None:
|
305
|
+
try:
|
306
|
+
gen: Generator[RunYield, RunYieldResume] = self.agent.run(input, context)
|
307
|
+
value = None
|
308
|
+
while True:
|
309
|
+
value = context.yield_sync(gen.send(value))
|
310
|
+
except StopIteration:
|
311
|
+
pass
|
312
|
+
except Exception as e:
|
313
|
+
context.yield_sync(e)
|
314
|
+
finally:
|
315
|
+
context.shutdown()
|
316
|
+
|
317
|
+
def _run_func(self, input: list[Message], context: Context) -> None:
|
318
|
+
try:
|
319
|
+
context.yield_sync(self.agent.run(input, context))
|
320
|
+
except Exception as e:
|
321
|
+
context.yield_sync(e)
|
322
|
+
finally:
|
323
|
+
context.shutdown()
|