acp-sdk 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. acp_sdk/__init__.py +2 -0
  2. acp_sdk/client/__init__.py +1 -0
  3. acp_sdk/client/client.py +178 -0
  4. acp_sdk/models/__init__.py +3 -0
  5. acp_sdk/models/errors.py +23 -0
  6. acp_sdk/models/models.py +192 -0
  7. acp_sdk/models/schemas.py +41 -0
  8. acp_sdk/server/__init__.py +7 -0
  9. acp_sdk/server/agent.py +178 -0
  10. acp_sdk/server/app.py +176 -0
  11. acp_sdk/server/bundle.py +146 -0
  12. acp_sdk/server/context.py +33 -0
  13. acp_sdk/server/errors.py +54 -0
  14. acp_sdk/server/logging.py +16 -0
  15. acp_sdk/server/server.py +166 -0
  16. acp_sdk/server/session.py +21 -0
  17. acp_sdk/server/telemetry.py +57 -0
  18. acp_sdk/server/types.py +6 -0
  19. acp_sdk/server/utils.py +14 -0
  20. acp_sdk/version.py +3 -0
  21. acp_sdk-0.1.0.dist-info/METADATA +113 -0
  22. acp_sdk-0.1.0.dist-info/RECORD +24 -0
  23. acp/__init__.py +0 -138
  24. acp/cli/__init__.py +0 -6
  25. acp/cli/claude.py +0 -139
  26. acp/cli/cli.py +0 -471
  27. acp/client/__init__.py +0 -0
  28. acp/client/__main__.py +0 -79
  29. acp/client/session.py +0 -372
  30. acp/client/sse.py +0 -145
  31. acp/client/stdio.py +0 -153
  32. acp/server/__init__.py +0 -3
  33. acp/server/__main__.py +0 -50
  34. acp/server/highlevel/__init__.py +0 -9
  35. acp/server/highlevel/agents/__init__.py +0 -5
  36. acp/server/highlevel/agents/agent_manager.py +0 -110
  37. acp/server/highlevel/agents/base.py +0 -20
  38. acp/server/highlevel/agents/templates.py +0 -21
  39. acp/server/highlevel/context.py +0 -185
  40. acp/server/highlevel/exceptions.py +0 -25
  41. acp/server/highlevel/prompts/__init__.py +0 -4
  42. acp/server/highlevel/prompts/base.py +0 -167
  43. acp/server/highlevel/prompts/manager.py +0 -50
  44. acp/server/highlevel/prompts/prompt_manager.py +0 -33
  45. acp/server/highlevel/resources/__init__.py +0 -23
  46. acp/server/highlevel/resources/base.py +0 -48
  47. acp/server/highlevel/resources/resource_manager.py +0 -94
  48. acp/server/highlevel/resources/templates.py +0 -80
  49. acp/server/highlevel/resources/types.py +0 -185
  50. acp/server/highlevel/server.py +0 -705
  51. acp/server/highlevel/tools/__init__.py +0 -4
  52. acp/server/highlevel/tools/base.py +0 -83
  53. acp/server/highlevel/tools/tool_manager.py +0 -53
  54. acp/server/highlevel/utilities/__init__.py +0 -1
  55. acp/server/highlevel/utilities/func_metadata.py +0 -210
  56. acp/server/highlevel/utilities/logging.py +0 -43
  57. acp/server/highlevel/utilities/types.py +0 -54
  58. acp/server/lowlevel/__init__.py +0 -3
  59. acp/server/lowlevel/helper_types.py +0 -9
  60. acp/server/lowlevel/server.py +0 -643
  61. acp/server/models.py +0 -17
  62. acp/server/session.py +0 -315
  63. acp/server/sse.py +0 -175
  64. acp/server/stdio.py +0 -83
  65. acp/server/websocket.py +0 -61
  66. acp/shared/__init__.py +0 -0
  67. acp/shared/context.py +0 -14
  68. acp/shared/exceptions.py +0 -14
  69. acp/shared/memory.py +0 -87
  70. acp/shared/progress.py +0 -40
  71. acp/shared/session.py +0 -413
  72. acp/shared/version.py +0 -3
  73. acp/types.py +0 -1258
  74. acp_sdk-0.0.6.dist-info/METADATA +0 -46
  75. acp_sdk-0.0.6.dist-info/RECORD +0 -57
  76. acp_sdk-0.0.6.dist-info/entry_points.txt +0 -2
  77. acp_sdk-0.0.6.dist-info/licenses/LICENSE +0 -22
  78. {acp → acp_sdk}/py.typed +0 -0
  79. {acp_sdk-0.0.6.dist-info → acp_sdk-0.1.0.dist-info}/WHEEL +0 -0
acp_sdk/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from acp_sdk.models import * # noqa: F403
2
+ from acp_sdk.version import __version__ as __version__
acp_sdk/client/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from acp_sdk.client.client import Client as Client
acp_sdk/client/client.py ADDED
@@ -0,0 +1,178 @@
1
+ import uuid
2
+ from collections.abc import AsyncGenerator, AsyncIterator
3
+ from contextlib import asynccontextmanager
4
+ from types import TracebackType
5
+ from typing import Self
6
+
7
+ import httpx
8
+ from httpx_sse import EventSource, aconnect_sse
9
+ from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
10
+ from pydantic import TypeAdapter
11
+
12
+ from acp_sdk.models import (
13
+ ACPError,
14
+ Agent,
15
+ AgentName,
16
+ AgentReadResponse,
17
+ AgentsListResponse,
18
+ AwaitResume,
19
+ CreatedEvent,
20
+ Error,
21
+ Message,
22
+ Run,
23
+ RunCancelResponse,
24
+ RunCreateRequest,
25
+ RunCreateResponse,
26
+ RunEvent,
27
+ RunId,
28
+ RunMode,
29
+ RunResumeRequest,
30
+ RunResumeResponse,
31
+ SessionId,
32
+ )
33
+
34
+
35
class Client:
    """Asynchronous client for an ACP server.

    Wraps an ``httpx.AsyncClient`` (created from ``base_url`` when none is supplied) that is
    instrumented with OpenTelemetry. Run-creation requests carry ``session_id``, and the
    session id is updated from the runs the server returns.

    HTTP error responses whose body is an ACP ``Error`` are raised as :class:`ACPError`.
    """

    # Built once and shared: constructing a TypeAdapter per streamed event is wasted work.
    _run_event_adapter = TypeAdapter(RunEvent)

    def __init__(
        self,
        *,
        base_url: httpx.URL | str = "",
        session_id: SessionId | None = None,
        client: httpx.AsyncClient | None = None,
    ) -> None:
        self.base_url = base_url
        self.session_id = session_id

        self._client = self._init_client(client)

    def _init_client(self, client: httpx.AsyncClient | None = None) -> httpx.AsyncClient:
        """Return ``client`` (or a new one bound to ``base_url``) after OpenTelemetry instrumentation."""
        client = client or httpx.AsyncClient(base_url=self.base_url)
        HTTPXClientInstrumentor.instrument_client(client)
        return client

    async def __aenter__(self) -> Self:
        await self._client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        await self._client.__aexit__(exc_type, exc_value, traceback)

    @asynccontextmanager
    async def session(self, session_id: SessionId | None = None) -> AsyncGenerator[Self]:
        """Yield a client that shares the underlying httpx client but uses ``session_id``.

        A random UUID is used when ``session_id`` is omitted.
        """
        # NOTE(review): the child Client runs _init_client on the shared httpx client, i.e.
        # instruments it a second time — confirm the instrumentor tolerates that.
        yield Client(client=self._client, session_id=session_id or uuid.uuid4())

    async def agents(self) -> AsyncIterator[Agent]:
        """Yield every agent listed by ``GET /agents``."""
        response = await self._client.get("/agents")
        self._raise_error(response)
        for agent in AgentsListResponse.model_validate(response.json()).agents:
            yield agent

    async def agent(self, *, name: AgentName) -> Agent:
        """Return the agent described by ``GET /agents/{name}``."""
        response = await self._client.get(f"/agents/{name}")
        self._raise_error(response)
        return AgentReadResponse.model_validate(response.json())

    async def run_sync(self, *, agent: AgentName, inputs: list[Message]) -> Run:
        """Create a run in ``sync`` mode and return the run from the response."""
        return await self._create_run(agent=agent, inputs=inputs, mode=RunMode.SYNC)

    async def run_async(self, *, agent: AgentName, inputs: list[Message]) -> Run:
        """Create a run in ``async`` mode and return the run from the response."""
        return await self._create_run(agent=agent, inputs=inputs, mode=RunMode.ASYNC)

    async def run_stream(self, *, agent: AgentName, inputs: list[Message]) -> AsyncIterator[RunEvent]:
        """Create a run in ``stream`` mode and yield its server-sent events.

        The session id is updated from the ``created`` event.

        Raises:
            ACPError: if the server answers with an ACP error instead of an event stream.
        """
        async with aconnect_sse(
            self._client,
            "POST",
            "/runs",
            content=self._run_create_body(agent=agent, inputs=inputs, mode=RunMode.STREAM),
        ) as event_source:
            # An error response is not an event stream; report it like the non-streaming calls.
            await self._raise_stream_error(event_source.response)
            async for event in self._validate_stream(event_source):
                if isinstance(event, CreatedEvent):
                    self._set_session(event.run)
                yield event

    async def run_status(self, *, run_id: RunId) -> Run:
        """Return the current state of run ``run_id``."""
        response = await self._client.get(f"/runs/{run_id}")
        self._raise_error(response)
        return Run.model_validate(response.json())

    async def run_cancel(self, *, run_id: RunId) -> Run:
        """Request cancellation of run ``run_id`` and return the run from the response."""
        response = await self._client.post(f"/runs/{run_id}/cancel")
        self._raise_error(response)
        return RunCancelResponse.model_validate(response.json())

    async def run_resume_sync(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume run ``run_id`` with ``await_`` in ``sync`` mode."""
        return await self._resume_run(run_id=run_id, await_=await_, mode=RunMode.SYNC)

    async def run_resume_async(self, *, run_id: RunId, await_: AwaitResume) -> Run:
        """Resume run ``run_id`` with ``await_`` in ``async`` mode."""
        return await self._resume_run(run_id=run_id, await_=await_, mode=RunMode.ASYNC)

    async def run_resume_stream(self, *, run_id: RunId, await_: AwaitResume) -> AsyncIterator[RunEvent]:
        """Resume run ``run_id`` with ``await_`` in ``stream`` mode and yield its events.

        Raises:
            ACPError: if the server answers with an ACP error instead of an event stream.
        """
        async with aconnect_sse(
            self._client,
            "POST",
            f"/runs/{run_id}",
            json=self._run_resume_body(await_=await_, mode=RunMode.STREAM),
        ) as event_source:
            await self._raise_stream_error(event_source.response)
            async for event in self._validate_stream(event_source):
                yield event

    async def _create_run(self, *, agent: AgentName, inputs: list[Message], mode: RunMode) -> Run:
        """POST a new run in ``mode`` and adopt the session id the server returns."""
        response = await self._client.post(
            "/runs",
            content=self._run_create_body(agent=agent, inputs=inputs, mode=mode),
        )
        self._raise_error(response)
        run = RunCreateResponse.model_validate(response.json())
        self._set_session(run)
        return run

    async def _resume_run(self, *, run_id: RunId, await_: AwaitResume, mode: RunMode) -> Run:
        """POST a resume request for ``run_id`` in ``mode`` and return the run from the response."""
        response = await self._client.post(
            f"/runs/{run_id}",
            json=self._run_resume_body(await_=await_, mode=mode),
        )
        self._raise_error(response)
        return RunResumeResponse.model_validate(response.json())

    def _run_create_body(self, *, agent: AgentName, inputs: list[Message], mode: RunMode) -> str:
        """Serialize a run-creation request bound to the current session id."""
        return RunCreateRequest(
            agent_name=agent,
            inputs=inputs,
            mode=mode,
            session_id=self.session_id,
        ).model_dump_json()

    @staticmethod
    def _run_resume_body(*, await_: AwaitResume, mode: RunMode) -> dict:
        """Build a JSON-ready resume request body.

        ``mode="json"`` guarantees every value is JSON-serializable, and ``by_alias`` emits
        the model's alias ``await`` rather than the Python field name ``await_``.
        """
        return RunResumeRequest(await_=await_, mode=mode).model_dump(mode="json", by_alias=True)

    async def _validate_stream(
        self,
        event_source: EventSource,
    ) -> AsyncIterator[RunEvent]:
        """Parse each server-sent event's data as a ``RunEvent``."""
        async for sse in event_source.aiter_sse():
            yield self._run_event_adapter.validate_json(sse.data)

    async def _raise_stream_error(self, response: httpx.Response) -> None:
        """Raise for an HTTP error status on a streaming ``response``."""
        if response.is_error:
            # A streamed body must be read before .json() can be used on it.
            await response.aread()
            self._raise_error(response)

    def _raise_error(self, response: httpx.Response) -> None:
        """Raise if ``response`` has an HTTP error status.

        Raises:
            ACPError: when the error body is a valid ACP ``Error``.
            httpx.HTTPStatusError: when the body is not an ACP ``Error`` (e.g. a proxy's
                HTML error page), rather than a confusing JSON/validation error.
        """
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            try:
                error = Error.model_validate(response.json())
            except ValueError:  # JSONDecodeError and pydantic's ValidationError are ValueErrors
                raise exc from None
            raise ACPError(error) from exc

    def _set_session(self, run: Run) -> None:
        """Adopt the session id carried by ``run``."""
        self.session_id = run.session_id
acp_sdk/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from acp_sdk.models.errors import * # noqa: F403
2
+ from acp_sdk.models.models import * # noqa: F403
3
+ from acp_sdk.models.schemas import * # noqa: F403
acp_sdk/models/errors.py ADDED
@@ -0,0 +1,23 @@
1
+ from enum import Enum
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
# Machine-readable category of an ACP error; values are the wire strings.
# (Kept as a comment rather than a docstring so generated JSON schemas are unchanged.)
class ErrorCode(str, Enum):
    SERVER_ERROR = "server_error"
    INVALID_INPUT = "invalid_input"
    NOT_FOUND = "not_found"
10
+
11
+
12
# Structured error payload: a machine-readable ``code`` plus a human-readable ``message``.
class Error(BaseModel):
    code: ErrorCode
    message: str
15
+
16
+
17
class ACPError(Exception):
    """Exception carrying a structured ACP ``Error``.

    ``str(exc)`` is the error's message; the full model is available as ``exc.error``.
    """

    def __init__(self, error: Error) -> None:
        # Forward ``error`` so ``args`` is populated: unpickling re-creates the exception as
        # ``ACPError(*args)``, which failed with empty args, and the repr now shows the error.
        super().__init__(error)
        self.error = error

    def __str__(self) -> str:
        return str(self.error.message)
acp_sdk/models/models.py ADDED
@@ -0,0 +1,192 @@
1
+ import uuid
2
+ from enum import Enum
3
+ from typing import Any, Literal, Optional, Union
4
+
5
+ from pydantic import AnyUrl, BaseModel, ConfigDict, Field
6
+
7
+ from acp_sdk.models.errors import Error
8
+
9
+
10
# Free-form metadata: no declared fields; any extra keys are accepted and kept.
class Metadata(BaseModel):
    model_config = ConfigDict(extra="allow")


# Arbitrary object with extra keys allowed, e.g. the payload of a GenericEvent.
class AnyModel(BaseModel):
    model_config = ConfigDict(extra="allow")
16
+
17
+
18
# One piece of message content, carried either inline (``content``) or by reference
# (``content_url``); exactly one of the two must be set. Unknown fields are rejected.
class MessagePart(BaseModel):
    name: Optional[str] = None
    content_type: str
    content: Optional[str] = None
    content_encoding: Optional[Literal["plain", "base64"]] = "plain"
    content_url: Optional[AnyUrl] = None

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, __context: Any) -> None:
        """Enforce that exactly one of ``content`` / ``content_url`` is provided."""
        has_inline = self.content is not None
        has_link = self.content_url is not None
        if not has_inline and not has_link:
            raise ValueError("Either content or content_url must be provided")
        if has_inline and has_link:
            raise ValueError("Only one of content or content_url can be provided")
32
+
33
+
34
# A message: an ordered list of parts. ``str(message)`` concatenates the inline
# content of its ``text/plain`` parts.
class Message(BaseModel):
    parts: list[MessagePart]

    def __add__(self, other: "Message") -> "Message":
        """Return a new message with ``other``'s parts appended to this message's parts.

        Raises:
            TypeError: if ``other`` is not a Message.
        """
        if not isinstance(other, Message):
            raise TypeError(f"Cannot concatenate Message with {type(other).__name__}")
        # Fields must be passed by keyword: BaseModel.__init__ accepts no positional arguments.
        return Message(parts=self.parts + other.parts)

    def __str__(self) -> str:
        """Concatenate the inline content of all ``text/plain`` parts."""
        return "".join(
            part.content for part in self.parts if part.content is not None and part.content_type == "text/plain"
        )
46
+
47
+
48
# Identifier aliases used across the models: agents are named by plain strings,
# sessions and runs by UUIDs.
AgentName = str
SessionId = uuid.UUID
RunId = uuid.UUID
51
+
52
+
53
# Execution mode requested when creating or resuming a run; values are the wire strings.
class RunMode(str, Enum):
    SYNC = "sync"
    ASYNC = "async"
    STREAM = "stream"
57
+
58
+
59
# Lifecycle state of a run; values are the wire strings.
class RunStatus(str, Enum):
    CREATED = "created"
    IN_PROGRESS = "in-progress"
    AWAITING = "awaiting"
    CANCELLING = "cancelling"
    CANCELLED = "cancelled"
    COMPLETED = "completed"
    FAILED = "failed"

    @property
    def is_terminal(self) -> bool:
        """True for the final states: completed, failed and cancelled."""
        return self in (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELLED)
72
+
73
+
74
# What a run is waiting on; currently only a placeholder type tag.
class Await(BaseModel):
    type: Literal["placeholder"] = "placeholder"


# Data supplied to resume an awaiting run; no fields are defined yet.
class AwaitResume(BaseModel):
    pass
80
+
81
+
82
# A named output artifact, carried either inline (``content``) or by reference
# (``content_url``); exactly one of the two must be set. Unknown fields are rejected.
class Artifact(BaseModel):
    name: str
    content_type: str
    content: Optional[str] = None
    content_encoding: Optional[Literal["plain", "base64"]] = "plain"
    content_url: Optional[AnyUrl] = None

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, __context: Any) -> None:
        """Enforce that exactly one of ``content`` / ``content_url`` is provided."""
        inline_given = self.content is not None
        url_given = self.content_url is not None
        if not (inline_given or url_given):
            raise ValueError("Either content or content_url must be provided")
        if inline_given and url_given:
            raise ValueError("Only one of content or content_url can be provided")
96
+
97
+
98
# State of one agent run: identity, status, pending await and results.
class Run(BaseModel):
    run_id: RunId = Field(default_factory=uuid.uuid4)
    agent_name: AgentName
    session_id: SessionId | None = None
    status: RunStatus = RunStatus.CREATED
    await_: Await | None = Field(None, alias="await")  # "await" is a Python keyword, hence the alias
    outputs: list[Message] = []
    artifacts: list[Artifact] = []
    error: Error | None = None

    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON using field aliases (``await``) unless the caller passes ``by_alias``.

        Passing ``by_alias`` explicitly used to raise ``TypeError`` (duplicate keyword).

        NOTE(review): this override only applies when called on a Run directly; a Run nested
        inside another model (e.g. an event) is serialized by pydantic without going through
        this method, so ``await_`` keeps its field name there unless the outer call passes
        ``by_alias=True``.
        """
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
118
+
119
+
120
# Stream event carrying a message produced by the run.
class MessageEvent(BaseModel):
    type: Literal["message"] = "message"
    message: Message


# Stream event carrying an artifact produced by the run.
class ArtifactEvent(BaseModel):
    type: Literal["artifact"] = "artifact"
    artifact: Artifact
128
+
129
+
130
# Stream event carrying an ``Await`` (presumably emitted when a run enters the awaiting
# state — confirm against the server). The payload travels under the key ``await``.
class AwaitEvent(BaseModel):
    type: Literal["await"] = "await"
    await_: Await | None = Field(alias="await")

    model_config = ConfigDict(populate_by_name=True)

    def model_dump_json(
        self,
        **kwargs: Any,
    ) -> str:
        """Serialize to JSON using field aliases (``await``) unless the caller passes ``by_alias``.

        Passing ``by_alias`` explicitly used to raise ``TypeError`` (duplicate keyword).
        """
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)
144
+
145
+
146
# Stream event with an arbitrary, schema-less payload.
class GenericEvent(BaseModel):
    type: Literal["generic"] = "generic"
    generic: AnyModel


# The following events each carry a snapshot of the run; the ``type`` tag names the
# lifecycle transition being reported.
class CreatedEvent(BaseModel):
    type: Literal["created"] = "created"
    run: Run


class InProgressEvent(BaseModel):
    type: Literal["in-progress"] = "in-progress"
    run: Run


class FailedEvent(BaseModel):
    type: Literal["failed"] = "failed"
    run: Run


class CancelledEvent(BaseModel):
    type: Literal["cancelled"] = "cancelled"
    run: Run


class CompletedEvent(BaseModel):
    type: Literal["completed"] = "completed"
    run: Run
174
+
175
+
176
# Every event that can appear on a run's event stream; each member has a distinct
# literal ``type`` tag. Member order is kept as-is because it feeds union validation.
RunEvent = Union[
    CreatedEvent,
    InProgressEvent,
    MessageEvent,
    AwaitEvent,
    GenericEvent,
    CancelledEvent,
    FailedEvent,
    CompletedEvent,
    ArtifactEvent,
]
187
+
188
+
189
# Public description of an agent: its name, optional description and free-form metadata.
class Agent(BaseModel):
    name: str
    description: str | None = None
    metadata: Metadata = Metadata()
acp_sdk/models/schemas.py ADDED
@@ -0,0 +1,41 @@
1
+ from pydantic import BaseModel, ConfigDict, Field
2
+
3
+ from acp_sdk.models.models import Agent, AgentName, AwaitResume, Message, Run, RunMode, SessionId
4
+
5
+
6
# Response body listing the available agents.
class AgentsListResponse(BaseModel):
    agents: list[Agent]


# Response body describing one agent; identical in shape to Agent.
class AgentReadResponse(Agent):
    pass
12
+
13
+
14
# Request body for creating a run; ``mode`` defaults to synchronous execution.
class RunCreateRequest(BaseModel):
    agent_name: AgentName
    session_id: SessionId | None = None
    inputs: list[Message]
    mode: RunMode = RunMode.SYNC


# Response body for run creation; identical in shape to Run.
class RunCreateResponse(Run):
    pass
23
+
24
+
25
# Request body for resuming an awaiting run. The resume payload travels under the
# keyword-named key ``await``; ``populate_by_name`` also accepts ``await_``.
class RunResumeRequest(BaseModel):
    await_: AwaitResume = Field(alias="await")
    mode: RunMode

    model_config = ConfigDict(populate_by_name=True)


# The remaining run responses all share Run's shape.
class RunResumeResponse(Run):
    pass


class RunReadResponse(Run):
    pass


class RunCancelResponse(Run):
    pass
acp_sdk/server/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from acp_sdk.server.agent import Agent as Agent
2
+ from acp_sdk.server.agent import agent as agent
3
+ from acp_sdk.server.app import create_app as create_app
4
+ from acp_sdk.server.context import Context as Context
5
+ from acp_sdk.server.server import Server as Server
6
+ from acp_sdk.server.types import RunYield as RunYield
7
+ from acp_sdk.server.types import RunYieldResume as RunYieldResume
acp_sdk/server/agent.py ADDED
@@ -0,0 +1,178 @@
1
+ import abc
2
+ import asyncio
3
+ import inspect
4
+ from collections.abc import AsyncGenerator, Coroutine, Generator
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from typing import Callable
7
+
8
+ import janus
9
+
10
+ from acp_sdk.models import (
11
+ AgentName,
12
+ Message,
13
+ SessionId,
14
+ )
15
+ from acp_sdk.models.models import Metadata
16
+ from acp_sdk.server.context import Context
17
+ from acp_sdk.server.types import RunYield, RunYieldResume
18
+
19
+
20
class Agent(abc.ABC):
    """Base class for agents served over ACP.

    Subclasses implement :meth:`run` as a plain function, a coroutine function, a
    generator function or an async generator function; :meth:`execute` detects which
    and drives it, relaying yielded values through the run ``Context``.
    """

    @property
    def name(self) -> AgentName:
        """Agent name; defaults to the subclass's class name."""
        return self.__class__.__name__

    @property
    def description(self) -> str:
        """Human-readable description; empty by default."""
        return ""

    @property
    def metadata(self) -> Metadata:
        """Agent metadata; empty by default."""
        return Metadata()

    @abc.abstractmethod
    def run(
        self, inputs: list[Message], context: Context
    ) -> (
        AsyncGenerator[RunYield, RunYieldResume] | Generator[RunYield, RunYieldResume] | Coroutine[RunYield] | RunYield
    ):
        """Process ``inputs``; either return a single value or yield values and receive resume values."""
        pass

    async def execute(
        self, inputs: list[Message], session_id: SessionId | None, executor: ThreadPoolExecutor
    ) -> AsyncGenerator[RunYield, RunYieldResume]:
        """Run the agent, yielding what it yields and forwarding values sent back to it.

        Exceptions raised by the agent propagate once the run has finished.
        """
        yield_queue: janus.Queue[RunYield] = janus.Queue()
        yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()

        context = Context(
            session_id=session_id, executor=executor, yield_queue=yield_queue, yield_resume_queue=yield_resume_queue
        )

        # Async flavours run as tasks on this event loop; sync flavours run on the executor.
        if inspect.isasyncgenfunction(self.run):
            run = asyncio.create_task(self._run_async_gen(inputs, context))
        elif inspect.iscoroutinefunction(self.run):
            run = asyncio.create_task(self._run_coro(inputs, context))
        elif inspect.isgeneratorfunction(self.run):
            run = asyncio.get_running_loop().run_in_executor(executor, self._run_gen, inputs, context)
        else:
            run = asyncio.get_running_loop().run_in_executor(executor, self._run_func, inputs, context)

        try:
            # Relay each value the agent yields to our caller and pass the caller's reply back.
            while True:
                value = yield await yield_queue.async_q.get()
                await yield_resume_queue.async_q.put(value)
        except janus.AsyncQueueShutDown:
            # Presumably raised once the runner's context.shutdown() has shut the queue down.
            pass
        finally:
            await run  # Raise exceptions

    async def _run_async_gen(self, inputs: list[Message], context: Context) -> None:
        """Drive an async-generator ``run``, sending each resume value back into it."""
        try:
            gen: AsyncGenerator[RunYield, RunYieldResume] = self.run(inputs, context)
            value = None
            while True:
                value = await context.yield_async(await gen.asend(value))
        except StopAsyncIteration:
            pass
        finally:
            context.shutdown()

    async def _run_coro(self, inputs: list[Message], context: Context) -> None:
        """Await a coroutine ``run`` and yield its single result."""
        try:
            await context.yield_async(await self.run(inputs, context))
        finally:
            context.shutdown()

    def _run_gen(self, inputs: list[Message], context: Context) -> None:
        """Drive a generator ``run`` (on an executor thread), sending each resume value back."""
        try:
            gen: Generator[RunYield, RunYieldResume] = self.run(inputs, context)
            value = None
            while True:
                value = context.yield_sync(gen.send(value))
        except StopIteration:
            pass
        finally:
            context.shutdown()

    def _run_func(self, inputs: list[Message], context: Context) -> None:
        """Call a plain-function ``run`` (on an executor thread) and yield its result."""
        try:
            context.yield_sync(self.run(inputs, context))
        finally:
            context.shutdown()
102
+
103
+
104
def agent(
    name: str | None = None,
    description: str | None = None,
    *,
    metadata: Metadata | None = None,
) -> Callable[[Callable], Agent]:
    """Decorator to create an agent.

    Turns a function taking ``input`` (and optionally a second parameter named
    ``context``) into an :class:`Agent` instance. The function may be a plain function,
    a coroutine function, a generator function or an async generator function.

    Args:
        name: Agent name; falls back to the function's ``__name__``.
        description: Agent description; falls back to the function's ``__doc__``, then "".
        metadata: Agent metadata; falls back to an empty ``Metadata()``.

    Returns:
        A decorator that raises ``TypeError`` when the function has no parameters, more
        than two, or a second parameter not named ``context``.
    """

    def decorator(fn: Callable) -> Agent:
        params = list(inspect.signature(fn).parameters.values())
        arity = len(params)

        if arity == 0:
            raise TypeError("The agent function must have at least 'input' argument")
        if arity > 2:
            raise TypeError("The agent function must have only 'input' and 'context' arguments")
        if arity == 2 and params[1].name != "context":
            raise TypeError("The second argument of the agent function must be 'context'")

        wants_context = arity == 2

        def call_fn(agent_input: Message, context: Context):
            # Hand the context over only to functions that declare it.
            return fn(agent_input, context) if wants_context else fn(agent_input)

        class DecoratorAgentBase(Agent):
            @property
            def name(self) -> str:
                return name or fn.__name__

            @property
            def description(self) -> str:
                return description or fn.__doc__ or ""

            @property
            def metadata(self) -> Metadata:
                return metadata or Metadata()

        if inspect.isasyncgenfunction(fn):

            class AsyncGenDecoratorAgent(DecoratorAgentBase):
                async def run(self, input: Message, context: Context) -> AsyncGenerator[RunYield, RunYieldResume]:
                    # Step the wrapped generator manually so values sent into this one reach it.
                    try:
                        inner: AsyncGenerator[RunYield, RunYieldResume] = call_fn(input, context)
                        sent = None
                        while True:
                            sent = yield await inner.asend(sent)
                    except StopAsyncIteration:
                        pass

            return AsyncGenDecoratorAgent()

        if inspect.iscoroutinefunction(fn):

            class CoroDecoratorAgent(DecoratorAgentBase):
                async def run(self, input: Message, context: Context) -> Coroutine[RunYield]:
                    return await call_fn(input, context)

            return CoroDecoratorAgent()

        if inspect.isgeneratorfunction(fn):

            class GenDecoratorAgent(DecoratorAgentBase):
                def run(self, input: Message, context: Context) -> Generator[RunYield, RunYieldResume]:
                    yield from call_fn(input, context)

            return GenDecoratorAgent()

        class FuncDecoratorAgent(DecoratorAgentBase):
            def run(self, input: Message, context: Context) -> RunYield:
                return call_fn(input, context)

        return FuncDecoratorAgent()

    return decorator