acp-sdk 0.0.6__py3-none-any.whl → 1.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. acp_sdk/client/__init__.py +1 -0
  2. acp_sdk/client/client.py +135 -0
  3. acp_sdk/models.py +219 -0
  4. acp_sdk/server/__init__.py +2 -0
  5. acp_sdk/server/agent.py +32 -0
  6. acp_sdk/server/bundle.py +133 -0
  7. acp_sdk/server/context.py +6 -0
  8. acp_sdk/server/server.py +137 -0
  9. acp_sdk/server/telemetry.py +45 -0
  10. acp_sdk/server/utils.py +12 -0
  11. acp_sdk-1.0.0rc1.dist-info/METADATA +53 -0
  12. acp_sdk-1.0.0rc1.dist-info/RECORD +15 -0
  13. acp/__init__.py +0 -138
  14. acp/cli/__init__.py +0 -6
  15. acp/cli/claude.py +0 -139
  16. acp/cli/cli.py +0 -471
  17. acp/client/__main__.py +0 -79
  18. acp/client/session.py +0 -372
  19. acp/client/sse.py +0 -145
  20. acp/client/stdio.py +0 -153
  21. acp/server/__init__.py +0 -3
  22. acp/server/__main__.py +0 -50
  23. acp/server/highlevel/__init__.py +0 -9
  24. acp/server/highlevel/agents/__init__.py +0 -5
  25. acp/server/highlevel/agents/agent_manager.py +0 -110
  26. acp/server/highlevel/agents/base.py +0 -20
  27. acp/server/highlevel/agents/templates.py +0 -21
  28. acp/server/highlevel/context.py +0 -185
  29. acp/server/highlevel/exceptions.py +0 -25
  30. acp/server/highlevel/prompts/__init__.py +0 -4
  31. acp/server/highlevel/prompts/base.py +0 -167
  32. acp/server/highlevel/prompts/manager.py +0 -50
  33. acp/server/highlevel/prompts/prompt_manager.py +0 -33
  34. acp/server/highlevel/resources/__init__.py +0 -23
  35. acp/server/highlevel/resources/base.py +0 -48
  36. acp/server/highlevel/resources/resource_manager.py +0 -94
  37. acp/server/highlevel/resources/templates.py +0 -80
  38. acp/server/highlevel/resources/types.py +0 -185
  39. acp/server/highlevel/server.py +0 -705
  40. acp/server/highlevel/tools/__init__.py +0 -4
  41. acp/server/highlevel/tools/base.py +0 -83
  42. acp/server/highlevel/tools/tool_manager.py +0 -53
  43. acp/server/highlevel/utilities/__init__.py +0 -1
  44. acp/server/highlevel/utilities/func_metadata.py +0 -210
  45. acp/server/highlevel/utilities/logging.py +0 -43
  46. acp/server/highlevel/utilities/types.py +0 -54
  47. acp/server/lowlevel/__init__.py +0 -3
  48. acp/server/lowlevel/helper_types.py +0 -9
  49. acp/server/lowlevel/server.py +0 -643
  50. acp/server/models.py +0 -17
  51. acp/server/session.py +0 -315
  52. acp/server/sse.py +0 -175
  53. acp/server/stdio.py +0 -83
  54. acp/server/websocket.py +0 -61
  55. acp/shared/__init__.py +0 -0
  56. acp/shared/context.py +0 -14
  57. acp/shared/exceptions.py +0 -14
  58. acp/shared/memory.py +0 -87
  59. acp/shared/progress.py +0 -40
  60. acp/shared/session.py +0 -413
  61. acp/shared/version.py +0 -3
  62. acp/types.py +0 -1258
  63. acp_sdk-0.0.6.dist-info/METADATA +0 -46
  64. acp_sdk-0.0.6.dist-info/RECORD +0 -57
  65. acp_sdk-0.0.6.dist-info/entry_points.txt +0 -2
  66. acp_sdk-0.0.6.dist-info/licenses/LICENSE +0 -22
  67. {acp/client → acp_sdk}/__init__.py +0 -0
  68. {acp → acp_sdk}/py.typed +0 -0
  69. {acp_sdk-0.0.6.dist-info → acp_sdk-1.0.0rc1.dist-info}/WHEEL +0 -0
@@ -0,0 +1 @@
1
+ from acp_sdk.client.client import Client
@@ -0,0 +1,135 @@
1
+ from types import TracebackType
2
+ from typing import AsyncIterator
3
+
4
+ import httpx
5
+ from httpx_sse import EventSource, aconnect_sse
6
+
7
+ from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
8
+
9
+ from acp_sdk.models import (
10
+ Agent,
11
+ AgentName,
12
+ AgentReadResponse,
13
+ AgentsListResponse,
14
+ AwaitResume,
15
+ Message,
16
+ RunCancelResponse,
17
+ RunCreateRequest,
18
+ Run,
19
+ RunCreateResponse,
20
+ RunEvent,
21
+ RunId,
22
+ RunMode,
23
+ RunResumeRequest,
24
+ RunResumeResponse,
25
+ )
26
+ from pydantic import TypeAdapter
27
+
28
+
29
+ class Client:
30
+ def __init__(
31
+ self, *, base_url: httpx.URL | str = "", client: httpx.AsyncClient | None = None
32
+ ):
33
+ self.base_url = base_url
34
+
35
+ self._client = self._init_client(client)
36
+
37
+ def _init_client(self, client: httpx.AsyncClient | None = None):
38
+ client = client or httpx.AsyncClient(base_url=self.base_url)
39
+ HTTPXClientInstrumentor.instrument_client(client)
40
+ return client
41
+
42
+ async def __aenter__(self):
43
+ await self._client.__aenter__()
44
+ return self
45
+
46
+ async def __aexit__(
47
+ self,
48
+ exc_type: type[BaseException] | None = None,
49
+ exc_value: BaseException | None = None,
50
+ traceback: TracebackType | None = None,
51
+ ):
52
+ await self._client.__aexit__(exc_type, exc_value, traceback)
53
+
54
+ async def agents(self) -> AsyncIterator[Agent]:
55
+ response = await self._client.get("/agents")
56
+ for agent in AgentsListResponse.model_validate(response.json()).agents:
57
+ yield agent
58
+
59
+ async def agent(self, *, name: AgentName) -> Agent:
60
+ response = await self._client.get(f"/agents/{name}")
61
+ return AgentReadResponse.model_validate(response.json())
62
+
63
+ async def run_sync(self, *, agent: AgentName, input: Message) -> Run:
64
+ response = await self._client.post(
65
+ "/runs",
66
+ json=RunCreateRequest(
67
+ agent_name=agent, input=input, mode=RunMode.SYNC
68
+ ).model_dump(),
69
+ )
70
+ return RunCreateResponse.model_validate(response.json())
71
+
72
+ async def run_async(self, *, agent: AgentName, input: Message) -> Run:
73
+ response = await self._client.post(
74
+ "/runs",
75
+ json=RunCreateRequest(
76
+ agent_name=agent, input=input, mode=RunMode.ASYNC
77
+ ).model_dump(),
78
+ )
79
+ return RunCreateResponse.model_validate(response.json())
80
+
81
+ async def run_stream(
82
+ self, *, agent: AgentName, input: Message
83
+ ) -> AsyncIterator[RunEvent]:
84
+ async with aconnect_sse(
85
+ self._client,
86
+ "POST",
87
+ "/runs",
88
+ json=RunCreateRequest(
89
+ agent_name=agent, input=input, mode=RunMode.STREAM
90
+ ).model_dump(),
91
+ ) as event_source:
92
+ async for event in self._validate_stream(event_source):
93
+ yield event
94
+
95
+ async def run_status(self, *, run_id: RunId) -> Run:
96
+ response = await self._client.get(f"/runs/{run_id}")
97
+ return Run.model_validate(response.json())
98
+
99
+ async def run_cancel(self, *, run_id: RunId) -> Run:
100
+ response = await self._client.post(f"/runs/{run_id}/cancel")
101
+ return RunCancelResponse.model_validate(response.json())
102
+
103
+ async def run_resume_sync(self, *, run_id: RunId, await_: AwaitResume) -> Run:
104
+ response = await self._client.post(
105
+ f"/runs/{run_id}",
106
+ json=RunResumeRequest(await_=await_, mode=RunMode.SYNC).model_dump(),
107
+ )
108
+ return RunResumeResponse.model_validate(response.json())
109
+
110
+ async def run_resume_async(self, *, run_id: RunId, await_: AwaitResume) -> Run:
111
+ response = await self._client.post(
112
+ f"/runs/{run_id}",
113
+ json=RunResumeRequest(await_=await_, mode=RunMode.ASYNC).model_dump(),
114
+ )
115
+ return RunResumeResponse.model_validate(response.json())
116
+
117
+ async def run_resume_stream(
118
+ self, *, run_id: RunId, await_: AwaitResume
119
+ ) -> AsyncIterator[RunEvent]:
120
+ async with aconnect_sse(
121
+ self._client,
122
+ "POST",
123
+ f"/runs/{run_id}",
124
+ json=RunResumeRequest(await_=await_, mode=RunMode.STREAM).model_dump(),
125
+ ) as event_source:
126
+ async for event in self._validate_stream(event_source):
127
+ yield event
128
+
129
+ async def _validate_stream(
130
+ self,
131
+ event_source: EventSource,
132
+ ) -> AsyncIterator[RunEvent]:
133
+ async for event in event_source.aiter_sse():
134
+ event = TypeAdapter(RunEvent).validate_json(event.data)
135
+ yield event
acp_sdk/models.py ADDED
@@ -0,0 +1,219 @@
1
+ from enum import Enum
2
+ from typing import Annotated, Literal, Union
3
+ import uuid
4
+
5
+ from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel
6
+
7
+
8
+ class ACPError(BaseModel):
9
+ code: str
10
+ message: str
11
+
12
+
13
+ class AnyModel(BaseModel):
14
+ model_config = ConfigDict(extra="allow")
15
+
16
+
17
+ class MessagePartBase(BaseModel):
18
+ type: Literal["text", "image", "artifact"]
19
+
20
+
21
+ class TextMessagePart(MessagePartBase):
22
+ type: Literal["text"] = "text"
23
+ content: str
24
+
25
+
26
+ class ImageMessagePart(MessagePartBase):
27
+ type: Literal["image"] = "image"
28
+ content_url: AnyUrl
29
+
30
+
31
+ class ArtifactMessagePart(MessagePartBase):
32
+ type: Literal["artifact"] = "artifact"
33
+ name: str
34
+ content_url: AnyUrl
35
+
36
+
37
+ MessagePart = Union[TextMessagePart, ImageMessagePart, ArtifactMessagePart]
38
+
39
+
40
+ class Message(RootModel):
41
+ root: list[MessagePart]
42
+
43
+ def __init__(self, *items: MessagePart):
44
+ super().__init__(root=list(items))
45
+
46
+ def __iter__(self):
47
+ return iter(self.root)
48
+
49
+ def __getitem__(self, item):
50
+ return self.root[item]
51
+
52
+ def __add__(self, other: "Message") -> "Message":
53
+ if not isinstance(other, Message):
54
+ raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
55
+ return Message(*(self.root + other.root))
56
+
57
+ def __str__(self):
58
+ return "".join(
59
+ str(part) for part in self.root if isinstance(part, TextMessagePart)
60
+ )
61
+
62
+
63
+ AgentName = str
64
+ SessionId = str
65
+ RunId = str
66
+
67
+
68
+ class RunMode(str, Enum):
69
+ SYNC = "sync"
70
+ ASYNC = "async"
71
+ STREAM = "stream"
72
+
73
+
74
+ class RunStatus(str, Enum):
75
+ CREATED = "created"
76
+ IN_PROGRESS = "in-progress"
77
+ AWAITING = "awaiting"
78
+ CANCELLING = "cancelling"
79
+ CANCELLED = "cancelled"
80
+ COMPLETED = "completed"
81
+ FAILED = "failed"
82
+
83
+ @property
84
+ def is_terminal(self) -> bool:
85
+ terminal_states = {RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELLED}
86
+ return self in terminal_states
87
+
88
+
89
+ class Await(BaseModel):
90
+ type: Literal["placeholder"] = "placeholder"
91
+
92
+
93
+ class AwaitResume(BaseModel):
94
+ pass
95
+
96
+
97
+ class Run(BaseModel):
98
+ run_id: RunId = str(uuid.uuid4())
99
+ agent_name: AgentName
100
+ session_id: SessionId | None = None
101
+ status: RunStatus = RunStatus.CREATED
102
+ await_: Await | None = Field(None, alias="await")
103
+ output: Message | None = None
104
+ error: ACPError | None = None
105
+
106
+ model_config = ConfigDict(populate_by_name=True)
107
+
108
+ def model_dump_json(
109
+ self,
110
+ **kwargs,
111
+ ):
112
+ return super().model_dump_json(
113
+ by_alias=True,
114
+ **kwargs,
115
+ )
116
+
117
+
118
+ class MessageEvent(BaseModel):
119
+ type: Literal["message"] = "message"
120
+ message: Message
121
+
122
+
123
+ class AwaitEvent(BaseModel):
124
+ type: Literal["await"] = "await"
125
+ await_: Await | None = Field(alias="await")
126
+
127
+ model_config = ConfigDict(populate_by_name=True)
128
+
129
+ def model_dump_json(
130
+ self,
131
+ **kwargs,
132
+ ):
133
+ return super().model_dump_json(
134
+ by_alias=True,
135
+ **kwargs,
136
+ )
137
+
138
+
139
+ class GenericEvent(BaseModel):
140
+ type: Literal["generic"] = "generic"
141
+ generic: AnyModel
142
+
143
+
144
+ class CreatedEvent(BaseModel):
145
+ type: Literal["created"] = "created"
146
+ run: Run
147
+
148
+
149
+ class InProgressEvent(BaseModel):
150
+ type: Literal["in-progress"] = "in-progress"
151
+ run: Run
152
+
153
+
154
+ class FailedEvent(BaseModel):
155
+ type: Literal["failed"] = "failed"
156
+ run: Run
157
+
158
+
159
+ class CancelledEvent(BaseModel):
160
+ type: Literal["cancelled"] = "cancelled"
161
+ run: Run
162
+
163
+
164
+ class CompletedEvent(BaseModel):
165
+ type: Literal["completed"] = "completed"
166
+ run: Run
167
+
168
+
169
+ RunEvent = Union[
170
+ CreatedEvent,
171
+ InProgressEvent,
172
+ MessageEvent,
173
+ AwaitEvent,
174
+ GenericEvent,
175
+ CancelledEvent,
176
+ FailedEvent,
177
+ CompletedEvent,
178
+ ]
179
+
180
+
181
+ class RunCreateRequest(BaseModel):
182
+ agent_name: AgentName
183
+ session_id: SessionId | None = None
184
+ input: Message
185
+ mode: RunMode = RunMode.SYNC
186
+
187
+
188
+ class RunCreateResponse(Run):
189
+ pass
190
+
191
+
192
+ class RunResumeRequest(BaseModel):
193
+ await_: AwaitResume = Field(alias="await")
194
+ mode: RunMode
195
+
196
+
197
+ class RunResumeResponse(Run):
198
+ pass
199
+
200
+
201
+ class RunReadResponse(Run):
202
+ pass
203
+
204
+
205
+ class RunCancelResponse(Run):
206
+ pass
207
+
208
+
209
+ class Agent(BaseModel):
210
+ name: str
211
+ description: str | None = None
212
+
213
+
214
+ class AgentsListResponse(BaseModel):
215
+ agents: list[Agent]
216
+
217
+
218
+ class AgentReadResponse(Agent):
219
+ pass
@@ -0,0 +1,2 @@
1
+ from acp_sdk.server.server import create_app
2
+ from acp_sdk.server.agent import Agent
@@ -0,0 +1,32 @@
1
+ import abc
2
+ from typing import AsyncGenerator
3
+
4
+ from acp_sdk.models import (
5
+ AgentName,
6
+ Message,
7
+ Await,
8
+ AwaitResume,
9
+ SessionId,
10
+ )
11
+ from acp_sdk.server.context import Context
12
+
13
+
14
+ class Agent(abc.ABC):
15
+ @property
16
+ def name(self) -> AgentName:
17
+ return self.__class__.__name__
18
+
19
+ @property
20
+ def description(self) -> str:
21
+ return ""
22
+
23
+ @abc.abstractmethod
24
+ def run(
25
+ self, input: Message, *, context: Context
26
+ ) -> AsyncGenerator[Message | Await, AwaitResume]:
27
+ pass
28
+
29
+ async def session(self, session_id: SessionId | None) -> SessionId | None:
30
+ if session_id:
31
+ raise NotImplementedError()
32
+ return None
@@ -0,0 +1,133 @@
1
+ import asyncio
2
+ import logging
3
+
4
+ from opentelemetry import trace
5
+ from pydantic import ValidationError
6
+
7
+ from acp_sdk.server.agent import Agent
8
+ from acp_sdk.models import (
9
+ ACPError,
10
+ AnyModel,
11
+ Await,
12
+ AwaitEvent,
13
+ CancelledEvent,
14
+ CompletedEvent,
15
+ CreatedEvent,
16
+ FailedEvent,
17
+ GenericEvent,
18
+ InProgressEvent,
19
+ Message,
20
+ MessageEvent,
21
+ Run,
22
+ AwaitResume,
23
+ RunEvent,
24
+ RunStatus,
25
+ )
26
+ from acp_sdk.server.context import Context
27
+
28
+ logger = logging.getLogger("uvicorn.error")
29
+
30
+
31
+ class RunBundle:
32
+ def __init__(self, *, agent: Agent, run: Run, task: asyncio.Task | None = None):
33
+ self.agent = agent
34
+ self.run = run
35
+ self.task = task
36
+
37
+ self.stream_queue: asyncio.Queue[RunEvent] = asyncio.Queue()
38
+ self.composed_message = Message()
39
+
40
+ self.await_queue: asyncio.Queue[AwaitResume] = asyncio.Queue(maxsize=1)
41
+ self.await_or_terminate_event = asyncio.Event()
42
+
43
+ async def stream(self):
44
+ try:
45
+ while True:
46
+ event = await self.stream_queue.get()
47
+ yield event
48
+ self.stream_queue.task_done()
49
+ except asyncio.QueueShutDown:
50
+ pass
51
+
52
+ async def emit(self, event: RunEvent):
53
+ await self.stream_queue.put(event)
54
+
55
+ async def await_(self) -> AwaitResume:
56
+ self.stream_queue.shutdown()
57
+ self.await_queue.empty()
58
+ self.await_or_terminate_event.set()
59
+ self.await_or_terminate_event.clear()
60
+ resume = await self.await_queue.get()
61
+ self.await_queue.task_done()
62
+ return resume
63
+
64
+ async def resume(self, resume: AwaitResume):
65
+ self.stream_queue = asyncio.Queue()
66
+ await self.await_queue.put(resume)
67
+
68
+ async def join(self):
69
+ await self.await_or_terminate_event.wait()
70
+
71
+ async def execute(self, input: Message):
72
+ with trace.get_tracer(__name__).start_as_current_span("execute"):
73
+ run_logger = logging.LoggerAdapter(logger, {"run_id": self.run.run_id})
74
+
75
+ await self.emit(CreatedEvent(run=self.run))
76
+ try:
77
+ self.run.session_id = await self.agent.session(self.run.session_id)
78
+ run_logger.info("Session loaded")
79
+
80
+ generator = self.agent.run(
81
+ input=input, context=Context(session_id=self.run.session_id)
82
+ )
83
+ run_logger.info("Run started")
84
+
85
+ self.run.status = RunStatus.IN_PROGRESS
86
+ await self.emit(InProgressEvent(run=self.run))
87
+
88
+ await_resume = None
89
+ while True:
90
+ next = await generator.asend(await_resume)
91
+ if isinstance(next, Message):
92
+ self.composed_message += next
93
+ await self.emit(MessageEvent(message=next))
94
+ elif isinstance(next, Await):
95
+ self.run.await_ = next
96
+ self.run.status = RunStatus.AWAITING
97
+ await self.emit(
98
+ AwaitEvent.model_validate(
99
+ {
100
+ "run_id": self.run.run_id,
101
+ "type": "await",
102
+ "await": next,
103
+ }
104
+ )
105
+ )
106
+ run_logger.info("Run awaited")
107
+ await_resume = await self.await_()
108
+ self.run.status = RunStatus.IN_PROGRESS
109
+ await self.emit(InProgressEvent(run=self.run))
110
+ run_logger.info("Run resumed")
111
+ else:
112
+ try:
113
+ generic = AnyModel.model_validate(next)
114
+ await self.emit(GenericEvent(generic=generic))
115
+ except ValidationError:
116
+ raise TypeError("Invalid yield")
117
+ except StopAsyncIteration:
118
+ self.run.output = self.composed_message
119
+ self.run.status = RunStatus.COMPLETED
120
+ await self.emit(CompletedEvent(run=self.run))
121
+ run_logger.info("Run completed")
122
+ except asyncio.CancelledError:
123
+ self.run.status = RunStatus.CANCELLED
124
+ await self.emit(CancelledEvent(run=self.run))
125
+ run_logger.info("Run cancelled")
126
+ except Exception as e:
127
+ self.run.error = ACPError(code="unspecified", message=str(e))
128
+ self.run.status = RunStatus.FAILED
129
+ await self.emit(FailedEvent(run=self.run))
130
+ run_logger.exception("Run failed")
131
+ finally:
132
+ self.await_or_terminate_event.set()
133
+ self.stream_queue.shutdown()
@@ -0,0 +1,6 @@
1
+ from acp_sdk.models import SessionId
2
+
3
+
4
+ class Context:
5
+ def __init__(self, *, session_id: SessionId | None = None):
6
+ self.session_id = session_id
@@ -0,0 +1,137 @@
1
+ import asyncio
2
+
3
+ from acp_sdk.server.telemetry import configure_telemetry
4
+ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
5
+
6
+ from fastapi import FastAPI, HTTPException, status
7
+ from fastapi.responses import JSONResponse, StreamingResponse
8
+
9
+ from acp_sdk.server.agent import Agent
10
+ from acp_sdk.models import (
11
+ AgentName,
12
+ Agent as AgentModel,
13
+ AgentsListResponse,
14
+ AgentReadResponse,
15
+ Run,
16
+ RunCancelResponse,
17
+ RunCreateRequest,
18
+ RunCreateResponse,
19
+ RunId,
20
+ RunMode,
21
+ RunReadResponse,
22
+ RunResumeRequest,
23
+ RunResumeResponse,
24
+ RunStatus,
25
+ )
26
+ from acp_sdk.server.bundle import RunBundle
27
+ from acp_sdk.server.utils import stream_sse
28
+
29
+
30
+ def create_app(*agents: Agent) -> FastAPI:
31
+ app = FastAPI(title="acp-agents")
32
+
33
+ configure_telemetry()
34
+ FastAPIInstrumentor.instrument_app(app)
35
+
36
+ agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
37
+ runs: dict[RunId, RunBundle] = dict()
38
+
39
+ def find_run_bundle(run_id: RunId):
40
+ bundle = runs.get(run_id, None)
41
+ if not bundle:
42
+ raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
43
+ return bundle
44
+
45
+ def find_agent(agent_name: AgentName):
46
+ agent = agents.get(agent_name, None)
47
+ if not agent:
48
+ raise HTTPException(status_code=404, detail=f"Agent {agent_name} not found")
49
+ return agent
50
+
51
+ @app.get("/agents")
52
+ async def list() -> AgentsListResponse:
53
+ return AgentsListResponse(
54
+ agents=[
55
+ AgentModel(name=agent.name, description=agent.description)
56
+ for agent in agents.values()
57
+ ]
58
+ )
59
+
60
+ @app.get("/agents/{name}")
61
+ async def read(name: AgentName) -> AgentReadResponse:
62
+ agent = find_agent(name)
63
+ return AgentModel(name=agent.name, description=agent.description)
64
+
65
+ @app.post("/runs")
66
+ async def create(request: RunCreateRequest) -> RunCreateResponse:
67
+ agent = find_agent(request.agent_name)
68
+ bundle = RunBundle(
69
+ agent=agent,
70
+ run=Run(
71
+ agent_name=agent.name,
72
+ session_id=request.session_id,
73
+ ),
74
+ )
75
+
76
+ bundle.task = asyncio.create_task(bundle.execute(request.input))
77
+ runs[bundle.run.run_id] = bundle
78
+
79
+ match request.mode:
80
+ case RunMode.STREAM:
81
+ return StreamingResponse(
82
+ stream_sse(bundle),
83
+ media_type="text/event-stream",
84
+ )
85
+ case RunMode.SYNC:
86
+ await bundle.join()
87
+ return bundle.run
88
+ case RunMode.ASYNC:
89
+ return JSONResponse(
90
+ status_code=status.HTTP_202_ACCEPTED,
91
+ content=bundle.run.model_dump(),
92
+ )
93
+ case _:
94
+ raise NotImplementedError()
95
+
96
+ @app.get("/runs/{run_id}")
97
+ async def read(run_id: RunId) -> RunReadResponse:
98
+ bundle = find_run_bundle(run_id)
99
+ return bundle.run
100
+
101
+ @app.post("/runs/{run_id}")
102
+ async def resume(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
103
+ bundle = find_run_bundle(run_id)
104
+ bundle.stream_queue = asyncio.Queue() # TODO improve
105
+ await bundle.await_queue.put(request.await_)
106
+ match request.mode:
107
+ case RunMode.STREAM:
108
+ return StreamingResponse(
109
+ stream_sse(bundle),
110
+ media_type="text/event-stream",
111
+ )
112
+ case RunMode.SYNC:
113
+ await bundle.join()
114
+ return bundle.run
115
+ case RunMode.ASYNC:
116
+ return JSONResponse(
117
+ status_code=status.HTTP_202_ACCEPTED,
118
+ content=bundle.run.model_dump(),
119
+ )
120
+ case _:
121
+ raise NotImplementedError()
122
+
123
+ @app.post("/runs/{run_id}/cancel")
124
+ async def cancel(run_id: RunId) -> RunCancelResponse:
125
+ bundle = find_run_bundle(run_id)
126
+ if bundle.run.status.is_terminal:
127
+ raise HTTPException(
128
+ status_code=403,
129
+ detail=f"Run with terminal status {bundle.run.status} can't be cancelled",
130
+ )
131
+ bundle.task.cancel()
132
+ bundle.run.status = RunStatus.CANCELLING
133
+ return JSONResponse(
134
+ status_code=status.HTTP_202_ACCEPTED, content=bundle.run.model_dump()
135
+ )
136
+
137
+ return app