grasp_agents 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as published in their respective public registries.
- grasp_agents/cloud_llm.py +191 -224
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +23 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +233 -73
- grasp_agents/processor.py +229 -91
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/METADATA +7 -6
- grasp_agents-0.5.1.dist-info/RECORD +57 -0
- grasp_agents-0.4.7.dist-info/RECORD +0 -50
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/litellm/converters.py
ADDED
@@ -0,0 +1,138 @@
+from collections.abc import Iterable
+from typing import Any
+
+from pydantic import BaseModel
+
+from ..openai.content_converters import from_api_content, to_api_content
+from ..openai.message_converters import (
+    from_api_system_message,
+    from_api_tool_message,
+    from_api_user_message,
+    to_api_system_message,
+    to_api_tool_message,
+    to_api_user_message,
+)
+from ..openai.tool_converters import to_api_tool, to_api_tool_choice
+from ..typing.completion import Completion, Usage
+from ..typing.completion_chunk import CompletionChunk
+from ..typing.content import Content
+from ..typing.converters import Converters
+from ..typing.message import AssistantMessage, SystemMessage, ToolMessage, UserMessage
+from ..typing.tool import BaseTool, ToolChoice
+from . import (
+    LiteLLMCompletion,
+    LiteLLMCompletionChunk,
+    LiteLLMCompletionMessage,
+    LiteLLMUsage,
+    OpenAIContentPartParam,
+    OpenAISystemMessageParam,
+    OpenAIToolChoiceOptionParam,
+    OpenAIToolMessageParam,
+    OpenAIToolParam,
+    OpenAIUserMessageParam,
+)
+from .completion_chunk_converters import from_api_completion_chunk
+from .completion_converters import (
+    from_api_completion,
+    from_api_completion_usage,
+    to_api_completion,
+)
+from .message_converters import from_api_assistant_message, to_api_assistant_message
+
+
+class LiteLLMConverters(Converters):
+    @staticmethod
+    def to_completion(completion: Completion, **kwargs: Any) -> LiteLLMCompletion:
+        return to_api_completion(completion, **kwargs)
+
+    @staticmethod
+    def from_completion(
+        raw_completion: LiteLLMCompletion, name: str | None = None, **kwargs: Any
+    ) -> Completion:
+        return from_api_completion(raw_completion, name=name, **kwargs)
+
+    @staticmethod
+    def to_completion_chunk(
+        chunk: CompletionChunk, **kwargs: Any
+    ) -> LiteLLMCompletionChunk:
+        raise NotImplementedError
+
+    @staticmethod
+    def from_completion_chunk(
+        raw_chunk: LiteLLMCompletionChunk, name: str | None = None, **kwargs: Any
+    ) -> CompletionChunk:
+        return from_api_completion_chunk(raw_chunk, name=name, **kwargs)
+
+    @staticmethod
+    def from_assistant_message(
+        raw_message: LiteLLMCompletionMessage, name: str | None = None, **kwargs: Any
+    ) -> AssistantMessage:
+        return from_api_assistant_message(raw_message, name=name, **kwargs)
+
+    @staticmethod
+    def to_assistant_message(
+        assistant_message: AssistantMessage, **kwargs: Any
+    ) -> LiteLLMCompletionMessage:
+        return to_api_assistant_message(assistant_message, **kwargs)
+
+    # The remaining converters are the same as OpenAIConverters
+
+    @staticmethod
+    def to_system_message(
+        system_message: SystemMessage, **kwargs: Any
+    ) -> OpenAISystemMessageParam:
+        return to_api_system_message(system_message, **kwargs)
+
+    @staticmethod
+    def from_system_message(
+        raw_message: OpenAISystemMessageParam, name: str | None = None, **kwargs: Any
+    ) -> SystemMessage:
+        return from_api_system_message(raw_message, name=name, **kwargs)
+
+    @staticmethod
+    def to_user_message(
+        user_message: UserMessage, **kwargs: Any
+    ) -> OpenAIUserMessageParam:
+        return to_api_user_message(user_message, **kwargs)
+
+    @staticmethod
+    def from_user_message(
+        raw_message: OpenAIUserMessageParam, name: str | None = None, **kwargs: Any
+    ) -> UserMessage:
+        return from_api_user_message(raw_message, name=name, **kwargs)
+
+    @staticmethod
+    def from_completion_usage(raw_usage: LiteLLMUsage, **kwargs: Any) -> Usage:
+        return from_api_completion_usage(raw_usage, **kwargs)
+
+    @staticmethod
+    def to_tool_message(
+        tool_message: ToolMessage, **kwargs: Any
+    ) -> OpenAIToolMessageParam:
+        return to_api_tool_message(tool_message, **kwargs)
+
+    @staticmethod
+    def from_tool_message(
+        raw_message: OpenAIToolMessageParam, name: str | None = None, **kwargs: Any
+    ) -> ToolMessage:
+        return from_api_tool_message(raw_message, name=name, **kwargs)
+
+    @staticmethod
+    def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> OpenAIToolParam:
+        return to_api_tool(tool, **kwargs)
+
+    @staticmethod
+    def to_tool_choice(
+        tool_choice: ToolChoice, **kwargs: Any
+    ) -> OpenAIToolChoiceOptionParam:
+        return to_api_tool_choice(tool_choice, **kwargs)
+
+    @staticmethod
+    def to_content(content: Content, **kwargs: Any) -> Iterable[OpenAIContentPartParam]:
+        return to_api_content(content, **kwargs)
+
+    @staticmethod
+    def from_content(
+        raw_content: str | Iterable[OpenAIContentPartParam], **kwargs: Any
+    ) -> Content:
+        return from_api_content(raw_content, **kwargs)
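The converter class above is pure delegation: LiteLLM-specific types for completions and assistant messages, OpenAI-compatible params for everything else. A minimal usage sketch (values are illustrative; it assumes `UserMessage` accepts plain string content, as `AssistantMessage` does in the message converters below):

from grasp_agents.litellm.converters import LiteLLMConverters
from grasp_agents.typing.message import UserMessage

converters = LiteLLMConverters()

# grasp_agents message -> OpenAI-style user message param (shared with OpenAIConverters)
api_user_msg = converters.to_user_message(UserMessage(content="What is 2 + 2?"))

# ...and back; `name` tags the reconstructed message
user_msg = converters.from_user_message(api_user_msg, name="example_user")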
grasp_agents/litellm/lite_llm.py
ADDED
@@ -0,0 +1,210 @@
+import logging
+from collections.abc import AsyncIterator, Mapping
+from typing import Any, cast
+
+import litellm
+from litellm.litellm_core_utils.get_supported_openai_params import (
+    get_supported_openai_params,  # type: ignore[no-redef]
+)
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.types.llms.anthropic import AnthropicThinkingParam
+from litellm.utils import (
+    supports_parallel_function_calling,
+    supports_prompt_caching,
+    supports_reasoning,
+    supports_response_schema,
+    supports_tool_choice,
+)
+
+# from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
+from pydantic import BaseModel
+
+from ..cloud_llm import APIProvider, CloudLLM, LLMRateLimiter
+from ..openai.openai_llm import OpenAILLMSettings
+from ..typing.tool import BaseTool
+from . import (
+    LiteLLMCompletion,
+    LiteLLMCompletionChunk,
+    OpenAIMessageParam,
+    OpenAIToolChoiceOptionParam,
+    OpenAIToolParam,
+)
+from .converters import LiteLLMConverters
+
+logger = logging.getLogger(__name__)
+
+
+class LiteLLMSettings(OpenAILLMSettings, total=False):
+    thinking: AnthropicThinkingParam | None
+
+
+class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
+    def __init__(
+        self,
+        # Base LLM args
+        model_name: str,
+        model_id: str | None = None,
+        llm_settings: LiteLLMSettings | None = None,
+        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        apply_response_schema_via_provider: bool = False,
+        # LLM provider
+        api_provider: APIProvider | None = None,
+        # deployment_id: str | None = None,
+        # api_version: str | None = None,
+        # Connection settings
+        timeout: float | None = None,
+        max_client_retries: int = 2,
+        # Rate limiting
+        rate_limiter: LLMRateLimiter | None = None,
+        # Drop unsupported LLM settings
+        drop_params: bool = True,
+        additional_drop_params: list[str] | None = None,
+        allowed_openai_params: list[str] | None = None,
+        # Mock LLM response for testing
+        mock_response: str | None = None,
+        # LLM response retries: try to regenerate to pass validation
+        max_response_retries: int = 1,
+    ) -> None:
+        self._lite_llm_completion_params: dict[str, Any] = {
+            "max_retries": max_client_retries,
+            "timeout": timeout,
+            "drop_params": drop_params,
+            "additional_drop_params": additional_drop_params,
+            "allowed_openai_params": allowed_openai_params,
+            "mock_response": mock_response,
+            # "deployment_id": deployment_id,
+            # "api_version": api_version,
+        }
+
+        if model_name in litellm.get_valid_models():  # type: ignore[no-untyped-call]
+            _, provider_name, _, _ = litellm.get_llm_provider(model_name)  # type: ignore[no-untyped-call]
+            api_provider = APIProvider(name=provider_name)
+        elif api_provider is not None:
+            self._lite_llm_completion_params["api_key"] = api_provider.get("api_key")
+            self._lite_llm_completion_params["api_base"] = api_provider.get("api_base")
+        elif api_provider is None:
+            raise ValueError(
+                f"Model '{model_name}' is not supported by LiteLLM and no API provider "
+                "was specified. Please provide a valid API provider or use a different "
+                "model."
+            )
+        super().__init__(
+            model_name=model_name,
+            model_id=model_id,
+            llm_settings=llm_settings,
+            converters=LiteLLMConverters(),
+            tools=tools,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
+            apply_response_schema_via_provider=apply_response_schema_via_provider,
+            api_provider=api_provider,
+            rate_limiter=rate_limiter,
+            max_client_retries=max_client_retries,
+            max_response_retries=max_response_retries,
+        )
+
+        if self._apply_response_schema_via_provider:
+            if self._tools:
+                for tool in self._tools.values():
+                    tool.strict = True
+            if not self.supports_response_schema:
+                raise ValueError(
+                    f"Model '{self._model_name}' does not support response schema "
+                    "natively. Please set `apply_response_schema_via_provider=False`"
+                )
+
+    def get_supported_openai_params(self) -> list[Any] | None:
+        return get_supported_openai_params(  # type: ignore[no-untyped-call]
+            model=self._model_name, request_type="chat_completion"
+        )
+
+    @property
+    def supports_reasoning(self) -> bool:
+        return supports_reasoning(model=self._model_name)
+
+    @property
+    def supports_parallel_function_calling(self) -> bool:
+        return supports_parallel_function_calling(model=self._model_name)
+
+    @property
+    def supports_prompt_caching(self) -> bool:
+        return supports_prompt_caching(model=self._model_name)
+
+    @property
+    def supports_response_schema(self) -> bool:
+        return supports_response_schema(model=self._model_name)
+
+    @property
+    def supports_tool_choice(self) -> bool:
+        return supports_tool_choice(model=self._model_name)
+
+    # # client
+    # model_list: Optional[list] = (None,)  # pass in a list of api_base,keys, etc.
+
+    async def _get_completion(
+        self,
+        api_messages: list[OpenAIMessageParam],
+        api_tools: list[OpenAIToolParam] | None = None,
+        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+        api_response_schema: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> LiteLLMCompletion:
+        completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
+            model=self._model_name,
+            messages=api_messages,
+            tools=api_tools,
+            tool_choice=api_tool_choice,  # type: ignore[arg-type]
+            response_format=api_response_schema,
+            n=n_choices,
+            stream=False,
+            **self._lite_llm_completion_params,
+            **api_llm_settings,
+        )
+        completion = cast("LiteLLMCompletion", completion)
+
+        # Should not be needed in litellm>=1.74
+        completion._hidden_params["response_cost"] = litellm.completion_cost(completion)  # type: ignore[no-untyped-call]
+
+        return completion
+
+    async def _get_completion_stream(  # type: ignore[no-untyped-def]
+        self,
+        api_messages: list[OpenAIMessageParam],
+        api_tools: list[OpenAIToolParam] | None = None,
+        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+        api_response_schema: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> AsyncIterator[LiteLLMCompletionChunk]:
+        stream = await litellm.acompletion(  # type: ignore[no-untyped-call]
+            model=self._model_name,
+            messages=api_messages,
+            tools=api_tools,
+            tool_choice=api_tool_choice,  # type: ignore[arg-type]
+            response_format=api_response_schema,
+            stream=True,
+            n=n_choices,
+            **self._lite_llm_completion_params,
+            **api_llm_settings,
+        )
+        stream = cast("CustomStreamWrapper", stream)
+
+        async for completion_chunk in stream:
+            yield completion_chunk
+
+    def combine_completion_chunks(
+        self, completion_chunks: list[LiteLLMCompletionChunk]
+    ) -> LiteLLMCompletion:
+        combined_chunk = cast(
+            "LiteLLMCompletion",
+            litellm.stream_chunk_builder(completion_chunks),  # type: ignore[no-untyped-call]
+        )
+        # Should not be needed in litellm>=1.74
+        combined_chunk._hidden_params["response_cost"] = litellm.completion_cost(  # type: ignore[no-untyped-call]
+            combined_chunk
+        )

+        return combined_chunk
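Note how `__init__` resolves the provider: if `model_name` is known to LiteLLM, the provider is inferred via `litellm.get_llm_provider`; otherwise an explicit `api_provider` must supply `api_key`/`api_base`, and passing neither raises `ValueError`. A hypothetical construction (model name and settings values are illustrative, and `temperature` is assumed to be among the inherited OpenAI settings keys):

from grasp_agents.litellm.lite_llm import LiteLLM, LiteLLMSettings

llm = LiteLLM(
    model_name="gpt-4o",  # must appear in litellm.get_valid_models(),
                          # or pass api_provider=... explicitly
    llm_settings=LiteLLMSettings(temperature=0.0),
    timeout=60.0,
    max_client_retries=2,
    max_response_retries=1,
)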
grasp_agents/litellm/message_converters.py
ADDED
@@ -0,0 +1,66 @@
+from ..typing.message import (
+    AssistantMessage,
+)
+from ..typing.tool import ToolCall
+from . import LiteLLMCompletionMessage, LiteLLMFunction, LiteLLMToolCall
+
+
+def from_api_assistant_message(
+    api_message: LiteLLMCompletionMessage, name: str | None = None
+) -> AssistantMessage:
+    tool_calls = None
+    if api_message.tool_calls is not None:
+        tool_calls = [
+            ToolCall(
+                id=tool_call.id,
+                tool_name=tool_call.function.name,  # type: ignore
+                tool_arguments=tool_call.function.arguments,
+            )
+            for tool_call in api_message.tool_calls
+        ]
+
+    return AssistantMessage(
+        content=api_message.content,
+        tool_calls=tool_calls,
+        name=name,
+        thinking_blocks=getattr(api_message, "thinking_blocks", None),
+        reasoning_content=getattr(api_message, "reasoning_content", None),
+        annotations=getattr(api_message, "annotations", None),
+        provider_specific_fields=api_message.provider_specific_fields,
+        refusal=getattr(api_message, "refusal", None),
+    )
+
+
+def to_api_assistant_message(
+    message: AssistantMessage,
+) -> LiteLLMCompletionMessage:
+    api_tool_calls = None
+    if message.tool_calls is not None:
+        api_tool_calls = [
+            LiteLLMToolCall(
+                type="function",
+                id=tool_call.id,
+                function=LiteLLMFunction(
+                    name=tool_call.tool_name,
+                    arguments=tool_call.tool_arguments,
+                ),
+            )
+            for tool_call in message.tool_calls
+        ]
+
+    api_message = LiteLLMCompletionMessage(role="assistant", content=message.content)
+
+    if api_tool_calls:
+        api_message.tool_calls = api_tool_calls
+
+    for key in [
+        "thinking_blocks",
+        "reasoning_content",
+        "annotations",
+        "provider_specific_fields",
+        "refusal",
+    ]:
+        if getattr(message, key):
+            api_message[key] = getattr(message, key)
+
+    return api_message
grasp_agents/llm.py
CHANGED
@@ -1,18 +1,25 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Mapping
+from collections.abc import AsyncIterator, Mapping, Sequence
 from typing import Any, Generic, TypeVar, cast
 from uuid import uuid4
 
-from pydantic import BaseModel
+from pydantic import BaseModel
 from typing_extensions import TypedDict
 
-from grasp_agents.utils import
+from grasp_agents.utils import (
+    validate_obj_from_json_or_py_string,
+    validate_tagged_objs_from_json_or_py_string,
+)
 
-from .errors import
+from .errors import (
+    JSONSchemaValidationError,
+    LLMResponseValidationError,
+    LLMToolCallValidationError,
+)
 from .typing.completion import Completion
 from .typing.converters import Converters
-from .typing.events import CompletionChunkEvent, CompletionEvent
+from .typing.events import CompletionChunkEvent, CompletionEvent, LLMStreamingErrorEvent
 from .typing.message import Messages
 from .typing.tool import BaseTool, ToolChoice
 
@@ -38,8 +45,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         model_name: str | None = None,
         model_id: str | None = None,
         llm_settings: SettingsT_co | None = None,
-        tools:
-
+        tools: Sequence[BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -50,20 +58,13 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         self._tools = {t.name: t for t in tools} if tools else None
         self._llm_settings: SettingsT_co = llm_settings or cast("SettingsT_co", {})
 
-
-
-
-
-
-
-
-        response_format: Any | Mapping[str, Any] | None = None,
-    ) -> TypeAdapter[Any] | Mapping[str, TypeAdapter[Any]]:
-        if response_format is None:
-            return TypeAdapter(Any)
-        if isinstance(response_format, Mapping):
-            return {k: TypeAdapter(v) for k, v in response_format.items()}  # type: ignore[return-value]
-        return TypeAdapter(response_format)
+        if response_schema and response_schema_by_xml_tag:
+            raise ValueError(
+                "Only one of response_schema and response_schema_by_xml_tag can be "
+                "provided, but not both."
+            )
+        self._response_schema = response_schema
+        self._response_schema_by_xml_tag = response_schema_by_xml_tag
 
     @property
     def model_id(self) -> str:
@@ -78,40 +79,59 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         return self._llm_settings
 
     @property
-    def
-        return self.
+    def response_schema(self) -> Any | None:
+        return self._response_schema
 
-    @
-    def
-        self.
-
-
-
+    @response_schema.setter
+    def response_schema(self, response_schema: Any | None) -> None:
+        self._response_schema = response_schema
+
+    @property
+    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
+        return self._response_schema_by_xml_tag
 
     @property
    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
         return self._tools
 
     @tools.setter
-    def tools(self, tools:
+    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
         self._tools = {t.name: t for t in tools} if tools else None
 
     def __repr__(self) -> str:
-        return (
-
-
-
+        return f"{type(self).__name__}[{self.model_id}]; model_name={self._model_name})"
+
+    def _validate_response(self, completion: Completion) -> None:
+        parsing_params = {
+            "from_substring": False,
+            "strip_language_markdown": True,
+        }
+        try:
+            for message in completion.messages:
+                if not message.tool_calls:
+                    if self._response_schema:
+                        validate_obj_from_json_or_py_string(
+                            message.content or "",
+                            schema=self._response_schema,
+                            **parsing_params,
+                        )
 
-
-
-
-
-
-
-
-
+                    elif self._response_schema_by_xml_tag:
+                        validate_tagged_objs_from_json_or_py_string(
+                            message.content or "",
+                            schema_by_xml_tag=self._response_schema_by_xml_tag,
+                            **parsing_params,
+                        )
+        except JSONSchemaValidationError as exc:
+            raise LLMResponseValidationError(
+                exc.s, exc.schema, message=str(exc)
+            ) from exc
 
     def _validate_tool_calls(self, completion: Completion) -> None:
+        parsing_params = {
+            "from_substring": False,
+            "strip_language_markdown": True,
+        }
         for message in completion.messages:
             if message.tool_calls:
                 for tool_call in message.tool_calls:
@@ -120,14 +140,21 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
 
                 available_tool_names = list(self.tools) if self.tools else []
                 if tool_name not in available_tool_names or not self.tools:
-                    raise
-
-
+                    raise LLMToolCallValidationError(
+                        tool_name,
+                        tool_arguments,
+                        message=f"Tool '{tool_name}' is not available in the LLM "
+                        f"tools (available: {available_tool_names})",
                     )
                 tool = self.tools[tool_name]
-
-
-
+                try:
+                    validate_obj_from_json_or_py_string(
+                        tool_arguments, schema=tool.in_type, **parsing_params
+                    )
+                except JSONSchemaValidationError as exc:
+                    raise LLMToolCallValidationError(
+                        tool_name, tool_arguments
+                    ) from exc
 
     @abstractmethod
     async def generate_completion(
@@ -136,6 +163,8 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
     ) -> Completion:
         pass
 
@@ -146,5 +175,11 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
-
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
        pass
+
+    @abstractmethod
+    def combine_completion_chunks(self, completion_chunks: list[Any]) -> Any:
+        raise NotImplementedError
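The net effect on the base class: the old `TypeAdapter`-based `response_format` mechanism is replaced by `response_schema` / `response_schema_by_xml_tag` (mutually exclusive), and validation failures now surface as typed errors (`LLMResponseValidationError`, `LLMToolCallValidationError`) instead of raw `JSONSchemaValidationError`s. Only messages without tool calls are checked against the response schema; tool-call arguments are validated separately in `_validate_tool_calls`. A hedged sketch of configuring a concrete subclass (`SomeLLM` is a hypothetical stand-in for e.g. `LiteLLM` or `OpenAILLM`):

from pydantic import BaseModel

class Answer(BaseModel):
    verdict: str
    confidence: float

# `SomeLLM` is a placeholder; any concrete LLM subclass takes these kwargs.
llm = SomeLLM(model_name="gpt-4o", response_schema=Answer)

# Setting both schema kinds is rejected in LLM.__init__:
# SomeLLM(..., response_schema=Answer, response_schema_by_xml_tag={"answer": Answer})
# -> ValueError("Only one of response_schema and response_schema_by_xml_tag ...")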