grasp_agents 0.3.11__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
grasp_agents/cloud_llm.py CHANGED
@@ -2,9 +2,9 @@ import fnmatch
2
2
  import logging
3
3
  import os
4
4
  from abc import abstractmethod
5
- from collections.abc import AsyncIterator, Mapping, Sequence
5
+ from collections.abc import AsyncIterator, Mapping
6
6
  from copy import deepcopy
7
- from typing import Any, Generic, Literal
7
+ from typing import Any, Generic, Literal, NotRequired
8
8
 
9
9
  import httpx
10
10
  from pydantic import BaseModel
@@ -16,10 +16,9 @@ from tenacity import (
16
16
  )
17
17
  from typing_extensions import TypedDict
18
18
 
19
- from .http_client import AsyncHTTPClientParams, create_async_http_client
19
+ from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
20
20
  from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
21
- from .message_history import MessageHistory
22
- from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate_chunked
21
+ from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
23
22
  from .typing.completion import Completion
24
23
  from .typing.completion_chunk import (
25
24
  CompletionChoice,
@@ -33,30 +32,30 @@ from .typing.tool import BaseTool, ToolChoice
33
32
  logger = logging.getLogger(__name__)
34
33
 
35
34
 
36
- APIProvider = Literal["openai", "openrouter", "google_ai_studio"]
35
+ APIProviderName = Literal["openai", "openrouter", "google_ai_studio"]
37
36
 
38
37
 
39
- class APIProviderInfo(TypedDict):
40
- name: APIProvider
38
+ class APIProvider(TypedDict):
39
+ name: APIProviderName
41
40
  base_url: str
42
- api_key: str | None
43
- struct_outputs_support: tuple[str, ...]
41
+ api_key: NotRequired[str | None]
42
+ struct_outputs_support: NotRequired[tuple[str, ...]]
44
43
 
45
44
 
46
- PROVIDERS: dict[APIProvider, APIProviderInfo] = {
47
- "openai": APIProviderInfo(
45
+ API_PROVIDERS: dict[APIProviderName, APIProvider] = {
46
+ "openai": APIProvider(
48
47
  name="openai",
49
48
  base_url="https://api.openai.com/v1",
50
49
  api_key=os.getenv("OPENAI_API_KEY"),
51
50
  struct_outputs_support=("*",),
52
51
  ),
53
- "openrouter": APIProviderInfo(
52
+ "openrouter": APIProvider(
54
53
  name="openrouter",
55
54
  base_url="https://openrouter.ai/api/v1",
56
55
  api_key=os.getenv("OPENROUTER_API_KEY"),
57
56
  struct_outputs_support=(),
58
57
  ),
59
- "google_ai_studio": APIProviderInfo(
58
+ "google_ai_studio": APIProvider(
60
59
  name="google_ai_studio",
61
60
  base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
62
61
  api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
@@ -66,18 +65,17 @@ PROVIDERS: dict[APIProvider, APIProviderInfo] = {
66
65
 
67
66
 
68
67
  def retry_error_callback(retry_state: RetryCallState) -> Completion:
69
- assert retry_state.outcome is not None
70
- exception = retry_state.outcome.exception()
68
+ exception = retry_state.outcome.exception() if retry_state.outcome else None
71
69
  if exception:
72
70
  if retry_state.attempt_number == 1:
73
71
  logger.warning(
74
- f"CloudLLM completion request failed:\n{exception}",
75
- exc_info=exception,
72
+ f"\nCloudLLM completion request failed:\n{exception}",
73
+ # exc_info=exception,
76
74
  )
77
75
  if retry_state.attempt_number > 1:
78
76
  logger.warning(
79
- f"CloudLLM completion request failed after retrying:\n{exception}",
80
- exc_info=exception,
77
+ f"\nCloudLLM completion request failed after retrying:\n{exception}",
78
+ # exc_info=exception,
81
79
  )
82
80
  failed_message = AssistantMessage(content=None, refusal=str(exception))
83
81
 
@@ -87,11 +85,12 @@ def retry_error_callback(retry_state: RetryCallState) -> Completion:
87
85
  )
88
86
 
89
87
 
90
- def retry_before_callback(retry_state: RetryCallState) -> None:
91
- if retry_state.attempt_number > 1:
88
+ def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
89
+ exception = retry_state.outcome.exception() if retry_state.outcome else None
90
+ if exception:
92
91
  logger.info(
93
- "Retrying CloudLLM completion request "
94
- f"(attempt {retry_state.attempt_number - 1}) ..."
92
+ "\nRetrying CloudLLM completion request "
93
+ f"(attempt {retry_state.attempt_number}):\n{exception}"
95
94
  )
96
95
 
97
96
 
@@ -106,10 +105,13 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
106
105
  model_name: str,
107
106
  converters: ConvertT_co,
108
107
  llm_settings: SettingsT_co | None = None,
109
- model_id: str | None = None,
110
108
  tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
111
109
  response_format: type | Mapping[str, type] | None = None,
110
+ model_id: str | None = None,
111
+ # Custom LLM provider
112
+ api_provider: APIProvider | None = None,
112
113
  # Connection settings
114
+ async_http_client: httpx.AsyncClient | None = None,
113
115
  async_http_client_params: (
114
116
  dict[str, Any] | AsyncHTTPClientParams | None
115
117
  ) = None,
@@ -120,8 +122,6 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
120
122
  rate_limiter_max_concurrency: int = 300,
121
123
  # Retries
122
124
  num_generation_retries: int = 0,
123
- # Disable tqdm for batch processing
124
- no_tqdm: bool = True,
125
125
  **kwargs: Any,
126
126
  ) -> None:
127
127
  self.llm_settings: CloudLLMSettings | None
@@ -139,29 +139,31 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
139
139
  self._model_name = model_name
140
140
  model_name_parts = model_name.split(":", 1)
141
141
 
142
- if len(model_name_parts) == 2 and model_name_parts[0] in PROVIDERS:
143
- api_provider, api_model_name = model_name_parts
144
- if api_provider not in PROVIDERS:
142
+ if len(model_name_parts) == 2:
143
+ api_provider_name, api_model_name = model_name_parts
144
+ self._api_model_name: str = api_model_name
145
+ if api_provider_name not in API_PROVIDERS:
145
146
  raise ValueError(
146
- f"API provider '{api_provider}' is not supported. "
147
- f"Supported providers are: {', '.join(PROVIDERS.keys())}"
147
+ f"API provider '{api_provider_name}' is not supported. "
148
+ f"Supported providers are: {', '.join(API_PROVIDERS.keys())}"
148
149
  )
149
-
150
- self._api_provider: APIProvider | None = api_provider
151
- self._api_model_name: str = api_model_name
152
- self._base_url: str | None = PROVIDERS[api_provider]["base_url"]
153
- self._api_key: str | None = PROVIDERS[api_provider]["api_key"]
154
- self._struct_outputs_support: bool = any(
155
- fnmatch.fnmatch(self._model_name, pat)
156
- for pat in PROVIDERS[api_provider]["struct_outputs_support"]
150
+ _api_provider = API_PROVIDERS[api_provider_name]
151
+ elif api_provider is not None:
152
+ self._api_model_name: str = model_name
153
+ _api_provider = api_provider
154
+ else:
155
+ raise ValueError(
156
+ "API provider must be specified either in the model name "
157
+ "or as a separate argument."
157
158
  )
158
159
 
159
- else:
160
- self._api_provider = None
161
- self._api_model_name = model_name
162
- self._base_url = None
163
- self._api_key = None
164
- self._struct_outputs_support = False
160
+ self._api_provider_name: APIProviderName = _api_provider["name"]
161
+ self._base_url: str | None = _api_provider.get("base_url")
162
+ self._api_key: str | None = _api_provider.get("api_key")
163
+ self._struct_outputs_support: bool = any(
164
+ fnmatch.fnmatch(self._model_name, pat)
165
+ for pat in _api_provider.get("struct_outputs_support", ())
166
+ )
165
167
 
166
168
  if (
167
169
  self._llm_settings.get("use_struct_outputs")
@@ -181,23 +183,20 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
181
183
  max_concurrency=rate_limiter_max_concurrency,
182
184
  )
183
185
  )
184
- self.no_tqdm = no_tqdm
185
- self._client: Any
186
186
 
187
187
  self._async_http_client: httpx.AsyncClient | None = None
188
- if async_http_client_params is not None:
189
- val_async_http_client_params = AsyncHTTPClientParams.model_validate(
188
+ if async_http_client is not None:
189
+ self._async_http_client = async_http_client
190
+ elif async_http_client_params is not None:
191
+ self._async_http_client = create_simple_async_httpx_client(
190
192
  async_http_client_params
191
193
  )
192
- self._async_http_client = create_async_http_client(
193
- val_async_http_client_params
194
- )
195
194
 
196
195
  self.num_generation_retries = num_generation_retries
197
196
 
198
197
  @property
199
- def api_provider(self) -> APIProvider | None:
200
- return self._api_provider
198
+ def api_provider_name(self) -> APIProviderName | None:
199
+ return self._api_provider_name
201
200
 
202
201
  @property
203
202
  def rate_limiter(
@@ -353,7 +352,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
353
352
 
354
353
  return iterate()
355
354
 
356
- async def generate_completion(
355
+ @limit_rate
356
+ async def generate_completion( # type: ignore[override]
357
357
  self,
358
358
  conversation: Messages,
359
359
  *,
@@ -363,7 +363,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
363
363
  wrapped_func = retry(
364
364
  wait=wait_random_exponential(min=1, max=8),
365
365
  stop=stop_after_attempt(self.num_generation_retries + 1),
366
- before=retry_before_callback,
366
+ before_sleep=retry_before_sleep_callback,
367
367
  retry_error_callback=retry_error_callback,
368
368
  )(self.__class__.generate_completion_no_retry)
369
369
 
@@ -371,23 +371,6 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
371
371
  self, conversation, tool_choice=tool_choice, n_choices=n_choices
372
372
  )
373
373
 
374
- @limit_rate_chunked # type: ignore
375
- async def _generate_completion_batch(
376
- self,
377
- conversation: Messages,
378
- *,
379
- tool_choice: ToolChoice | None = None,
380
- ) -> Completion:
381
- return await self.generate_completion(conversation, tool_choice=tool_choice)
382
-
383
- async def generate_completion_batch(
384
- self, message_history: MessageHistory, *, tool_choice: ToolChoice | None = None
385
- ) -> Sequence[Completion]:
386
- return await self._generate_completion_batch(
387
- list(message_history.conversations), # type: ignore
388
- tool_choice=tool_choice,
389
- )
390
-
391
374
  def _get_rate_limiter(
392
375
  self,
393
376
  rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = None,
@@ -5,12 +5,13 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
5
5
  from pydantic import BaseModel
6
6
  from pydantic.json_schema import SkipJsonSchema
7
7
 
8
+ from .memory import MemT
8
9
  from .packet import Packet
9
10
  from .packet_pool import PacketPool
10
11
  from .processor import Processor
11
12
  from .run_context import CtxT, RunContext
12
13
  from .typing.events import Event, PacketEvent
13
- from .typing.io import InT_contra, MemT_co, OutT_co, ProcName
14
+ from .typing.io import InT, OutT_co, ProcName
14
15
 
15
16
  logger = logging.getLogger(__name__)
16
17
 
@@ -31,8 +32,8 @@ class ExitCommunicationHandler(Protocol[_OutT_contra, CtxT]):
31
32
 
32
33
 
33
34
  class CommProcessor(
34
- Processor[InT_contra, OutT_co, MemT_co, CtxT],
35
- Generic[InT_contra, OutT_co, MemT_co, CtxT],
35
+ Processor[InT, OutT_co, MemT, CtxT],
36
+ Generic[InT, OutT_co, MemT, CtxT],
36
37
  ):
37
38
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
38
39
  0: "_in_type",
@@ -45,8 +46,9 @@ class CommProcessor(
45
46
  *,
46
47
  recipients: Sequence[ProcName] | None = None,
47
48
  packet_pool: PacketPool[CtxT] | None = None,
49
+ num_par_run_retries: int = 0,
48
50
  ) -> None:
49
- super().__init__(name=name)
51
+ super().__init__(name=name, num_par_run_retries=num_par_run_retries)
50
52
 
51
53
  self.recipients = recipients or []
52
54
 
@@ -56,6 +58,10 @@ class CommProcessor(
56
58
  ExitCommunicationHandler[OutT_co, CtxT] | None
57
59
  ) = None
58
60
 
61
+ @property
62
+ def packet_pool(self) -> PacketPool[CtxT] | None:
63
+ return self._packet_pool
64
+
59
65
  def _validate_routing(self, payloads: Sequence[OutT_co]) -> Sequence[ProcName]:
60
66
  if all(isinstance(p, DynCommPayload) for p in payloads):
61
67
  payloads_ = cast("Sequence[DynCommPayload]", payloads)
@@ -88,9 +94,10 @@ class CommProcessor(
88
94
  self,
89
95
  chat_inputs: Any | None = None,
90
96
  *,
91
- in_packet: Packet[InT_contra] | None = None,
92
- in_args: InT_contra | Sequence[InT_contra] | None = None,
97
+ in_packet: Packet[InT] | None = None,
98
+ in_args: InT | Sequence[InT] | None = None,
93
99
  forgetful: bool = False,
100
+ run_id: str | None = None,
94
101
  ctx: RunContext[CtxT] | None = None,
95
102
  ) -> Packet[OutT_co]:
96
103
  out_packet = await super().run(
@@ -98,6 +105,7 @@ class CommProcessor(
98
105
  in_packet=in_packet,
99
106
  in_args=in_args,
100
107
  forgetful=forgetful,
108
+ run_id=run_id,
101
109
  ctx=ctx,
102
110
  )
103
111
  recipients = self._validate_routing(out_packet.payloads)
@@ -114,9 +122,10 @@ class CommProcessor(
114
122
  self,
115
123
  chat_inputs: Any | None = None,
116
124
  *,
117
- in_packet: Packet[InT_contra] | None = None,
118
- in_args: InT_contra | Sequence[InT_contra] | None = None,
125
+ in_packet: Packet[InT] | None = None,
126
+ in_args: InT | None = None,
119
127
  forgetful: bool = False,
128
+ run_id: str | None = None,
120
129
  ctx: RunContext[CtxT] | None = None,
121
130
  ) -> AsyncIterator[Event[Any]]:
122
131
  out_packet: Packet[OutT_co] | None = None
@@ -125,6 +134,7 @@ class CommProcessor(
125
134
  in_packet=in_packet,
126
135
  in_args=in_args,
127
136
  forgetful=forgetful,
137
+ run_id=run_id,
128
138
  ctx=ctx,
129
139
  ):
130
140
  if isinstance(event, PacketEvent):
@@ -152,7 +162,7 @@ class CommProcessor(
152
162
 
153
163
  return func
154
164
 
155
- def _exit_communication_fn(
165
+ def _exit_communication(
156
166
  self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT] | None
157
167
  ) -> bool:
158
168
  if self._exit_communication_impl:
@@ -162,7 +172,7 @@ class CommProcessor(
162
172
 
163
173
  async def _packet_handler(
164
174
  self,
165
- packet: Packet[InT_contra],
175
+ packet: Packet[InT],
166
176
  ctx: RunContext[CtxT] | None = None,
167
177
  **run_kwargs: Any,
168
178
  ) -> None:
@@ -170,7 +180,7 @@ class CommProcessor(
170
180
 
171
181
  out_packet = await self.run(ctx=ctx, in_packet=packet, **run_kwargs)
172
182
 
173
- if self._exit_communication_fn(out_packet=out_packet, ctx=ctx):
183
+ if self._exit_communication(out_packet=out_packet, ctx=ctx):
174
184
  await self._packet_pool.stop_all()
175
185
  return
176
186
 
grasp_agents/errors.py ADDED
@@ -0,0 +1,34 @@
1
+ class InputValidationError(Exception):
2
+ pass
3
+
4
+
5
+ class StringParsingError(Exception):
6
+ pass
7
+
8
+
9
+ class CompletionError(Exception):
10
+ pass
11
+
12
+
13
+ class CombineCompletionChunksError(Exception):
14
+ pass
15
+
16
+
17
+ class ToolValidationError(Exception):
18
+ pass
19
+
20
+
21
+ class OutputValidationError(Exception):
22
+ pass
23
+
24
+
25
+ class WorkflowConstructionError(Exception):
26
+ pass
27
+
28
+
29
+ class SystemPromptBuilderError(Exception):
30
+ pass
31
+
32
+
33
+ class InputPromptBuilderError(Exception):
34
+ pass
grasp_agents/http_client.py CHANGED
@@ -1,3 +1,5 @@
1
+ from typing import Any
2
+
1
3
  import httpx
2
4
  from pydantic import BaseModel, NonNegativeFloat, PositiveInt
3
5
 
@@ -9,10 +11,12 @@ class AsyncHTTPClientParams(BaseModel):
9
11
  keepalive_expiry: float | None = 5
10
12
 
11
13
 
12
- def create_async_http_client(
13
- client_params: AsyncHTTPClientParams,
14
+ def create_simple_async_httpx_client(
15
+ client_params: AsyncHTTPClientParams | dict[str, Any],
14
16
  ) -> httpx.AsyncClient:
15
- http_client = httpx.AsyncClient(
17
+ if isinstance(client_params, dict):
18
+ client_params = AsyncHTTPClientParams(**client_params)
19
+ return httpx.AsyncClient(
16
20
  timeout=httpx.Timeout(client_params.timeout),
17
21
  limits=httpx.Limits(
18
22
  max_connections=client_params.max_connections,
@@ -20,5 +24,3 @@ def create_async_http_client(
20
24
  keepalive_expiry=client_params.keepalive_expiry,
21
25
  ),
22
26
  )
23
-
24
- return http_client
grasp_agents/llm.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
  from abc import ABC, abstractmethod
3
- from collections.abc import AsyncIterator, Mapping, Sequence
3
+ from collections.abc import AsyncIterator, Mapping
4
4
  from typing import Any, Generic, TypeVar, cast
5
5
  from uuid import uuid4
6
6
 
@@ -9,7 +9,7 @@ from typing_extensions import TypedDict
9
9
 
10
10
  from grasp_agents.utils import validate_obj_from_json_or_py_string
11
11
 
12
- from .message_history import MessageHistory
12
+ from .errors import ToolValidationError
13
13
  from .typing.completion import Completion
14
14
  from .typing.converters import Converters
15
15
  from .typing.events import CompletionChunkEvent, CompletionEvent
@@ -118,7 +118,7 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
118
118
 
119
119
  available_tool_names = list(self.tools) if self.tools else []
120
120
  if tool_name not in available_tool_names or not self.tools:
121
- raise ValueError(
121
+ raise ToolValidationError(
122
122
  f"Tool '{tool_name}' is not available in the LLM tools "
123
123
  f"(available: {available_tool_names}"
124
124
  )
@@ -146,9 +146,3 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
146
146
  n_choices: int | None = None,
147
147
  ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
148
148
  pass
149
-
150
- @abstractmethod
151
- async def generate_completion_batch(
152
- self, message_history: MessageHistory, *, tool_choice: ToolChoice | None = None
153
- ) -> Sequence[Completion]:
154
- pass