grasp_agents 0.1.18__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +2 -2
- grasp_agents/agent_message_pool.py +6 -8
- grasp_agents/base_agent.py +15 -36
- grasp_agents/cloud_llm.py +9 -6
- grasp_agents/comm_agent.py +39 -43
- grasp_agents/generics_utils.py +159 -0
- grasp_agents/llm.py +4 -0
- grasp_agents/llm_agent.py +68 -37
- grasp_agents/llm_agent_state.py +9 -5
- grasp_agents/prompt_builder.py +45 -25
- grasp_agents/rate_limiting/rate_limiter_chunked.py +49 -48
- grasp_agents/rate_limiting/types.py +19 -40
- grasp_agents/rate_limiting/utils.py +24 -27
- grasp_agents/run_context.py +2 -15
- grasp_agents/tool_orchestrator.py +30 -8
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/io.py +4 -9
- grasp_agents/typing/tool.py +26 -7
- grasp_agents/utils.py +26 -39
- grasp_agents/workflow/looped_agent.py +12 -9
- grasp_agents/workflow/sequential_agent.py +9 -6
- grasp_agents/workflow/workflow_agent.py +16 -11
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/METADATA +37 -33
- grasp_agents-0.2.0.dist-info/RECORD +45 -0
- grasp_agents-0.1.18.dist-info/RECORD +0 -44
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.1.18.dist-info → grasp_agents-0.2.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/agent_message.py
CHANGED
@@ -4,9 +4,9 @@ from uuid import uuid4
|
|
4
4
|
|
5
5
|
from pydantic import BaseModel, ConfigDict, Field
|
6
6
|
|
7
|
-
from .typing.io import AgentID,
|
7
|
+
from .typing.io import AgentID, AgentState
|
8
8
|
|
9
|
-
_PayloadT = TypeVar("_PayloadT",
|
9
|
+
_PayloadT = TypeVar("_PayloadT", covariant=True) # noqa: PLC0105
|
10
10
|
_StateT = TypeVar("_StateT", bound=AgentState, covariant=True) # noqa: PLC0105
|
11
11
|
|
12
12
|
|
grasp_agents/agent_message_pool.py
CHANGED
@@ -4,12 +4,12 @@ from typing import Any, Generic, Protocol, TypeVar
|
|
4
4
|
|
5
5
|
from .agent_message import AgentMessage
|
6
6
|
from .run_context import CtxT, RunContextWrapper
|
7
|
-
from .typing.io import AgentID,
|
7
|
+
from .typing.io import AgentID, AgentState
|
8
8
|
|
9
9
|
logger = logging.getLogger(__name__)
|
10
10
|
|
11
11
|
|
12
|
-
_MH_PayloadT = TypeVar("_MH_PayloadT",
|
12
|
+
_MH_PayloadT = TypeVar("_MH_PayloadT", contravariant=True) # noqa: PLC0105
|
13
13
|
_MH_StateT = TypeVar("_MH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
|
14
14
|
|
15
15
|
|
@@ -24,15 +24,13 @@ class MessageHandler(Protocol[_MH_PayloadT, _MH_StateT, CtxT]):
|
|
24
24
|
|
25
25
|
class AgentMessagePool(Generic[CtxT]):
|
26
26
|
def __init__(self) -> None:
|
27
|
-
self._queues: dict[
|
28
|
-
AgentID, asyncio.Queue[AgentMessage[AgentPayload, AgentState]]
|
29
|
-
] = {}
|
27
|
+
self._queues: dict[AgentID, asyncio.Queue[AgentMessage[Any, AgentState]]] = {}
|
30
28
|
self._message_handlers: dict[
|
31
|
-
AgentID, MessageHandler[
|
29
|
+
AgentID, MessageHandler[Any, AgentState, CtxT]
|
32
30
|
] = {}
|
33
31
|
self._tasks: dict[AgentID, asyncio.Task[None]] = {}
|
34
32
|
|
35
|
-
async def post(self, message: AgentMessage[
|
33
|
+
async def post(self, message: AgentMessage[Any, AgentState]) -> None:
|
36
34
|
for recipient_id in message.recipient_ids:
|
37
35
|
queue = self._queues.setdefault(recipient_id, asyncio.Queue())
|
38
36
|
await queue.put(message)
|
@@ -40,7 +38,7 @@ class AgentMessagePool(Generic[CtxT]):
|
|
40
38
|
def register_message_handler(
|
41
39
|
self,
|
42
40
|
agent_id: AgentID,
|
43
|
-
handler: MessageHandler[
|
41
|
+
handler: MessageHandler[Any, AgentState, CtxT],
|
44
42
|
ctx: RunContextWrapper[CtxT] | None = None,
|
45
43
|
**run_kwargs: Any,
|
46
44
|
) -> None:
|
grasp_agents/base_agent.py
CHANGED
@@ -1,47 +1,30 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, ClassVar, Generic
|
3
3
|
|
4
|
-
from pydantic import
|
4
|
+
from pydantic import TypeAdapter
|
5
5
|
|
6
|
+
from .generics_utils import AutoInstanceAttributesMixin
|
6
7
|
from .run_context import CtxT, RunContextWrapper
|
7
|
-
from .typing.io import AgentID,
|
8
|
+
from .typing.io import AgentID, OutT, StateT
|
8
9
|
from .typing.tool import BaseTool
|
9
10
|
|
10
11
|
|
11
|
-
class
|
12
|
-
|
13
|
-
self, *args: Any, ctx: RunContextWrapper[CtxT] | None, **kwargs: Any
|
14
|
-
) -> OutT: ...
|
12
|
+
class BaseAgent(AutoInstanceAttributesMixin, ABC, Generic[OutT, StateT, CtxT]):
|
13
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_out_type"}
|
15
14
|
|
16
|
-
|
17
|
-
class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
|
18
15
|
@abstractmethod
|
19
|
-
def __init__(
|
20
|
-
self
|
21
|
-
agent_id: AgentID,
|
22
|
-
*,
|
23
|
-
out_schema: type[OutT] = AgentPayload,
|
24
|
-
**kwargs: Any,
|
25
|
-
) -> None:
|
16
|
+
def __init__(self, agent_id: AgentID, **kwargs: Any) -> None:
|
17
|
+
self._out_type: type[OutT]
|
26
18
|
self._state: StateT
|
27
|
-
self._agent_id = agent_id
|
28
|
-
self._out_schema = out_schema
|
29
|
-
self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None
|
30
19
|
|
31
|
-
|
32
|
-
self, func: ParseOutputHandler[OutT, CtxT]
|
33
|
-
) -> ParseOutputHandler[OutT, CtxT]:
|
34
|
-
self._parse_output_impl = func
|
20
|
+
super().__init__()
|
35
21
|
|
36
|
-
|
37
|
-
|
38
|
-
def _parse_output(
|
39
|
-
self, *args: Any, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
|
40
|
-
) -> OutT:
|
41
|
-
if self._parse_output_impl:
|
42
|
-
return self._parse_output_impl(*args, ctx=ctx, **kwargs)
|
22
|
+
self._agent_id = agent_id
|
23
|
+
self._out_type_adapter: TypeAdapter[OutT] = TypeAdapter(self._out_type)
|
43
24
|
|
44
|
-
|
25
|
+
@property
|
26
|
+
def out_type(self) -> type[OutT]:
|
27
|
+
return self._out_type
|
45
28
|
|
46
29
|
@property
|
47
30
|
def agent_id(self) -> AgentID:
|
@@ -51,10 +34,6 @@ class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
|
|
51
34
|
def state(self) -> StateT:
|
52
35
|
return self._state
|
53
36
|
|
54
|
-
@property
|
55
|
-
def out_schema(self) -> type[OutT]:
|
56
|
-
return self._out_schema
|
57
|
-
|
58
37
|
@abstractmethod
|
59
38
|
async def run(
|
60
39
|
self,
|
@@ -68,5 +47,5 @@ class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
|
|
68
47
|
@abstractmethod
|
69
48
|
def as_tool(
|
70
49
|
self, tool_name: str, tool_description: str, tool_strict: bool = True
|
71
|
-
) -> BaseTool[
|
50
|
+
) -> BaseTool[Any, OutT, CtxT]:
|
72
51
|
pass
|
grasp_agents/cloud_llm.py
CHANGED
@@ -26,7 +26,7 @@ from .rate_limiting.rate_limiter_chunked import ( # type: ignore
|
|
26
26
|
from .typing.completion import Completion, CompletionChunk
|
27
27
|
from .typing.message import AssistantMessage, Conversation
|
28
28
|
from .typing.tool import BaseTool, ToolChoice
|
29
|
-
from .utils import
|
29
|
+
from .utils import validate_obj_from_json_or_py_string
|
30
30
|
|
31
31
|
logger = logging.getLogger(__name__)
|
32
32
|
|
@@ -271,6 +271,11 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
271
271
|
api_completion, model_id=self.model_id
|
272
272
|
)
|
273
273
|
|
274
|
+
self._validate_completion(completion)
|
275
|
+
|
276
|
+
return completion
|
277
|
+
|
278
|
+
def _validate_completion(self, completion: Completion) -> None:
|
274
279
|
for choice in completion.choices:
|
275
280
|
message = choice.message
|
276
281
|
if (
|
@@ -278,12 +283,10 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
|
|
278
283
|
and not self._llm_settings.get("use_structured_outputs")
|
279
284
|
and not message.tool_calls
|
280
285
|
):
|
281
|
-
|
282
|
-
message.content,
|
286
|
+
validate_obj_from_json_or_py_string(
|
287
|
+
message.content,
|
288
|
+
adapter=self._response_format_pyd,
|
283
289
|
)
|
284
|
-
self._response_format_pyd.validate_python(message_json)
|
285
|
-
|
286
|
-
return completion
|
287
290
|
|
288
291
|
async def generate_completion_stream(
|
289
292
|
self,
|
grasp_agents/comm_agent.py
CHANGED
@@ -1,26 +1,26 @@
|
|
1
1
|
import logging
|
2
2
|
from abc import abstractmethod
|
3
3
|
from collections.abc import Sequence
|
4
|
-
from typing import Any, Generic, Protocol, TypeVar, cast, final
|
4
|
+
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
|
5
5
|
|
6
|
-
from pydantic import BaseModel
|
6
|
+
from pydantic import BaseModel, TypeAdapter
|
7
7
|
from pydantic.json_schema import SkipJsonSchema
|
8
8
|
|
9
9
|
from .agent_message import AgentMessage
|
10
10
|
from .agent_message_pool import AgentMessagePool
|
11
11
|
from .base_agent import BaseAgent
|
12
12
|
from .run_context import CtxT, RunContextWrapper
|
13
|
-
from .typing.io import AgentID,
|
13
|
+
from .typing.io import AgentID, AgentState, InT, OutT, StateT
|
14
14
|
from .typing.tool import BaseTool
|
15
15
|
|
16
16
|
logger = logging.getLogger(__name__)
|
17
17
|
|
18
18
|
|
19
|
-
class
|
19
|
+
class DynCommPayload(BaseModel):
|
20
20
|
selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]
|
21
21
|
|
22
22
|
|
23
|
-
_EH_OutT = TypeVar("_EH_OutT",
|
23
|
+
_EH_OutT = TypeVar("_EH_OutT", contravariant=True) # noqa: PLC0105
|
24
24
|
_EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
|
25
25
|
|
26
26
|
|
@@ -35,44 +35,37 @@ class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
|
|
35
35
|
class CommunicatingAgent(
|
36
36
|
BaseAgent[OutT, StateT, CtxT], Generic[InT, OutT, StateT, CtxT]
|
37
37
|
):
|
38
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
39
|
+
0: "_in_type",
|
40
|
+
1: "_out_type",
|
41
|
+
}
|
42
|
+
|
38
43
|
def __init__(
|
39
44
|
self,
|
40
45
|
agent_id: AgentID,
|
41
46
|
*,
|
42
|
-
out_schema: type[OutT] = AgentPayload,
|
43
|
-
rcv_args_schema: type[InT] = AgentPayload,
|
44
47
|
recipient_ids: Sequence[AgentID] | None = None,
|
45
48
|
message_pool: AgentMessagePool[CtxT] | None = None,
|
46
49
|
**kwargs: Any,
|
47
50
|
) -> None:
|
48
|
-
|
49
|
-
|
51
|
+
self._in_type: type[InT]
|
52
|
+
super().__init__(agent_id=agent_id, **kwargs)
|
53
|
+
|
54
|
+
self._rcv_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
55
|
+
self.recipient_ids = recipient_ids or []
|
50
56
|
|
57
|
+
self._message_pool = message_pool or AgentMessagePool()
|
51
58
|
self._is_listening = False
|
52
59
|
self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None
|
53
60
|
|
54
|
-
self._rcv_args_schema = rcv_args_schema
|
55
|
-
self.recipient_ids = recipient_ids or []
|
56
|
-
|
57
61
|
@property
|
58
|
-
def
|
59
|
-
|
60
|
-
|
61
|
-
def _parse_output(
|
62
|
-
self,
|
63
|
-
*args: Any,
|
64
|
-
rcv_args: InT | None = None,
|
65
|
-
ctx: RunContextWrapper[CtxT] | None = None,
|
66
|
-
**kwargs: Any,
|
67
|
-
) -> OutT:
|
68
|
-
if self._parse_output_impl:
|
69
|
-
return self._parse_output_impl(*args, rcv_args=rcv_args, ctx=ctx, **kwargs)
|
70
|
-
|
71
|
-
return self._out_schema()
|
62
|
+
def in_type(self) -> type[InT]: # type: ignore
|
63
|
+
# Exposing the type of a contravariant variable only, should be safe
|
64
|
+
return self._in_type
|
72
65
|
|
73
66
|
def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
|
74
|
-
if all(isinstance(p,
|
75
|
-
payloads_ = cast("Sequence[
|
67
|
+
if all(isinstance(p, DynCommPayload) for p in payloads):
|
68
|
+
payloads_ = cast("Sequence[DynCommPayload]", payloads)
|
76
69
|
selected_recipient_ids_per_payload = [
|
77
70
|
set(p.selected_recipient_ids or []) for p in payloads_
|
78
71
|
]
|
@@ -91,7 +84,7 @@ class CommunicatingAgent(
|
|
91
84
|
|
92
85
|
return selected_recipient_ids
|
93
86
|
|
94
|
-
if all((not isinstance(p,
|
87
|
+
if all((not isinstance(p, DynCommPayload)) for p in payloads):
|
95
88
|
return self.recipient_ids
|
96
89
|
|
97
90
|
raise ValueError(
|
@@ -109,7 +102,7 @@ class CommunicatingAgent(
|
|
109
102
|
inp_items: Any | None = None,
|
110
103
|
*,
|
111
104
|
ctx: RunContextWrapper[CtxT] | None = None,
|
112
|
-
rcv_message: AgentMessage[InT,
|
105
|
+
rcv_message: AgentMessage[InT, AgentState] | None = None,
|
113
106
|
entry_point: bool = False,
|
114
107
|
forbid_state_change: bool = False,
|
115
108
|
**kwargs: Any,
|
@@ -143,11 +136,11 @@ class CommunicatingAgent(
|
|
143
136
|
|
144
137
|
async def _message_handler(
|
145
138
|
self,
|
146
|
-
message: AgentMessage[
|
139
|
+
message: AgentMessage[Any, AgentState],
|
147
140
|
ctx: RunContextWrapper[CtxT] | None = None,
|
148
141
|
**run_kwargs: Any,
|
149
142
|
) -> None:
|
150
|
-
rcv_message = cast("AgentMessage[InT,
|
143
|
+
rcv_message = cast("AgentMessage[InT, AgentState]", message)
|
151
144
|
out_message = await self.run(ctx=ctx, rcv_message=rcv_message, **run_kwargs)
|
152
145
|
|
153
146
|
if self._exit_condition(output_message=out_message, ctx=ctx):
|
@@ -185,15 +178,20 @@ class CommunicatingAgent(
|
|
185
178
|
tool_name: str,
|
186
179
|
tool_description: str,
|
187
180
|
tool_strict: bool = True,
|
188
|
-
) -> BaseTool[
|
181
|
+
) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
|
182
|
+
# Will check if InT is a BaseModel at runtime
|
189
183
|
agent_instance = self
|
184
|
+
in_type = agent_instance.in_type
|
185
|
+
out_type = agent_instance.out_type
|
186
|
+
if not issubclass(in_type, BaseModel):
|
187
|
+
raise TypeError(
|
188
|
+
"Cannot create a tool from an agent with "
|
189
|
+
f"non-BaseModel input type: {in_type}"
|
190
|
+
)
|
190
191
|
|
191
|
-
class AgentTool(BaseTool[
|
192
|
+
class AgentTool(BaseTool[in_type, out_type, Any]):
|
192
193
|
name: str = tool_name
|
193
194
|
description: str = tool_description
|
194
|
-
in_schema: type[BaseModel] = agent_instance.rcv_args_schema
|
195
|
-
out_schema: Any = agent_instance.out_schema
|
196
|
-
|
197
195
|
strict: bool | None = tool_strict
|
198
196
|
|
199
197
|
async def run(
|
@@ -201,16 +199,14 @@ class CommunicatingAgent(
|
|
201
199
|
inp: InT,
|
202
200
|
ctx: RunContextWrapper[CtxT] | None = None,
|
203
201
|
) -> OutT:
|
204
|
-
rcv_args =
|
205
|
-
|
206
|
-
rcv_message = AgentMessage( # type: ignore[arg-type]
|
202
|
+
rcv_args = in_type.model_validate(inp)
|
203
|
+
rcv_message = AgentMessage[in_type, AgentState](
|
207
204
|
payloads=[rcv_args],
|
208
205
|
sender_id="<tool_user>",
|
209
206
|
recipient_ids=[agent_instance.agent_id],
|
210
207
|
)
|
211
|
-
|
212
208
|
agent_result = await agent_instance.run(
|
213
|
-
rcv_message=rcv_message,
|
209
|
+
rcv_message=rcv_message,
|
214
210
|
entry_point=False,
|
215
211
|
forbid_state_change=True,
|
216
212
|
ctx=ctx,
|
@@ -218,4 +214,4 @@ class CommunicatingAgent(
|
|
218
214
|
|
219
215
|
return agent_result.payloads[0]
|
220
216
|
|
221
|
-
return AgentTool()
|
217
|
+
return AgentTool() # type: ignore[return-value]
|
grasp_agents/generics_utils.py
ADDED
@@ -0,0 +1,159 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Any, ClassVar, Self, TypeVar, cast, get_args
|
4
|
+
|
5
|
+
|
6
|
+
class AutoInstanceAttributesMixin:
|
7
|
+
"""
|
8
|
+
A **runtime convenience mix-in** that automatically exposes the concrete
|
9
|
+
types supplied to a *generic* base class as **instance attributes**.
|
10
|
+
|
11
|
+
Example:
|
12
|
+
-------
|
13
|
+
from typing import Generic, TypeVar
|
14
|
+
from grasp_agents.generics_utils import AutoInstanceAttributesMixin
|
15
|
+
|
16
|
+
T = TypeVar("T")
|
17
|
+
U = TypeVar("U")
|
18
|
+
|
19
|
+
class MyBase(AutoInstanceAttributesMixin, Generic[T, U]):
|
20
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
21
|
+
0: "elem_type",
|
22
|
+
1: "meta_type",
|
23
|
+
}
|
24
|
+
|
25
|
+
class Packet(MyBase[int, str]):
|
26
|
+
...
|
27
|
+
|
28
|
+
Alias = MyBase[bytes, float] # "late" specialization
|
29
|
+
|
30
|
+
print(Packet().elem_type) # <class 'int'>
|
31
|
+
print(Alias().meta_type) # <class 'float'>
|
32
|
+
|
33
|
+
"""
|
34
|
+
|
35
|
+
# Configure this on your *generic* base class
|
36
|
+
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {}
|
37
|
+
|
38
|
+
# Filled automatically for every concrete specialization
|
39
|
+
_resolved_instance_attr_types: ClassVar[dict[str, type]]
|
40
|
+
|
41
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
42
|
+
super().__init__(*args, **kwargs)
|
43
|
+
self._set_resolved_generic_instance_attributes()
|
44
|
+
|
45
|
+
def __init_subclass__(cls, **kwargs: Any) -> None:
|
46
|
+
"""Runs when a *class statement* defines a new subclass."""
|
47
|
+
super().__init_subclass__(**kwargs)
|
48
|
+
cls._resolved_instance_attr_types = cls._compute_resolved_attrs(cls)
|
49
|
+
|
50
|
+
@classmethod
|
51
|
+
def __class_getitem__(cls, params: Any) -> type[Self]:
|
52
|
+
"""
|
53
|
+
Run when someone writes ``SomeGeneric[ConcreteTypes]``.
|
54
|
+
Ensures aliases receive the same resolved-type mapping.
|
55
|
+
"""
|
56
|
+
specialized: type[Self] = cast("type[Self]", super().__class_getitem__(params)) # type: ignore[assignment]
|
57
|
+
specialized._resolved_instance_attr_types = cls._compute_resolved_attrs( # noqa: SLF001
|
58
|
+
specialized
|
59
|
+
)
|
60
|
+
|
61
|
+
return specialized
|
62
|
+
|
63
|
+
@staticmethod
|
64
|
+
def _compute_resolved_attrs(_cls: type) -> dict[str, type]:
|
65
|
+
"""
|
66
|
+
Walks the MRO, finds the first generic base that defines
|
67
|
+
`_generic_arg_to_instance_attr_map`, and resolves concrete types.
|
68
|
+
"""
|
69
|
+
target_generic_base: Any | None = None
|
70
|
+
attr_mapping: dict[int, str] | None = None
|
71
|
+
|
72
|
+
# Locate the mapping
|
73
|
+
for mro_cls in _cls.mro():
|
74
|
+
if mro_cls in {AutoInstanceAttributesMixin, object}:
|
75
|
+
continue
|
76
|
+
if (
|
77
|
+
hasattr(mro_cls, "__parameters__")
|
78
|
+
and mro_cls.__parameters__ # type: ignore[has-type]
|
79
|
+
and "_generic_arg_to_instance_attr_map" in mro_cls.__dict__
|
80
|
+
):
|
81
|
+
target_generic_base = mro_cls
|
82
|
+
attr_mapping = cast(
|
83
|
+
"dict[int, str]",
|
84
|
+
mro_cls._generic_arg_to_instance_attr_map, # type: ignore[attr-defined] # noqa: SLF001
|
85
|
+
)
|
86
|
+
break
|
87
|
+
|
88
|
+
if target_generic_base is None or attr_mapping is None:
|
89
|
+
return {}
|
90
|
+
|
91
|
+
resolved: dict[str, type] = {}
|
92
|
+
|
93
|
+
def _add_to_resolved(generic_args: tuple[type | TypeVar, ...]) -> None:
|
94
|
+
for index, attr_name in attr_mapping.items():
|
95
|
+
if attr_name in resolved:
|
96
|
+
continue
|
97
|
+
if index < len(generic_args):
|
98
|
+
arg = generic_args[index]
|
99
|
+
if not isinstance(arg, TypeVar):
|
100
|
+
resolved[attr_name] = arg
|
101
|
+
|
102
|
+
def _all_resolved() -> bool:
|
103
|
+
return all(name in resolved for name in attr_mapping.values())
|
104
|
+
|
105
|
+
# Scenario 1: _cls itself is the direct parameterization (handles aliases).
|
106
|
+
# e.g., _cls is MyBase[bytes, float]. Its __origin__ is MyBase.
|
107
|
+
if getattr(_cls, "__origin__", None) is target_generic_base:
|
108
|
+
_add_to_resolved(get_args(_cls))
|
109
|
+
|
110
|
+
# Scenario 2: Check MRO for subclasses or more complex structures.
|
111
|
+
# This also acts as a fallback if Scenario 1 didn't resolve all attributes.
|
112
|
+
if not _all_resolved():
|
113
|
+
for mro_candidate in _cls.mro():
|
114
|
+
# Pydantic-specific check first
|
115
|
+
pydantic_generic_metadata = getattr(
|
116
|
+
mro_candidate, "__pydantic_generic_metadata__", None
|
117
|
+
)
|
118
|
+
if (
|
119
|
+
pydantic_generic_metadata
|
120
|
+
and pydantic_generic_metadata.get("origin") is target_generic_base
|
121
|
+
):
|
122
|
+
_add_to_resolved(pydantic_generic_metadata.get("args", ()))
|
123
|
+
if _all_resolved():
|
124
|
+
break
|
125
|
+
|
126
|
+
# Fallback to standard generic introspection if not fully resolved by Pydantic check
|
127
|
+
if not _all_resolved():
|
128
|
+
mro_candidate_origin = getattr(mro_candidate, "__origin__", None)
|
129
|
+
if mro_candidate_origin is target_generic_base:
|
130
|
+
_add_to_resolved(get_args(mro_candidate))
|
131
|
+
if _all_resolved():
|
132
|
+
break
|
133
|
+
|
134
|
+
if not _all_resolved():
|
135
|
+
mro_candidate_orig_bases = getattr(
|
136
|
+
mro_candidate, "__orig_bases__", []
|
137
|
+
)
|
138
|
+
for param_base in mro_candidate_orig_bases:
|
139
|
+
param_base_origin = getattr(param_base, "__origin__", None)
|
140
|
+
if param_base_origin is target_generic_base:
|
141
|
+
_add_to_resolved(get_args(param_base))
|
142
|
+
if _all_resolved():
|
143
|
+
break
|
144
|
+
|
145
|
+
if _all_resolved():
|
146
|
+
break
|
147
|
+
|
148
|
+
return resolved
|
149
|
+
|
150
|
+
def _set_resolved_generic_instance_attributes(self) -> None:
|
151
|
+
for name, typ in getattr(
|
152
|
+
self.__class__, "_resolved_instance_attr_types", {}
|
153
|
+
).items():
|
154
|
+
_typ = None if typ is type(None) else typ
|
155
|
+
pyd_private = getattr(self, "__pydantic_private__", {})
|
156
|
+
if name in pyd_private:
|
157
|
+
pyd_private[name] = _typ
|
158
|
+
else:
|
159
|
+
setattr(self, name, _typ)
|
grasp_agents/llm.py
CHANGED
@@ -69,6 +69,10 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
|
|
69
69
|
def tools(self, tools: list[BaseTool[BaseModel, Any, Any]] | None) -> None:
|
70
70
|
self._tools = {t.name: t for t in tools} if tools else None
|
71
71
|
|
72
|
+
@response_format.setter
|
73
|
+
def response_format(self, response_format: type | None) -> None:
|
74
|
+
self._response_format = response_format
|
75
|
+
|
72
76
|
def __repr__(self) -> str:
|
73
77
|
return (
|
74
78
|
f"{type(self).__name__}(model_id={self.model_id}; "
|