grasp_agents-0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/agent_message.py +28 -0
  2. grasp_agents/agent_message_pool.py +94 -0
  3. grasp_agents/base_agent.py +72 -0
  4. grasp_agents/cloud_llm.py +353 -0
  5. grasp_agents/comm_agent.py +230 -0
  6. grasp_agents/costs_dict.yaml +122 -0
  7. grasp_agents/data_retrieval/__init__.py +7 -0
  8. grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
  9. grasp_agents/data_retrieval/types.py +57 -0
  10. grasp_agents/data_retrieval/utils.py +57 -0
  11. grasp_agents/grasp_logging.py +36 -0
  12. grasp_agents/http_client.py +24 -0
  13. grasp_agents/llm.py +106 -0
  14. grasp_agents/llm_agent.py +361 -0
  15. grasp_agents/llm_agent_state.py +73 -0
  16. grasp_agents/memory.py +150 -0
  17. grasp_agents/openai/__init__.py +83 -0
  18. grasp_agents/openai/completion_converters.py +49 -0
  19. grasp_agents/openai/content_converters.py +80 -0
  20. grasp_agents/openai/converters.py +170 -0
  21. grasp_agents/openai/message_converters.py +155 -0
  22. grasp_agents/openai/openai_llm.py +179 -0
  23. grasp_agents/openai/tool_converters.py +37 -0
  24. grasp_agents/printer.py +156 -0
  25. grasp_agents/prompt_builder.py +204 -0
  26. grasp_agents/run_context.py +90 -0
  27. grasp_agents/tool_orchestrator.py +181 -0
  28. grasp_agents/typing/__init__.py +0 -0
  29. grasp_agents/typing/completion.py +30 -0
  30. grasp_agents/typing/content.py +116 -0
  31. grasp_agents/typing/converters.py +118 -0
  32. grasp_agents/typing/io.py +32 -0
  33. grasp_agents/typing/message.py +130 -0
  34. grasp_agents/typing/tool.py +52 -0
  35. grasp_agents/usage_tracker.py +99 -0
  36. grasp_agents/utils.py +151 -0
  37. grasp_agents/workflow/__init__.py +0 -0
  38. grasp_agents/workflow/looped_agent.py +113 -0
  39. grasp_agents/workflow/sequential_agent.py +57 -0
  40. grasp_agents/workflow/workflow_agent.py +69 -0
  41. grasp_agents-0.1.5.dist-info/METADATA +14 -0
  42. grasp_agents-0.1.5.dist-info/RECORD +44 -0
  43. grasp_agents-0.1.5.dist-info/WHEEL +4 -0
  44. grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
grasp_agents/memory.py ADDED
@@ -0,0 +1,150 @@
+ import logging
+ from collections.abc import Iterator, Sequence
+ from copy import deepcopy
+
+ from .typing.io import LLMPrompt
+ from .typing.message import Conversation, Message, SystemMessage
+
+ logger = logging.getLogger(__name__)
+
+
+ class MessageHistory:
+     def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
+         self._sys_prompt = sys_prompt
+         self._batched_conversations: list[Conversation]
+         self.reset()
+
+     @property
+     def sys_prompt(self) -> LLMPrompt | None:
+         return self._sys_prompt
+
+     def add_message_batch(self, message_batch: Sequence[Message]) -> None:
+         """
+         Adds a batch of messages to the current batched conversations.
+         This method verifies that the size of the input message batch matches
+         the expected batch size (self.batch_size).
+         If there is a mismatch, the method adjusts by duplicating either
+         the message or the conversation as necessary:
+
+         - If the message batch contains exactly one message and
+           self.batch_size > 1, the single message is duplicated to match
+           the batch size.
+         - If the message batch contains multiple messages but
+           self.batch_size == 1, the entire conversation is duplicated to
+           accommodate each message in the batch.
+         - If the message batch size does not match self.batch_size and none of
+           the above adjustments apply, a ValueError is raised.
+
+         Afterwards, each message in the batch is appended to its corresponding
+         conversation in the batched conversations.
+
+         Args:
+             message_batch: A sequence of Message objects
+                 representing the batch of messages to be added. Must align with
+                 or be adjusted to match the current batch size.
+
+         Raises:
+             ValueError: If the message batch size does not match the current
+                 batch size and cannot be automatically adjusted.
+
+         """
+         message_batch_size = len(message_batch)
+
+         if message_batch_size == 1 and self.batch_size > 1:
+             logger.info(
+                 "Message batch size is 1, current batch size is "
+                 f"{self.batch_size}: duplicating the message to match the "
+                 "current batch size"
+             )
+             message_batch = self._duplicate_message_to_current_batch_size(message_batch)
+             message_batch_size = self.batch_size
+         elif message_batch_size > 1 and self.batch_size == 1:
+             logger.info(
+                 f"Message batch size is {len(message_batch)}, current batch "
+                 "size is 1: duplicating the conversation to match the message "
+                 "batch size"
+             )
+             self._duplicate_conversation_to_message_batch_size(message_batch_size)
+         elif message_batch_size != self.batch_size:
+             raise ValueError(
+                 f"Message batch size {message_batch_size} does not match "
+                 f"current batch size {self.batch_size}"
+             )
+
+         for batch_id in range(message_batch_size):
+             self._batched_conversations[batch_id].append(message_batch[batch_id])
+
+     def add_message_batches(self, message_batches: Sequence[Sequence[Message]]) -> None:
+         for message_batch in message_batches:
+             self.add_message_batch(message_batch)
+
+     def add_message(self, message: Message) -> None:
+         for conversation in self._batched_conversations:
+             conversation.append(message)
+
+     def add_messages(self, messages: Sequence[Message]) -> None:
+         for message in messages:
+             self.add_message(message)
+
+     def __len__(self) -> int:
+         return len(self._batched_conversations[0])
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(len={len(self)}; bs={self.batch_size})"
+
+     def __getitem__(self, idx: int) -> tuple[Message, ...]:
+         return tuple(conversation[idx] for conversation in self._batched_conversations)
+
+     def __iter__(self) -> Iterator[tuple[Message, ...]]:
+         for idx in range(len(self)):
+             yield tuple(
+                 conversation[idx] for conversation in self._batched_conversations
+             )
+
+     def _duplicate_message_to_current_batch_size(
+         self, message_batch: Sequence[Message]
+     ) -> Sequence[Message]:
+         assert len(message_batch) == 1, (
+             "Message batch size must be 1 to duplicate to current batch size"
+         )
+
+         return [deepcopy(message_batch[0]) for _ in range(self.batch_size)]
+
+     def _duplicate_conversation_to_message_batch_size(
+         self, target_batch_size: int
+     ) -> None:
+         assert self.batch_size == 1, "Batch size must be 1 to duplicate conversation"
+         self._batched_conversations = [
+             deepcopy(self._batched_conversations[0]) for _ in range(target_batch_size)
+         ]
+
+     @property
+     def batched_conversations(self) -> list[Conversation]:
+         return self._batched_conversations
+
+     @property
+     def batch_size(self) -> int:
+         return len(self._batched_conversations)
+
+     def reset(
+         self, sys_prompt: LLMPrompt | None = None, *, batch_size: int = 1
+     ) -> None:
+         if sys_prompt is not None:
+             self._sys_prompt = sys_prompt
+
+         conv: Conversation
+         if self._sys_prompt is not None:
+             conv = [SystemMessage(content=self._sys_prompt)]
+         else:
+             conv = []
+
+         self._batched_conversations = [deepcopy(conv) for _ in range(batch_size)]
+
+     def erase(self) -> None:
+         self._batched_conversations = [[]]
+
+     # def get_batch(self, batch_id: int) -> list[Message]:
+     #     return self._batched_conversations[batch_id]
+
+     # def iterate_conversations(self) -> Iterator[list[Message]]:
+     #     return iter(self._batched_conversations)
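A minimal usage sketch, not part of the wheel, illustrating the batch-broadcast behavior documented in add_message_batch. The UserMessage constructor shown here is an assumption; this diff does not show how messages are built.

# Hypothetical usage sketch; the UserMessage(content=...) constructor is assumed.
from grasp_agents.memory import MessageHistory
from grasp_agents.typing.message import UserMessage

history = MessageHistory(sys_prompt="You are a helpful assistant.")
history.reset(batch_size=3)              # three parallel conversations, each seeded with the system prompt

msg = UserMessage(content="Hello!")      # assumed constructor
history.add_message_batch([msg])         # a batch of 1 is duplicated across all 3 conversations
print(history.batch_size, len(history))  # 3 conversations, each now 2 messages long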
grasp_agents/openai/__init__.py ADDED
@@ -0,0 +1,83 @@
+ # pyright: reportUnusedImport=false
+
+ from openai._streaming import (
+     AsyncStream as ChatCompletionAsyncStream,  # type: ignore[import] # noqa: PLC2701
+ )
+ from openai.types import CompletionUsage as ChatCompletionUsage
+ from openai.types.chat.chat_completion import ChatCompletion
+ from openai.types.chat.chat_completion_assistant_message_param import (
+     ChatCompletionAssistantMessageParam,
+ )
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+ from openai.types.chat.chat_completion_content_part_image_param import (
+     ChatCompletionContentPartImageParam,
+ )
+ from openai.types.chat.chat_completion_content_part_image_param import (
+     ImageURL as ChatCompletionImageURL,
+ )
+ from openai.types.chat.chat_completion_content_part_param import (
+     ChatCompletionContentPartParam,
+ )
+ from openai.types.chat.chat_completion_content_part_text_param import (
+     ChatCompletionContentPartTextParam,
+ )
+ from openai.types.chat.chat_completion_developer_message_param import (
+     ChatCompletionDeveloperMessageParam,
+ )
+ from openai.types.chat.chat_completion_function_message_param import (
+     ChatCompletionFunctionMessageParam,
+ )
+ from openai.types.chat.chat_completion_message import ChatCompletionMessage
+ from openai.types.chat.chat_completion_message_param import (
+     ChatCompletionMessageParam,
+ )
+ from openai.types.chat.chat_completion_message_tool_call_param import (
+     ChatCompletionMessageToolCallParam,
+ )
+ from openai.types.chat.chat_completion_message_tool_call_param import (
+     Function as ChatCompletionToolCallFunction,
+ )
+ from openai.types.chat.chat_completion_named_tool_choice_param import (
+     ChatCompletionNamedToolChoiceParam,
+ )
+ from openai.types.chat.chat_completion_named_tool_choice_param import (
+     Function as ChatCompletionNamedToolChoiceFunction,
+ )
+ from openai.types.chat.chat_completion_prediction_content_param import (
+     ChatCompletionPredictionContentParam,
+ )
+ from openai.types.chat.chat_completion_stream_options_param import (
+     ChatCompletionStreamOptionsParam,
+ )
+ from openai.types.chat.chat_completion_system_message_param import (
+     ChatCompletionSystemMessageParam,
+ )
+ from openai.types.chat.chat_completion_tool_choice_option_param import (
+     ChatCompletionToolChoiceOptionParam,
+ )
+ from openai.types.chat.chat_completion_tool_message_param import (
+     ChatCompletionToolMessageParam,
+ )
+ from openai.types.chat.chat_completion_tool_param import (
+     ChatCompletionToolParam,
+ )
+ from openai.types.chat.chat_completion_user_message_param import (
+     ChatCompletionUserMessageParam,
+ )
+ from openai.types.chat.parsed_chat_completion import (
+     ParsedChatCompletion,
+     ParsedChatCompletionMessage,
+     ParsedChoice,
+ )
+ from openai.types.shared_params.function_definition import (
+     FunctionDefinition as ChatCompletionFunctionDefinition,
+ )
+ from openai.types.shared_params.response_format_json_object import (
+     ResponseFormatJSONObject,
+ )
+ from openai.types.shared_params.response_format_json_schema import (
+     ResponseFormatJSONSchema,
+ )
+ from openai.types.shared_params.response_format_text import (
+     ResponseFormatText,
+ )
grasp_agents/openai/completion_converters.py ADDED
@@ -0,0 +1,49 @@
+ from collections.abc import AsyncIterator
+
+ from ..typing.completion import Completion, CompletionChoice, CompletionChunk
+ from . import (
+     ChatCompletion,
+     ChatCompletionAsyncStream,  # type: ignore[import]
+     ChatCompletionChunk,
+ )
+ from .message_converters import from_api_assistant_message
+
+
+ def from_api_completion(
+     api_completion: ChatCompletion, model_id: str | None = None
+ ) -> Completion:
+     choices: list[CompletionChoice] = []
+     # TODO: add custom error type
+     if api_completion.choices is None:  # type: ignore
+         raise ValueError(
+             f"Completion API error: {getattr(api_completion, 'error', None)}"
+         )
+     for api_choice in api_completion.choices:
+         # TODO: no way to assign individual message usages when len(choices) > 1
+         message = from_api_assistant_message(
+             api_choice.message, api_completion.usage, model_id=model_id
+         )
+         finish_reason = api_choice.finish_reason
+         choices.append(CompletionChoice(message=message, finish_reason=finish_reason))
+
+     return Completion(choices=choices, model_id=model_id)
+
+
+ def to_api_completion(completion: Completion) -> ChatCompletion:
+     raise NotImplementedError
+
+
+ def from_api_completion_chunk(
+     api_completion_chunk: ChatCompletionChunk, model_id: str | None = None
+ ) -> CompletionChunk:
+     delta = api_completion_chunk.choices[0].delta.content
+
+     return CompletionChunk(delta=delta, model_id=model_id)
+
+
+ async def from_api_completion_chunk_iterator(
+     api_completion_chunk_iterator: ChatCompletionAsyncStream[ChatCompletionChunk],
+     model_id: str | None = None,
+ ) -> AsyncIterator[CompletionChunk]:
+     async for api_chunk in api_completion_chunk_iterator:
+         yield from_api_completion_chunk(api_chunk, model_id=model_id)
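For context, a small consumption sketch (not included in the package) showing how the async chunk iterator above could wrap an OpenAI streaming response. The client setup and model name are illustrative; a recent openai client is assumed.

import asyncio

from openai import AsyncOpenAI

from grasp_agents.openai.completion_converters import from_api_completion_chunk_iterator


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    # Wrap the raw ChatCompletionChunk stream into the package's CompletionChunk objects.
    async for chunk in from_api_completion_chunk_iterator(stream, model_id="gpt-4o-mini"):
        print(chunk.delta or "", end="", flush=True)


asyncio.run(main())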
grasp_agents/openai/content_converters.py ADDED
@@ -0,0 +1,80 @@
+ from collections.abc import Iterable
+
+ from ..typing.content import (
+     Content,
+     ContentPart,
+     ContentPartImage,
+     ContentPartText,
+     ImageData,
+ )
+ from . import (
+     ChatCompletionContentPartImageParam,
+     ChatCompletionContentPartParam,
+     ChatCompletionContentPartTextParam,
+     ChatCompletionImageURL,
+ )
+
+ BASE64_PREFIX = "data:image/jpeg;base64,"
+
+
+ def image_data_to_str(image_data: ImageData) -> str:
+     if image_data.type == "url":
+         return str(image_data.url)
+     if image_data.type == "base64":
+         return f"{BASE64_PREFIX}{image_data.base64}"
+     raise ValueError(f"Unsupported image data type: {image_data.type}")
+
+
+ def from_api_content(
+     api_content: str | Iterable[ChatCompletionContentPartParam],
+ ) -> "Content":
+     if isinstance(api_content, str):
+         return Content(parts=[ContentPartText(data=api_content)])
+
+     content_parts: list[ContentPart] = []
+     for api_content_part in api_content:
+         content_part: ContentPart
+
+         if api_content_part["type"] == "text":
+             text_data = api_content_part["text"]
+             content_part = ContentPartText(data=text_data)
+
+         elif api_content_part["type"] == "image_url":
+             url = api_content_part["image_url"]["url"]
+             detail = api_content_part["image_url"].get("detail")
+             if url.startswith(BASE64_PREFIX):
+                 image_data = ImageData.from_base64(
+                     base64_encoding=url.removeprefix(BASE64_PREFIX),
+                     detail=detail,
+                 )
+             else:
+                 image_data = ImageData.from_url(img_url=url, detail=detail)  # type: ignore
+             content_part = ContentPartImage(data=image_data)
+
+         content_parts.append(content_part)  # type: ignore
+
+     return Content(parts=content_parts)
+
+
+ def to_api_content(
+     content: Content,
+ ) -> Iterable[ChatCompletionContentPartParam]:
+     api_content: list[ChatCompletionContentPartParam] = []
+     for content_part in content.parts:
+         api_content_part: ChatCompletionContentPartParam
+         if isinstance(content_part, ContentPartText):
+             api_content_part = ChatCompletionContentPartTextParam(
+                 type="text",
+                 text=content_part.data,
+             )
+         else:
+             api_content_part = ChatCompletionContentPartImageParam(
+                 type="image_url",
+                 image_url=ChatCompletionImageURL(
+                     url=image_data_to_str(content_part.data),
+                     detail=content_part.data.detail,
+                 ),
+             )
+         api_content.append(api_content_part)
+
+     return api_content
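A brief round-trip sketch (not part of the package) of the converters above. The ImageData.from_url signature mirrors the call inside from_api_content; the Content and ContentPart constructors are assumptions inferred from this diff.

from grasp_agents.openai.content_converters import from_api_content, to_api_content
from grasp_agents.typing.content import Content, ContentPartImage, ContentPartText, ImageData

content = Content(
    parts=[
        ContentPartText(data="What is shown in this picture?"),
        # from_url(img_url=..., detail=...) mirrors the call used in from_api_content above
        ContentPartImage(data=ImageData.from_url(img_url="https://example.com/cat.jpg", detail="low")),
    ]
)

api_parts = list(to_api_content(content))  # OpenAI-style content-part dicts
restored = from_api_content(api_parts)     # back to the package's Content model
assert len(restored.parts) == len(content.parts)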
grasp_agents/openai/converters.py ADDED
@@ -0,0 +1,170 @@
+ from collections.abc import AsyncIterator, Iterable
+ from typing import Any
+
+ from pydantic import BaseModel
+
+ from ..typing.completion import Completion, CompletionChunk
+ from ..typing.content import Content
+ from ..typing.converters import Converters
+ from ..typing.message import (
+     AssistantMessage,
+     SystemMessage,
+     ToolMessage,
+     UserMessage,
+ )
+ from ..typing.tool import BaseTool, ToolChoice
+ from . import (
+     ChatCompletion,
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionAsyncStream,  # type: ignore[import]
+     ChatCompletionChunk,
+     ChatCompletionContentPartParam,
+     # ChatCompletionDeveloperMessageParam,
+     ChatCompletionMessage,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionToolChoiceOptionParam,
+     ChatCompletionToolMessageParam,
+     ChatCompletionToolParam,
+     ChatCompletionUsage,
+     ChatCompletionUserMessageParam,
+ )
+ from .completion_converters import (
+     from_api_completion,
+     from_api_completion_chunk,
+     from_api_completion_chunk_iterator,
+     to_api_completion,
+ )
+ from .content_converters import from_api_content, to_api_content
+ from .message_converters import (
+     from_api_assistant_message,
+     from_api_system_message,
+     from_api_tool_message,
+     from_api_user_message,
+     to_api_assistant_message,
+     to_api_system_message,
+     to_api_tool_message,
+     to_api_user_message,
+ )
+ from .tool_converters import to_api_tool, to_api_tool_choice
+
+
+ class OpenAIConverters(Converters):
+     @staticmethod
+     def to_system_message(
+         system_message: SystemMessage, **kwargs: Any
+     ) -> ChatCompletionSystemMessageParam:
+         return to_api_system_message(system_message, **kwargs)
+
+     @staticmethod
+     def from_system_message(
+         raw_message: ChatCompletionSystemMessageParam,
+         model_id: str | None = None,
+         **kwargs: Any,
+     ) -> SystemMessage:
+         return from_api_system_message(raw_message, model_id=model_id, **kwargs)
+
+     @staticmethod
+     def to_user_message(
+         user_message: UserMessage, **kwargs: Any
+     ) -> ChatCompletionUserMessageParam:
+         return to_api_user_message(user_message, **kwargs)
+
+     @staticmethod
+     def from_user_message(
+         raw_message: ChatCompletionUserMessageParam,
+         model_id: str | None = None,
+         **kwargs: Any,
+     ) -> UserMessage:
+         return from_api_user_message(raw_message, model_id=model_id, **kwargs)
+
+     @staticmethod
+     def to_assistant_message(
+         assistant_message: AssistantMessage, **kwargs: Any
+     ) -> ChatCompletionAssistantMessageParam:
+         return to_api_assistant_message(assistant_message, **kwargs)
+
+     @staticmethod
+     def from_assistant_message(
+         raw_message: ChatCompletionMessage,
+         raw_usage: ChatCompletionUsage,
+         model_id: str | None = None,
+         **kwargs: Any,
+     ) -> AssistantMessage:
+         return from_api_assistant_message(
+             raw_message, raw_usage, model_id=model_id, **kwargs
+         )
+
+     @staticmethod
+     def to_tool_message(
+         tool_message: ToolMessage, **kwargs: Any
+     ) -> ChatCompletionToolMessageParam:
+         return to_api_tool_message(tool_message, **kwargs)
+
+     @staticmethod
+     def from_tool_message(
+         raw_message: ChatCompletionToolMessageParam,
+         model_id: str | None = None,
+         **kwargs: Any,
+     ) -> ToolMessage:
+         return from_api_tool_message(raw_message, model_id=model_id, **kwargs)
+
+     @staticmethod
+     def to_tool(
+         tool: BaseTool[BaseModel, BaseModel, Any], **kwargs: Any
+     ) -> ChatCompletionToolParam:
+         return to_api_tool(tool, **kwargs)
+
+     @staticmethod
+     def to_tool_choice(
+         tool_choice: ToolChoice, **kwargs: Any
+     ) -> ChatCompletionToolChoiceOptionParam:
+         return to_api_tool_choice(tool_choice, **kwargs)
+
+     @staticmethod
+     def to_content(
+         content: Content, **kwargs: Any
+     ) -> Iterable[ChatCompletionContentPartParam]:
+         return to_api_content(content, **kwargs)
+
+     @staticmethod
+     def from_content(
+         raw_content: str | Iterable[ChatCompletionContentPartParam],
+         **kwargs: Any,
+     ) -> Content:
+         return from_api_content(raw_content, **kwargs)
+
+     @staticmethod
+     def to_completion(completion: Completion, **kwargs: Any) -> ChatCompletion:
+         return to_api_completion(completion, **kwargs)
+
+     @staticmethod
+     def from_completion(
+         raw_completion: ChatCompletion,
+         model_id: str | None = None,
+         **kwargs: Any,
+     ) -> Completion:
+         return from_api_completion(raw_completion, model_id=model_id, **kwargs)
+
+     @staticmethod
+     def to_completion_chunk(
+         chunk: CompletionChunk, **kwargs: Any
+     ) -> ChatCompletionChunk:
+         raise NotImplementedError
+
+     @staticmethod
+     def from_completion_chunk(
+         raw_chunk: ChatCompletionChunk,
+         model_id: str | None = None,
+         **kwargs: Any,
+     ) -> CompletionChunk:
+         return from_api_completion_chunk(raw_chunk, model_id=model_id, **kwargs)
+
+     @staticmethod
+     def from_completion_chunk_iterator(  # type: ignore[override]
+         raw_chunk_iterator: ChatCompletionAsyncStream[ChatCompletionChunk],
+         model_id: str | None = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[CompletionChunk]:
+         return from_api_completion_chunk_iterator(
+             raw_chunk_iterator, model_id=model_id, **kwargs
+         )
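An illustrative sketch (not shipped with the package) of using the facade: the static methods dispatch to the per-type converters, so agent code can stay independent of the OpenAI wire format. The UserMessage and Content constructors below are assumptions inferred from the message converters in this diff.

from grasp_agents.openai.converters import OpenAIConverters
from grasp_agents.typing.content import Content, ContentPartText
from grasp_agents.typing.message import UserMessage

# Assumed constructor: UserMessage wrapping a Content object, as implied by the message converters.
message = UserMessage(content=Content(parts=[ContentPartText(data="Hi there")]))

api_message = OpenAIConverters.to_user_message(message)  # ChatCompletionUserMessageParam
restored = OpenAIConverters.from_user_message(api_message, model_id="gpt-4o")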
grasp_agents/openai/message_converters.py ADDED
@@ -0,0 +1,155 @@
+ from typing import TypeAlias
+
+ from ..typing.message import (
+     AssistantMessage,
+     SystemMessage,
+     ToolMessage,
+     Usage,
+     UserMessage,
+ )
+ from ..typing.tool import ToolCall
+ from . import (
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionDeveloperMessageParam,
+     ChatCompletionFunctionMessageParam,
+     ChatCompletionMessage,
+     ChatCompletionMessageToolCallParam,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionToolCallFunction,
+     ChatCompletionToolMessageParam,
+     ChatCompletionUsage,
+     ChatCompletionUserMessageParam,
+ )
+ from .content_converters import from_api_content, to_api_content
+
+ OpenAIMessage: TypeAlias = (
+     ChatCompletionAssistantMessageParam
+     | ChatCompletionToolMessageParam
+     | ChatCompletionUserMessageParam
+     | ChatCompletionDeveloperMessageParam
+     | ChatCompletionSystemMessageParam
+     | ChatCompletionFunctionMessageParam
+ )
+
+
+ def from_api_user_message(
+     api_message: ChatCompletionUserMessageParam, model_id: str | None = None
+ ) -> UserMessage:
+     content = from_api_content(api_message["content"])
+
+     return UserMessage(content=content, model_id=model_id)
+
+
+ def to_api_user_message(message: UserMessage) -> ChatCompletionUserMessageParam:
+     api_content = to_api_content(message.content)
+
+     return ChatCompletionUserMessageParam(role="user", content=api_content)
+
+
+ def from_api_assistant_message(
+     api_message: ChatCompletionMessage,
+     api_usage: ChatCompletionUsage | None = None,
+     model_id: str | None = None,
+ ) -> AssistantMessage:
+     content = api_message.content or ""
+     assert isinstance(content, str), (
+         "Only string content is currently supported in assistant messages"
+     )
+
+     usage = None
+     if api_usage is not None:
+         reasoning_tokens = None
+         cached_tokens = None
+
+         if api_usage.completion_tokens_details is not None:
+             reasoning_tokens = api_usage.completion_tokens_details.reasoning_tokens
+         if api_usage.prompt_tokens_details is not None:
+             cached_tokens = api_usage.prompt_tokens_details.cached_tokens
+
+         input_tokens = api_usage.prompt_tokens - (cached_tokens or 0)
+         output_tokens = api_usage.completion_tokens - (reasoning_tokens or 0)
+
+         usage = Usage(
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+             reasoning_tokens=reasoning_tokens,
+             cached_tokens=cached_tokens,
+         )
+
+     tool_calls = None
+     if api_message.tool_calls is not None:
+         tool_calls = [
+             ToolCall(
+                 id=tool_call.id,
+                 tool_name=tool_call.function.name,
+                 tool_arguments=tool_call.function.arguments,
+             )
+             for tool_call in api_message.tool_calls
+         ]
+
+     return AssistantMessage(
+         content=content,
+         usage=usage,
+         tool_calls=tool_calls,
+         refusal=api_message.refusal,
+         model_id=model_id,
+     )
+
+
+ def to_api_assistant_message(
+     message: AssistantMessage,
+ ) -> ChatCompletionAssistantMessageParam:
+     api_tool_calls = None
+     if message.tool_calls is not None:
+         api_tool_calls = [
+             ChatCompletionMessageToolCallParam(
+                 type="function",
+                 id=tool_call.id,
+                 function=ChatCompletionToolCallFunction(
+                     name=tool_call.tool_name,
+                     arguments=tool_call.tool_arguments,
+                 ),
+             )
+             for tool_call in message.tool_calls
+         ]
+
+     return ChatCompletionAssistantMessageParam(
+         role="assistant",
+         content=message.content,
+         tool_calls=api_tool_calls,  # type: ignore
+         refusal=message.refusal,
+     )
+
+
+ def from_api_system_message(
+     api_message: ChatCompletionSystemMessageParam,
+     model_id: str | None = None,
+ ) -> SystemMessage:
+     return SystemMessage(content=api_message["content"], model_id=model_id)  # type: ignore
+
+
+ def to_api_system_message(
+     message: SystemMessage,
+ ) -> ChatCompletionSystemMessageParam:
+     return ChatCompletionSystemMessageParam(role="system", content=message.content)
+     # return ChatCompletionSystemMessageParam(
+     #     role="system", content=message.content
+     # )
+
+
+ def from_api_tool_message(
+     api_message: ChatCompletionToolMessageParam, model_id: str | None = None
+ ) -> ToolMessage:
+     return ToolMessage(
+         content=api_message["content"],  # type: ignore
+         tool_call_id=api_message["tool_call_id"],
+         model_id=model_id,
+     )
+
+
+ def to_api_tool_message(message: ToolMessage) -> ChatCompletionToolMessageParam:
+     return ChatCompletionToolMessageParam(
+         role="tool",
+         content=message.content,
+         tool_call_id=message.tool_call_id,
+     )
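A worked example (not part of the package) of the usage arithmetic in from_api_assistant_message: prompt tokens minus cached tokens become input_tokens, and completion tokens minus reasoning tokens become output_tokens. It assumes a recent openai client whose ChatCompletionMessage model exposes refusal; the token counts are illustrative.

from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionMessage

from grasp_agents.openai.message_converters import from_api_assistant_message

api_message = ChatCompletionMessage(role="assistant", content="Hello!")
api_usage = CompletionUsage(prompt_tokens=120, completion_tokens=40, total_tokens=160)

message = from_api_assistant_message(api_message, api_usage, model_id="gpt-4o")
# With no cached or reasoning token details reported, the split is simply
# input_tokens=120 and output_tokens=40.
print(message.usage)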