grasp_agents 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those versions as they appear in their respective public registries.
@@ -4,7 +4,6 @@ from uuid import uuid4
4
4
 
5
5
  from pydantic import BaseModel, ConfigDict, Field
6
6
 
7
- # from .base_agent import StateT
8
7
  from .typing.io import AgentID, AgentPayload, AgentState
9
8
 
10
9
  _PayloadT = TypeVar("_PayloadT", bound=AgentPayload, covariant=True) # noqa: PLC0105
@@ -68,5 +68,5 @@ class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
68
68
  @abstractmethod
69
69
  def as_tool(
70
70
  self, tool_name: str, tool_description: str, tool_strict: bool = True
71
- ) -> BaseTool[BaseModel, BaseModel, CtxT]:
71
+ ) -> BaseTool[BaseModel, Any, CtxT]:
72
72
  pass
grasp_agents/cloud_llm.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
3
  import os
4
4
  from abc import abstractmethod
5
5
  from collections.abc import AsyncIterator, Sequence
6
+ from copy import deepcopy
6
7
  from typing import Any, Generic, Literal
7
8
 
8
9
  import httpx
@@ -15,14 +16,13 @@ from tenacity import (
15
16
  )
16
17
  from typing_extensions import TypedDict
17
18
 
18
- from .data_retrieval.rate_limiter_chunked import ( # type: ignore
19
- RateLimiterC,
20
- limit_rate_chunked,
21
- )
22
-
23
19
  from .http_client import AsyncHTTPClientParams, create_async_http_client
24
20
  from .llm import LLM, ConvertT, LLMSettings, SettingsT
25
21
  from .memory import MessageHistory
22
+ from .rate_limiting.rate_limiter_chunked import ( # type: ignore
23
+ RateLimiterC,
24
+ limit_rate_chunked,
25
+ )
26
26
  from .typing.completion import Completion, CompletionChunk
27
27
  from .typing.message import AssistantMessage, Conversation
28
28
  from .typing.tool import BaseTool, ToolChoice
@@ -38,7 +38,7 @@ class APIProviderInfo(TypedDict):
38
38
  name: APIProvider
39
39
  base_url: str
40
40
  api_key: str | None
41
- struct_output_support: list[str]
41
+ struct_output_support: tuple[str, ...]
42
42
 
43
43
 
44
44
  PROVIDERS: dict[APIProvider, APIProviderInfo] = {
@@ -46,19 +46,19 @@ PROVIDERS: dict[APIProvider, APIProviderInfo] = {
46
46
  name="openai",
47
47
  base_url="https://api.openai.com/v1",
48
48
  api_key=os.getenv("OPENAI_API_KEY"),
49
- struct_output_support=["*"],
49
+ struct_output_support=("*",),
50
50
  ),
51
51
  "openrouter": APIProviderInfo(
52
52
  name="openrouter",
53
53
  base_url="https://openrouter.ai/api/v1",
54
54
  api_key=os.getenv("OPENROUTER_API_KEY"),
55
- struct_output_support=[],
55
+ struct_output_support=(),
56
56
  ),
57
57
  "google_ai_studio": APIProviderInfo(
58
58
  name="google_ai_studio",
59
59
  base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
60
60
  api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
61
- struct_output_support=["*"],
61
+ struct_output_support=("*",),
62
62
  ),
63
63
  }
64
64
 
@@ -92,6 +92,7 @@ class CloudLLMSettings(LLMSettings, total=False):
92
92
  temperature: float | None
93
93
  top_p: float | None
94
94
  seed: int | None
95
+ use_structured_outputs: bool
95
96
 
96
97
 
97
98
  class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
@@ -102,7 +103,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
102
103
  converters: ConvertT,
103
104
  llm_settings: SettingsT | None = None,
104
105
  model_id: str | None = None,
105
- tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
106
+ tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
106
107
  response_format: type | None = None,
107
108
  # Connection settings
108
109
  api_provider: APIProvider = "openai",
@@ -135,13 +136,21 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
135
136
  self._model_name = model_name
136
137
  self._api_provider: APIProvider = api_provider
137
138
 
138
- patterns = PROVIDERS[api_provider]["struct_output_support"]
139
139
  self._struct_output_support: bool = any(
140
- fnmatch.fnmatch(self._model_name, pat) for pat in patterns
140
+ fnmatch.fnmatch(self._model_name, pat)
141
+ for pat in PROVIDERS[api_provider]["struct_output_support"]
141
142
  )
142
143
  self._response_format_pyd: TypeAdapter[Any] | None = (
143
144
  TypeAdapter(self._response_format) if response_format else None
144
145
  )
146
+ if (
147
+ self._llm_settings.get("use_structured_outputs")
148
+ and not self._struct_output_support
149
+ ):
150
+ raise ValueError(
151
+ f"Model {api_provider}:{self._model_name} does "
152
+ "not support structured outputs."
153
+ )
145
154
 
146
155
  self._rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = (
147
156
  self._get_rate_limiter(
@@ -181,8 +190,8 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
181
190
  def _make_completion_kwargs(
182
191
  self, conversation: Conversation, tool_choice: ToolChoice | None = None
183
192
  ) -> dict[str, Any]:
184
- api_llm_settings = self.llm_settings or {}
185
193
  api_messages = [self._converters.to_message(m) for m in conversation]
194
+
186
195
  api_tools = None
187
196
  api_tool_choice = None
188
197
  if self.tools:
@@ -190,6 +199,9 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
190
199
  if tool_choice is not None:
191
200
  api_tool_choice = self._converters.to_tool_choice(tool_choice)
192
201
 
202
+ api_llm_settings = deepcopy(self.llm_settings or {})
203
+ api_llm_settings.pop("use_structured_outputs", None)
204
+
193
205
  return dict(
194
206
  api_messages=api_messages,
195
207
  api_tools=api_tools,
@@ -216,6 +228,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
216
228
  *,
217
229
  api_tools: list[Any] | None = None,
218
230
  api_tool_choice: Any | None = None,
231
+ api_response_format: type | None = None,
219
232
  **api_llm_settings: Any,
220
233
  ) -> Any:
221
234
  pass
@@ -242,7 +255,11 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
242
255
  conversation=conversation, tool_choice=tool_choice
243
256
  )
244
257
 
245
- if self._response_format is None or not self._struct_output_support:
258
+ if (
259
+ self._response_format is None
260
+ or (not self._struct_output_support)
261
+ or (not self._llm_settings.get("use_structured_outputs"))
262
+ ):
246
263
  completion_kwargs.pop("api_response_format", None)
247
264
  api_completion = await self._get_completion(**completion_kwargs, **kwargs)
248
265
  else:
@@ -250,7 +267,23 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
250
267
  **completion_kwargs, **kwargs
251
268
  )
252
269
 
253
- return self._converters.from_completion(api_completion, model_id=self.model_id)
270
+ completion = self._converters.from_completion(
271
+ api_completion, model_id=self.model_id
272
+ )
273
+
274
+ for choice in completion.choices:
275
+ message = choice.message
276
+ if (
277
+ self._response_format_pyd is not None
278
+ and not self._llm_settings.get("use_structured_outputs")
279
+ and not message.tool_calls
280
+ ):
281
+ message_json = extract_json(
282
+ message.content, return_none_on_failure=True
283
+ )
284
+ self._response_format_pyd.validate_python(message_json)
285
+
286
+ return completion
254
287
 
255
288
  async def generate_completion_stream(
256
289
  self,
@@ -271,63 +304,73 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
271
304
  api_completion_chunk_iterator, model_id=self.model_id
272
305
  )
273
306
 
274
- async def generate_message(
307
+ async def _generate_completion_with_retry(
275
308
  self,
276
309
  conversation: Conversation,
277
310
  *,
278
311
  tool_choice: ToolChoice | None = None,
279
312
  **kwargs: Any,
280
- ) -> AssistantMessage:
281
- completion = await self.generate_completion(
282
- conversation, tool_choice=tool_choice, **kwargs
283
- )
284
- message = completion.choices[0].message
285
- if self._response_format_pyd is not None and not self._struct_output_support:
286
- self._response_format_pyd.validate_python(extract_json(message.content))
287
-
288
- return message
289
-
290
- async def _generate_message_with_retry(
291
- self,
292
- conversation: Conversation,
293
- *,
294
- tool_choice: ToolChoice | None = None,
295
- **kwargs: Any,
296
- ) -> AssistantMessage:
313
+ ) -> Completion:
297
314
  wrapped_func = retry(
298
315
  wait=wait_random_exponential(min=1, max=8),
299
316
  stop=stop_after_attempt(self.num_generation_retries + 1),
300
317
  before=retry_before_callback,
301
318
  retry_error_callback=retry_error_callback,
302
- )(self.__class__.generate_message)
319
+ )(self.__class__.generate_completion)
303
320
 
304
321
  return await wrapped_func(self, conversation, tool_choice=tool_choice, **kwargs)
305
322
 
306
323
  @limit_rate_chunked # type: ignore
307
- async def _generate_message_batch_with_retry_and_rate_lim(
324
+ async def _generate_completion_batch_with_retry_and_rate_lim(
308
325
  self,
309
326
  conversation: Conversation,
310
327
  *,
311
328
  tool_choice: ToolChoice | None = None,
312
329
  **kwargs: Any,
313
- ) -> AssistantMessage:
314
- return await self._generate_message_with_retry(
330
+ ) -> Completion:
331
+ return await self._generate_completion_with_retry(
315
332
  conversation, tool_choice=tool_choice, **kwargs
316
333
  )
317
334
 
318
- async def generate_message_batch(
335
+ async def generate_completion_batch(
319
336
  self,
320
337
  message_history: MessageHistory,
321
338
  *,
322
339
  tool_choice: ToolChoice | None = None,
323
340
  **kwargs: Any,
324
- ) -> Sequence[AssistantMessage]:
325
- return await self._generate_message_batch_with_retry_and_rate_lim(
341
+ ) -> Sequence[Completion]:
342
+ return await self._generate_completion_batch_with_retry_and_rate_lim(
326
343
  list(message_history.batched_conversations), # type: ignore
327
344
  tool_choice=tool_choice,
328
345
  **kwargs,
329
346
  )
330
347
 
348
+ async def generate_message(
349
+ self,
350
+ conversation: Conversation,
351
+ *,
352
+ tool_choice: ToolChoice | None = None,
353
+ **kwargs: Any,
354
+ ) -> AssistantMessage:
355
+ completion = await self.generate_completion(
356
+ conversation, tool_choice=tool_choice, **kwargs
357
+ )
358
+
359
+ return completion.choices[0].message
360
+
361
+ async def generate_message_batch(
362
+ self,
363
+ message_history: MessageHistory,
364
+ *,
365
+ tool_choice: ToolChoice | None = None,
366
+ **kwargs: Any,
367
+ ) -> Sequence[AssistantMessage]:
368
+ completion_batch = await self.generate_completion_batch(
369
+ message_history, tool_choice=tool_choice, **kwargs
370
+ )
371
+
372
+ return [completion.choices[0].message for completion in completion_batch]
373
+
331
374
  def _get_rate_limiter(
332
375
  self,
333
376
  rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = None,
@@ -4,6 +4,7 @@ from collections.abc import Sequence
4
4
  from typing import Any, Generic, Protocol, TypeVar, cast, final
5
5
 
6
6
  from pydantic import BaseModel
7
+ from pydantic.json_schema import SkipJsonSchema
7
8
 
8
9
  from .agent_message import AgentMessage
9
10
  from .agent_message_pool import AgentMessagePool
@@ -14,6 +15,11 @@ from .typing.tool import BaseTool
14
15
 
15
16
  logger = logging.getLogger(__name__)
16
17
 
18
+
19
+ class DCommAgentPayload(AgentPayload):
20
+ selected_recipient_ids: SkipJsonSchema[Sequence[AgentID]]
21
+
22
+
17
23
  _EH_OutT = TypeVar("_EH_OutT", bound=AgentPayload, contravariant=True) # noqa: PLC0105
18
24
  _EH_StateT = TypeVar("_EH_StateT", bound=AgentState, contravariant=True) # noqa: PLC0105
19
25
 
@@ -22,7 +28,6 @@ class ExitHandler(Protocol[_EH_OutT, _EH_StateT, CtxT]):
22
28
  def __call__(
23
29
  self,
24
30
  output_message: AgentMessage[_EH_OutT, _EH_StateT],
25
- agent_state: _EH_StateT,
26
31
  ctx: RunContextWrapper[CtxT] | None,
27
32
  ) -> bool: ...
28
33
 
@@ -38,14 +43,11 @@ class CommunicatingAgent(
38
43
  rcv_args_schema: type[InT] = AgentPayload,
39
44
  recipient_ids: Sequence[AgentID] | None = None,
40
45
  message_pool: AgentMessagePool[CtxT] | None = None,
41
- dynamic_routing: bool = False,
42
46
  **kwargs: Any,
43
47
  ) -> None:
44
48
  super().__init__(agent_id=agent_id, out_schema=out_schema, **kwargs)
45
49
  self._message_pool = message_pool or AgentMessagePool()
46
50
 
47
- self._dynamic_routing = dynamic_routing
48
-
49
51
  self._is_listening = False
50
52
  self._exit_impl: ExitHandler[OutT, StateT, CtxT] | None = None
51
53
 
@@ -56,10 +58,6 @@ class CommunicatingAgent(
56
58
  def rcv_args_schema(self) -> type[InT]: # type: ignore[reportInvalidTypeVarUse]
57
59
  return self._rcv_args_schema
58
60
 
59
- @property
60
- def dynamic_routing(self) -> bool:
61
- return self._dynamic_routing
62
-
63
61
  def _parse_output(
64
62
  self,
65
63
  *args: Any,
@@ -72,41 +70,36 @@ class CommunicatingAgent(
72
70
 
73
71
  return self._out_schema()
74
72
 
75
- def _validate_dynamic_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
76
- assert all((p.selected_recipient_ids is not None) for p in payloads), (
77
- "Dynamic routing is enabled, but some payloads have no recipient IDs"
78
- )
79
-
80
- selected_recipient_ids_per_payload = [
81
- set(p.selected_recipient_ids or []) for p in payloads
82
- ]
83
- assert all(
84
- x == selected_recipient_ids_per_payload[0]
85
- for x in selected_recipient_ids_per_payload
86
- ), "All payloads must have the same recipient IDs for dynamic routing"
87
-
88
- assert payloads[0].selected_recipient_ids is not None
89
- selected_recipient_ids = payloads[0].selected_recipient_ids
73
+ def _validate_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
74
+ if all(isinstance(p, DCommAgentPayload) for p in payloads):
75
+ payloads_ = cast("Sequence[DCommAgentPayload]", payloads)
76
+ selected_recipient_ids_per_payload = [
77
+ set(p.selected_recipient_ids or []) for p in payloads_
78
+ ]
79
+ assert all(
80
+ x == selected_recipient_ids_per_payload[0]
81
+ for x in selected_recipient_ids_per_payload
82
+ ), "All payloads must have the same recipient IDs for dynamic routing"
83
+
84
+ assert payloads_[0].selected_recipient_ids is not None
85
+ selected_recipient_ids = payloads_[0].selected_recipient_ids
86
+
87
+ assert all(rid in self.recipient_ids for rid in selected_recipient_ids), (
88
+ "Dynamic routing is enabled, but recipient IDs are not in "
89
+ "the allowed agent's recipient IDs"
90
+ )
90
91
 
91
- assert all(rid in self.recipient_ids for rid in selected_recipient_ids), (
92
- "Dynamic routing is enabled, but recipient IDs are not in "
93
- "the allowed agent's recipient IDs"
94
- )
92
+ return selected_recipient_ids
95
93
 
96
- return selected_recipient_ids
94
+ if all((not isinstance(p, DCommAgentPayload)) for p in payloads):
95
+ return self.recipient_ids
97
96
 
98
- def _validate_static_routing(self, payloads: Sequence[OutT]) -> Sequence[AgentID]:
99
- assert all((p.selected_recipient_ids is None) for p in payloads), (
100
- "Dynamic routing is not enabled, but some payloads have recipient IDs"
97
+ raise ValueError(
98
+ "All payloads must be either DCommAgentPayload or not DCommAgentPayload"
101
99
  )
102
100
 
103
- return self.recipient_ids
104
-
105
101
  async def post_message(self, message: AgentMessage[OutT, StateT]) -> None:
106
- if self._dynamic_routing:
107
- self._validate_dynamic_routing(message.payloads)
108
- else:
109
- self._validate_static_routing(message.payloads)
102
+ self._validate_routing(message.payloads)
110
103
 
111
104
  await self._message_pool.post(message)
112
105
 
@@ -144,9 +137,7 @@ class CommunicatingAgent(
144
137
  ctx: RunContextWrapper[CtxT] | None,
145
138
  ) -> bool:
146
139
  if self._exit_impl:
147
- return self._exit_impl(
148
- output_message=output_message, agent_state=self.state, ctx=ctx
149
- )
140
+ return self._exit_impl(output_message=output_message, ctx=ctx)
150
141
 
151
142
  return False
152
143
 
@@ -190,28 +181,28 @@ class CommunicatingAgent(
190
181
 
191
182
  @final
192
183
  def as_tool(
193
- self, tool_name: str, tool_description: str, tool_strict: bool = True
194
- ) -> BaseTool[BaseModel, BaseModel, CtxT]:
195
- # assert self.state.batch_size == 1, (
196
- # "Using agents as tools is only supported for batch size 1"
197
- # )
198
-
184
+ self,
185
+ tool_name: str,
186
+ tool_description: str,
187
+ tool_strict: bool = True,
188
+ ) -> BaseTool[Any, Any, Any]:
199
189
  agent_instance = self
200
190
 
201
- class AgentTool(BaseTool[BaseModel, BaseModel, Any]):
191
+ class AgentTool(BaseTool[Any, Any, Any]):
202
192
  name: str = tool_name
203
193
  description: str = tool_description
204
194
  in_schema: type[BaseModel] = agent_instance.rcv_args_schema
205
- out_schema: type[BaseModel] = agent_instance.out_schema
195
+ out_schema: Any = agent_instance.out_schema
206
196
 
207
197
  strict: bool | None = tool_strict
208
198
 
209
199
  async def run(
210
200
  self,
211
- inp: BaseModel,
201
+ inp: InT,
212
202
  ctx: RunContextWrapper[CtxT] | None = None,
213
203
  ) -> OutT:
214
204
  rcv_args = agent_instance.rcv_args_schema.model_validate(inp)
205
+
215
206
  rcv_message = AgentMessage( # type: ignore[arg-type]
216
207
  payloads=[rcv_args],
217
208
  sender_id="<tool_user>",
grasp_agents/llm.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from abc import ABC, abstractmethod
3
3
  from collections.abc import AsyncIterator, Sequence
4
- from typing import Any, Generic, TypeVar
4
+ from typing import Any, Generic, TypeVar, cast
5
5
  from uuid import uuid4
6
6
 
7
7
  from pydantic import BaseModel
@@ -32,7 +32,7 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
32
32
  model_name: str | None = None,
33
33
  model_id: str | None = None,
34
34
  llm_settings: SettingsT | None = None,
35
- tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
35
+ tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
36
36
  response_format: type | None = None,
37
37
  **kwargs: Any,
38
38
  ) -> None:
@@ -41,9 +41,9 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
41
41
  self._converters = converters
42
42
  self._model_id = model_id or str(uuid4())[:8]
43
43
  self._model_name = model_name
44
- self._llm_settings = llm_settings
45
44
  self._tools = {t.name: t for t in tools} if tools else None
46
45
  self._response_format = response_format
46
+ self._llm_settings: SettingsT = llm_settings or cast("SettingsT", {})
47
47
 
48
48
  @property
49
49
  def model_id(self) -> str:
@@ -54,11 +54,11 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
54
54
  return self._model_name
55
55
 
56
56
  @property
57
- def llm_settings(self) -> SettingsT | None:
57
+ def llm_settings(self) -> SettingsT:
58
58
  return self._llm_settings
59
59
 
60
60
  @property
61
- def tools(self) -> dict[str, BaseTool[BaseModel, BaseModel, Any]] | None:
61
+ def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
62
62
  return self._tools
63
63
 
64
64
  @property
@@ -66,7 +66,7 @@ class LLM(ABC, Generic[SettingsT, ConvertT]):
66
66
  return self._response_format
67
67
 
68
68
  @tools.setter
69
- def tools(self, tools: list[BaseTool[BaseModel, BaseModel, Any]] | None) -> None:
69
+ def tools(self, tools: list[BaseTool[BaseModel, Any, Any]] | None) -> None:
70
70
  self._tools = {t.name: t for t in tools} if tools else None
71
71
 
72
72
  def __repr__(self) -> str: