grasp_agents 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +28 -0
- grasp_agents/agent_message_pool.py +94 -0
- grasp_agents/base_agent.py +72 -0
- grasp_agents/cloud_llm.py +353 -0
- grasp_agents/comm_agent.py +230 -0
- grasp_agents/costs_dict.yaml +122 -0
- grasp_agents/data_retrieval/__init__.py +7 -0
- grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
- grasp_agents/data_retrieval/types.py +57 -0
- grasp_agents/data_retrieval/utils.py +57 -0
- grasp_agents/grasp_logging.py +36 -0
- grasp_agents/http_client.py +24 -0
- grasp_agents/llm.py +106 -0
- grasp_agents/llm_agent.py +361 -0
- grasp_agents/llm_agent_state.py +73 -0
- grasp_agents/memory.py +150 -0
- grasp_agents/openai/__init__.py +83 -0
- grasp_agents/openai/completion_converters.py +49 -0
- grasp_agents/openai/content_converters.py +80 -0
- grasp_agents/openai/converters.py +170 -0
- grasp_agents/openai/message_converters.py +155 -0
- grasp_agents/openai/openai_llm.py +179 -0
- grasp_agents/openai/tool_converters.py +37 -0
- grasp_agents/printer.py +156 -0
- grasp_agents/prompt_builder.py +204 -0
- grasp_agents/run_context.py +90 -0
- grasp_agents/tool_orchestrator.py +181 -0
- grasp_agents/typing/__init__.py +0 -0
- grasp_agents/typing/completion.py +30 -0
- grasp_agents/typing/content.py +116 -0
- grasp_agents/typing/converters.py +118 -0
- grasp_agents/typing/io.py +32 -0
- grasp_agents/typing/message.py +130 -0
- grasp_agents/typing/tool.py +52 -0
- grasp_agents/usage_tracker.py +99 -0
- grasp_agents/utils.py +151 -0
- grasp_agents/workflow/__init__.py +0 -0
- grasp_agents/workflow/looped_agent.py +113 -0
- grasp_agents/workflow/sequential_agent.py +57 -0
- grasp_agents/workflow/workflow_agent.py +69 -0
- grasp_agents-0.1.5.dist-info/METADATA +14 -0
- grasp_agents-0.1.5.dist-info/RECORD +44 -0
- grasp_agents-0.1.5.dist-info/WHEEL +4 -0
- grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
grasp_agents/tool_orchestrator.py (new file)
@@ -0,0 +1,181 @@
```python
import asyncio
import json
from collections.abc import Coroutine, Sequence
from logging import getLogger
from typing import Any, Generic, Protocol

from pydantic import BaseModel

from .llm import LLM, LLMSettings
from .llm_agent_state import LLMAgentState
from .run_context import CtxT, RunContextWrapper
from .typing.converters import Converters
from .typing.message import AssistantMessage, Conversation, Message, ToolMessage
from .typing.tool import BaseTool, ToolCall, ToolChoice

logger = getLogger(__name__)


class ToolCallLoopExitHandler(Protocol[CtxT]):
    def __call__(
        self,
        conversation: Conversation,
        *,
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> bool: ...


class ToolOrchestrator(Generic[CtxT]):
    def __init__(
        self,
        agent_id: str,
        llm: LLM[LLMSettings, Converters],
        tools: list[BaseTool[BaseModel, BaseModel, CtxT]] | None,
        max_turns: int,
        react_mode: bool = False,
    ) -> None:
        self._agent_id = agent_id

        self._llm = llm
        self._tools = tools
        self.llm.tools = tools

        self._max_turns = max_turns
        self._react_mode = react_mode

        self.tool_call_loop_exit_impl: ToolCallLoopExitHandler[CtxT] | None = None

    @property
    def agent_id(self) -> str:
        return self._agent_id

    @property
    def llm(self) -> LLM[LLMSettings, Converters]:
        return self._llm

    @property
    def tools(self) -> dict[str, BaseTool[BaseModel, BaseModel, CtxT]]:
        return self._llm.tools or {}

    @property
    def max_turns(self) -> int:
        return self._max_turns

    def _tool_call_loop_exit(
        self,
        conversation: Conversation,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> bool:
        if self.tool_call_loop_exit_impl:
            return self.tool_call_loop_exit_impl(
                conversation=conversation, ctx=ctx, **kwargs
            )

        assert conversation, "Conversation must not be empty"
        assert isinstance(conversation[-1], AssistantMessage), (
            "Last message in conversation must be an AssistantMessage"
        )

        return not bool(conversation[-1].tool_calls)

    async def generate_once(
        self,
        agent_state: LLMAgentState,
        tool_choice: ToolChoice | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[AssistantMessage]:
        message_history = agent_state.message_history
        message_batch = await self.llm.generate_message_batch(
            message_history, tool_choice=tool_choice
        )
        message_history.add_message_batch(message_batch)

        self._print_messages_and_track_usage(message_batch, ctx=ctx)

        return message_batch

    async def run_loop(
        self,
        agent_state: LLMAgentState,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        message_history = agent_state.message_history
        assert message_history.batch_size == 1, (
            "Batch size must be 1 for tool call loop"
        )

        tool_choice: ToolChoice

        tool_choice = "none" if self._react_mode else "auto"
        gen_message_batch = await self.generate_once(
            agent_state, tool_choice=tool_choice, ctx=ctx
        )

        turns = 0

        while True:
            if self._tool_call_loop_exit(
                message_history.batched_conversations[0], ctx=ctx, num_turns=turns
            ):
                return
            if turns >= self.max_turns:
                logger.info(
                    f"Max turns reached: {self.max_turns}. Stopping tool call loop."
                )
                return

            msg = gen_message_batch[0]
            if msg.tool_calls:
                tool_messages = await self.call_tools(msg.tool_calls, ctx=ctx)
                message_history.add_messages(tool_messages)

            tool_choice = "none" if (self._react_mode and msg.tool_calls) else "auto"
            gen_message_batch = await self.generate_once(
                agent_state, tool_choice=tool_choice, ctx=ctx
            )

            turns += 1

    async def call_tools(
        self,
        calls: Sequence[ToolCall],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[ToolMessage]:
        corouts: list[Coroutine[Any, Any, BaseModel]] = []
        for call in calls:
            tool = self.tools[call.tool_name]
            args = json.loads(call.tool_arguments)
            corouts.append(tool(ctx=ctx, **args))

        outs = await asyncio.gather(*corouts)

        tool_messages = [
            ToolMessage.from_tool_output(out, call, model_id=self.agent_id)
            for out, call in zip(outs, calls, strict=False)
        ]

        self._print_messages(tool_messages, ctx=ctx)

        return tool_messages

    def _print_messages(
        self,
        message_batch: Sequence[Message],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        if ctx:
            ctx.printer.print_llm_messages(message_batch, agent_id=self.agent_id)

    def _print_messages_and_track_usage(
        self,
        message_batch: Sequence[AssistantMessage],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        if ctx:
            self._print_messages(message_batch, ctx=ctx)
            ctx.usage_tracker.update(
                messages=message_batch, model_name=self.llm.model_name
            )
```
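The loop's default exit condition stops once the last assistant message carries no tool calls, but `tool_call_loop_exit_impl` lets callers plug in their own predicate; the loop also forwards `num_turns` to the handler through `**kwargs`. Below is a minimal sketch of a custom handler matching the `ToolCallLoopExitHandler` protocol, assuming the wheel is installed as `grasp_agents`; the `"DONE"` sentinel is a hypothetical convention, not part of the package:

```python
from typing import Any

from grasp_agents.run_context import RunContextWrapper
from grasp_agents.typing.message import AssistantMessage, Conversation


def exit_when_done(
    conversation: Conversation,
    *,
    ctx: RunContextWrapper[Any] | None,
    **kwargs: Any,  # receives num_turns=<int> from run_loop
) -> bool:
    # Stop once the assistant emits "DONE" (hypothetical sentinel),
    # even if the message still carries tool calls.
    last = conversation[-1]
    return isinstance(last, AssistantMessage) and "DONE" in last.content


print(exit_when_done([AssistantMessage(content="All set. DONE")], ctx=None))  # True

# Wiring, given an already-constructed orchestrator:
# orchestrator.tool_call_loop_exit_impl = exit_when_done
```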
grasp_agents/typing/__init__.py (new empty file; the viewer reports "File without changes")
grasp_agents/typing/completion.py (new file)
@@ -0,0 +1,30 @@
```python
from abc import ABC

from pydantic import BaseModel

from .message import AssistantMessage


class CompletionChoice(BaseModel):
    # TODO: add fields
    message: AssistantMessage
    finish_reason: str | None


class CompletionError(BaseModel):
    message: str
    metadata: dict[str, str | None] | None = None
    code: int


class Completion(BaseModel, ABC):
    # TODO: add fields
    choices: list[CompletionChoice]
    model_id: str | None = None
    error: CompletionError | None = None


class CompletionChunk(BaseModel):
    # TODO: add more fields and tool use support (and choices?)
    delta: str | None = None
    model_id: str | None = None
```
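`Completion` is declared abstract, so provider integrations presumably subclass it. A minimal sketch, assuming the wheel is installed; the `ChatCompletion` name and the `model_id` value are hypothetical:

```python
from grasp_agents.typing.completion import Completion, CompletionChoice
from grasp_agents.typing.message import AssistantMessage


class ChatCompletion(Completion):
    """Hypothetical concrete completion; Completion itself is an ABC."""


completion = ChatCompletion(
    choices=[
        CompletionChoice(
            message=AssistantMessage(content="Hello!"),
            finish_reason="stop",
        )
    ],
    model_id="gpt-4o",  # illustrative model id
)
print(completion.choices[0].message.content)  # Hello!
```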
grasp_agents/typing/content.py (new file)
@@ -0,0 +1,116 @@
```python
import base64
import re
from collections.abc import Iterable
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias

from pydantic import AnyUrl, BaseModel, Field


class ContentType(StrEnum):
    TEXT = "text"
    IMAGE = "image"


ImageDetail: TypeAlias = Literal["low", "high", "auto"]


class ImageData(BaseModel):
    type: Literal["url", "base64"]
    url: AnyUrl | None = None
    base64: str | None = None

    # Supported by OpenAI API
    detail: ImageDetail = "high"

    @classmethod
    def from_base64(cls, base64_encoding: str, **kwargs: Any) -> "ImageData":
        return cls(type="base64", base64=base64_encoding, **kwargs)

    @classmethod
    def from_path(cls, img_path: str | Path, **kwargs: Any) -> "ImageData":
        image_bytes = Path(img_path).read_bytes()
        base64_encoding = base64.b64encode(image_bytes).decode("utf-8")
        return cls(type="base64", base64=base64_encoding, **kwargs)

    @classmethod
    def from_url(cls, img_url: str, **kwargs: Any) -> "ImageData":
        return cls(type="url", url=img_url, **kwargs)  # type: ignore

    def to_str(self) -> str:
        if self.type == "url":
            return str(self.url)
        if self.type == "base64":
            return str(self.base64)
        raise ValueError(f"Unsupported image data type: {self.type}")


class ContentPartText(BaseModel):
    type: Literal[ContentType.TEXT] = ContentType.TEXT
    data: str


class ContentPartImage(BaseModel):
    type: Literal[ContentType.IMAGE] = ContentType.IMAGE
    data: ImageData


ContentPart = Annotated[ContentPartText | ContentPartImage, Field(discriminator="type")]


class Content(BaseModel):
    parts: list[ContentPart]

    @classmethod
    def from_formatted_prompt(
        cls,
        prompt_template: str,
        prompt_args: dict[str, str | ImageData] | None = None,
    ) -> "Content":
        prompt_args = prompt_args or {}
        image_args = {
            arg_name: arg_val
            for arg_name, arg_val in prompt_args.items()
            if isinstance(arg_val, ImageData)
        }
        text_args = {
            arg_name: arg_val
            for arg_name, arg_val in prompt_args.items()
            if isinstance(arg_val, (str, int, float))
        }

        if not image_args:
            prompt_with_args = prompt_template.format(**text_args)
            return cls(parts=[ContentPartText(data=prompt_with_args)])

        pattern = r"({})".format("|".join([r"\{" + s + r"\}" for s in image_args]))
        input_prompt_chunks = re.split(pattern, prompt_template)

        content_parts: list[ContentPart] = []
        for chunk in input_prompt_chunks:
            stripped_chunk = chunk.strip(" \n")
            if re.match(pattern, stripped_chunk):
                image_data = image_args[stripped_chunk[1:-1]]
                content_part = ContentPartImage(data=image_data)
            else:
                text_data = stripped_chunk.format(**text_args)
                content_part = ContentPartText(data=text_data)
            content_parts.append(content_part)

        return cls(parts=content_parts)

    @classmethod
    def from_text(cls, text: str) -> "Content":
        return cls(parts=[ContentPartText(data=text)])

    @classmethod
    def from_content_parts(cls, content_parts: Iterable[str | ImageData]) -> "Content":
        parts: list[ContentPart] = []
        for part in content_parts:
            if isinstance(part, str):
                parts.append(ContentPartText(data=part))
            else:
                parts.append(ContentPartImage(data=part))

        return cls(parts=parts)
```
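`Content.from_formatted_prompt` splits the template on image placeholders and `str.format`s the remaining text chunks, so one prompt string can interleave text and images. A minimal sketch, assuming the wheel is installed; the URL and argument names are illustrative:

```python
from grasp_agents.typing.content import Content, ImageData

img = ImageData.from_url("https://example.com/cat.png")  # illustrative URL
content = Content.from_formatted_prompt(
    "Describe this image: {photo} in the style of {style}.",
    prompt_args={"photo": img, "style": "a haiku"},
)
for part in content.parts:
    print(part.type, type(part.data).__name__)
# text str
# image ImageData
# text str
```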
grasp_agents/typing/converters.py (new file)
@@ -0,0 +1,118 @@
```python
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any

from pydantic import BaseModel

from .completion import Completion, CompletionChunk
from .content import Content
from .message import (
    AssistantMessage,
    Message,
    SystemMessage,
    ToolMessage,
    UserMessage,
)
from .tool import BaseTool, ToolChoice


class Converters(ABC):
    @staticmethod
    @abstractmethod
    def to_system_message(system_message: SystemMessage, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def from_system_message(raw_message: Any, **kwargs: Any) -> SystemMessage:
        pass

    @staticmethod
    @abstractmethod
    def to_user_message(user_message: UserMessage, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def from_user_message(raw_message: Any, **kwargs: Any) -> UserMessage:
        pass

    @staticmethod
    @abstractmethod
    def to_assistant_message(assistant_message: AssistantMessage, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def from_assistant_message(
        raw_message: Any, raw_usage: Any, **kwargs: Any
    ) -> AssistantMessage:
        pass

    @staticmethod
    @abstractmethod
    def to_tool_message(tool_message: ToolMessage, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def from_tool_message(raw_message: Any, **kwargs: Any) -> ToolMessage:
        pass

    @classmethod
    def to_message(cls, message: Message, **kwargs: Any) -> Any:
        if isinstance(message, UserMessage):
            return cls.to_user_message(message, **kwargs)
        if isinstance(message, AssistantMessage):
            return cls.to_assistant_message(message, **kwargs)
        if isinstance(message, ToolMessage):
            return cls.to_tool_message(message, **kwargs)

        return cls.to_system_message(message, **kwargs)

    @staticmethod
    @abstractmethod
    def to_tool(tool: BaseTool[BaseModel, BaseModel, Any], **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def to_tool_choice(tool_choice: ToolChoice, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def to_content(content: Content, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def from_content(raw_content: Any, **kwargs: Any) -> Content:
        pass

    @staticmethod
    @abstractmethod
    def to_completion(completion: Completion, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def from_completion(raw_completion: Any, **kwargs: Any) -> Completion:
        pass

    @staticmethod
    @abstractmethod
    def to_completion_chunk(chunk: CompletionChunk, **kwargs: Any) -> Any:
        pass

    @staticmethod
    @abstractmethod
    def from_completion_chunk(raw_chunk: Any, **kwargs: Any) -> CompletionChunk:
        pass

    @staticmethod
    @abstractmethod
    def from_completion_chunk_iterator(
        raw_chunk_iterator: AsyncIterator[Any], **kwargs: Any
    ) -> AsyncIterator[CompletionChunk]:
        pass
```
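`Converters.to_message` dispatches on the concrete message type, so a backend only has to supply the per-role converters (the package's real implementations live under grasp_agents/openai/). A toy sketch targeting a plain dict wire format, assuming the wheel is installed; `DictConverters` and its output shape are hypothetical:

```python
from typing import Any

from grasp_agents.typing.converters import Converters
from grasp_agents.typing.message import (
    AssistantMessage,
    SystemMessage,
    ToolMessage,
    UserMessage,
)


class DictConverters(Converters):
    """Toy converter to a dict wire format (illustrative only)."""

    @staticmethod
    def to_system_message(system_message: SystemMessage, **kwargs: Any) -> Any:
        return {"role": "system", "content": system_message.content}

    @staticmethod
    def to_user_message(user_message: UserMessage, **kwargs: Any) -> Any:
        return {
            "role": "user",
            "content": [p.model_dump(mode="json") for p in user_message.content.parts],
        }

    @staticmethod
    def to_assistant_message(assistant_message: AssistantMessage, **kwargs: Any) -> Any:
        return {"role": "assistant", "content": assistant_message.content}

    @staticmethod
    def to_tool_message(tool_message: ToolMessage, **kwargs: Any) -> Any:
        return {
            "role": "tool",
            "content": tool_message.content,
            "tool_call_id": tool_message.tool_call_id,
        }


# Static methods are callable on the class itself, so the remaining abstract
# converters need not be overridden just to exercise the dispatch:
print(DictConverters.to_message(UserMessage.from_text("hello")))
# {'role': 'user', 'content': [{'type': 'text', 'data': 'hello'}]}
```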
grasp_agents/typing/io.py (new file)
@@ -0,0 +1,32 @@
```python
from collections.abc import Sequence
from typing import TypeAlias, TypeVar

from pydantic import BaseModel
from pydantic.json_schema import SkipJsonSchema

from .content import ImageData

AgentID: TypeAlias = str


class AgentPayload(BaseModel):
    # TODO: do we need conversation?
    selected_recipient_ids: SkipJsonSchema[Sequence[AgentID] | None] = None


class AgentState(BaseModel):
    pass


InT = TypeVar("InT", bound=AgentPayload, contravariant=True)  # noqa: PLC0105
OutT = TypeVar("OutT", bound=AgentPayload, covariant=True)  # noqa: PLC0105
StateT = TypeVar("StateT", bound=AgentState, covariant=True)  # noqa: PLC0105


class LLMPromptArgs(BaseModel):
    pass


LLMPrompt: TypeAlias = str
LLMFormattedSystemArgs: TypeAlias = dict[str, str]
LLMFormattedArgs: TypeAlias = dict[str, str | ImageData]
```
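Agent inputs and outputs are typed by subclassing `AgentPayload`; `selected_recipient_ids` is carried on every payload for routing but excluded from the JSON schema via `SkipJsonSchema`. A minimal sketch with hypothetical payload names, assuming the wheel is installed:

```python
from grasp_agents.typing.io import AgentPayload


class Answer(AgentPayload):  # hypothetical output payload
    answer: str


out = Answer(answer="42", selected_recipient_ids=["judge_agent"])
print(out.model_dump())  # routing field is still present on the instance
print("selected_recipient_ids" in Answer.model_json_schema().get("properties", {}))
# False: hidden from the schema handed to an LLM
```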
grasp_agents/typing/message.py (new file)
@@ -0,0 +1,130 @@
```python
from collections.abc import Hashable, Sequence
from enum import StrEnum
from typing import Annotated, Literal, TypeAlias
from uuid import uuid4

from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt

from .content import Content, ImageData
from .tool import ToolCall


class Role(StrEnum):
    USER = "user"
    SYSTEM = "system"
    ASSISTANT = "assistant"
    TOOL = "tool"


class Usage(BaseModel):
    input_tokens: NonNegativeInt = 0
    output_tokens: NonNegativeInt = 0
    reasoning_tokens: NonNegativeInt | None = None
    cached_tokens: NonNegativeInt | None = None
    cost: NonNegativeFloat | None = None

    def __add__(self, add_usage: "Usage") -> "Usage":
        input_tokens = self.input_tokens + add_usage.input_tokens
        output_tokens = self.output_tokens + add_usage.output_tokens
        if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
            reasoning_tokens = (self.reasoning_tokens or 0) + (
                add_usage.reasoning_tokens or 0
            )
        else:
            reasoning_tokens = None

        if self.cached_tokens is not None or add_usage.cached_tokens is not None:
            cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
        else:
            cached_tokens = None

        cost = (
            (self.cost or 0.0) + add_usage.cost
            if (add_usage.cost is not None)
            else None
        )
        return Usage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            reasoning_tokens=reasoning_tokens,
            cached_tokens=cached_tokens,
            cost=cost,
        )


class MessageBase(BaseModel):
    message_id: Hashable = Field(default_factory=lambda: str(uuid4())[:8])
    model_id: str | None = None


class AssistantMessage(MessageBase):
    role: Literal[Role.ASSISTANT] = Role.ASSISTANT
    content: str
    usage: Usage | None = None
    tool_calls: Sequence[ToolCall] | None = None
    refusal: str | None = None


class UserMessage(MessageBase):
    role: Literal[Role.USER] = Role.USER
    content: Content

    @classmethod
    def from_text(cls, text: str, model_id: str | None = None) -> "UserMessage":
        return cls(content=Content.from_text(text), model_id=model_id)

    @classmethod
    def from_formatted_prompt(
        cls,
        prompt_template: str,
        prompt_args: dict[str, str | ImageData] | None = None,
        model_id: str | None = None,
    ) -> "UserMessage":
        content = Content.from_formatted_prompt(
            prompt_template=prompt_template, prompt_args=prompt_args
        )

        return cls(content=content, model_id=model_id)

    @classmethod
    def from_content_parts(
        cls,
        content_parts: Sequence[str | ImageData],
        model_id: str | None = None,
    ) -> "UserMessage":
        content = Content.from_content_parts(content_parts)

        return cls(content=content, model_id=model_id)


class SystemMessage(MessageBase):
    role: Literal[Role.SYSTEM] = Role.SYSTEM
    content: str


class ToolMessage(MessageBase):
    role: Literal[Role.TOOL] = Role.TOOL
    content: str
    tool_call_id: str

    @classmethod
    def from_tool_output(
        cls,
        tool_output: BaseModel,
        tool_call: ToolCall,
        model_id: str | None = None,
        indent: int = 2,
    ) -> "ToolMessage":
        return cls(
            content=tool_output.model_dump_json(indent=indent),
            tool_call_id=tool_call.id,
            model_id=model_id,
        )


Message = Annotated[
    AssistantMessage | UserMessage | SystemMessage | ToolMessage,
    Field(discriminator="role"),
]

Conversation: TypeAlias = list[Message]
```
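Note the accumulation semantics of `Usage.__add__`: a `None` on either side of `reasoning_tokens` or `cached_tokens` counts as 0, but `cost` is dropped entirely whenever the right-hand operand has no cost, so the order of accumulation matters. A short sketch, assuming the wheel is installed:

```python
from grasp_agents.typing.message import Usage

a = Usage(input_tokens=100, output_tokens=20, cost=0.003)
b = Usage(input_tokens=50, output_tokens=10, reasoning_tokens=5)

total = a + b
print(total.input_tokens, total.output_tokens)  # 150 30
print(total.reasoning_tokens)  # 5: None on one side is treated as 0
print(total.cost)       # None: b has no cost, so the accumulated cost is dropped
print((b + a).cost)     # 0.003: addition is not symmetric in cost
```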
grasp_agents/typing/tool.py (new file)
@@ -0,0 +1,52 @@
```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar

from pydantic import BaseModel

if TYPE_CHECKING:
    from ..run_context import CtxT, RunContextWrapper
else:
    CtxT = TypeVar("CtxT")

    class RunContextWrapper(Generic[CtxT]):
        """Runtime placeholder so RunContextWrapper[CtxT] works"""


ToolInT = TypeVar("ToolInT", bound=BaseModel, contravariant=True)  # noqa: PLC0105
ToolOutT = TypeVar("ToolOutT", bound=BaseModel, covariant=True)  # noqa: PLC0105


class ToolCall(BaseModel):
    id: str
    tool_name: str
    tool_arguments: str


class BaseTool(BaseModel, ABC, Generic[ToolInT, ToolOutT, CtxT]):
    name: str
    description: str
    in_schema: type[ToolInT]
    out_schema: type[ToolOutT]

    # Supported by OpenAI API
    strict: bool | None = None

    @abstractmethod
    async def run(
        self, inp: ToolInT, ctx: RunContextWrapper[CtxT] | None = None
    ) -> ToolOutT:
        pass

    async def __call__(
        self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
    ) -> ToolOutT:
        result = await self.run(self.in_schema(**kwargs), ctx=ctx)

        return self.out_schema.model_validate(result)


ToolChoice: TypeAlias = (
    Literal["none", "auto", "required"] | BaseTool[BaseModel, BaseModel, Any]
)
```