grasp_agents 0.1.17__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,9 +4,9 @@ from uuid import uuid4
4
4
 
5
5
  from pydantic import BaseModel, ConfigDict, Field
6
6
 
7
- from .typing.io import AgentID, AgentPayload, AgentState
7
+ from .typing.io import AgentID, AgentState
8
8
 
9
- _PayloadT = TypeVar("_PayloadT", bound=AgentPayload, covariant=True) # noqa: PLC0105
9
+ _PayloadT = TypeVar("_PayloadT", covariant=True) # noqa: PLC0105
10
10
  _StateT = TypeVar("_StateT", bound=AgentState, covariant=True) # noqa: PLC0105
11
11
 
12
12
 
@@ -4,12 +4,12 @@ from typing import Any, Generic, Protocol, TypeVar
4
4
 
5
5
  from .agent_message import AgentMessage
6
6
  from .run_context import CtxT, RunContextWrapper
7
- from .typing.io import AgentID, AgentPayload, AgentState
7
+ from .typing.io import AgentID, AgentState
8
8
 
9
9
  logger = logging.getLogger(__name__)
10
10
 
11
11
 
12
- _MH_PayloadT = TypeVar("_MH_PayloadT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
12
+ _MH_PayloadT = TypeVar("_MH_PayloadT", contravariant=True) # noqa: PLC0105
13
13
  _MH_StateT = TypeVar("_MH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
14
14
 
15
15
 
@@ -24,15 +24,13 @@ class MessageHandler(Protocol[_MH_PayloadT, _MH_StateT, CtxT]):
24
24
 
25
25
  class AgentMessagePool(Generic[CtxT]):
26
26
  def __init__(self) -> None:
27
- self._queues: dict[
28
- AgentID, asyncio.Queue[AgentMessage[AgentPayload, AgentState]]
29
- ] = {}
27
+ self._queues: dict[AgentID, asyncio.Queue[AgentMessage[Any, AgentState]]] = {}
30
28
  self._message_handlers: dict[
31
- AgentID, MessageHandler[AgentPayload, AgentState, CtxT]
29
+ AgentID, MessageHandler[Any, AgentState, CtxT]
32
30
  ] = {}
33
31
  self._tasks: dict[AgentID, asyncio.Task[None]] = {}
34
32
 
35
- async def post(self, message: AgentMessage[AgentPayload, AgentState]) -> None:
33
+ async def post(self, message: AgentMessage[Any, AgentState]) -> None:
36
34
  for recipient_id in message.recipient_ids:
37
35
  queue = self._queues.setdefault(recipient_id, asyncio.Queue())
38
36
  await queue.put(message)
@@ -40,7 +38,7 @@ class AgentMessagePool(Generic[CtxT]):
40
38
  def register_message_handler(
41
39
  self,
42
40
  agent_id: AgentID,
43
- handler: MessageHandler[AgentPayload, AgentState, CtxT],
41
+ handler: MessageHandler[Any, AgentState, CtxT],
44
42
  ctx: RunContextWrapper[CtxT] | None = None,
45
43
  **run_kwargs: Any,
46
44
  ) -> None:
@@ -1,47 +1,30 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Any, Generic, Protocol
2
+ from typing import Any, ClassVar, Generic
3
3
 
4
- from pydantic import BaseModel
4
+ from pydantic import TypeAdapter
5
5
 
6
+ from .generics_utils import AutoInstanceAttributesMixin
6
7
  from .run_context import CtxT, RunContextWrapper
7
- from .typing.io import AgentID, AgentPayload, OutT, StateT
8
+ from .typing.io import AgentID, OutT, StateT
8
9
  from .typing.tool import BaseTool
9
10
 
10
11
 
11
- class ParseOutputHandler(Protocol[OutT, CtxT]):
12
- def __call__(
13
- self, *args: Any, ctx: RunContextWrapper[CtxT] | None, **kwargs: Any
14
- ) -> OutT: ...
12
+ class BaseAgent(AutoInstanceAttributesMixin, ABC, Generic[OutT, StateT, CtxT]):
13
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_out_type"}
15
14
 
16
-
17
- class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
18
15
  @abstractmethod
19
- def __init__(
20
- self,
21
- agent_id: AgentID,
22
- *,
23
- out_schema: type[OutT] = AgentPayload,
24
- **kwargs: Any,
25
- ) -> None:
16
+ def __init__(self, agent_id: AgentID, **kwargs: Any) -> None:
17
+ self._out_type: type[OutT]
26
18
  self._state: StateT
27
- self._agent_id = agent_id
28
- self._out_schema = out_schema
29
- self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None
30
19
 
31
- def parse_output_handler(
32
- self, func: ParseOutputHandler[OutT, CtxT]
33
- ) -> ParseOutputHandler[OutT, CtxT]:
34
- self._parse_output_impl = func
20
+ super().__init__()
35
21
 
36
- return func
37
-
38
- def _parse_output(
39
- self, *args: Any, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
40
- ) -> OutT:
41
- if self._parse_output_impl:
42
- return self._parse_output_impl(*args, ctx=ctx, **kwargs)
22
+ self._agent_id = agent_id
23
+ self._out_type_adapter: TypeAdapter[OutT] = TypeAdapter(self._out_type)
43
24
 
44
- return self._out_schema()
25
+ @property
26
+ def out_type(self) -> type[OutT]:
27
+ return self._out_type
45
28
 
46
29
  @property
47
30
  def agent_id(self) -> AgentID:
@@ -51,10 +34,6 @@ class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
51
34
  def state(self) -> StateT:
52
35
  return self._state
53
36
 
54
- @property
55
- def out_schema(self) -> type[OutT]:
56
- return self._out_schema
57
-
58
37
  @abstractmethod
59
38
  async def run(
60
39
  self,
@@ -68,5 +47,5 @@ class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
68
47
  @abstractmethod
69
48
  def as_tool(
70
49
  self, tool_name: str, tool_description: str, tool_strict: bool = True
71
- ) -> BaseTool[BaseModel, Any, CtxT]:
50
+ ) -> BaseTool[Any, OutT, CtxT]:
72
51
  pass
grasp_agents/cloud_llm.py CHANGED
@@ -26,7 +26,7 @@ from .rate_limiting.rate_limiter_chunked import ( # type: ignore
26
26
  from .typing.completion import Completion, CompletionChunk
27
27
  from .typing.message import AssistantMessage, Conversation
28
28
  from .typing.tool import BaseTool, ToolChoice
29
- from .utils import extract_json
29
+ from .utils import validate_obj_from_json_or_py_string
30
30
 
31
31
  logger = logging.getLogger(__name__)
32
32
 
@@ -271,6 +271,11 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
271
271
  api_completion, model_id=self.model_id
272
272
  )
273
273
 
274
+ self._validate_completion(completion)
275
+
276
+ return completion
277
+
278
+ def _validate_completion(self, completion: Completion) -> None:
274
279
  for choice in completion.choices:
275
280
  message = choice.message
276
281
  if (
@@ -278,12 +283,10 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
278
283
  and not self._llm_settings.get("use_structured_outputs")
279
284
  and not message.tool_calls
280
285
  ):
281
- message_json = extract_json(
282
- message.content, return_none_on_failure=True
286
+ validate_obj_from_json_or_py_string(
287
+ message.content,
288
+ adapter=self._response_format_pyd,
283
289
  )
284
- self._response_format_pyd.validate_python(message_json)
285
-
286
- return completion
287
290
 
288
291
  async def generate_completion_stream(
289
292
  self,
@@ -1,26 +1,26 @@
1
1
  import logging
2
2
  from abc import abstractmethod
3
3
  from collections.abc import Sequence
4
- from typing import Any, Generic, Protocol, TypeVar, cast, final
4
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, final
5
5
 
6
- from pydantic import BaseModel
6
+ from pydantic import BaseModel, TypeAdapter
7
7
  from pydantic.json_schema import SkipJsonSchema
8
8
 
9
9
  from .agent_message import AgentMessage
10
10
  from .agent_message_pool import AgentMessagePool
11
11
  from .base_agent import BaseAgent
12
12
  from .run_context import CtxT, RunContextWrapper
13
- from .typing.io import AgentID, AgentPayload, AgentState, InT, OutT, StateT
13
+ from .typing.io import AgentID, AgentState, InT, OutT, StateT
14
14
  from .typing.tool import BaseTool
15
15
 
16
16
  logger = logging.getLogger(__name__)
17
17
 
18
18
 
19
- class DCommAgentPayload(AgentPayload):
19
+ class DynCommPayload(BaseModel):
20
20
  selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]
21
21
 
22
22
 
23
- _EH_OutT = TypeVar("_EH_OutT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
23
+ _EH_OutT = TypeVar("_EH_OutT", contravariant=True) # noqa: PLC0105
24
24
  _EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
25
25
 
26
26
 
@@ -35,44 +35,37 @@ class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
35
35
  class CommunicatingAgent(
36
36
  BaseAgent[OutT, StateT, CtxT], Generic[InT, OutT, StateT, CtxT]
37
37
  ):
38
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
39
+ 0: "_in_type",
40
+ 1: "_out_type",
41
+ }
42
+
38
43
  def __init__(
39
44
  self,
40
45
  agent_id: AgentID,
41
46
  *,
42
- out_schema: type[OutT] = AgentPayload,
43
- rcv_args_schema: type[InT] = AgentPayload,
44
47
  recipient_ids: Sequence[AgentID] | None = None,
45
48
  message_pool: AgentMessagePool[CtxT] | None = None,
46
49
  **kwargs: Any,
47
50
  ) -> None:
48
- super().__init__(agent_id=agent_id, out_schema=out_schema, **kwargs)
49
- self._message_pool = message_pool or AgentMessagePool()
51
+ self._in_type: type[InT]
52
+ super().__init__(agent_id=agent_id, **kwargs)
53
+
54
+ self._rcv_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
55
+ self.recipient_ids = recipient_ids or []
50
56
 
57
+ self._message_pool = message_pool or AgentMessagePool()
51
58
  self._is_listening = False
52
59
  self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None
53
60
 
54
- self._rcv_args_schema = rcv_args_schema
55
- self.recipient_ids = recipient_ids or []
56
-
57
61
  @property
58
- def rcv_args_schema(self) -> type[InT]: # type: ignore[reportInvalidTypeVarUse]
59
- return self._rcv_args_schema
60
-
61
- def _parse_output(
62
- self,
63
- *args: Any,
64
- rcv_args: InT | None = None,
65
- ctx: RunContextWrapper[CtxT] | None = None,
66
- **kwargs: Any,
67
- ) -> OutT:
68
- if self._parse_output_impl:
69
- return self._parse_output_impl(*args, rcv_args=rcv_args, ctx=ctx, **kwargs)
70
-
71
- return self._out_schema()
62
+ def in_type(self) -> type[InT]: # type: ignore
63
+ # Exposing the type of a contravariant variable only, should be safe
64
+ return self._in_type
72
65
 
73
66
  def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
74
- if all(isinstance(p, DCommAgentPayload) for p in payloads):
75
- payloads_ = cast("Sequence[DCommAgentPayload]", payloads)
67
+ if all(isinstance(p, DynCommPayload) for p in payloads):
68
+ payloads_ = cast("Sequence[DynCommPayload]", payloads)
76
69
  selected_recipient_ids_per_payload = [
77
70
  set(p.selected_recipient_ids or []) for p in payloads_
78
71
  ]
@@ -91,7 +84,7 @@ class CommunicatingAgent(
91
84
 
92
85
  return selected_recipient_ids
93
86
 
94
- if all((not isinstance(p, DCommAgentPayload)) for p in payloads):
87
+ if all((not isinstance(p, DynCommPayload)) for p in payloads):
95
88
  return self.recipient_ids
96
89
 
97
90
  raise ValueError(
@@ -109,7 +102,7 @@ class CommunicatingAgent(
109
102
  inp_items: Any | None = None,
110
103
  *,
111
104
  ctx: RunContextWrapper[CtxT] | None = None,
112
- rcv_message: AgentMessage[InT, StateT] | None = None,
105
+ rcv_message: AgentMessage[InT, AgentState] | None = None,
113
106
  entry_point: bool = False,
114
107
  forbid_state_change: bool = False,
115
108
  **kwargs: Any,
@@ -143,11 +136,11 @@ class CommunicatingAgent(
143
136
 
144
137
  async def _message_handler(
145
138
  self,
146
- message: AgentMessage[AgentPayload, AgentState],
139
+ message: AgentMessage[Any, AgentState],
147
140
  ctx: RunContextWrapper[CtxT] | None = None,
148
141
  **run_kwargs: Any,
149
142
  ) -> None:
150
- rcv_message = cast("AgentMessage[InT, StateT]", message)
143
+ rcv_message = cast("AgentMessage[InT, AgentState]", message)
151
144
  out_message = await self.run(ctx=ctx, rcv_message=rcv_message, **run_kwargs)
152
145
 
153
146
  if self._exit_condition(output_message=out_message, ctx=ctx):
@@ -185,15 +178,20 @@ class CommunicatingAgent(
185
178
  tool_name: str,
186
179
  tool_description: str,
187
180
  tool_strict: bool = True,
188
- ) -> BaseTool[Any, Any, Any]:
181
+ ) -> BaseTool[InT, OutT, Any]: # type: ignore[override]
182
+ # Will check if InT is a BaseModel at runtime
189
183
  agent_instance = self
184
+ in_type = agent_instance.in_type
185
+ out_type = agent_instance.out_type
186
+ if not issubclass(in_type, BaseModel):
187
+ raise TypeError(
188
+ "Cannot create a tool from an agent with "
189
+ f"non-BaseModel input type: {in_type}"
190
+ )
190
191
 
191
- class AgentTool(BaseTool[Any, Any, Any]):
192
+ class AgentTool(BaseTool[in_type, out_type, Any]):
192
193
  name: str = tool_name
193
194
  description: str = tool_description
194
- in_schema: type[BaseModel] = agent_instance.rcv_args_schema
195
- out_schema: Any = agent_instance.out_schema
196
-
197
195
  strict: bool | None = tool_strict
198
196
 
199
197
  async def run(
@@ -201,16 +199,14 @@ class CommunicatingAgent(
201
199
  inp: InT,
202
200
  ctx: RunContextWrapper[CtxT] | None = None,
203
201
  ) -> OutT:
204
- rcv_args = agent_instance.rcv_args_schema.model_validate(inp)
205
-
206
- rcv_message = AgentMessage( # type: ignore[arg-type]
202
+ rcv_args = in_type.model_validate(inp)
203
+ rcv_message = AgentMessage[in_type, AgentState](
207
204
  payloads=[rcv_args],
208
205
  sender_id="<tool_user>",
209
206
  recipient_ids=[agent_instance.agent_id],
210
207
  )
211
-
212
208
  agent_result = await agent_instance.run(
213
- rcv_message=rcv_message, # type: ignore[arg-type]
209
+ rcv_message=rcv_message,
214
210
  entry_point=False,
215
211
  forbid_state_change=True,
216
212
  ctx=ctx,
@@ -218,4 +214,4 @@ class CommunicatingAgent(
218
214
 
219
215
  return agent_result.payloads[0]
220
216
 
221
- return AgentTool()
217
+ return AgentTool() # type: ignore[return-value]
@@ -0,0 +1,159 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, ClassVar, Self, TypeVar, cast, get_args
4
+
5
+
6
+ class AutoInstanceAttributesMixin:
7
+ """
8
+ A **runtime convenience mix-in** that automatically exposes the concrete
9
+ types supplied to a *generic* base class as **instance attributes**.
10
+
11
+ Example:
12
+ -------
13
+ from typing import Generic, TypeVar
14
+ from grasp_agents.generics_utils import AutoInstanceAttributesMixin
15
+
16
+ T = TypeVar("T")
17
+ U = TypeVar("U")
18
+
19
+ class MyBase(AutoInstanceAttributesMixin, Generic[T, U]):
20
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
21
+ 0: "elem_type",
22
+ 1: "meta_type",
23
+ }
24
+
25
+ class Packet(MyBase[int, str]):
26
+ ...
27
+
28
+ Alias = MyBase[bytes, float] # "late" specialization
29
+
30
+ print(Packet().elem_type) # <class 'int'>
31
+ print(Alias().meta_type) # <class 'float'>
32
+
33
+ """
34
+
35
+ # Configure this on your *generic* base class
36
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {}
37
+
38
+ # Filled automatically for every concrete specialization
39
+ _resolved_instance_attr_types: ClassVar[dict[str, type]]
40
+
41
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
42
+ super().__init__(*args, **kwargs)
43
+ self._set_resolved_generic_instance_attributes()
44
+
45
+ def __init_subclass__(cls, **kwargs: Any) -> None:
46
+ """Runs when a *class statement* defines a new subclass."""
47
+ super().__init_subclass__(**kwargs)
48
+ cls._resolved_instance_attr_types = cls._compute_resolved_attrs(cls)
49
+
50
+ @classmethod
51
+ def __class_getitem__(cls, params: Any) -> type[Self]:
52
+ """
53
+ Run when someone writes ``SomeGeneric[ConcreteTypes]``.
54
+ Ensures aliases receive the same resolved-type mapping.
55
+ """
56
+ specialized: type[Self] = cast("type[Self]", super().__class_getitem__(params)) # type: ignore[assignment]
57
+ specialized._resolved_instance_attr_types = cls._compute_resolved_attrs( # noqa: SLF001
58
+ specialized
59
+ )
60
+
61
+ return specialized
62
+
63
+ @staticmethod
64
+ def _compute_resolved_attrs(_cls: type) -> dict[str, type]:
65
+ """
66
+ Walks the MRO, finds the first generic base that defines
67
+ `_generic_arg_to_instance_attr_map`, and resolves concrete types.
68
+ """
69
+ target_generic_base: Any | None = None
70
+ attr_mapping: dict[int, str] | None = None
71
+
72
+ # Locate the mapping
73
+ for mro_cls in _cls.mro():
74
+ if mro_cls in {AutoInstanceAttributesMixin, object}:
75
+ continue
76
+ if (
77
+ hasattr(mro_cls, "__parameters__")
78
+ and mro_cls.__parameters__ # type: ignore[has-type]
79
+ and "_generic_arg_to_instance_attr_map" in mro_cls.__dict__
80
+ ):
81
+ target_generic_base = mro_cls
82
+ attr_mapping = cast(
83
+ "dict[int, str]",
84
+ mro_cls._generic_arg_to_instance_attr_map, # type: ignore[attr-defined] # noqa: SLF001
85
+ )
86
+ break
87
+
88
+ if target_generic_base is None or attr_mapping is None:
89
+ return {}
90
+
91
+ resolved: dict[str, type] = {}
92
+
93
+ def _add_to_resolved(generic_args: tuple[type | TypeVar, ...]) -> None:
94
+ for index, attr_name in attr_mapping.items():
95
+ if attr_name in resolved:
96
+ continue
97
+ if index < len(generic_args):
98
+ arg = generic_args[index]
99
+ if not isinstance(arg, TypeVar):
100
+ resolved[attr_name] = arg
101
+
102
+ def _all_resolved() -> bool:
103
+ return all(name in resolved for name in attr_mapping.values())
104
+
105
+ # Scenario 1: _cls itself is the direct parameterization (handles aliases).
106
+ # e.g., _cls is MyBase[bytes, float]. Its __origin__ is MyBase.
107
+ if getattr(_cls, "__origin__", None) is target_generic_base:
108
+ _add_to_resolved(get_args(_cls))
109
+
110
+ # Scenario 2: Check MRO for subclasses or more complex structures.
111
+ # This also acts as a fallback if Scenario 1 didn't resolve all attributes.
112
+ if not _all_resolved():
113
+ for mro_candidate in _cls.mro():
114
+ # Pydantic-specific check first
115
+ pydantic_generic_metadata = getattr(
116
+ mro_candidate, "__pydantic_generic_metadata__", None
117
+ )
118
+ if (
119
+ pydantic_generic_metadata
120
+ and pydantic_generic_metadata.get("origin") is target_generic_base
121
+ ):
122
+ _add_to_resolved(pydantic_generic_metadata.get("args", ()))
123
+ if _all_resolved():
124
+ break
125
+
126
+ # Fallback to standard generic introspection if not fully resolved by Pydantic check
127
+ if not _all_resolved():
128
+ mro_candidate_origin = getattr(mro_candidate, "__origin__", None)
129
+ if mro_candidate_origin is target_generic_base:
130
+ _add_to_resolved(get_args(mro_candidate))
131
+ if _all_resolved():
132
+ break
133
+
134
+ if not _all_resolved():
135
+ mro_candidate_orig_bases = getattr(
136
+ mro_candidate, "__orig_bases__", []
137
+ )
138
+ for param_base in mro_candidate_orig_bases:
139
+ param_base_origin = getattr(param_base, "__origin__", None)
140
+ if param_base_origin is target_generic_base:
141
+ _add_to_resolved(get_args(param_base))
142
+ if _all_resolved():
143
+ break
144
+
145
+ if _all_resolved():
146
+ break
147
+
148
+ return resolved
149
+
150
+ def _set_resolved_generic_instance_attributes(self) -> None:
151
+ for name, typ in getattr(
152
+ self.__class__, "_resolved_instance_attr_types", {}
153
+ ).items():
154
+ _typ = None if typ is type(None) else typ
155
+ pyd_private = getattr(self, "__pydantic_private__", {})
156
+ if name in pyd_private:
157
+ pyd_private[name] = _typ
158
+ else:
159
+ setattr(self, name, _typ)
grasp_agents/llm.py CHANGED
@@ -69,6 +69,10 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
69
69
  def tools(self, tools: list[BaseTool[BaseModel, Any, Any]] | None) -> None:
70
70
  self._tools = {t.name: t for t in tools} if tools else None
71
71
 
72
+ @response_format.setter
73
+ def response_format(self, response_format: type | None) -> None:
74
+ self._response_format = response_format
75
+
72
76
  def __repr__(self) -> str:
73
77
  return (
74
78
  f"{type(self).__name__}(model_id={self.model_id}; "