acp-sdk 0.0.6__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. acp_sdk/__init__.py +1 -0
  2. acp_sdk/client/__init__.py +1 -0
  3. acp_sdk/client/client.py +139 -0
  4. acp_sdk/models/__init__.py +3 -0
  5. acp_sdk/models/errors.py +23 -0
  6. acp_sdk/models/models.py +181 -0
  7. acp_sdk/models/schemas.py +39 -0
  8. acp_sdk/server/__init__.py +6 -0
  9. acp_sdk/server/agent.py +105 -0
  10. acp_sdk/server/app.py +161 -0
  11. acp_sdk/server/bundle.py +131 -0
  12. acp_sdk/server/context.py +33 -0
  13. acp_sdk/server/errors.py +54 -0
  14. acp_sdk/server/logging.py +16 -0
  15. acp_sdk/server/server.py +120 -0
  16. acp_sdk/server/telemetry.py +52 -0
  17. acp_sdk/server/types.py +6 -0
  18. acp_sdk/server/utils.py +14 -0
  19. acp_sdk-0.1.0rc5.dist-info/METADATA +78 -0
  20. acp_sdk-0.1.0rc5.dist-info/RECORD +22 -0
  21. acp/__init__.py +0 -138
  22. acp/cli/__init__.py +0 -6
  23. acp/cli/claude.py +0 -139
  24. acp/cli/cli.py +0 -471
  25. acp/client/__init__.py +0 -0
  26. acp/client/__main__.py +0 -79
  27. acp/client/session.py +0 -372
  28. acp/client/sse.py +0 -145
  29. acp/client/stdio.py +0 -153
  30. acp/server/__init__.py +0 -3
  31. acp/server/__main__.py +0 -50
  32. acp/server/highlevel/__init__.py +0 -9
  33. acp/server/highlevel/agents/__init__.py +0 -5
  34. acp/server/highlevel/agents/agent_manager.py +0 -110
  35. acp/server/highlevel/agents/base.py +0 -20
  36. acp/server/highlevel/agents/templates.py +0 -21
  37. acp/server/highlevel/context.py +0 -185
  38. acp/server/highlevel/exceptions.py +0 -25
  39. acp/server/highlevel/prompts/__init__.py +0 -4
  40. acp/server/highlevel/prompts/base.py +0 -167
  41. acp/server/highlevel/prompts/manager.py +0 -50
  42. acp/server/highlevel/prompts/prompt_manager.py +0 -33
  43. acp/server/highlevel/resources/__init__.py +0 -23
  44. acp/server/highlevel/resources/base.py +0 -48
  45. acp/server/highlevel/resources/resource_manager.py +0 -94
  46. acp/server/highlevel/resources/templates.py +0 -80
  47. acp/server/highlevel/resources/types.py +0 -185
  48. acp/server/highlevel/server.py +0 -705
  49. acp/server/highlevel/tools/__init__.py +0 -4
  50. acp/server/highlevel/tools/base.py +0 -83
  51. acp/server/highlevel/tools/tool_manager.py +0 -53
  52. acp/server/highlevel/utilities/__init__.py +0 -1
  53. acp/server/highlevel/utilities/func_metadata.py +0 -210
  54. acp/server/highlevel/utilities/logging.py +0 -43
  55. acp/server/highlevel/utilities/types.py +0 -54
  56. acp/server/lowlevel/__init__.py +0 -3
  57. acp/server/lowlevel/helper_types.py +0 -9
  58. acp/server/lowlevel/server.py +0 -643
  59. acp/server/models.py +0 -17
  60. acp/server/session.py +0 -315
  61. acp/server/sse.py +0 -175
  62. acp/server/stdio.py +0 -83
  63. acp/server/websocket.py +0 -61
  64. acp/shared/__init__.py +0 -0
  65. acp/shared/context.py +0 -14
  66. acp/shared/exceptions.py +0 -14
  67. acp/shared/memory.py +0 -87
  68. acp/shared/progress.py +0 -40
  69. acp/shared/session.py +0 -413
  70. acp/shared/version.py +0 -3
  71. acp/types.py +0 -1258
  72. acp_sdk-0.0.6.dist-info/METADATA +0 -46
  73. acp_sdk-0.0.6.dist-info/RECORD +0 -57
  74. acp_sdk-0.0.6.dist-info/entry_points.txt +0 -2
  75. acp_sdk-0.0.6.dist-info/licenses/LICENSE +0 -22
  76. {acp → acp_sdk}/py.typed +0 -0
  77. {acp_sdk-0.0.6.dist-info → acp_sdk-0.1.0rc5.dist-info}/WHEEL +0 -0
acp_sdk/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from acp_sdk.models import * # noqa: F403
@@ -0,0 +1 @@
1
+ from acp_sdk.client.client import Client as Client
@@ -0,0 +1,139 @@
1
+ from collections.abc import AsyncIterator
2
+ from types import TracebackType
3
+ from typing import Self
4
+
5
+ import httpx
6
+ from httpx_sse import EventSource, aconnect_sse
7
+ from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
8
+ from pydantic import TypeAdapter
9
+
10
+ from acp_sdk.models import (
11
+ ACPError,
12
+ Agent,
13
+ AgentName,
14
+ AgentReadResponse,
15
+ AgentsListResponse,
16
+ AwaitResume,
17
+ Error,
18
+ Message,
19
+ Run,
20
+ RunCancelResponse,
21
+ RunCreateRequest,
22
+ RunCreateResponse,
23
+ RunEvent,
24
+ RunId,
25
+ RunMode,
26
+ RunResumeRequest,
27
+ RunResumeResponse,
28
+ )
29
+
30
+
31
class Client:
    """Async client for the ACP HTTP API (REST + server-sent events).

    Wraps an ``httpx.AsyncClient`` that is instrumented with OpenTelemetry.
    Use it as an async context manager so the underlying client is entered
    and closed::

        async with Client(base_url="http://localhost:8000") as client:
            run = await client.run_sync(agent="echo", input=message)

    Methods that talk to the server raise :class:`ACPError` when the server
    answers with an ACP error document, and ``httpx.HTTPStatusError`` when an
    error reply is not in that format.
    """

    def __init__(self, *, base_url: httpx.URL | str = "", client: httpx.AsyncClient | None = None) -> None:
        """Create a client.

        Args:
            base_url: Server base URL; only used when ``client`` is not given.
            client: Pre-configured ``httpx.AsyncClient`` to use instead of a new one.
        """
        self.base_url = base_url

        self._client = self._init_client(client)

    def _init_client(self, client: httpx.AsyncClient | None = None) -> httpx.AsyncClient:
        """Return ``client`` (or a new client bound to ``base_url``) with HTTPX tracing enabled."""
        client = client or httpx.AsyncClient(base_url=self.base_url)
        HTTPXClientInstrumentor.instrument_client(client)
        return client

    async def __aenter__(self) -> Self:
        await self._client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        await self._client.__aexit__(exc_type, exc_value, traceback)

    async def agents(self) -> AsyncIterator[Agent]:
        """Yield every agent listed by ``GET /agents``."""
        response = await self._client.get("/agents")
        self._raise_error(response)
        for agent in AgentsListResponse.model_validate(response.json()).agents:
            yield agent

    async def agent(self, *, name: AgentName) -> Agent:
        """Return the description of agent ``name`` (``GET /agents/{name}``)."""
        response = await self._client.get(f"/agents/{name}")
        self._raise_error(response)
        return AgentReadResponse.model_validate(response.json())

    async def run_sync(self, *, agent: AgentName, input: Message) -> Run:
        """Start a run of ``agent`` in ``sync`` mode and return the run the server sends back."""
        return await self._create_run(agent, input, RunMode.SYNC)

    async def run_async(self, *, agent: AgentName, input: Message) -> Run:
        """Start a run of ``agent`` in ``async`` mode and return the run the server sends back."""
        return await self._create_run(agent, input, RunMode.ASYNC)

    async def run_stream(self, *, agent: AgentName, input: Message) -> AsyncIterator[RunEvent]:
        """Start a run of ``agent`` in ``stream`` mode and yield each event the server streams."""
        async for event in self._stream("/runs", self._create_payload(agent, input, RunMode.STREAM)):
            yield event

    async def run_status(self, *, run_id: RunId) -> Run:
        """Return the current state of run ``run_id`` (``GET /runs/{run_id}``)."""
        response = await self._client.get(f"/runs/{run_id}")
        self._raise_error(response)
        return Run.model_validate(response.json())

    async def run_cancel(self, *, run_id: RunId) -> Run:
        """Request cancellation of run ``run_id`` and return the run the server sends back."""
        response = await self._client.post(f"/runs/{run_id}/cancel")
        self._raise_error(response)
        return RunCancelResponse.model_validate(response.json())

    async def run_resume_sync(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume run ``run_id`` with ``await_`` in ``sync`` mode."""
        return await self._resume_run(run_id, await_, RunMode.SYNC)

    async def run_resume_async(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume run ``run_id`` with ``await_`` in ``async`` mode."""
        return await self._resume_run(run_id, await_, RunMode.ASYNC)

    async def run_resume_stream(self, *, run_id: RunId, await_: AwaitResume) -> AsyncIterator[RunEvent]:
        """Resume run ``run_id`` with ``await_`` in ``stream`` mode and yield the streamed events."""
        async for event in self._stream(f"/runs/{run_id}", self._resume_payload(await_, RunMode.STREAM)):
            yield event

    async def _create_run(self, agent: AgentName, input: Message, mode: RunMode) -> Run:
        """POST a run-creation request for a non-streaming ``mode`` and parse the reply."""
        response = await self._client.post("/runs", json=self._create_payload(agent, input, mode))
        self._raise_error(response)
        return RunCreateResponse.model_validate(response.json())

    async def _resume_run(self, run_id: RunId, await_: AwaitResume, mode: RunMode) -> Run:
        """POST a resume request for a non-streaming ``mode`` and parse the reply."""
        response = await self._client.post(f"/runs/{run_id}", json=self._resume_payload(await_, mode))
        self._raise_error(response)
        return RunResumeResponse.model_validate(response.json())

    @staticmethod
    def _create_payload(agent: AgentName, input: Message, mode: RunMode) -> dict[str, object]:
        """Build the JSON body for ``POST /runs``."""
        # mode="json" converts URLs (image/artifact parts) and enums to plain JSON
        # values; python-mode dumps would leave objects httpx cannot serialize.
        return RunCreateRequest(agent_name=agent, input=input, mode=mode).model_dump(mode="json")

    @staticmethod
    def _resume_payload(await_: AwaitResume, mode: RunMode) -> dict[str, object]:
        """Build the JSON body for ``POST /runs/{run_id}``."""
        # Populate via the wire alias "await" (RunResumeRequest is not guaranteed to
        # accept the field name ``await_``) and dump by alias so the server sees "await".
        request = RunResumeRequest.model_validate({"await": await_, "mode": mode})
        return request.model_dump(mode="json", by_alias=True)

    async def _stream(self, url: str, payload: dict[str, object]) -> AsyncIterator[RunEvent]:
        """POST ``payload`` to ``url`` and yield the validated server-sent events."""
        async with aconnect_sse(self._client, "POST", url, json=payload) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    async def _validate_stream(
        self,
        event_source: EventSource,
    ) -> AsyncIterator[RunEvent]:
        """Yield each SSE message of ``event_source`` parsed as a :data:`RunEvent`.

        Raises:
            ACPError / httpx.HTTPStatusError: if the server replied with an error
                status instead of an event stream.
        """
        response = event_source.response
        if response.is_error:
            # Error replies are JSON, not SSE; read the body so it can be parsed
            # instead of failing later on the content-type check in aiter_sse().
            await response.aread()
            self._raise_error(response)
        adapter = TypeAdapter(RunEvent)
        async for sse in event_source.aiter_sse():
            yield adapter.validate_json(sse.data)

    def _raise_error(self, response: httpx.Response) -> None:
        """Raise if ``response`` has an error status.

        Raises:
            ACPError: when the body is an ACP ``Error`` document.
            httpx.HTTPStatusError: when the body is not (e.g. an HTML proxy page).
        """
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            try:
                error = Error.model_validate(response.json())
            except ValueError:
                # JSONDecodeError and pydantic ValidationError both subclass ValueError.
                raise exc from None
            raise ACPError(error) from exc
@@ -0,0 +1,3 @@
1
+ from acp_sdk.models.errors import * # noqa: F403
2
+ from acp_sdk.models.models import * # noqa: F403
3
+ from acp_sdk.models.schemas import * # noqa: F403
@@ -0,0 +1,23 @@
1
+ from enum import Enum
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
# Machine-readable category carried by every ACP ``Error``.
# Members compare equal to their string values (str mixin).
class ErrorCode(str, Enum):
    SERVER_ERROR = "server_error"  # unexpected failure on the server side
    INVALID_INPUT = "invalid_input"  # request failed validation
    NOT_FOUND = "not_found"  # referenced agent/run does not exist
10
+
11
+
12
# Error document returned by the API: a category code plus a human-readable message.
class Error(BaseModel):
    code: ErrorCode  # one of the ErrorCode categories
    message: str  # text shown to the user / used as the exception message
15
+
16
+
17
class ACPError(Exception):
    """Exception wrapping a structured ACP :class:`Error` document.

    Attributes:
        error: The ``Error`` payload (code and message).
    """

    def __init__(self, error: Error) -> None:
        # Pass the payload to Exception so ``args`` is populated: this keeps repr()
        # informative and makes the exception picklable, since unpickling re-calls
        # ``ACPError(*args)`` (an empty ``args`` would raise TypeError).
        super().__init__(error)
        self.error = error

    def __str__(self) -> str:
        """Return the error's human-readable message."""
        return str(self.error.message)
@@ -0,0 +1,181 @@
1
+ import uuid
2
+ from collections.abc import Iterator
3
+ from enum import Enum
4
+ from typing import Any, Literal, Union
5
+
6
+ from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel
7
+
8
+ from acp_sdk.models.errors import Error
9
+
10
+
11
# Open-ended metadata bag: any extra keys are accepted and preserved.
class Metadata(BaseModel):
    model_config = ConfigDict(extra="allow")


# Schema-less model used for free-form payloads (e.g. generic events); extra keys kept.
class AnyModel(BaseModel):
    model_config = ConfigDict(extra="allow")
17
+
18
+
19
# Common base for message parts; ``type`` tags which kind of part it is.
class MessagePartBase(BaseModel):
    type: Literal["text", "image", "artifact"]


# Inline text content.
class TextMessagePart(MessagePartBase):
    type: Literal["text"] = "text"
    content: str


# Image referenced by URL.
class ImageMessagePart(MessagePartBase):
    type: Literal["image"] = "image"
    content_url: AnyUrl


# Named artifact referenced by URL.
class ArtifactMessagePart(MessagePartBase):
    type: Literal["artifact"] = "artifact"
    name: str
    content_url: AnyUrl


# Any concrete part; the ``type`` literal tells the variants apart during validation.
MessagePart = Union[TextMessagePart, ImageMessagePart, ArtifactMessagePart]
40
+
41
+
42
# A message is an ordered list of parts; it serializes as a bare JSON list.
class Message(RootModel):
    root: list[MessagePart]

    def __init__(self, *items: MessagePart) -> None:
        """Build a message from parts passed positionally: ``Message(part_a, part_b)``."""
        super().__init__(root=list(items))

    def __iter__(self) -> Iterator[MessagePart]:
        """Iterate over the message parts in order."""
        return iter(self.root)

    def __add__(self, other: "Message") -> "Message":
        """Return a new message with the parts of ``self`` followed by those of ``other``.

        Raises:
            TypeError: if ``other`` is not a ``Message``.
        """
        if not isinstance(other, Message):
            raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
        return Message(*self.root, *other.root)

    def __str__(self) -> str:
        """Return the concatenated content of the text parts; other parts are skipped."""
        # Use ``part.content``: str(part) would be the pydantic repr
        # ("type='text' content='...'"), not the text itself.
        return "".join(part.content for part in self.root if isinstance(part, TextMessagePart))
58
+
59
+
60
# Identifier aliases used across the API models; all are plain strings.
AgentName = str  # agent name
SessionId = str  # session identifier
RunId = str  # run identifier
63
+
64
+
65
# How a client wants a run request answered: wait for the result ("sync"),
# return immediately ("async"), or stream events ("stream").
class RunMode(str, Enum):
    SYNC = "sync"
    ASYNC = "async"
    STREAM = "stream"
69
+
70
+
71
# Lifecycle states of a run. COMPLETED, FAILED and CANCELLED are terminal.
class RunStatus(str, Enum):
    CREATED = "created"
    IN_PROGRESS = "in-progress"
    AWAITING = "awaiting"
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"
    COMPLETED = "completed"
    FAILED = "failed"

    @property
    def is_terminal(self) -> bool:
        """True when no further status transition is expected (completed/failed/cancelled)."""
        return self in (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELLED)
84
+
85
+
86
# Request made by a run when it awaits; currently only a placeholder shape.
class Await(BaseModel):
    type: Literal["placeholder"] = "placeholder"


# Data sent by a client to resume an awaiting run; no fields defined yet.
class AwaitResume(BaseModel):
    ...
92
+
93
+
94
# State of a single agent run as exchanged over the API.
class Run(BaseModel):
    # A fresh id per instance. A plain ``= str(uuid.uuid4())`` default is evaluated
    # once at import time, which made every run share the same id.
    run_id: RunId = Field(default_factory=lambda: str(uuid.uuid4()))
    agent_name: AgentName
    session_id: SessionId | None = None
    status: RunStatus = RunStatus.CREATED
    # "await" is a Python keyword, hence the ``await_`` field with a JSON alias.
    await_: Await | None = Field(None, alias="await")
    output: Message | None = None
    error: Error | None = None

    # Accept both "await" and "await_" when validating.
    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON using field aliases (``"await"``) by default.

        An explicit ``by_alias`` argument is honoured instead of raising a
        duplicate-keyword ``TypeError``.
        """
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
113
+
114
+
115
# Stream event carrying one message produced by a run.
class MessageEvent(BaseModel):
    type: Literal["message"] = "message"  # event discriminator
    message: Message  # the produced message
118
+
119
+
120
# Stream event carrying an ``Await`` request from a run.
class AwaitEvent(BaseModel):
    type: Literal["await"] = "await"
    # Required field (may be None); serialized under the JSON key "await".
    await_: Await | None = Field(alias="await")

    # Accept both "await" and "await_" when validating.
    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON using field aliases (``"await"``) by default.

        An explicit ``by_alias`` argument is honoured instead of raising a
        duplicate-keyword ``TypeError``.
        """
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
134
+
135
+
136
# Stream event with a free-form payload.
class GenericEvent(BaseModel):
    type: Literal["generic"] = "generic"
    generic: AnyModel


# The lifecycle events below each carry a snapshot of the run.
class CreatedEvent(BaseModel):
    type: Literal["created"] = "created"
    run: Run


class InProgressEvent(BaseModel):
    type: Literal["in-progress"] = "in-progress"
    run: Run


class FailedEvent(BaseModel):
    type: Literal["failed"] = "failed"
    run: Run


class CancelledEvent(BaseModel):
    type: Literal["cancelled"] = "cancelled"
    run: Run


class CompletedEvent(BaseModel):
    type: Literal["completed"] = "completed"
    run: Run


# Every event that can appear in a run's SSE stream; the ``type`` literal
# tells the variants apart during validation.
RunEvent = Union[
    CreatedEvent,
    InProgressEvent,
    MessageEvent,
    AwaitEvent,
    GenericEvent,
    CancelledEvent,
    FailedEvent,
    CompletedEvent,
]
176
+
177
+
178
# Public description of an agent as returned by the discovery endpoints.
class Agent(BaseModel):
    name: str  # unique agent name
    description: str | None = None  # optional human-readable description
    metadata: Metadata = Metadata()  # free-form extra attributes
@@ -0,0 +1,39 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ from acp_sdk.models.models import Agent, AgentName, AwaitResume, Message, Run, RunMode, SessionId
4
+
5
+
6
# Body of GET /agents.
class AgentsListResponse(BaseModel):
    agents: list[Agent]


# Body of GET /agents/{name}: the agent description itself.
class AgentReadResponse(Agent):
    ...
12
+
13
+
14
# Body of POST /runs: which agent to run, on what input, and how to answer.
class RunCreateRequest(BaseModel):
    agent_name: AgentName  # target agent
    session_id: SessionId | None = None  # optional session to attach to
    input: Message  # message handed to the agent
    mode: RunMode = RunMode.SYNC  # response style; synchronous unless specified
19
+
20
+
21
# Reply to POST /runs (non-streaming modes): the run itself.
class RunCreateResponse(Run):
    ...
23
+
24
+
25
# Body of POST /runs/{run_id}: resume data (JSON key "await") and response mode.
class RunResumeRequest(BaseModel):
    await_: AwaitResume = Field(alias="await")
    mode: RunMode

    # Also accept the Python field name, so ``RunResumeRequest(await_=...)`` works
    # (as it does for Run/AwaitEvent); without this only the alias "await" is accepted
    # and construction by field name fails with a missing-field error.
    # A plain dict is a valid pydantic v2 config (ConfigDict is a TypedDict).
    model_config = {"populate_by_name": True}
28
+
29
+
30
# Reply to POST /runs/{run_id} (non-streaming modes).
class RunResumeResponse(Run):
    ...


# Reply to GET /runs/{run_id}.
class RunReadResponse(Run):
    ...


# Reply to POST /runs/{run_id}/cancel.
class RunCancelResponse(Run):
    ...
@@ -0,0 +1,6 @@
1
+ from acp_sdk.server.agent import Agent as Agent
2
+ from acp_sdk.server.app import create_app as create_app
3
+ from acp_sdk.server.context import Context as Context
4
+ from acp_sdk.server.server import Server as Server
5
+ from acp_sdk.server.types import RunYield as RunYield
6
+ from acp_sdk.server.types import RunYieldResume as RunYieldResume
@@ -0,0 +1,105 @@
1
+ import abc
2
+ import asyncio
3
+ import inspect
4
+ from collections.abc import AsyncGenerator, Coroutine, Generator
5
+ from concurrent.futures import ThreadPoolExecutor
6
+
7
+ import janus
8
+
9
+ from acp_sdk.models import (
10
+ AgentName,
11
+ Message,
12
+ SessionId,
13
+ )
14
+ from acp_sdk.models.models import Metadata
15
+ from acp_sdk.server.context import Context
16
+ from acp_sdk.server.types import RunYield, RunYieldResume
17
+
18
+
19
class Agent(abc.ABC):
    """Base class for agents served by an ACP server.

    Subclasses implement :meth:`run`. ``name``, ``description`` and ``metadata``
    have defaults that subclasses may override. :meth:`execute` adapts whatever
    kind of callable ``run`` is (async generator, coroutine function, generator
    function or plain function) to a single async-generator interface.
    """

    @property
    def name(self) -> AgentName:
        """Agent name; defaults to the subclass's class name."""
        return self.__class__.__name__

    @property
    def description(self) -> str:
        """Human-readable description; empty by default."""
        return ""

    @property
    def metadata(self) -> Metadata:
        """Extra agent metadata; an empty ``Metadata`` by default."""
        return Metadata()

    @abc.abstractmethod
    def run(
        self, input: Message, context: Context
    ) -> (
        AsyncGenerator[RunYield, RunYieldResume] | Generator[RunYield, RunYieldResume] | Coroutine[RunYield] | RunYield
    ):
        """Agent logic, implemented by subclasses.

        May be an async generator, a coroutine function, a generator function or a
        plain function; :meth:`execute` detects which and drives it accordingly.
        Generator forms receive, via ``send``/``asend``, the value that
        ``context.yield_sync``/``context.yield_async`` returns for each yielded item.
        """
        pass

    async def session(self, session_id: SessionId | None) -> SessionId | None:
        """Resolve a session. The base class has no session support.

        Returns:
            None when ``session_id`` is falsy.

        Raises:
            NotImplementedError: if a truthy ``session_id`` is given.
        """
        if session_id:
            raise NotImplementedError()
        return None

    async def execute(
        self, input: Message, session_id: SessionId | None, executor: ThreadPoolExecutor
    ) -> AsyncGenerator[RunYield, RunYieldResume]:
        """Run the agent, yielding each value it produces.

        Whatever the consumer sends back (``asend``) is forwarded to the agent as
        the resume value. The agent runs concurrently, either as an asyncio task
        (async forms) or on ``executor`` (sync forms). Exceptions raised by the
        agent propagate out of this generator when it finishes.
        """
        # Two janus queues (sync and async ends) link this generator with the agent,
        # which may be running on a worker thread:
        #   yield_queue        agent -> consumer (produced values)
        #   yield_resume_queue consumer -> agent (resume values)
        yield_queue: janus.Queue[RunYield] = janus.Queue()
        yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()

        context = Context(
            session_id=session_id, executor=executor, yield_queue=yield_queue, yield_resume_queue=yield_resume_queue
        )

        # Async forms run as tasks on the event loop; sync forms run on the thread
        # pool so they do not block the loop.
        if inspect.isasyncgenfunction(self.run):
            run = asyncio.create_task(self._run_async_gen(input, context))
        elif inspect.iscoroutinefunction(self.run):
            run = asyncio.create_task(self._run_coro(input, context))
        elif inspect.isgeneratorfunction(self.run):
            run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen, input, context)
        else:
            run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, input, context)

        try:
            while True:
                # Wait for the agent's next value, hand it to our consumer, then pass
                # whatever the consumer sends back to the agent.
                value = yield await yield_queue.async_q.get()
                await yield_resume_queue.async_q.put(value)
        except janus.AsyncQueueShutDown:
            # Presumably raised once context.shutdown() (called by every _run_* helper
            # in its finally) shuts the queues down - TODO confirm in Context.
            pass
        finally:
            await run  # Raise exceptions

    async def _run_async_gen(self, input: Message, context: Context) -> None:
        """Drive an async-generator ``run``: each yielded item goes through
        ``context.yield_async`` and its result is sent back into the generator."""
        try:
            gen: AsyncGenerator[RunYield, RunYieldResume] = self.run(input, context)
            value = None  # the first asend() must send None to start the generator
            while True:
                value = await context.yield_async(await gen.asend(value))
        except StopAsyncIteration:
            pass
        finally:
            context.shutdown()

    async def _run_coro(self, input: Message, context: Context) -> None:
        """Await a coroutine ``run`` and publish its single result via ``context.yield_async``."""
        try:
            await context.yield_async(await self.run(input, context))
        finally:
            context.shutdown()

    def _run_gen(self, input: Message, context: Context) -> None:
        """Drive a sync-generator ``run`` (on a worker thread): each yielded item goes
        through ``context.yield_sync`` and its result is sent back into the generator."""
        try:
            gen: Generator[RunYield, RunYieldResume] = self.run(input, context)
            value = None  # the first send() must send None to start the generator
            while True:
                value = context.yield_sync(gen.send(value))
        except StopIteration:
            pass
        finally:
            context.shutdown()

    def _run_func(self, input: Message, context: Context) -> None:
        """Call a plain-function ``run`` (on a worker thread) and publish its result
        via ``context.yield_sync``."""
        try:
            context.yield_sync(self.run(input, context))
        finally:
            context.shutdown()
acp_sdk/server/app.py ADDED
@@ -0,0 +1,161 @@
1
+ import asyncio
2
+ from collections.abc import AsyncGenerator
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from contextlib import asynccontextmanager
5
+
6
+ from fastapi import FastAPI, HTTPException, status
7
+ from fastapi.responses import JSONResponse, StreamingResponse
8
+ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
9
+
10
+ from acp_sdk.models import (
11
+ Agent as AgentModel,
12
+ )
13
+ from acp_sdk.models import (
14
+ AgentName,
15
+ AgentReadResponse,
16
+ AgentsListResponse,
17
+ Run,
18
+ RunCancelResponse,
19
+ RunCreateRequest,
20
+ RunCreateResponse,
21
+ RunId,
22
+ RunMode,
23
+ RunReadResponse,
24
+ RunResumeRequest,
25
+ RunResumeResponse,
26
+ RunStatus,
27
+ )
28
+ from acp_sdk.models.errors import ACPError
29
+ from acp_sdk.server.agent import Agent
30
+ from acp_sdk.server.bundle import RunBundle
31
+ from acp_sdk.server.errors import (
32
+ RequestValidationError,
33
+ StarletteHTTPException,
34
+ acp_error_handler,
35
+ catch_all_exception_handler,
36
+ http_exception_handler,
37
+ validation_exception_handler,
38
+ )
39
+ from acp_sdk.server.utils import stream_sse
40
+
41
+
42
+ def create_app(*agents: Agent) -> FastAPI:
43
+ executor: ThreadPoolExecutor
44
+
45
+ @asynccontextmanager
46
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
47
+ nonlocal executor
48
+ with ThreadPoolExecutor(max_workers=5) as exec:
49
+ executor = exec
50
+ yield
51
+
52
+ app = FastAPI(lifespan=lifespan)
53
+
54
+ FastAPIInstrumentor.instrument_app(app)
55
+
56
+ agents: dict[AgentName, Agent] = {agent.name: agent for agent in agents}
57
+ runs: dict[RunId, RunBundle] = {}
58
+
59
+ app.exception_handler(ACPError)(acp_error_handler)
60
+ app.exception_handler(StarletteHTTPException)(http_exception_handler)
61
+ app.exception_handler(RequestValidationError)(validation_exception_handler)
62
+ app.exception_handler(Exception)(catch_all_exception_handler)
63
+
64
+ def find_run_bundle(run_id: RunId) -> RunBundle:
65
+ bundle = runs.get(run_id)
66
+ if not bundle:
67
+ raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
68
+ return bundle
69
+
70
+ def find_agent(agent_name: AgentName) -> Agent:
71
+ agent = agents.get(agent_name, None)
72
+ if not agent:
73
+ raise HTTPException(status_code=404, detail=f"Agent {agent_name} not found")
74
+ return agent
75
+
76
+ @app.get("/agents")
77
+ async def list_agents() -> AgentsListResponse:
78
+ return AgentsListResponse(
79
+ agents=[
80
+ AgentModel(name=agent.name, description=agent.description, metadata=agent.metadata)
81
+ for agent in agents.values()
82
+ ]
83
+ )
84
+
85
+ @app.get("/agents/{name}")
86
+ async def read_agent(name: AgentName) -> AgentReadResponse:
87
+ agent = find_agent(name)
88
+ return AgentModel(name=agent.name, description=agent.description, metadata=agent.metadata)
89
+
90
+ @app.post("/runs")
91
+ async def create_run(request: RunCreateRequest) -> RunCreateResponse:
92
+ agent = find_agent(request.agent_name)
93
+ bundle = RunBundle(
94
+ agent=agent,
95
+ run=Run(
96
+ agent_name=agent.name,
97
+ session_id=request.session_id,
98
+ ),
99
+ )
100
+
101
+ nonlocal executor
102
+ bundle.task = asyncio.create_task(bundle.execute(request.input, executor=executor))
103
+ runs[bundle.run.run_id] = bundle
104
+
105
+ match request.mode:
106
+ case RunMode.STREAM:
107
+ return StreamingResponse(
108
+ stream_sse(bundle),
109
+ media_type="text/event-stream",
110
+ )
111
+ case RunMode.SYNC:
112
+ await bundle.join()
113
+ return bundle.run
114
+ case RunMode.ASYNC:
115
+ return JSONResponse(
116
+ status_code=status.HTTP_202_ACCEPTED,
117
+ content=bundle.run.model_dump(),
118
+ )
119
+ case _:
120
+ raise NotImplementedError()
121
+
122
+ @app.get("/runs/{run_id}")
123
+ async def read_run(run_id: RunId) -> RunReadResponse:
124
+ bundle = find_run_bundle(run_id)
125
+ return bundle.run
126
+
127
+ @app.post("/runs/{run_id}")
128
+ async def resume_run(run_id: RunId, request: RunResumeRequest) -> RunResumeResponse:
129
+ bundle = find_run_bundle(run_id)
130
+ bundle.stream_queue = asyncio.Queue() # TODO improve
131
+ await bundle.await_queue.put(request.await_)
132
+ match request.mode:
133
+ case RunMode.STREAM:
134
+ return StreamingResponse(
135
+ stream_sse(bundle),
136
+ media_type="text/event-stream",
137
+ )
138
+ case RunMode.SYNC:
139
+ await bundle.join()
140
+ return bundle.run
141
+ case RunMode.ASYNC:
142
+ return JSONResponse(
143
+ status_code=status.HTTP_202_ACCEPTED,
144
+ content=bundle.run.model_dump(),
145
+ )
146
+ case _:
147
+ raise NotImplementedError()
148
+
149
+ @app.post("/runs/{run_id}/cancel")
150
+ async def cancel_run(run_id: RunId) -> RunCancelResponse:
151
+ bundle = find_run_bundle(run_id)
152
+ if bundle.run.status.is_terminal:
153
+ raise HTTPException(
154
+ status_code=403,
155
+ detail=f"Run with terminal status {bundle.run.status} can't be cancelled",
156
+ )
157
+ bundle.task.cancel()
158
+ bundle.run.status = RunStatus.CANCELLING
159
+ return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content=bundle.run.model_dump())
160
+
161
+ return app