grasp_agents 0.4.7__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -224
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.7.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/litellm/converters.py ADDED
@@ -0,0 +1,138 @@
+ from collections.abc import Iterable
+ from typing import Any
+
+ from pydantic import BaseModel
+
+ from ..openai.content_converters import from_api_content, to_api_content
+ from ..openai.message_converters import (
+     from_api_system_message,
+     from_api_tool_message,
+     from_api_user_message,
+     to_api_system_message,
+     to_api_tool_message,
+     to_api_user_message,
+ )
+ from ..openai.tool_converters import to_api_tool, to_api_tool_choice
+ from ..typing.completion import Completion, Usage
+ from ..typing.completion_chunk import CompletionChunk
+ from ..typing.content import Content
+ from ..typing.converters import Converters
+ from ..typing.message import AssistantMessage, SystemMessage, ToolMessage, UserMessage
+ from ..typing.tool import BaseTool, ToolChoice
+ from . import (
+     LiteLLMCompletion,
+     LiteLLMCompletionChunk,
+     LiteLLMCompletionMessage,
+     LiteLLMUsage,
+     OpenAIContentPartParam,
+     OpenAISystemMessageParam,
+     OpenAIToolChoiceOptionParam,
+     OpenAIToolMessageParam,
+     OpenAIToolParam,
+     OpenAIUserMessageParam,
+ )
+ from .completion_chunk_converters import from_api_completion_chunk
+ from .completion_converters import (
+     from_api_completion,
+     from_api_completion_usage,
+     to_api_completion,
+ )
+ from .message_converters import from_api_assistant_message, to_api_assistant_message
+
+
+ class LiteLLMConverters(Converters):
+     @staticmethod
+     def to_completion(completion: Completion, **kwargs: Any) -> LiteLLMCompletion:
+         return to_api_completion(completion, **kwargs)
+
+     @staticmethod
+     def from_completion(
+         raw_completion: LiteLLMCompletion, name: str | None = None, **kwargs: Any
+     ) -> Completion:
+         return from_api_completion(raw_completion, name=name, **kwargs)
+
+     @staticmethod
+     def to_completion_chunk(
+         chunk: CompletionChunk, **kwargs: Any
+     ) -> LiteLLMCompletionChunk:
+         raise NotImplementedError
+
+     @staticmethod
+     def from_completion_chunk(
+         raw_chunk: LiteLLMCompletionChunk, name: str | None = None, **kwargs: Any
+     ) -> CompletionChunk:
+         return from_api_completion_chunk(raw_chunk, name=name, **kwargs)
+
+     @staticmethod
+     def from_assistant_message(
+         raw_message: LiteLLMCompletionMessage, name: str | None = None, **kwargs: Any
+     ) -> AssistantMessage:
+         return from_api_assistant_message(raw_message, name=name, **kwargs)
+
+     @staticmethod
+     def to_assistant_message(
+         assistant_message: AssistantMessage, **kwargs: Any
+     ) -> LiteLLMCompletionMessage:
+         return to_api_assistant_message(assistant_message, **kwargs)
+
+     # The remaining converters are the same as OpenAIConverters
+
+     @staticmethod
+     def to_system_message(
+         system_message: SystemMessage, **kwargs: Any
+     ) -> OpenAISystemMessageParam:
+         return to_api_system_message(system_message, **kwargs)
+
+     @staticmethod
+     def from_system_message(
+         raw_message: OpenAISystemMessageParam, name: str | None = None, **kwargs: Any
+     ) -> SystemMessage:
+         return from_api_system_message(raw_message, name=name, **kwargs)
+
+     @staticmethod
+     def to_user_message(
+         user_message: UserMessage, **kwargs: Any
+     ) -> OpenAIUserMessageParam:
+         return to_api_user_message(user_message, **kwargs)
+
+     @staticmethod
+     def from_user_message(
+         raw_message: OpenAIUserMessageParam, name: str | None = None, **kwargs: Any
+     ) -> UserMessage:
+         return from_api_user_message(raw_message, name=name, **kwargs)
+
+     @staticmethod
+     def from_completion_usage(raw_usage: LiteLLMUsage, **kwargs: Any) -> Usage:
+         return from_api_completion_usage(raw_usage, **kwargs)
+
+     @staticmethod
+     def to_tool_message(
+         tool_message: ToolMessage, **kwargs: Any
+     ) -> OpenAIToolMessageParam:
+         return to_api_tool_message(tool_message, **kwargs)
+
+     @staticmethod
+     def from_tool_message(
+         raw_message: OpenAIToolMessageParam, name: str | None = None, **kwargs: Any
+     ) -> ToolMessage:
+         return from_api_tool_message(raw_message, name=name, **kwargs)
+
+     @staticmethod
+     def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> OpenAIToolParam:
+         return to_api_tool(tool, **kwargs)
+
+     @staticmethod
+     def to_tool_choice(
+         tool_choice: ToolChoice, **kwargs: Any
+     ) -> OpenAIToolChoiceOptionParam:
+         return to_api_tool_choice(tool_choice, **kwargs)
+
+     @staticmethod
+     def to_content(content: Content, **kwargs: Any) -> Iterable[OpenAIContentPartParam]:
+         return to_api_content(content, **kwargs)
+
+     @staticmethod
+     def from_content(
+         raw_content: str | Iterable[OpenAIContentPartParam], **kwargs: Any
+     ) -> Content:
+         return from_api_content(raw_content, **kwargs)
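Note: LiteLLMConverters is a thin delegation layer. Completions, chunks, assistant messages, and usage go through the LiteLLM-specific converters, while system/user/tool messages, tools, and content reuse the OpenAI converters unchanged; to_completion_chunk is the only unimplemented path. A minimal round-trip sketch, assuming AssistantMessage accepts plain string content (the message itself is hypothetical; real ones come from Completion.messages):

    from grasp_agents.litellm.converters import LiteLLMConverters
    from grasp_agents.typing.message import AssistantMessage

    converters = LiteLLMConverters()

    # Hypothetical message for illustration only.
    message = AssistantMessage(content="Hello!")
    api_message = converters.to_assistant_message(message)  # LiteLLMCompletionMessage
    restored = converters.from_assistant_message(api_message, name="assistant")
    assert restored.content == "Hello!"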
grasp_agents/litellm/lite_llm.py ADDED
@@ -0,0 +1,210 @@
+ import logging
+ from collections.abc import AsyncIterator, Mapping
+ from typing import Any, cast
+
+ import litellm
+ from litellm.litellm_core_utils.get_supported_openai_params import (
+     get_supported_openai_params,  # type: ignore[no-redef]
+ )
+ from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+ from litellm.types.llms.anthropic import AnthropicThinkingParam
+ from litellm.utils import (
+     supports_parallel_function_calling,
+     supports_prompt_caching,
+     supports_reasoning,
+     supports_response_schema,
+     supports_tool_choice,
+ )
+
+ # from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
+ from pydantic import BaseModel
+
+ from ..cloud_llm import APIProvider, CloudLLM, LLMRateLimiter
+ from ..openai.openai_llm import OpenAILLMSettings
+ from ..typing.tool import BaseTool
+ from . import (
+     LiteLLMCompletion,
+     LiteLLMCompletionChunk,
+     OpenAIMessageParam,
+     OpenAIToolChoiceOptionParam,
+     OpenAIToolParam,
+ )
+ from .converters import LiteLLMConverters
+
+ logger = logging.getLogger(__name__)
+
+
+ class LiteLLMSettings(OpenAILLMSettings, total=False):
+     thinking: AnthropicThinkingParam | None
+
+
+ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
+     def __init__(
+         self,
+         # Base LLM args
+         model_name: str,
+         model_id: str | None = None,
+         llm_settings: LiteLLMSettings | None = None,
+         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
+         response_schema: Any | None = None,
+         response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+         apply_response_schema_via_provider: bool = False,
+         # LLM provider
+         api_provider: APIProvider | None = None,
+         # deployment_id: str | None = None,
+         # api_version: str | None = None,
+         # Connection settings
+         timeout: float | None = None,
+         max_client_retries: int = 2,
+         # Rate limiting
+         rate_limiter: LLMRateLimiter | None = None,
+         # Drop unsupported LLM settings
+         drop_params: bool = True,
+         additional_drop_params: list[str] | None = None,
+         allowed_openai_params: list[str] | None = None,
+         # Mock LLM response for testing
+         mock_response: str | None = None,
+         # LLM response retries: try to regenerate to pass validation
+         max_response_retries: int = 1,
+     ) -> None:
+         self._lite_llm_completion_params: dict[str, Any] = {
+             "max_retries": max_client_retries,
+             "timeout": timeout,
+             "drop_params": drop_params,
+             "additional_drop_params": additional_drop_params,
+             "allowed_openai_params": allowed_openai_params,
+             "mock_response": mock_response,
+             # "deployment_id": deployment_id,
+             # "api_version": api_version,
+         }
+
+         if model_name in litellm.get_valid_models():  # type: ignore[no-untyped-call]
+             _, provider_name, _, _ = litellm.get_llm_provider(model_name)  # type: ignore[no-untyped-call]
+             api_provider = APIProvider(name=provider_name)
+         elif api_provider is not None:
+             self._lite_llm_completion_params["api_key"] = api_provider.get("api_key")
+             self._lite_llm_completion_params["api_base"] = api_provider.get("api_base")
+         elif api_provider is None:
+             raise ValueError(
+                 f"Model '{model_name}' is not supported by LiteLLM and no API provider "
+                 "was specified. Please provide a valid API provider or use a different "
+                 "model."
+             )
+         super().__init__(
+             model_name=model_name,
+             model_id=model_id,
+             llm_settings=llm_settings,
+             converters=LiteLLMConverters(),
+             tools=tools,
+             response_schema=response_schema,
+             response_schema_by_xml_tag=response_schema_by_xml_tag,
+             apply_response_schema_via_provider=apply_response_schema_via_provider,
+             api_provider=api_provider,
+             rate_limiter=rate_limiter,
+             max_client_retries=max_client_retries,
+             max_response_retries=max_response_retries,
+         )
+
+         if self._apply_response_schema_via_provider:
+             if self._tools:
+                 for tool in self._tools.values():
+                     tool.strict = True
+             if not self.supports_response_schema:
+                 raise ValueError(
+                     f"Model '{self._model_name}' does not support response schema "
+                     "natively. Please set `apply_response_schema_via_provider=False`"
+                 )
+
+     def get_supported_openai_params(self) -> list[Any] | None:
+         return get_supported_openai_params(  # type: ignore[no-untyped-call]
+             model=self._model_name, request_type="chat_completion"
+         )
+
+     @property
+     def supports_reasoning(self) -> bool:
+         return supports_reasoning(model=self._model_name)
+
+     @property
+     def supports_parallel_function_calling(self) -> bool:
+         return supports_parallel_function_calling(model=self._model_name)
+
+     @property
+     def supports_prompt_caching(self) -> bool:
+         return supports_prompt_caching(model=self._model_name)
+
+     @property
+     def supports_response_schema(self) -> bool:
+         return supports_response_schema(model=self._model_name)
+
+     @property
+     def supports_tool_choice(self) -> bool:
+         return supports_tool_choice(model=self._model_name)
+
+     # # client
+     # model_list: Optional[list] = (None,)  # pass in a list of api_base,keys, etc.
+
+     async def _get_completion(
+         self,
+         api_messages: list[OpenAIMessageParam],
+         api_tools: list[OpenAIToolParam] | None = None,
+         api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+         api_response_schema: type | None = None,
+         n_choices: int | None = None,
+         **api_llm_settings: Any,
+     ) -> LiteLLMCompletion:
+         completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
+             model=self._model_name,
+             messages=api_messages,
+             tools=api_tools,
+             tool_choice=api_tool_choice,  # type: ignore[arg-type]
+             response_format=api_response_schema,
+             n=n_choices,
+             stream=False,
+             **self._lite_llm_completion_params,
+             **api_llm_settings,
+         )
+         completion = cast("LiteLLMCompletion", completion)
+
+         # Should not be needed in litellm>=1.74
+         completion._hidden_params["response_cost"] = litellm.completion_cost(completion)  # type: ignore[no-untyped-call]
+
+         return completion
+
+     async def _get_completion_stream(  # type: ignore[no-untyped-def]
+         self,
+         api_messages: list[OpenAIMessageParam],
+         api_tools: list[OpenAIToolParam] | None = None,
+         api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+         api_response_schema: type | None = None,
+         n_choices: int | None = None,
+         **api_llm_settings: Any,
+     ) -> AsyncIterator[LiteLLMCompletionChunk]:
+         stream = await litellm.acompletion(  # type: ignore[no-untyped-call]
+             model=self._model_name,
+             messages=api_messages,
+             tools=api_tools,
+             tool_choice=api_tool_choice,  # type: ignore[arg-type]
+             response_format=api_response_schema,
+             stream=True,
+             n=n_choices,
+             **self._lite_llm_completion_params,
+             **api_llm_settings,
+         )
+         stream = cast("CustomStreamWrapper", stream)
+
+         async for completion_chunk in stream:
+             yield completion_chunk
+
+     def combine_completion_chunks(
+         self, completion_chunks: list[LiteLLMCompletionChunk]
+     ) -> LiteLLMCompletion:
+         combined_chunk = cast(
+             "LiteLLMCompletion",
+             litellm.stream_chunk_builder(completion_chunks),  # type: ignore[no-untyped-call]
+         )
+         # Should not be needed in litellm>=1.74
+         combined_chunk._hidden_params["response_cost"] = litellm.completion_cost(  # type: ignore[no-untyped-call]
+             combined_chunk
+         )
+
+         return combined_chunk
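Note: LiteLLM wires CloudLLM's request/validation/retry machinery to litellm.acompletion, infers the provider automatically for models litellm recognizes, and probes model capabilities through litellm.utils. A hedged construction sketch; the model name and temperature value are illustrative (temperature is assumed to be a valid OpenAILLMSettings key), and mock_response makes litellm return canned text instead of calling the API:

    from grasp_agents.litellm.lite_llm import LiteLLM, LiteLLMSettings

    llm = LiteLLM(
        model_name="gpt-4o",  # recognized by litellm, so api_provider is inferred
        llm_settings=LiteLLMSettings(temperature=0.0),
        mock_response="pong",  # short-circuits the network call, useful for testing
    )

    # Capability flags are delegated to litellm.utils at property-access time.
    print(llm.supports_response_schema, llm.supports_tool_choice)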
grasp_agents/litellm/message_converters.py ADDED
@@ -0,0 +1,66 @@
+ from ..typing.message import (
+     AssistantMessage,
+ )
+ from ..typing.tool import ToolCall
+ from . import LiteLLMCompletionMessage, LiteLLMFunction, LiteLLMToolCall
+
+
+ def from_api_assistant_message(
+     api_message: LiteLLMCompletionMessage, name: str | None = None
+ ) -> AssistantMessage:
+     tool_calls = None
+     if api_message.tool_calls is not None:
+         tool_calls = [
+             ToolCall(
+                 id=tool_call.id,
+                 tool_name=tool_call.function.name,  # type: ignore
+                 tool_arguments=tool_call.function.arguments,
+             )
+             for tool_call in api_message.tool_calls
+         ]
+
+     return AssistantMessage(
+         content=api_message.content,
+         tool_calls=tool_calls,
+         name=name,
+         thinking_blocks=getattr(api_message, "thinking_blocks", None),
+         reasoning_content=getattr(api_message, "reasoning_content", None),
+         annotations=getattr(api_message, "annotations", None),
+         provider_specific_fields=api_message.provider_specific_fields,
+         refusal=getattr(api_message, "refusal", None),
+     )
+
+
+ def to_api_assistant_message(
+     message: AssistantMessage,
+ ) -> LiteLLMCompletionMessage:
+     api_tool_calls = None
+     if message.tool_calls is not None:
+         api_tool_calls = [
+             LiteLLMToolCall(
+                 type="function",
+                 id=tool_call.id,
+                 function=LiteLLMFunction(
+                     name=tool_call.tool_name,
+                     arguments=tool_call.tool_arguments,
+                 ),
+             )
+             for tool_call in message.tool_calls
+         ]
+
+     api_message = LiteLLMCompletionMessage(role="assistant", content=message.content)
+
+     if api_tool_calls:
+         api_message.tool_calls = api_tool_calls
+
+     for key in [
+         "thinking_blocks",
+         "reasoning_content",
+         "annotations",
+         "provider_specific_fields",
+         "refusal",
+     ]:
+         if getattr(message, key):
+             api_message[key] = getattr(message, key)
+
+     return api_message
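Note: the trailing loop copies provider extras (thinking blocks, reasoning content, annotations, refusals) onto the API message only when they are set, so plain messages stay minimal. A hedged round-trip sketch of the tool-call path; the id and arguments are invented, and the AssistantMessage constructor call assumes its fields accept these keyword arguments:

    from grasp_agents.litellm.message_converters import (
        from_api_assistant_message,
        to_api_assistant_message,
    )
    from grasp_agents.typing.message import AssistantMessage
    from grasp_agents.typing.tool import ToolCall

    # Hypothetical values; real ids/arguments come from the provider.
    message = AssistantMessage(
        content=None,
        tool_calls=[
            ToolCall(
                id="call_1",
                tool_name="get_weather",
                tool_arguments='{"city": "Paris"}',
            )
        ],
    )
    api_message = to_api_assistant_message(message)
    restored = from_api_assistant_message(api_message, name="assistant")
    assert restored.tool_calls[0].tool_name == "get_weather"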
grasp_agents/llm.py CHANGED
@@ -1,18 +1,25 @@
  import logging
  from abc import ABC, abstractmethod
- from collections.abc import AsyncIterator, Mapping
+ from collections.abc import AsyncIterator, Mapping, Sequence
  from typing import Any, Generic, TypeVar, cast
  from uuid import uuid4

- from pydantic import BaseModel, TypeAdapter
+ from pydantic import BaseModel
  from typing_extensions import TypedDict

- from grasp_agents.utils import validate_obj_from_json_or_py_string
+ from grasp_agents.utils import (
+     validate_obj_from_json_or_py_string,
+     validate_tagged_objs_from_json_or_py_string,
+ )

- from .errors import ToolValidationError
+ from .errors import (
+     JSONSchemaValidationError,
+     LLMResponseValidationError,
+     LLMToolCallValidationError,
+ )
  from .typing.completion import Completion
  from .typing.converters import Converters
- from .typing.events import CompletionChunkEvent, CompletionEvent
+ from .typing.events import CompletionChunkEvent, CompletionEvent, LLMStreamingErrorEvent
  from .typing.message import Messages
  from .typing.tool import BaseTool, ToolChoice

@@ -38,8 +45,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
          model_name: str | None = None,
          model_id: str | None = None,
          llm_settings: SettingsT_co | None = None,
-         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-         response_format: Any | Mapping[str, Any] | None = None,
+         tools: Sequence[BaseTool[BaseModel, Any, Any]] | None = None,
+         response_schema: Any | None = None,
+         response_schema_by_xml_tag: Mapping[str, Any] | None = None,
          **kwargs: Any,
      ) -> None:
          super().__init__()
@@ -50,20 +58,13 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
          self._tools = {t.name: t for t in tools} if tools else None
          self._llm_settings: SettingsT_co = llm_settings or cast("SettingsT_co", {})

-         self._response_format = response_format
-         self._response_format_adapter: (
-             TypeAdapter[Any] | Mapping[str, TypeAdapter[Any]]
-         ) = self._get_response_format_adapter(response_format=response_format)
-
-     @staticmethod
-     def _get_response_format_adapter(
-         response_format: Any | Mapping[str, Any] | None = None,
-     ) -> TypeAdapter[Any] | Mapping[str, TypeAdapter[Any]]:
-         if response_format is None:
-             return TypeAdapter(Any)
-         if isinstance(response_format, Mapping):
-             return {k: TypeAdapter(v) for k, v in response_format.items()}  # type: ignore[return-value]
-         return TypeAdapter(response_format)
+         if response_schema and response_schema_by_xml_tag:
+             raise ValueError(
+                 "Only one of response_schema and response_schema_by_xml_tag can be "
+                 "provided, but not both."
+             )
+         self._response_schema = response_schema
+         self._response_schema_by_xml_tag = response_schema_by_xml_tag

      @property
      def model_id(self) -> str:
@@ -78,40 +79,59 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
          return self._llm_settings

      @property
-     def response_format(self) -> Any | Mapping[str, Any] | None:
-         return self._response_format
+     def response_schema(self) -> Any | None:
+         return self._response_schema

-     @response_format.setter
-     def response_format(self, response_format: Any | Mapping[str, Any] | None) -> None:
-         self._response_format = response_format
-         self._response_format_adapter = self._get_response_format_adapter(
-             response_format
-         )
+     @response_schema.setter
+     def response_schema(self, response_schema: Any | None) -> None:
+         self._response_schema = response_schema
+
+     @property
+     def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
+         return self._response_schema_by_xml_tag

      @property
      def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
          return self._tools

      @tools.setter
-     def tools(self, tools: list[BaseTool[BaseModel, Any, Any]] | None) -> None:
+     def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
          self._tools = {t.name: t for t in tools} if tools else None

      def __repr__(self) -> str:
-         return (
-             f"{type(self).__name__}(model_id={self.model_id}; "
-             f"model_name={self._model_name})"
-         )
+         return f"{type(self).__name__}[{self.model_id}]; model_name={self._model_name})"
+
+     def _validate_response(self, completion: Completion) -> None:
+         parsing_params = {
+             "from_substring": False,
+             "strip_language_markdown": True,
+         }
+         try:
+             for message in completion.messages:
+                 if not message.tool_calls:
+                     if self._response_schema:
+                         validate_obj_from_json_or_py_string(
+                             message.content or "",
+                             schema=self._response_schema,
+                             **parsing_params,
+                         )

-     def _validate_completion(self, completion: Completion) -> None:
-         for message in completion.messages:
-             if not message.tool_calls:
-                 validate_obj_from_json_or_py_string(
-                     message.content or "",
-                     adapter=self._response_format_adapter,
-                     from_substring=True,
-                 )
+                     elif self._response_schema_by_xml_tag:
+                         validate_tagged_objs_from_json_or_py_string(
+                             message.content or "",
+                             schema_by_xml_tag=self._response_schema_by_xml_tag,
+                             **parsing_params,
+                         )
+         except JSONSchemaValidationError as exc:
+             raise LLMResponseValidationError(
+                 exc.s, exc.schema, message=str(exc)
+             ) from exc

      def _validate_tool_calls(self, completion: Completion) -> None:
+         parsing_params = {
+             "from_substring": False,
+             "strip_language_markdown": True,
+         }
          for message in completion.messages:
              if message.tool_calls:
                  for tool_call in message.tool_calls:
@@ -120,14 +140,21 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):

                      available_tool_names = list(self.tools) if self.tools else []
                      if tool_name not in available_tool_names or not self.tools:
-                         raise ToolValidationError(
-                             f"Tool '{tool_name}' is not available in the LLM tools "
-                             f"(available: {available_tool_names}"
+                         raise LLMToolCallValidationError(
+                             tool_name,
+                             tool_arguments,
+                             message=f"Tool '{tool_name}' is not available in the LLM "
+                             f"tools (available: {available_tool_names})",
                          )
                      tool = self.tools[tool_name]
-                     validate_obj_from_json_or_py_string(
-                         tool_arguments, adapter=TypeAdapter(tool.in_type)
-                     )
+                     try:
+                         validate_obj_from_json_or_py_string(
+                             tool_arguments, schema=tool.in_type, **parsing_params
+                         )
+                     except JSONSchemaValidationError as exc:
+                         raise LLMToolCallValidationError(
+                             tool_name, tool_arguments
+                         ) from exc

      @abstractmethod
      async def generate_completion(
@@ -136,6 +163,8 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
          *,
          tool_choice: ToolChoice | None = None,
          n_choices: int | None = None,
+         proc_name: str | None = None,
+         call_id: str | None = None,
      ) -> Completion:
          pass

@@ -146,5 +175,11 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
          *,
          tool_choice: ToolChoice | None = None,
          n_choices: int | None = None,
-     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+         proc_name: str | None = None,
+         call_id: str | None = None,
+     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
          pass
+
+     @abstractmethod
+     def combine_completion_chunks(self, completion_chunks: list[Any]) -> Any:
+         raise NotImplementedError
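Note: taken together, the llm.py changes replace the TypeAdapter-based response_format machinery with two mutually exclusive validation modes: a single response_schema for the whole message, or response_schema_by_xml_tag for per-tag validation, with failures surfaced as LLMResponseValidationError and LLMToolCallValidationError. A hedged sketch of the two modes; the schemas and model name are invented for illustration:

    from pydantic import BaseModel

    from grasp_agents.litellm.lite_llm import LiteLLM


    class Plan(BaseModel):
        steps: list[str]


    class Critique(BaseModel):
        verdict: str


    # Whole-message mode: non-tool-call content must validate as Plan.
    llm = LiteLLM(model_name="gpt-4o", response_schema=Plan)

    # Tagged mode: content inside each XML tag validates against its own schema.
    tagged = LiteLLM(
        model_name="gpt-4o",
        response_schema_by_xml_tag={"plan": Plan, "critique": Critique},
    )

    # Supplying both raises ValueError in LLM.__init__.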