grasp_agents 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/agent_message.py +28 -0
  2. grasp_agents/agent_message_pool.py +94 -0
  3. grasp_agents/base_agent.py +72 -0
  4. grasp_agents/cloud_llm.py +353 -0
  5. grasp_agents/comm_agent.py +230 -0
  6. grasp_agents/costs_dict.yaml +122 -0
  7. grasp_agents/data_retrieval/__init__.py +7 -0
  8. grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
  9. grasp_agents/data_retrieval/types.py +57 -0
  10. grasp_agents/data_retrieval/utils.py +57 -0
  11. grasp_agents/grasp_logging.py +36 -0
  12. grasp_agents/http_client.py +24 -0
  13. grasp_agents/llm.py +106 -0
  14. grasp_agents/llm_agent.py +361 -0
  15. grasp_agents/llm_agent_state.py +73 -0
  16. grasp_agents/memory.py +150 -0
  17. grasp_agents/openai/__init__.py +83 -0
  18. grasp_agents/openai/completion_converters.py +49 -0
  19. grasp_agents/openai/content_converters.py +80 -0
  20. grasp_agents/openai/converters.py +170 -0
  21. grasp_agents/openai/message_converters.py +155 -0
  22. grasp_agents/openai/openai_llm.py +179 -0
  23. grasp_agents/openai/tool_converters.py +37 -0
  24. grasp_agents/printer.py +156 -0
  25. grasp_agents/prompt_builder.py +204 -0
  26. grasp_agents/run_context.py +90 -0
  27. grasp_agents/tool_orchestrator.py +181 -0
  28. grasp_agents/typing/__init__.py +0 -0
  29. grasp_agents/typing/completion.py +30 -0
  30. grasp_agents/typing/content.py +116 -0
  31. grasp_agents/typing/converters.py +118 -0
  32. grasp_agents/typing/io.py +32 -0
  33. grasp_agents/typing/message.py +130 -0
  34. grasp_agents/typing/tool.py +52 -0
  35. grasp_agents/usage_tracker.py +99 -0
  36. grasp_agents/utils.py +151 -0
  37. grasp_agents/workflow/__init__.py +0 -0
  38. grasp_agents/workflow/looped_agent.py +113 -0
  39. grasp_agents/workflow/sequential_agent.py +57 -0
  40. grasp_agents/workflow/workflow_agent.py +69 -0
  41. grasp_agents-0.1.5.dist-info/METADATA +14 -0
  42. grasp_agents-0.1.5.dist-info/RECORD +44 -0
  43. grasp_agents-0.1.5.dist-info/WHEEL +4 -0
  44. grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,181 @@
1
+ import asyncio
2
+ import json
3
+ from collections.abc import Coroutine, Sequence
4
+ from logging import getLogger
5
+ from typing import Any, Generic, Protocol
6
+
7
+ from pydantic import BaseModel
8
+
9
+ from .llm import LLM, LLMSettings
10
+ from .llm_agent_state import LLMAgentState
11
+ from .run_context import CtxT, RunContextWrapper
12
+ from .typing.converters import Converters
13
+ from .typing.message import AssistantMessage, Conversation, Message, ToolMessage
14
+ from .typing.tool import BaseTool, ToolCall, ToolChoice
15
+
16
# Module-level logger named after this module.
logger = getLogger(__name__)
17
+
18
+
19
class ToolCallLoopExitHandler(Protocol[CtxT]):
    """Callable protocol deciding whether a tool-call loop should stop.

    Implementations receive the current conversation, the (optional) run
    context and any extra keyword arguments, and return True to stop.
    """

    def __call__(
        self, conversation: Conversation, *, ctx: RunContextWrapper[CtxT] | None, **kwargs: Any
    ) -> bool:
        """Return True to leave the tool-call loop, False to keep going."""
27
+
28
+
29
class ToolOrchestrator(Generic[CtxT]):
    """Run an LLM tool-calling loop on behalf of a single agent.

    The orchestrator registers ``tools`` on ``llm``, generates assistant
    messages into the agent's message history, executes any tool calls those
    messages contain, appends the tool results, and repeats until the exit
    condition holds or ``max_turns`` follow-up generations have been made.
    """

    def __init__(
        self,
        agent_id: str,
        llm: LLM[LLMSettings, Converters],
        tools: list[BaseTool[BaseModel, BaseModel, CtxT]] | None,
        max_turns: int,
        react_mode: bool = False,
    ) -> None:
        """Create an orchestrator.

        Args:
            agent_id: Identifier used when printing and as ``model_id`` on
                generated tool messages.
            llm: The LLM to drive; ``tools`` is assigned to ``llm.tools``.
            tools: Tools the LLM may call, or None for no tools.
            max_turns: Maximum number of follow-up generations in ``run_loop``.
            react_mode: If True, the first generation and every generation
                made right after executing tool calls use ``tool_choice="none"``;
                other generations use ``"auto"``.
        """
        self._agent_id = agent_id

        self._llm = llm
        self._tools = tools
        # The LLM holds the tool registry; the `tools` property reads it back.
        # NOTE(review): a list is assigned here while the property returns a
        # name-keyed dict, so presumably the LLM setter indexes by name - confirm.
        self.llm.tools = tools

        self._max_turns = max_turns
        self._react_mode = react_mode

        # Optional override for the default loop-exit check.
        self.tool_call_loop_exit_impl: ToolCallLoopExitHandler[CtxT] | None = None

    @property
    def agent_id(self) -> str:
        """Identifier of the owning agent."""
        return self._agent_id

    @property
    def llm(self) -> LLM[LLMSettings, Converters]:
        """The LLM driven by this orchestrator."""
        return self._llm

    @property
    def tools(self) -> dict[str, BaseTool[BaseModel, BaseModel, CtxT]]:
        """Tools registered on the LLM, keyed by tool name (empty if none)."""
        return self._llm.tools or {}

    @property
    def max_turns(self) -> int:
        """Maximum number of follow-up generations in ``run_loop``."""
        return self._max_turns

    def _tool_call_loop_exit(
        self,
        conversation: Conversation,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> bool:
        """Return True when the tool-call loop should stop.

        Delegates to ``tool_call_loop_exit_impl`` when one is set; otherwise
        stops once the last (assistant) message contains no tool calls.
        """
        if self.tool_call_loop_exit_impl:
            return self.tool_call_loop_exit_impl(
                conversation=conversation, ctx=ctx, **kwargs
            )

        assert conversation, "Conversation must not be empty"
        assert isinstance(conversation[-1], AssistantMessage), (
            "Last message in conversation must be an AssistantMessage"
        )

        return not bool(conversation[-1].tool_calls)

    async def generate_once(
        self,
        agent_state: LLMAgentState,
        tool_choice: ToolChoice | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[AssistantMessage]:
        """Generate one batch of assistant messages and append it to history.

        Returns:
            The generated batch. When ``ctx`` is given it is also printed and
            recorded in ``ctx.usage_tracker``.
        """
        message_history = agent_state.message_history
        message_batch = await self.llm.generate_message_batch(
            message_history, tool_choice=tool_choice
        )
        message_history.add_message_batch(message_batch)

        self._print_messages_and_track_usage(message_batch, ctx=ctx)

        return message_batch

    async def run_loop(
        self,
        agent_state: LLMAgentState,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Alternate LLM generations and tool executions until done.

        Stops when ``_tool_call_loop_exit`` returns True for the single
        conversation in the history, or after ``max_turns`` follow-up
        generations. All messages are written into
        ``agent_state.message_history``.
        """
        message_history = agent_state.message_history
        assert message_history.batch_size == 1, (
            "Batch size must be 1 for tool call loop"
        )

        tool_choice: ToolChoice

        # In react mode the first generation is made with tools disabled.
        tool_choice = "none" if self._react_mode else "auto"
        gen_message_batch = await self.generate_once(
            agent_state, tool_choice=tool_choice, ctx=ctx
        )

        turns = 0

        while True:
            if self._tool_call_loop_exit(
                message_history.batched_conversations[0], ctx=ctx, num_turns=turns
            ):
                return
            if turns >= self.max_turns:
                # Lazy %-formatting: the message is only built if it is emitted.
                logger.info(
                    "Max turns reached: %s. Stopping tool call loop.", self.max_turns
                )
                return

            msg = gen_message_batch[0]
            if msg.tool_calls:
                tool_messages = await self.call_tools(msg.tool_calls, ctx=ctx)
                message_history.add_messages(tool_messages)

            # React mode: right after executing tools, generate with tools disabled.
            tool_choice = "none" if (self._react_mode and msg.tool_calls) else "auto"
            gen_message_batch = await self.generate_once(
                agent_state, tool_choice=tool_choice, ctx=ctx
            )

            turns += 1

    async def call_tools(
        self,
        calls: Sequence[ToolCall],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[ToolMessage]:
        """Execute tool calls concurrently and wrap their outputs as messages.

        Every call is resolved (tool lookup + JSON argument parsing) before any
        tool coroutine is created, so an invalid call does not leave earlier
        coroutines un-awaited.

        Returns:
            One ToolMessage per call, in the same order as ``calls``.

        Raises:
            KeyError: If a call names a tool that is not registered.
            json.JSONDecodeError: If a call's arguments are not valid JSON.
        """
        tools = self.tools  # resolve once; the property re-reads the LLM
        resolved: list[tuple[BaseTool[BaseModel, BaseModel, CtxT], Any]] = []
        for call in calls:
            try:
                tool = tools[call.tool_name]
            except KeyError as err:
                # LLMs can request tools that do not exist; make that diagnosable.
                raise KeyError(
                    f"Unknown tool '{call.tool_name}' requested for agent "
                    f"'{self.agent_id}'. Available tools: {sorted(tools)}"
                ) from err
            resolved.append((tool, json.loads(call.tool_arguments)))

        corouts: list[Coroutine[Any, Any, BaseModel]] = [
            tool(ctx=ctx, **args) for tool, args in resolved
        ]

        # gather preserves input order, so outputs line up with `calls`.
        outs = await asyncio.gather(*corouts)

        tool_messages = [
            ToolMessage.from_tool_output(out, call, model_id=self.agent_id)
            for out, call in zip(outs, calls, strict=True)
        ]

        self._print_messages(tool_messages, ctx=ctx)

        return tool_messages

    def _print_messages(
        self,
        message_batch: Sequence[Message],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print messages through ``ctx.printer``; no-op when ``ctx`` is None."""
        if ctx:
            ctx.printer.print_llm_messages(message_batch, agent_id=self.agent_id)

    def _print_messages_and_track_usage(
        self,
        message_batch: Sequence[AssistantMessage],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print messages and record their usage under this LLM's model name.

        No-op when ``ctx`` is None.
        """
        if ctx:
            self._print_messages(message_batch, ctx=ctx)
            ctx.usage_tracker.update(
                messages=message_batch, model_name=self.llm.model_name
            )
File without changes
@@ -0,0 +1,30 @@
1
+ from abc import ABC
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from .message import AssistantMessage
6
+
7
+
8
class CompletionChoice(BaseModel):
    """A single candidate assistant message within a completion."""

    # TODO: add fields
    message: AssistantMessage
    # Required field, though its value may be None.
    finish_reason: str | None
12
+
13
+
14
class CompletionError(BaseModel):
    """Error information attached to a completion."""

    message: str
    metadata: dict[str, str | None] | None = None
    # Required even though it follows a defaulted field (pydantic allows this).
    code: int
18
+
19
+
20
class Completion(BaseModel, ABC):
    """Provider-independent result of a completion request."""

    # TODO: add fields
    choices: list[CompletionChoice]
    model_id: str | None = None
    # Set when the request produced an error.
    error: CompletionError | None = None
25
+
26
+
27
class CompletionChunk(BaseModel):
    """One streamed fragment of a completion."""

    # TODO: add more fields and tool use support (and choices?)
    delta: str | None = None  # newly generated text, if any
    model_id: str | None = None
@@ -0,0 +1,116 @@
1
+ import base64
2
+ import re
3
+ from collections.abc import Iterable
4
+ from enum import StrEnum
5
+ from pathlib import Path
6
+ from typing import Annotated, Any, Literal, TypeAlias
7
+
8
+ from pydantic import AnyUrl, BaseModel, Field
9
+
10
+
11
+ class ContentType(StrEnum):
12
+ TEXT = "text"
13
+ IMAGE = "image"
14
+
15
+
16
+ ImageDetail: TypeAlias = Literal["low", "high", "auto"]
17
+
18
+
19
+ class ImageData(BaseModel):
20
+ type: Literal["url", "base64"]
21
+ url: AnyUrl | None = None
22
+ base64: str | None = None
23
+
24
+ # Supported by OpenAI API
25
+ detail: ImageDetail = "high"
26
+
27
+ @classmethod
28
+ def from_base64(cls, base64_encoding: str, **kwargs: Any) -> "ImageData":
29
+ return cls(type="base64", base64=base64_encoding, **kwargs)
30
+
31
+ @classmethod
32
+ def from_path(cls, img_path: str | Path, **kwargs: Any) -> "ImageData":
33
+ image_bytes = Path(img_path).read_bytes()
34
+ base64_encoding = base64.b64encode(image_bytes).decode("utf-8")
35
+ return cls(type="base64", base64=base64_encoding, **kwargs)
36
+
37
+ @classmethod
38
+ def from_url(cls, img_url: str, **kwargs: Any) -> "ImageData":
39
+ return cls(type="url", url=img_url, **kwargs) # type: ignore
40
+
41
+ def to_str(self) -> str:
42
+ if self.type == "url":
43
+ return str(self.url)
44
+ if self.type == "base64":
45
+ return str(self.base64)
46
+ raise ValueError(f"Unsupported image data type: {self.type}")
47
+
48
+
49
class ContentPartText(BaseModel):
    """Text fragment of a message's content."""

    type: Literal[ContentType.TEXT] = ContentType.TEXT
    data: str


class ContentPartImage(BaseModel):
    """Image fragment of a message's content."""

    type: Literal[ContentType.IMAGE] = ContentType.IMAGE
    data: ImageData


# Tagged union: pydantic selects the concrete part class from its `type` field.
ContentPart = Annotated[
    ContentPartText | ContentPartImage,
    Field(discriminator="type"),
]
60
+
61
+
62
class Content(BaseModel):
    """Message content: an ordered list of text and image parts."""

    parts: list[ContentPart]

    @classmethod
    def from_formatted_prompt(
        cls,
        prompt_template: str,
        prompt_args: dict[str, str | ImageData] | None = None,
    ) -> "Content":
        """Build content from a ``str.format``-style template.

        Text arguments (str/int/float values) are substituted with
        ``str.format``. Each ``{name}`` placeholder whose argument is an
        ImageData splits the template. The image becomes its own part, and the
        surrounding text becomes text parts stripped of spaces/newlines at both
        ends. Text segments that are empty after stripping are skipped; such
        segments appear, for example, before a leading image placeholder or
        between two adjacent ones.

        Args:
            prompt_template: Template containing ``{name}`` placeholders.
            prompt_args: Values for the placeholders; None means no arguments.

        Returns:
            A Content with a single text part when there are no image
            arguments, otherwise the interleaved text/image parts.
        """
        prompt_args = prompt_args or {}
        image_args = {
            name: value
            for name, value in prompt_args.items()
            if isinstance(value, ImageData)
        }
        text_args = {
            name: value
            for name, value in prompt_args.items()
            if isinstance(value, (str, int, float))
        }

        if not image_args:
            prompt_with_args = prompt_template.format(**text_args)
            return cls(parts=[ContentPartText(data=prompt_with_args)])

        # Literal placeholder text -> image argument name.
        placeholders = {"{" + name + "}": name for name in image_args}
        # Escape so argument names match literally; the capturing group keeps
        # the placeholders themselves in the split result.
        pattern = "(" + "|".join(re.escape(p) for p in placeholders) + ")"
        input_prompt_chunks = re.split(pattern, prompt_template)

        content_parts: list[ContentPart] = []
        for chunk in input_prompt_chunks:
            stripped_chunk = chunk.strip(" \n")
            if stripped_chunk in placeholders:
                image_data = image_args[placeholders[stripped_chunk]]
                content_parts.append(ContentPartImage(data=image_data))
            elif stripped_chunk:
                # Skip empty segments instead of emitting empty text parts.
                text_data = stripped_chunk.format(**text_args)
                content_parts.append(ContentPartText(data=text_data))

        return cls(parts=content_parts)

    @classmethod
    def from_text(cls, text: str) -> "Content":
        """Build content consisting of a single text part."""
        return cls(parts=[ContentPartText(data=text)])

    @classmethod
    def from_content_parts(cls, content_parts: Iterable[str | ImageData]) -> "Content":
        """Build content from strings (text parts) and ImageData (image parts)."""
        parts: list[ContentPart] = []
        for part in content_parts:
            if isinstance(part, str):
                parts.append(ContentPartText(data=part))
            else:
                parts.append(ContentPartImage(data=part))

        return cls(parts=parts)
@@ -0,0 +1,118 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import AsyncIterator
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from .completion import Completion, CompletionChunk
8
+ from .content import Content
9
+ from .message import (
10
+ AssistantMessage,
11
+ Message,
12
+ SystemMessage,
13
+ ToolMessage,
14
+ UserMessage,
15
+ )
16
+ from .tool import BaseTool, ToolChoice
17
+
18
+
19
+ class Converters(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def to_system_message(system_message: SystemMessage, **kwargs: Any) -> Any:
23
+ pass
24
+
25
+ @staticmethod
26
+ @abstractmethod
27
+ def from_system_message(raw_message: Any, **kwargs: Any) -> SystemMessage:
28
+ pass
29
+
30
+ @staticmethod
31
+ @abstractmethod
32
+ def to_user_message(user_message: UserMessage, **kwargs: Any) -> Any:
33
+ pass
34
+
35
+ @staticmethod
36
+ @abstractmethod
37
+ def from_user_message(raw_message: Any, **kwargs: Any) -> UserMessage:
38
+ pass
39
+
40
+ @staticmethod
41
+ @abstractmethod
42
+ def to_assistant_message(assistant_message: AssistantMessage, **kwargs: Any) -> Any:
43
+ pass
44
+
45
+ @staticmethod
46
+ @abstractmethod
47
+ def from_assistant_message(
48
+ raw_message: Any, raw_usage: Any, **kwargs: Any
49
+ ) -> AssistantMessage:
50
+ pass
51
+
52
+ @staticmethod
53
+ @abstractmethod
54
+ def to_tool_message(tool_message: ToolMessage, **kwargs: Any) -> Any:
55
+ pass
56
+
57
+ @staticmethod
58
+ @abstractmethod
59
+ def from_tool_message(raw_message: Any, **kwargs: Any) -> ToolMessage:
60
+ pass
61
+
62
+ @classmethod
63
+ def to_message(cls, message: Message, **kwargs: Any) -> Any:
64
+ if isinstance(message, UserMessage):
65
+ return cls.to_user_message(message, **kwargs)
66
+ if isinstance(message, AssistantMessage):
67
+ return cls.to_assistant_message(message, **kwargs)
68
+ if isinstance(message, ToolMessage):
69
+ return cls.to_tool_message(message, **kwargs)
70
+
71
+ return cls.to_system_message(message, **kwargs)
72
+
73
+ @staticmethod
74
+ @abstractmethod
75
+ def to_tool(tool: BaseTool[BaseModel, BaseModel, Any], **kwargs: Any) -> Any:
76
+ pass
77
+
78
+ @staticmethod
79
+ @abstractmethod
80
+ def to_tool_choice(tool_choice: ToolChoice, **kwargs: Any) -> Any:
81
+ pass
82
+
83
+ @staticmethod
84
+ @abstractmethod
85
+ def to_content(content: Content, **kwargs: Any) -> Any:
86
+ pass
87
+
88
+ @staticmethod
89
+ @abstractmethod
90
+ def from_content(raw_content: Any, **kwargs: Any) -> Content:
91
+ pass
92
+
93
+ @staticmethod
94
+ @abstractmethod
95
+ def to_completion(completion: Completion, **kwargs: Any) -> Any:
96
+ pass
97
+
98
+ @staticmethod
99
+ @abstractmethod
100
+ def from_completion(raw_completion: Any, **kwargs: Any) -> Completion:
101
+ pass
102
+
103
+ @staticmethod
104
+ @abstractmethod
105
+ def to_completion_chunk(chunk: CompletionChunk, **kwargs: Any) -> Any:
106
+ pass
107
+
108
+ @staticmethod
109
+ @abstractmethod
110
+ def from_completion_chunk(raw_chunk: Any, **kwargs: Any) -> CompletionChunk:
111
+ pass
112
+
113
+ @staticmethod
114
+ @abstractmethod
115
+ def from_completion_chunk_iterator(
116
+ raw_chunk_iterator: AsyncIterator[Any], **kwargs: Any
117
+ ) -> AsyncIterator[CompletionChunk]:
118
+ pass
@@ -0,0 +1,32 @@
1
+ from collections.abc import Sequence
2
+ from typing import TypeAlias, TypeVar
3
+
4
+ from pydantic import BaseModel
5
+ from pydantic.json_schema import SkipJsonSchema
6
+
7
+ from .content import ImageData
8
+
9
# Agents are identified by plain strings.
AgentID: TypeAlias = str


class AgentPayload(BaseModel):
    """Base model for agent payloads (the bound of InT and OutT)."""

    # TODO: do we need conversation?
    # Wrapped in SkipJsonSchema so the field is left out of the JSON schema.
    selected_recipient_ids: SkipJsonSchema[Sequence[AgentID] | None] = None
15
+
16
+
17
class AgentState(BaseModel):
    """Empty base model for agent state (the bound of StateT)."""
19
+
20
+
21
# Input payloads vary contravariantly; outputs and state vary covariantly.
InT = TypeVar("InT", bound=AgentPayload, contravariant=True)  # noqa: PLC0105
OutT = TypeVar("OutT", bound=AgentPayload, covariant=True)  # noqa: PLC0105
StateT = TypeVar("StateT", bound=AgentState, covariant=True)  # noqa: PLC0105
24
+
25
+
26
class LLMPromptArgs(BaseModel):
    """Empty pydantic base for LLM prompt-argument models."""


# A prompt template / prompt text.
LLMPrompt: TypeAlias = str
# Already-formatted arguments for a system prompt (text only).
LLMFormattedSystemArgs: TypeAlias = dict[str, str]
# Already-formatted arguments that may include images, matching the
# `prompt_args` shape of Content.from_formatted_prompt.
LLMFormattedArgs: TypeAlias = dict[str, str | ImageData]
@@ -0,0 +1,130 @@
1
+ from collections.abc import Hashable, Sequence
2
+ from enum import StrEnum
3
+ from typing import Annotated, Literal, TypeAlias
4
+ from uuid import uuid4
5
+
6
+ from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
7
+
8
+ from .content import Content, ImageData
9
+ from .tool import ToolCall
10
+
11
+
12
+ class Role(StrEnum):
13
+ USER = "user"
14
+ SYSTEM = "system"
15
+ ASSISTANT = "assistant"
16
+ TOOL = "tool"
17
+
18
+
19
class Usage(BaseModel):
    """Token counts and cost of one or more LLM calls; supports ``+``."""

    input_tokens: NonNegativeInt = 0
    output_tokens: NonNegativeInt = 0
    # Optional metrics: None means "not set", which __add__ keeps distinct from 0.
    reasoning_tokens: NonNegativeInt | None = None
    cached_tokens: NonNegativeInt | None = None
    cost: NonNegativeFloat | None = None

    def __add__(self, add_usage: "Usage") -> "Usage":
        """Return the field-wise sum of ``self`` and ``add_usage``.

        The optional fields (``reasoning_tokens``, ``cached_tokens``, ``cost``)
        are None in the result only if they are None on both operands;
        otherwise a None side counts as zero.
        """
        input_tokens = self.input_tokens + add_usage.input_tokens
        output_tokens = self.output_tokens + add_usage.output_tokens

        if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
            reasoning_tokens = (self.reasoning_tokens or 0) + (
                add_usage.reasoning_tokens or 0
            )
        else:
            reasoning_tokens = None

        if self.cached_tokens is not None or add_usage.cached_tokens is not None:
            cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
        else:
            cached_tokens = None

        # Same rule for cost: an operand without cost must not discard the
        # other operand's cost (e.g. when accumulating a running total).
        if self.cost is not None or add_usage.cost is not None:
            cost = (self.cost or 0.0) + (add_usage.cost or 0.0)
        else:
            cost = None

        return Usage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            reasoning_tokens=reasoning_tokens,
            cached_tokens=cached_tokens,
            cost=cost,
        )
53
+
54
+
55
class MessageBase(BaseModel):
    """Fields shared by every message type."""

    # Short random id: the first 8 hex digits of a fresh UUID4.
    message_id: Hashable = Field(default_factory=lambda: uuid4().hex[:8])
    model_id: str | None = None
58
+
59
+
60
class AssistantMessage(MessageBase):
    """Message produced by the LLM, optionally carrying tool calls."""

    role: Literal[Role.ASSISTANT] = Role.ASSISTANT
    content: str
    usage: Usage | None = None
    # Tool invocations requested in this message, if any.
    tool_calls: Sequence[ToolCall] | None = None
    refusal: str | None = None
66
+
67
+
68
class UserMessage(MessageBase):
    """User-authored message whose content may mix text and images."""

    role: Literal[Role.USER] = Role.USER
    content: Content

    @classmethod
    def from_text(cls, text: str, model_id: str | None = None) -> "UserMessage":
        """Build a text-only user message."""
        text_content = Content.from_text(text)
        return cls(content=text_content, model_id=model_id)

    @classmethod
    def from_formatted_prompt(
        cls,
        prompt_template: str,
        prompt_args: dict[str, str | ImageData] | None = None,
        model_id: str | None = None,
    ) -> "UserMessage":
        """Build a user message by filling ``prompt_template`` with ``prompt_args``."""
        return cls(
            content=Content.from_formatted_prompt(
                prompt_template=prompt_template, prompt_args=prompt_args
            ),
            model_id=model_id,
        )

    @classmethod
    def from_content_parts(
        cls,
        content_parts: Sequence[str | ImageData],
        model_id: str | None = None,
    ) -> "UserMessage":
        """Build a user message from text strings and ImageData objects."""
        return cls(
            content=Content.from_content_parts(content_parts), model_id=model_id
        )
98
+
99
+
100
class SystemMessage(MessageBase):
    """System-prompt message with plain-text content."""

    role: Literal[Role.SYSTEM] = Role.SYSTEM
    content: str
103
+
104
+
105
class ToolMessage(MessageBase):
    """Result of a tool call, linked to the call by ``tool_call_id``."""

    role: Literal[Role.TOOL] = Role.TOOL
    content: str
    tool_call_id: str

    @classmethod
    def from_tool_output(
        cls,
        tool_output: BaseModel,
        tool_call: ToolCall,
        model_id: str | None = None,
        indent: int = 2,
    ) -> "ToolMessage":
        """Serialize ``tool_output`` to indented JSON as the reply to ``tool_call``."""
        serialized_output = tool_output.model_dump_json(indent=indent)
        return cls(
            content=serialized_output, tool_call_id=tool_call.id, model_id=model_id
        )
123
+
124
+
125
# Any chat message; pydantic picks the concrete class from the `role` field.
Message = Annotated[
    AssistantMessage | UserMessage | SystemMessage | ToolMessage,
    Field(discriminator="role"),
]

# An ordered list of chat messages.
Conversation: TypeAlias = list[Message]
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar
5
+
6
+ from pydantic import BaseModel
7
+
8
# Static type checkers see the real context types from ..run_context. At
# runtime they are replaced by local stand-ins, presumably to avoid a circular
# import with ..run_context (TODO confirm).
if TYPE_CHECKING:
    from ..run_context import CtxT, RunContextWrapper
else:
    # Runtime stand-in type variable so Generic[..., CtxT] below can be built.
    CtxT = TypeVar("CtxT")

    class RunContextWrapper(Generic[CtxT]):
        """Runtime placeholder so RunContextWrapper[CtxT] works"""
15
+
16
+
17
# Tool input schemas vary contravariantly and output schemas covariantly.
ToolInT = TypeVar("ToolInT", bound=BaseModel, contravariant=True)  # noqa: PLC0105
ToolOutT = TypeVar("ToolOutT", bound=BaseModel, covariant=True)  # noqa: PLC0105
19
+
20
+
21
class ToolCall(BaseModel):
    """A tool invocation requested by the LLM."""

    id: str  # echoed back as the ToolMessage's tool_call_id
    tool_name: str  # name of the tool to invoke
    tool_arguments: str  # keyword arguments as a JSON string
25
+
26
+
27
class BaseTool(BaseModel, ABC, Generic[ToolInT, ToolOutT, CtxT]):
    """Abstract base for a tool callable by an LLM.

    Subclasses implement ``run``. Calling the tool builds ``in_schema`` from
    the keyword arguments, awaits ``run`` and validates the result against
    ``out_schema``.
    """

    name: str
    description: str
    in_schema: type[ToolInT]
    out_schema: type[ToolOutT]

    # Supported by OpenAI API
    strict: bool | None = None

    @abstractmethod
    async def run(
        self, inp: ToolInT, ctx: RunContextWrapper[CtxT] | None = None
    ) -> ToolOutT:
        """Execute the tool on already-validated input."""

    async def __call__(
        self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
    ) -> ToolOutT:
        """Validate ``kwargs``, run the tool and validate its output."""
        validated_input = self.in_schema(**kwargs)
        raw_output = await self.run(validated_input, ctx=ctx)
        return self.out_schema.model_validate(raw_output)
48
+
49
+
50
# Tool-use policy for a generation: one of the literal modes or a specific
# tool. Mode names presumably mirror OpenAI's `tool_choice` - confirm.
ToolChoice: TypeAlias = (
    Literal["none", "auto", "required"] | BaseTool[BaseModel, BaseModel, Any]
)
+ )