grasp_agents 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -218
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.6.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/openai/openai_llm.py CHANGED
@@ -1,21 +1,22 @@
+import fnmatch
 import logging
+import os
 from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
-from typing import Any, Literal, NamedTuple
+from typing import Any, Literal
 
 import httpx
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, AsyncStream
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (
     AsyncChatCompletionStreamManager as OpenAIAsyncChatCompletionStreamManager,
 )
+from openai.lib.streaming.chat import ChatCompletionStreamState
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel
 
-from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings, LLMRateLimiter
 from ..http_client import AsyncHTTPClientParams
-from ..rate_limiting.rate_limiter_chunked import RateLimiterC
-from ..typing.message import AssistantMessage, Messages
 from ..typing.tool import BaseTool
 from . import (
     OpenAICompletion,
@@ -23,17 +24,40 @@ from . import (
     OpenAIMessageParam,
     OpenAIParsedCompletion,
     OpenAIPredictionContentParam,
+    OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatText,
     OpenAIStreamOptionsParam,
     OpenAIToolChoiceOptionParam,
     OpenAIToolParam,
+    OpenAIWebSearchOptions,
 )
 from .converters import OpenAIConverters
 
 logger = logging.getLogger(__name__)
 
 
-class ToolCallSettings(NamedTuple):
-    strict: bool | None = None
+def get_openai_compatible_providers() -> list[APIProvider]:
+    """Returns a list of available OpenAI-compatible API providers."""
+    return [
+        APIProvider(
+            name="openai",
+            base_url="https://api.openai.com/v1",
+            api_key=os.getenv("OPENAI_API_KEY"),
+            response_schema_support=("*",),
+        ),
+        APIProvider(
+            name="gemini_openai",
+            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+            api_key=os.getenv("GEMINI_API_KEY"),
+            response_schema_support=("*",),
+        ),
+        APIProvider(
+            name="openrouter",
+            base_url="https://openrouter.ai/api/v1",
+            api_key=os.getenv("OPENROUTER_API_KEY"),
+            response_schema_support=(),
+        ),
+    ]
 
 
 class OpenAILLMSettings(CloudLLMSettings, total=False):
@@ -41,7 +65,7 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
 
     parallel_tool_calls: bool
 
-    modalities: list[Literal["text", "audio"]] | None
+    modalities: list[Literal["text"]] | None
 
     frequency_penalty: float | None
     presence_penalty: float | None
@@ -50,19 +74,20 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     logprobs: bool | None
     top_logprobs: int | None
 
+    stream_options: OpenAIStreamOptionsParam | None
+
     prediction: OpenAIPredictionContentParam | None
 
-    stream_options: OpenAIStreamOptionsParam | None
+    web_search_options: OpenAIWebSearchOptions | None
 
     metadata: dict[str, str] | None
     store: bool | None
     user: str
 
-    # response_format: (
-    #     OpenAIResponseFormatText
-    #     | OpenAIResponseFormatJSONSchema
-    #     | OpenAIResponseFormatJSONObject
-    # )
+    # To support the old JSON mode without response schemas
+    response_format: OpenAIResponseFormatJSONObject | OpenAIResponseFormatText
+
+    # TODO: support audio
 
 
 class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
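
Since `OpenAILLMSettings` is a `TypedDict` with `total=False`, any subset of the keys above may be supplied. A minimal sketch of the new settings surface (the specific values are illustrative, not defaults):

```python
from grasp_agents.openai.openai_llm import OpenAILLMSettings

settings: OpenAILLMSettings = {
    "parallel_tool_calls": False,
    "modalities": ["text"],  # audio is no longer accepted (see the TODO above)
    "stream_options": {"include_usage": True},
    # The old JSON mode, now re-exposed without a response schema:
    "response_format": {"type": "json_object"},
}
```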
@@ -72,52 +97,87 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         model_name: str,
         llm_settings: OpenAILLMSettings | None = None,
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_format: type | Mapping[str, type] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        apply_response_schema_via_provider: bool = False,
         model_id: str | None = None,
         # Custom LLM provider
         api_provider: APIProvider | None = None,
         # Connection settings
+        max_client_retries: int = 2,
         async_http_client: httpx.AsyncClient | None = None,
         async_http_client_params: (
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
         async_openai_client_params: dict[str, Any] | None = None,
         # Rate limiting
-        rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
-        rate_limiter_rpm: float | None = None,
-        rate_limiter_chunk_size: int = 1000,
-        rate_limiter_max_concurrency: int = 300,
-        # Retries
-        num_generation_retries: int = 0,
+        rate_limiter: LLMRateLimiter | None = None,
+        # LLM response retries: try to regenerate to pass validation
+        max_response_retries: int = 1,
     ) -> None:
+        openai_compatible_providers = get_openai_compatible_providers()
+
+        model_name_parts = model_name.split("/", 1)
+        if api_provider is not None:
+            provider_model_name = model_name
+        elif len(model_name_parts) == 2:
+            compat_providers_map = {
+                provider["name"]: provider for provider in openai_compatible_providers
+            }
+            provider_name, provider_model_name = model_name_parts
+            if provider_name not in compat_providers_map:
+                raise ValueError(
+                    f"OpenAI compatible API provider '{provider_name}' "
+                    "is not supported. Supported providers are: "
+                    f"{', '.join(compat_providers_map.keys())}"
+                )
+            api_provider = compat_providers_map[provider_name]
+        else:
+            raise ValueError(
+                "Model name must be in the format 'provider/model_name' or "
+                "you must provide an 'api_provider' argument."
+            )
+
         super().__init__(
-            model_name=model_name,
+            model_name=provider_model_name,
             model_id=model_id,
             llm_settings=llm_settings,
            converters=OpenAIConverters(),
             tools=tools,
-            response_format=response_format,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
+            apply_response_schema_via_provider=apply_response_schema_via_provider,
             api_provider=api_provider,
             async_http_client=async_http_client,
             async_http_client_params=async_http_client_params,
             rate_limiter=rate_limiter,
-            rate_limiter_rpm=rate_limiter_rpm,
-            rate_limiter_chunk_size=rate_limiter_chunk_size,
-            rate_limiter_max_concurrency=rate_limiter_max_concurrency,
-            num_generation_retries=num_generation_retries,
+            max_client_retries=max_client_retries,
+            max_response_retries=max_response_retries,
         )
 
-        self._tool_call_settings = {
-            "strict": self._llm_settings.get("use_struct_outputs", False),
-        }
+        response_schema_support: bool = any(
+            fnmatch.fnmatch(self._model_name, pat)
+            for pat in api_provider.get("response_schema_support") or []
+        )
+        if apply_response_schema_via_provider:
+            if self._tools:
+                for tool in self._tools.values():
+                    tool.strict = True
+            if not response_schema_support:
+                raise ValueError(
+                    "Native response schema validation is not supported for model "
+                    f"'{self._model_name}' by the API provider. Please set "
+                    "apply_response_schema_via_provider=False."
+                )
 
         _async_openai_client_params = deepcopy(async_openai_client_params or {})
         if self._async_http_client is not None:
             _async_openai_client_params["http_client"] = self._async_http_client
 
         self._client: AsyncOpenAI = AsyncOpenAI(
-            base_url=self._base_url,
-            api_key=self._api_key,
+            base_url=self.api_provider.get("base_url"),
+            api_key=self.api_provider.get("api_key"),
+            max_retries=max_client_retries,
             **_async_openai_client_params,
         )
 
@@ -126,15 +186,28 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         api_messages: Iterable[OpenAIMessageParam],
         api_tools: list[OpenAIToolParam] | None = None,
         api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+        api_response_schema: type[Any] | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
-    ) -> OpenAICompletion:
+    ) -> OpenAICompletion | OpenAIParsedCompletion[Any]:
         tools = api_tools or NOT_GIVEN
         tool_choice = api_tool_choice or NOT_GIVEN
+        response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
+        if self._apply_response_schema_via_provider:
+            return await self._client.beta.chat.completions.parse(
+                model=self._model_name,
+                messages=api_messages,
+                tools=tools,
+                tool_choice=tool_choice,
+                response_format=response_format,
+                n=n,
+                **api_llm_settings,
+            )
+
         return await self._client.chat.completions.create(
-            model=self._api_model_name,
+            model=self._model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -143,89 +216,68 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             **api_llm_settings,
         )
 
-    async def _get_parsed_completion(
-        self,
-        api_messages: Iterable[OpenAIMessageParam],
-        api_tools: list[OpenAIToolParam] | None = None,
-        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
-        api_response_format: type | None = None,
-        n_choices: int | None = None,
-        **api_llm_settings: Any,
-    ) -> OpenAIParsedCompletion[Any]:
-        tools = api_tools or NOT_GIVEN
-        tool_choice = api_tool_choice or NOT_GIVEN
-        n = n_choices or NOT_GIVEN
-        response_format = api_response_format or NOT_GIVEN
-
-        return await self._client.beta.chat.completions.parse(
-            model=self._api_model_name,
-            messages=api_messages,
-            tools=tools,
-            tool_choice=tool_choice,
-            response_format=response_format,
-            n=n,
-            **api_llm_settings,
-        )
-
-    async def _get_completion_stream(
+    async def _get_completion_stream(  # type: ignore[override]
         self,
         api_messages: Iterable[OpenAIMessageParam],
         api_tools: list[OpenAIToolParam] | None = None,
         api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
+        api_response_schema: type[Any] | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> AsyncIterator[OpenAICompletionChunk]:
         tools = api_tools or NOT_GIVEN
         tool_choice = api_tool_choice or NOT_GIVEN
+        response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
-        stream_generator = await self._client.chat.completions.create(
-            model=self._api_model_name,
-            messages=api_messages,
-            tools=tools,
-            tool_choice=tool_choice,
-            stream=True,
-            n=n,
-            **api_llm_settings,
-        )
-
-        async def iterate() -> AsyncIterator[OpenAICompletionChunk]:
+        if self._apply_response_schema_via_provider:
+            stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
+                self._client.beta.chat.completions.stream(
+                    model=self._model_name,
+                    messages=api_messages,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    response_format=response_format,
+                    n=n,
+                    **api_llm_settings,
+                )
+            )
+            async with stream_manager as stream:
+                async for chunk_event in stream:
+                    if isinstance(chunk_event, OpenAIChunkEvent):
+                        yield chunk_event.chunk
+        else:
+            stream_generator: AsyncStream[
+                OpenAICompletionChunk
+            ] = await self._client.chat.completions.create(
+                model=self._model_name,
+                messages=api_messages,
+                tools=tools,
+                tool_choice=tool_choice,
+                stream=True,
+                n=n,
+                **api_llm_settings,
+            )
             async with stream_generator as stream:
                 async for completion_chunk in stream:
                     yield completion_chunk
 
-        return iterate()
-
-    async def _get_parsed_completion_stream(
-        self,
-        api_messages: Iterable[OpenAIMessageParam],
-        api_tools: list[OpenAIToolParam] | None = None,
-        api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
-        api_response_format: type | None = None,
-        n_choices: int | None = None,
-        **api_llm_settings: Any,
-    ) -> AsyncIterator[OpenAICompletionChunk]:
-        tools = api_tools or NOT_GIVEN
-        tool_choice = api_tool_choice or NOT_GIVEN
-        response_format = api_response_format or NOT_GIVEN
-        n = n_choices or NOT_GIVEN
-
-        stream_manager: OpenAIAsyncChatCompletionStreamManager[
-            OpenAICompletionChunk
-        ] = self._client.beta.chat.completions.stream(
-            model=self._api_model_name,
-            messages=api_messages,
-            tools=tools,
-            tool_choice=tool_choice,
-            response_format=response_format,
-            n=n,
-            **api_llm_settings,
+    def combine_completion_chunks(
+        self, completion_chunks: list[OpenAICompletionChunk]
+    ) -> OpenAICompletion:
+        response_format = NOT_GIVEN
+        input_tools = NOT_GIVEN
+        if self._apply_response_schema_via_provider:
+            if self._response_schema:
+                response_format = self._response_schema
+            if self._tools:
+                input_tools = [
+                    self._converters.to_tool(tool) for tool in self._tools.values()
+                ]
+        state = ChatCompletionStreamState[Any](
+            input_tools=input_tools, response_format=response_format
         )
+        for chunk in completion_chunks:
+            state.handle_chunk(chunk)
 
-        async def iterate() -> AsyncIterator[OpenAICompletionChunk]:
-            async with stream_manager as stream:
-                async for chunk_event in stream:
-                    if isinstance(chunk_event, OpenAIChunkEvent):
-                        yield chunk_event.chunk
-
-        return iterate()
+        return state.get_final_completion()
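
The new `combine_completion_chunks` method delegates accumulation to the OpenAI SDK's `ChatCompletionStreamState`. A standalone sketch of the same mechanism, without tools or a response schema (the function name is ours; it assumes both constructor arguments default to `NOT_GIVEN`, as the method's explicit usage above suggests):

```python
from openai.lib.streaming.chat import ChatCompletionStreamState
from openai.types.chat import ChatCompletion, ChatCompletionChunk

def combine_chunks(chunks: list[ChatCompletionChunk]) -> ChatCompletion:
    # Feed every streamed chunk into the accumulator, then ask it for the
    # reconstructed final completion.
    state = ChatCompletionStreamState()
    for chunk in chunks:
        state.handle_chunk(chunk)
    return state.get_final_completion()
```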
grasp_agents/openai/tool_converters.py CHANGED
@@ -13,10 +13,8 @@ from . import (
 )
 
 
-def to_api_tool(
-    tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
-) -> OpenAIToolParam:
-    if strict:
+def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
+    if tool.strict:
         return pydantic_function_tool(
             model=tool.in_type, name=tool.name, description=tool.description
         )
@@ -25,9 +23,9 @@ def to_api_tool(
         name=tool.name,
         description=tool.description,
         parameters=tool.in_type.model_json_schema(),
-        strict=strict,
+        strict=tool.strict,
    )
-    if strict is None:
+    if tool.strict is None:
         function.pop("strict")
 
     return OpenAIToolParam(type="function", function=function)
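
Strictness now travels on the tool itself (`tool.strict`, set by `OpenAILLM` when `apply_response_schema_via_provider=True`) instead of being passed as an argument. In the strict branch, `pydantic_function_tool` from the OpenAI SDK builds a strict function schema from the tool's Pydantic input model. A standalone sketch of that call (`SearchArgs` and the metadata are illustrative):

```python
from openai import pydantic_function_tool
from pydantic import BaseModel

class SearchArgs(BaseModel):
    query: str

# What the tool.strict branch produces: a function tool with a strict schema.
strict_tool = pydantic_function_tool(
    SearchArgs, name="search", description="Search the web."
)
print(strict_tool["function"]["strict"])  # True
```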
grasp_agents/packet.py CHANGED
@@ -15,10 +15,13 @@ class Packet(BaseModel, Generic[_PayloadT_co]):
     sender: ProcName
     recipients: Sequence[ProcName] = Field(default_factory=list)
 
-    model_config = ConfigDict(extra="forbid", frozen=True)
+    model_config = ConfigDict(extra="forbid")
 
     def __repr__(self) -> str:
         return (
-            f"From: {self.sender}, To: {', '.join(self.recipients)}, "
+            f"{self.__class__.__name__}:\n"
+            f"ID: {self.id}\n"
+            f"From: {self.sender}\n"
+            f"To: {', '.join(self.recipients)}\n"
             f"Payloads: {len(self.payloads)}"
         )
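
Packets are no longer frozen, and the repr now spans multiple lines and includes the packet id. Roughly (field names are taken from the diff; `Packet` construction details and the id format are assumed):

```python
from grasp_agents.packet import Packet

packet = Packet(payloads=["hi"], sender="writer", recipients=["editor"])
print(repr(packet))
# Packet:
# ID: ...        <- auto-generated; exact format not shown in this diff
# From: writer
# To: editor
# Payloads: 1
```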
grasp_agents/packet_pool.py CHANGED
@@ -1,9 +1,11 @@
 import asyncio
 import logging
+from collections.abc import AsyncIterator
 from typing import Any, Generic, Protocol, TypeVar
 
 from .packet import Packet
 from .run_context import CtxT, RunContext
+from .typing.events import Event
 from .typing.io import ProcName
 
 logger = logging.getLogger(__name__)
@@ -16,16 +18,16 @@ class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
     async def __call__(
         self,
         packet: Packet[_PayloadT_contra],
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
         **kwargs: Any,
-    ) -> None: ...
+    ) -> AsyncIterator[Event[Any]] | None: ...
 
 
 class PacketPool(Generic[CtxT]):
     def __init__(self) -> None:
         self._queues: dict[ProcName, asyncio.Queue[Packet[Any]]] = {}
         self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
-        self._tasks: dict[ProcName, asyncio.Task[None]] = {}
+        self._tasks: dict[ProcName, asyncio.Task[AsyncIterator[Event[Any]] | None]] = {}
 
     async def post(self, packet: Packet[Any]) -> None:
         for recipient_id in packet.recipients:
@@ -36,7 +38,7 @@ class PacketPool(Generic[CtxT]):
         self,
         processor_name: ProcName,
         handler: PacketHandler[Any, CtxT],
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
         **run_kwargs: Any,
     ) -> None:
         self._packet_handlers[processor_name] = handler
@@ -47,11 +49,8 @@ class PacketPool(Generic[CtxT]):
         )
 
     async def _handle_packets(
-        self,
-        processor_name: ProcName,
-        ctx: RunContext[CtxT] | None = None,
-        **run_kwargs: Any,
-    ) -> None:
+        self, processor_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
+    ) -> AsyncIterator[Event[Any]] | None:
         queue = self._queues[processor_name]
         while True:
             try:
@@ -59,11 +58,13 @@ class PacketPool(Generic[CtxT]):
                 handler = self._packet_handlers.get(processor_name)
                 if handler is None:
                     break
-
                 try:
-                    await self._packet_handlers[processor_name](
-                        packet, ctx=ctx, **run_kwargs
-                    )
+                    if ctx.is_streaming:
+                        async for event in handler(packet, ctx=ctx, **run_kwargs):  # type: ignore[return-value]
+                            yield event
+                    else:
+                        await handler(packet, ctx=ctx, **run_kwargs)
+
                 except Exception:
                     logger.exception(f"Error handling packet for {processor_name}")
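
Handlers may now be async generators: when `ctx.is_streaming` is set, `_handle_packets` iterates the handler and re-yields its events; otherwise it simply awaits it. A skeleton conforming to the updated `PacketHandler` protocol (the function name is ours; `Event` construction is package-specific and left as a placeholder):

```python
from collections.abc import AsyncIterator
from typing import Any

from grasp_agents.packet import Packet
from grasp_agents.run_context import RunContext
from grasp_agents.typing.events import Event

async def on_packet(
    packet: Packet[Any],
    ctx: RunContext[Any],
    **kwargs: Any,
) -> AsyncIterator[Event[Any]]:
    # Do per-payload work here, yielding streaming events as they occur.
    for _payload in packet.payloads:
        event: Event[Any] = ...  # placeholder: build a real Event here
        yield event
```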