grasp_agents-0.1.5-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/agent_message.py +28 -0
  2. grasp_agents/agent_message_pool.py +94 -0
  3. grasp_agents/base_agent.py +72 -0
  4. grasp_agents/cloud_llm.py +353 -0
  5. grasp_agents/comm_agent.py +230 -0
  6. grasp_agents/costs_dict.yaml +122 -0
  7. grasp_agents/data_retrieval/__init__.py +7 -0
  8. grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
  9. grasp_agents/data_retrieval/types.py +57 -0
  10. grasp_agents/data_retrieval/utils.py +57 -0
  11. grasp_agents/grasp_logging.py +36 -0
  12. grasp_agents/http_client.py +24 -0
  13. grasp_agents/llm.py +106 -0
  14. grasp_agents/llm_agent.py +361 -0
  15. grasp_agents/llm_agent_state.py +73 -0
  16. grasp_agents/memory.py +150 -0
  17. grasp_agents/openai/__init__.py +83 -0
  18. grasp_agents/openai/completion_converters.py +49 -0
  19. grasp_agents/openai/content_converters.py +80 -0
  20. grasp_agents/openai/converters.py +170 -0
  21. grasp_agents/openai/message_converters.py +155 -0
  22. grasp_agents/openai/openai_llm.py +179 -0
  23. grasp_agents/openai/tool_converters.py +37 -0
  24. grasp_agents/printer.py +156 -0
  25. grasp_agents/prompt_builder.py +204 -0
  26. grasp_agents/run_context.py +90 -0
  27. grasp_agents/tool_orchestrator.py +181 -0
  28. grasp_agents/typing/__init__.py +0 -0
  29. grasp_agents/typing/completion.py +30 -0
  30. grasp_agents/typing/content.py +116 -0
  31. grasp_agents/typing/converters.py +118 -0
  32. grasp_agents/typing/io.py +32 -0
  33. grasp_agents/typing/message.py +130 -0
  34. grasp_agents/typing/tool.py +52 -0
  35. grasp_agents/usage_tracker.py +99 -0
  36. grasp_agents/utils.py +151 -0
  37. grasp_agents/workflow/__init__.py +0 -0
  38. grasp_agents/workflow/looped_agent.py +113 -0
  39. grasp_agents/workflow/sequential_agent.py +57 -0
  40. grasp_agents/workflow/workflow_agent.py +69 -0
  41. grasp_agents-0.1.5.dist-info/METADATA +14 -0
  42. grasp_agents-0.1.5.dist-info/RECORD +44 -0
  43. grasp_agents-0.1.5.dist-info/WHEEL +4 -0
  44. grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
grasp_agents/agent_message.py
@@ -0,0 +1,28 @@
+ from collections.abc import Sequence
+ from typing import Generic, TypeVar
+ from uuid import uuid4
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+ # from .base_agent import StateT
+ from .typing.io import AgentID, AgentPayload, AgentState
+
+ _PayloadT = TypeVar("_PayloadT", bound=AgentPayload, covariant=True)  # noqa: PLC0105
+ _StateT = TypeVar("_StateT", bound=AgentState, covariant=True)  # noqa: PLC0105
+
+
+ class AgentMessage(BaseModel, Generic[_PayloadT, _StateT]):
+     payloads: Sequence[_PayloadT]
+     sender_id: AgentID
+     sender_state: _StateT | None = None
+     recipient_ids: Sequence[AgentID] = Field(default_factory=list)
+
+     message_id: str = Field(default_factory=lambda: str(uuid4())[:8])
+
+     model_config = ConfigDict(extra="forbid", frozen=True)
+
+     def __repr__(self) -> str:
+         return (
+             f"From: {self.sender_id}, To: {', '.join(self.recipient_ids)}, "
+             f"Payloads: {len(self.payloads)}"
+         )
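
For orientation, a minimal usage sketch (not part of the package): it assumes AgentID is a plain string alias and that a trivial AgentPayload subclass is enough to construct a message. The Greeting class is hypothetical.

from grasp_agents.agent_message import AgentMessage
from grasp_agents.typing.io import AgentPayload

class Greeting(AgentPayload):  # hypothetical payload subclass
    text: str

msg = AgentMessage(
    payloads=[Greeting(text="hello")],
    sender_id="agent_a",
    recipient_ids=["agent_b"],
)
print(repr(msg))  # -> From: agent_a, To: agent_b, Payloads: 1

Since the model is frozen with extra="forbid", a message is immutable once posted and rejects unknown fields.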
grasp_agents/agent_message_pool.py
@@ -0,0 +1,94 @@
+ import asyncio
+ import logging
+ from typing import Any, Generic, Protocol, TypeVar
+
+ from .agent_message import AgentMessage
+ from .run_context import CtxT, RunContextWrapper
+ from .typing.io import AgentID, AgentPayload, AgentState
+
+ logger = logging.getLogger(__name__)
+
+
+ _MH_PayloadT = TypeVar("_MH_PayloadT", bound=AgentPayload, contravariant=True)  # noqa: PLC0105
+ _MH_StateT = TypeVar("_MH_StateT", bound=AgentState, contravariant=True)  # noqa: PLC0105
+
+
+ class MessageHandler(Protocol[_MH_PayloadT, _MH_StateT, CtxT]):
+     async def __call__(
+         self,
+         message: AgentMessage[_MH_PayloadT, _MH_StateT],
+         ctx: RunContextWrapper[CtxT] | None,
+         **kwargs: Any,
+     ) -> None: ...
+
+
+ class AgentMessagePool(Generic[CtxT]):
+     def __init__(self) -> None:
+         self._queues: dict[
+             AgentID, asyncio.Queue[AgentMessage[AgentPayload, AgentState]]
+         ] = {}
+         self._message_handlers: dict[
+             AgentID, MessageHandler[AgentPayload, AgentState, CtxT]
+         ] = {}
+         self._tasks: dict[AgentID, asyncio.Task[None]] = {}
+
+     async def post(self, message: AgentMessage[AgentPayload, AgentState]) -> None:
+         for recipient_id in message.recipient_ids:
+             queue = self._queues.setdefault(recipient_id, asyncio.Queue())
+             await queue.put(message)
+
+     def register_message_handler(
+         self,
+         agent_id: AgentID,
+         handler: MessageHandler[AgentPayload, AgentState, CtxT],
+         ctx: RunContextWrapper[CtxT] | None = None,
+         **run_kwargs: Any,
+     ) -> None:
+         self._message_handlers[agent_id] = handler
+         self._queues.setdefault(agent_id, asyncio.Queue())
+         if agent_id not in self._tasks:
+             self._tasks[agent_id] = asyncio.create_task(
+                 self._process_messages(agent_id, ctx=ctx, **run_kwargs)
+             )
+
+     async def _process_messages(
+         self,
+         agent_id: AgentID,
+         ctx: RunContextWrapper[CtxT] | None = None,
+         **run_kwargs: Any,
+     ) -> None:
+         queue = self._queues[agent_id]
+         while True:
+             try:
+                 message = await queue.get()
+                 handler = self._message_handlers.get(agent_id)
+                 if handler is None:
+                     break
+
+                 try:
+                     await handler(message, ctx=ctx, **run_kwargs)
+                 except Exception:
+                     logger.exception(f"Error handling message for {agent_id}")
+
+                 queue.task_done()
+
+             except Exception:
+                 logger.exception(f"Unexpected error in processing loop for {agent_id}")
+
+     async def unregister_message_handler(self, agent_id: AgentID) -> None:
+         if task := self._tasks.get(agent_id):
+             task.cancel()
+             try:
+                 await task
+             except asyncio.CancelledError:
+                 logger.debug(f"{agent_id} exited")
+
+         self._tasks.pop(agent_id, None)
+         self._queues.pop(agent_id, None)
+         self._message_handlers.pop(agent_id, None)
+
+     async def stop_all(self) -> None:
+         for agent_id in list(self._tasks):
+             await self.unregister_message_handler(agent_id)
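
A sketch of how the pool is meant to be wired (Ping is a hypothetical payload subclass; handlers must accept ctx as a keyword argument, per the MessageHandler protocol above). register_message_handler calls asyncio.create_task, so registration has to happen inside a running event loop:

import asyncio

from grasp_agents.agent_message import AgentMessage
from grasp_agents.agent_message_pool import AgentMessagePool
from grasp_agents.typing.io import AgentPayload

class Ping(AgentPayload):  # hypothetical payload subclass
    text: str

async def main() -> None:
    pool: AgentMessagePool[None] = AgentMessagePool()

    async def on_message(message, ctx=None, **kwargs) -> None:
        print(f"agent_b got {len(message.payloads)} payload(s)")

    pool.register_message_handler("agent_b", on_message)
    await pool.post(
        AgentMessage(
            payloads=[Ping(text="hi")],
            sender_id="agent_a",
            recipient_ids=["agent_b"],
        )
    )
    await asyncio.sleep(0.1)  # let the background task drain the queue
    await pool.stop_all()     # cancels the per-agent processing tasks

asyncio.run(main())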
grasp_agents/base_agent.py
@@ -0,0 +1,72 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Generic, Protocol
+
+ from pydantic import BaseModel
+
+ from .run_context import CtxT, RunContextWrapper
+ from .typing.io import AgentID, AgentPayload, OutT, StateT
+ from .typing.tool import BaseTool
+
+
+ class ParseOutputHandler(Protocol[OutT, CtxT]):
+     def __call__(
+         self, *args: Any, ctx: RunContextWrapper[CtxT] | None, **kwargs: Any
+     ) -> OutT: ...
+
+
+ class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
+     @abstractmethod
+     def __init__(
+         self,
+         agent_id: AgentID,
+         *,
+         out_schema: type[OutT] = AgentPayload,
+         **kwargs: Any,
+     ) -> None:
+         self._state: StateT
+         self._agent_id = agent_id
+         self._out_schema = out_schema
+         self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None
+
+     def parse_output_handler(
+         self, func: ParseOutputHandler[OutT, CtxT]
+     ) -> ParseOutputHandler[OutT, CtxT]:
+         self._parse_output_impl = func
+
+         return func
+
+     def _parse_output(
+         self, *args: Any, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
+     ) -> OutT:
+         if self._parse_output_impl:
+             return self._parse_output_impl(*args, ctx=ctx, **kwargs)
+
+         return self._out_schema()
+
+     @property
+     def agent_id(self) -> AgentID:
+         return self._agent_id
+
+     @property
+     def state(self) -> StateT:
+         return self._state
+
+     @property
+     def out_schema(self) -> type[OutT]:
+         return self._out_schema
+
+     @abstractmethod
+     async def run(
+         self,
+         inp_items: Any,
+         *,
+         ctx: RunContextWrapper[CtxT] | None = None,
+         **kwargs: Any,
+     ) -> Any:
+         pass
+
+     @abstractmethod
+     def as_tool(
+         self, tool_name: str, tool_description: str, tool_strict: bool = True
+     ) -> BaseTool[BaseModel, BaseModel, CtxT]:
+         pass
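
The parse_output_handler decorator is the main extension point here. A sketch of the intended pattern, using a hypothetical concrete subclass (EchoAgent; as_tool is stubbed for brevity): when no handler is registered, _parse_output falls back to constructing an empty out_schema instance.

from grasp_agents.base_agent import BaseAgent
from grasp_agents.typing.io import AgentPayload, AgentState

class EchoAgent(BaseAgent[AgentPayload, AgentState, None]):  # hypothetical
    def __init__(self, agent_id: str) -> None:
        super().__init__(agent_id, out_schema=AgentPayload)

    async def run(self, inp_items, *, ctx=None, **kwargs):
        # Defers to the registered handler, or falls back to out_schema()
        return self._parse_output(inp_items, ctx=ctx)

    def as_tool(self, tool_name, tool_description, tool_strict=True):
        raise NotImplementedError

agent = EchoAgent("echo")

@agent.parse_output_handler
def parse(*args, ctx=None, **kwargs) -> AgentPayload:
    return AgentPayload()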
grasp_agents/cloud_llm.py
@@ -0,0 +1,353 @@
+ import fnmatch
+ import logging
+ import os
+ from abc import abstractmethod
+ from collections.abc import AsyncIterator, Sequence
+ from typing import Any, Generic, Literal
+
+ import httpx
+ from pydantic import BaseModel, TypeAdapter
+ from tenacity import (
+     RetryCallState,
+     retry,
+     stop_after_attempt,
+     wait_random_exponential,
+ )
+ from typing_extensions import TypedDict
+
+ from .data_retrieval.rate_limiter_chunked import (  # type: ignore
+     RateLimiterC,
+     limit_rate_chunked,
+ )
+
+ from .http_client import AsyncHTTPClientParams, create_async_http_client
+ from .llm import LLM, ConvertT, LLMSettings, SettingsT
+ from .memory import MessageHistory
+ from .typing.completion import Completion, CompletionChunk
+ from .typing.message import AssistantMessage, Conversation
+ from .typing.tool import BaseTool, ToolChoice
+ from .utils import extract_json
+
+ logger = logging.getLogger(__name__)
+
+
+ APIProvider = Literal["openai", "openrouter", "google_ai_studio"]
+
+
+ class APIProviderInfo(TypedDict):
+     name: APIProvider
+     base_url: str
+     api_key: str | None
+     struct_output_support: list[str]
+
+
+ PROVIDERS: dict[APIProvider, APIProviderInfo] = {
+     "openai": APIProviderInfo(
+         name="openai",
+         base_url="https://api.openai.com/v1",
+         api_key=os.getenv("OPENAI_API_KEY"),
+         struct_output_support=["*"],
+     ),
+     "openrouter": APIProviderInfo(
+         name="openrouter",
+         base_url="https://openrouter.ai/api/v1",
+         api_key=os.getenv("OPENROUTER_API_KEY"),
+         struct_output_support=[],
+     ),
+     "google_ai_studio": APIProviderInfo(
+         name="google_ai_studio",
+         base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+         api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
+         struct_output_support=["*"],
+     ),
+ }
+
+
+ def retry_error_callback(retry_state: RetryCallState) -> None:
+     assert retry_state.outcome is not None
+     exception = retry_state.outcome.exception()
+     if exception:
+         if retry_state.attempt_number == 1:
+             logger.error(
+                 f"CloudLLM completion request failed:\n{exception}",
+                 exc_info=exception,
+             )
+         if retry_state.attempt_number > 1:
+             logger.warning(
+                 f"CloudLLM completion request failed after retrying:\n{exception}",
+                 exc_info=exception,
+             )
+
+
+ def retry_before_callback(retry_state: RetryCallState) -> None:
+     if retry_state.attempt_number > 1:
+         logger.info(
+             "Retrying CloudLLM completion request "
+             f"(attempt {retry_state.attempt_number - 1}) ..."
+         )
+
+
+ class CloudLLMSettings(LLMSettings, total=False):
+     max_completion_tokens: int | None
+     temperature: float | None
+     top_p: float | None
+     seed: int | None
+
+
+ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
+     def __init__(
+         self,
+         # Base LLM args
+         model_name: str,
+         converters: ConvertT,
+         llm_settings: SettingsT | None = None,
+         model_id: str | None = None,
+         tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
+         response_format: type | None = None,
+         # Connection settings
+         api_provider: APIProvider = "openai",
+         async_http_client_params: (
+             dict[str, Any] | AsyncHTTPClientParams | None
+         ) = None,
+         # Rate limiting
+         rate_limiter: (RateLimiterC[Conversation, AssistantMessage] | None) = None,
+         rate_limiter_rpm: float | None = None,
+         rate_limiter_chunk_size: int = 1000,
+         rate_limiter_max_concurrency: int = 300,
+         # Retries
+         num_generation_retries: int = 0,
+         # Disable tqdm for batch processing
+         no_tqdm: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         self.llm_settings: CloudLLMSettings | None
+
+         super().__init__(
+             model_name=model_name,
+             llm_settings=llm_settings,
+             converters=converters,
+             model_id=model_id,
+             tools=tools,
+             response_format=response_format,
+             **kwargs,
+         )
+
+         self._model_name = model_name
+         self._api_provider: APIProvider = api_provider
+
+         patterns = PROVIDERS[api_provider]["struct_output_support"]
+         self._struct_output_support: bool = any(
+             fnmatch.fnmatch(self._model_name, pat) for pat in patterns
+         )
+         self._response_format_pyd: TypeAdapter[Any] | None = (
+             TypeAdapter(self._response_format) if response_format else None
+         )
+
+         self._rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = (
+             self._get_rate_limiter(
+                 rate_limiter=rate_limiter,
+                 rate_limiter_rpm=rate_limiter_rpm,
+                 rate_limiter_chunk_size=rate_limiter_chunk_size,
+                 rate_limiter_max_concurrency=rate_limiter_max_concurrency,
+             )
+         )
+         self.no_tqdm = no_tqdm
+
+         self._base_url: str = PROVIDERS[api_provider]["base_url"]
+         self._api_key: str | None = PROVIDERS[api_provider]["api_key"]
+         self._client: Any
+
+         self._async_http_client: httpx.AsyncClient | None = None
+         if async_http_client_params is not None:
+             val_async_http_client_params = AsyncHTTPClientParams.model_validate(
+                 async_http_client_params
+             )
+             self._async_http_client = create_async_http_client(
+                 val_async_http_client_params
+             )
+
+         self.num_generation_retries = num_generation_retries
+
+     @property
+     def api_provider(self) -> APIProvider:
+         return self._api_provider
+
+     @property
+     def rate_limiter(
+         self,
+     ) -> RateLimiterC[Conversation, AssistantMessage] | None:
+         return self._rate_limiter
+
+     def _make_completion_kwargs(
+         self, conversation: Conversation, tool_choice: ToolChoice | None = None
+     ) -> dict[str, Any]:
+         api_llm_settings = self.llm_settings or {}
+         api_messages = [self._converters.to_message(m) for m in conversation]
+         api_tools = None
+         api_tool_choice = None
+         if self.tools:
+             api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
+             if tool_choice is not None:
+                 api_tool_choice = self._converters.to_tool_choice(tool_choice)
+
+         return dict(
+             api_messages=api_messages,
+             api_tools=api_tools,
+             api_tool_choice=api_tool_choice,
+             api_response_format=self._response_format,
+             **api_llm_settings,
+         )
+
+     @abstractmethod
+     async def _get_completion(
+         self,
+         api_messages: list[Any],
+         *,
+         api_tools: list[Any] | None = None,
+         api_tool_choice: Any | None = None,
+         **api_llm_settings: Any,
+     ) -> Any:
+         pass
+
+     @abstractmethod
+     async def _get_parsed_completion(
+         self,
+         api_messages: list[Any],
+         *,
+         api_tools: list[Any] | None = None,
+         api_tool_choice: Any | None = None,
+         **api_llm_settings: Any,
+     ) -> Any:
+         pass
+
+     @abstractmethod
+     async def _get_completion_stream(
+         self,
+         api_messages: list[Any],
+         *,
+         api_tools: list[Any] | None = None,
+         api_tool_choice: Any | None = None,
+         **api_llm_settings: Any,
+     ) -> AsyncIterator[Any]:
+         pass
+
+     async def generate_completion(
+         self,
+         conversation: Conversation,
+         *,
+         tool_choice: ToolChoice | None = None,
+         **kwargs: Any,
+     ) -> Completion:
+         completion_kwargs = self._make_completion_kwargs(
+             conversation=conversation, tool_choice=tool_choice
+         )
+
+         if self._response_format is None or not self._struct_output_support:
+             completion_kwargs.pop("api_response_format", None)
+             api_completion = await self._get_completion(**completion_kwargs, **kwargs)
+         else:
+             api_completion = await self._get_parsed_completion(
+                 **completion_kwargs, **kwargs
+             )
+
+         return self._converters.from_completion(api_completion, model_id=self.model_id)
+
+     async def generate_completion_stream(
+         self,
+         conversation: Conversation,
+         *,
+         tool_choice: ToolChoice | None = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[CompletionChunk]:
+         completion_kwargs = self._make_completion_kwargs(
+             conversation=conversation, tool_choice=tool_choice
+         )
+         completion_kwargs.pop("api_response_format", None)
+         api_completion_chunk_iterator = await self._get_completion_stream(
+             **completion_kwargs, **kwargs
+         )
+
+         return self._converters.from_completion_chunk_iterator(
+             api_completion_chunk_iterator, model_id=self.model_id
+         )
+
+     async def generate_message(
+         self,
+         conversation: Conversation,
+         *,
+         tool_choice: ToolChoice | None = None,
+         **kwargs: Any,
+     ) -> AssistantMessage:
+         completion = await self.generate_completion(
+             conversation, tool_choice=tool_choice, **kwargs
+         )
+         message = completion.choices[0].message
+         if self._response_format_pyd is not None and not self._struct_output_support:
+             self._response_format_pyd.validate_python(extract_json(message.content))
+
+         return message
+
+     async def _generate_message_with_retry(
+         self,
+         conversation: Conversation,
+         *,
+         tool_choice: ToolChoice | None = None,
+         **kwargs: Any,
+     ) -> AssistantMessage:
+         wrapped_func = retry(
+             wait=wait_random_exponential(min=1, max=8),
+             stop=stop_after_attempt(self.num_generation_retries + 1),
+             before=retry_before_callback,
+             retry_error_callback=retry_error_callback,
+         )(self.__class__.generate_message)
+
+         return await wrapped_func(self, conversation, tool_choice=tool_choice, **kwargs)
+
+     @limit_rate_chunked  # type: ignore
+     async def _generate_message_batch_with_retry_and_rate_lim(
+         self,
+         conversation: Conversation,
+         *,
+         tool_choice: ToolChoice | None = None,
+         **kwargs: Any,
+     ) -> AssistantMessage:
+         return await self._generate_message_with_retry(
+             conversation, tool_choice=tool_choice, **kwargs
+         )
+
+     async def generate_message_batch(
+         self,
+         message_history: MessageHistory,
+         *,
+         tool_choice: ToolChoice | None = None,
+         **kwargs: Any,
+     ) -> Sequence[AssistantMessage]:
+         return await self._generate_message_batch_with_retry_and_rate_lim(
+             list(message_history.batched_conversations),  # type: ignore
+             tool_choice=tool_choice,
+             **kwargs,
+         )
+
+     def _get_rate_limiter(
+         self,
+         rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = None,
+         rate_limiter_rpm: float | None = None,
+         rate_limiter_chunk_size: int = 1000,
+         rate_limiter_max_concurrency: int = 300,
+     ) -> RateLimiterC[Conversation, AssistantMessage] | None:
+         if rate_limiter is not None:
+             logger.info(
+                 f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
+             )
+             return rate_limiter
+         if rate_limiter_rpm is not None:
+             logger.info(
+                 f"[{self.__class__.__name__}] Set rate limit to {rate_limiter_rpm} RPM"
+             )
+             return RateLimiterC(
+                 rpm=rate_limiter_rpm,
+                 chunk_size=rate_limiter_chunk_size,
+                 max_concurrency=rate_limiter_max_concurrency,
+             )
+
+         return None
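
CloudLLM itself is abstract (the three _get_* methods are provider-specific), so any usage goes through a subclass. A hedged configuration sketch, assuming the OpenAILLM subclass listed above (grasp_agents/openai/openai_llm.py) accepts CloudLLM's constructor arguments and supplies its own converters, and assuming a UserMessage type in grasp_agents.typing.message; the model name is a placeholder:

import asyncio

from grasp_agents.openai.openai_llm import OpenAILLM  # concrete subclass from the file list
from grasp_agents.typing.message import UserMessage   # assumed message type

async def main() -> None:
    llm = OpenAILLM(
        model_name="gpt-4o-mini",   # placeholder; checked against struct_output_support patterns
        api_provider="openai",      # API key is read from OPENAI_API_KEY (see PROVIDERS)
        rate_limiter_rpm=60,        # _get_rate_limiter builds a RateLimiterC from this
        num_generation_retries=2,   # tenacity retry wrapper around generate_message
    )
    message = await llm.generate_message([UserMessage(content="ping")])
    print(message.content)

asyncio.run(main())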