grasp_agents 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +28 -0
- grasp_agents/agent_message_pool.py +94 -0
- grasp_agents/base_agent.py +72 -0
- grasp_agents/cloud_llm.py +353 -0
- grasp_agents/comm_agent.py +230 -0
- grasp_agents/costs_dict.yaml +122 -0
- grasp_agents/data_retrieval/__init__.py +7 -0
- grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
- grasp_agents/data_retrieval/types.py +57 -0
- grasp_agents/data_retrieval/utils.py +57 -0
- grasp_agents/grasp_logging.py +36 -0
- grasp_agents/http_client.py +24 -0
- grasp_agents/llm.py +106 -0
- grasp_agents/llm_agent.py +361 -0
- grasp_agents/llm_agent_state.py +73 -0
- grasp_agents/memory.py +150 -0
- grasp_agents/openai/__init__.py +83 -0
- grasp_agents/openai/completion_converters.py +49 -0
- grasp_agents/openai/content_converters.py +80 -0
- grasp_agents/openai/converters.py +170 -0
- grasp_agents/openai/message_converters.py +155 -0
- grasp_agents/openai/openai_llm.py +179 -0
- grasp_agents/openai/tool_converters.py +37 -0
- grasp_agents/printer.py +156 -0
- grasp_agents/prompt_builder.py +204 -0
- grasp_agents/run_context.py +90 -0
- grasp_agents/tool_orchestrator.py +181 -0
- grasp_agents/typing/__init__.py +0 -0
- grasp_agents/typing/completion.py +30 -0
- grasp_agents/typing/content.py +116 -0
- grasp_agents/typing/converters.py +118 -0
- grasp_agents/typing/io.py +32 -0
- grasp_agents/typing/message.py +130 -0
- grasp_agents/typing/tool.py +52 -0
- grasp_agents/usage_tracker.py +99 -0
- grasp_agents/utils.py +151 -0
- grasp_agents/workflow/__init__.py +0 -0
- grasp_agents/workflow/looped_agent.py +113 -0
- grasp_agents/workflow/sequential_agent.py +57 -0
- grasp_agents/workflow/workflow_agent.py +69 -0
- grasp_agents-0.1.5.dist-info/METADATA +14 -0
- grasp_agents-0.1.5.dist-info/RECORD +44 -0
- grasp_agents-0.1.5.dist-info/WHEEL +4 -0
- grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,28 @@
|
|
1
|
+
from collections.abc import Sequence
|
2
|
+
from typing import Generic, TypeVar
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field
|
6
|
+
|
7
|
+
# from .base_agent import StateT
|
8
|
+
from .typing.io import AgentID, AgentPayload, AgentState
|
9
|
+
|
10
|
+
# Messages only *carry* payloads/state outward, so both parameters are
# declared covariant.
_PayloadT = TypeVar("_PayloadT", bound=AgentPayload, covariant=True)  # noqa: PLC0105
_StateT = TypeVar("_StateT", bound=AgentState, covariant=True)  # noqa: PLC0105


class AgentMessage(BaseModel, Generic[_PayloadT, _StateT]):
    """Immutable envelope that moves payloads from one agent to others.

    Extra fields are rejected and instances are frozen (see ``model_config``).
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    payloads: Sequence[_PayloadT]
    sender_id: AgentID
    sender_state: _StateT | None = None
    recipient_ids: Sequence[AgentID] = Field(default_factory=list)

    # Short random identifier: the first 8 characters of a UUID4 string.
    message_id: str = Field(default_factory=lambda: str(uuid4())[:8])

    def __repr__(self) -> str:
        """Summarise sender, recipients and payload count."""
        recipients = ", ".join(self.recipient_ids)
        payload_count = len(self.payloads)
        return f"From: {self.sender_id}, To: {recipients}, Payloads: {payload_count}"
|
@@ -0,0 +1,94 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from typing import Any, Generic, Protocol, TypeVar
|
4
|
+
|
5
|
+
from .agent_message import AgentMessage
|
6
|
+
from .run_context import CtxT, RunContextWrapper
|
7
|
+
from .typing.io import AgentID, AgentPayload, AgentState
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)


# These parameters only appear in the argument position of
# ``MessageHandler.__call__``, hence they are declared contravariant.
_MH_PayloadT = TypeVar("_MH_PayloadT", bound=AgentPayload, contravariant=True)  # noqa: PLC0105
_MH_StateT = TypeVar("_MH_StateT", bound=AgentState, contravariant=True)  # noqa: PLC0105


class MessageHandler(Protocol[_MH_PayloadT, _MH_StateT, CtxT]):
    """Structural type of an async callback that receives one AgentMessage.

    The callback is given the message, the (optional) run context, and any
    extra keyword arguments; its return value is ignored (``None``).
    """

    async def __call__(
        self,
        message: AgentMessage[_MH_PayloadT, _MH_StateT],
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> None:
        ...
|
23
|
+
|
24
|
+
|
25
|
+
class AgentMessagePool(Generic[CtxT]):
    """In-process mailbox routing AgentMessages to per-agent async handlers.

    Every agent id gets its own ``asyncio.Queue``. Registering a handler
    starts a background task that pulls messages from that queue and awaits
    the handler on each one.
    """

    def __init__(self) -> None:
        # One inbox per recipient; created on first post() or registration.
        self._queues: dict[
            AgentID, asyncio.Queue[AgentMessage[AgentPayload, AgentState]]
        ] = {}
        self._message_handlers: dict[
            AgentID, MessageHandler[AgentPayload, AgentState, CtxT]
        ] = {}
        # Background consumer task per registered agent.
        self._tasks: dict[AgentID, asyncio.Task[None]] = {}

    async def post(self, message: AgentMessage[AgentPayload, AgentState]) -> None:
        """Enqueue ``message`` on the inbox of every one of its recipients.

        Inboxes are created on demand, so messages for agents that have not
        registered a handler yet are buffered until they do.
        """
        for recipient_id in message.recipient_ids:
            queue = self._queues.setdefault(recipient_id, asyncio.Queue())
            await queue.put(message)

    def register_message_handler(
        self,
        agent_id: AgentID,
        handler: MessageHandler[AgentPayload, AgentState, CtxT],
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Install ``handler`` for ``agent_id`` and make sure a consumer runs.

        If a consumer task is already running for this agent, only the
        handler is swapped; ``ctx`` and ``run_kwargs`` of the running task are
        kept. A consumer that has already finished (exited or was cancelled)
        is replaced by a fresh one so queued messages are processed again.
        Must be called while an event loop is running (uses ``create_task``).
        """
        self._message_handlers[agent_id] = handler
        self._queues.setdefault(agent_id, asyncio.Queue())

        existing_task = self._tasks.get(agent_id)
        if existing_task is None or existing_task.done():
            self._tasks[agent_id] = asyncio.create_task(
                self._process_messages(agent_id, ctx=ctx, **run_kwargs)
            )

    async def _process_messages(
        self,
        agent_id: AgentID,
        ctx: RunContextWrapper[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Consume ``agent_id``'s inbox until the handler is removed or cancelled.

        Handler exceptions are logged and do not stop the loop. Each message
        taken off the queue is paired with exactly one ``task_done()`` call -
        including when the loop exits because no handler is registered (the
        message is then dropped) and when cancellation interrupts a handler -
        so ``queue.join()`` never waits on an already-dequeued message.
        """
        queue = self._queues[agent_id]
        while True:
            try:
                message = await queue.get()
                try:
                    handler = self._message_handlers.get(agent_id)
                    if handler is None:
                        break
                    try:
                        await handler(message, ctx=ctx, **run_kwargs)
                    except Exception:
                        logger.exception("Error handling message for %s", agent_id)
                finally:
                    # Balance the successful get() above in every exit path.
                    queue.task_done()
            except Exception:
                logger.exception(
                    "Unexpected error in processing loop for %s", agent_id
                )

    async def unregister_message_handler(self, agent_id: AgentID) -> None:
        """Cancel the consumer for ``agent_id`` and forget its inbox and handler.

        Messages still waiting in the inbox are discarded with it.
        """
        if task := self._tasks.get(agent_id):
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                logger.debug("%s exited", agent_id)

        self._tasks.pop(agent_id, None)
        self._queues.pop(agent_id, None)
        self._message_handlers.pop(agent_id, None)

    async def stop_all(self) -> None:
        """Unregister every agent that has a consumer task."""
        # Snapshot the keys: unregistering mutates self._tasks.
        for agent_id in list(self._tasks):
            await self.unregister_message_handler(agent_id)
|
@@ -0,0 +1,72 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Any, Generic, Protocol
|
3
|
+
|
4
|
+
from pydantic import BaseModel
|
5
|
+
|
6
|
+
from .run_context import CtxT, RunContextWrapper
|
7
|
+
from .typing.io import AgentID, AgentPayload, OutT, StateT
|
8
|
+
from .typing.tool import BaseTool
|
9
|
+
|
10
|
+
|
11
|
+
class ParseOutputHandler(Protocol[OutT, CtxT]):
    """Structural type of a user-supplied output parser for an agent.

    It accepts arbitrary positional and keyword arguments, plus the run
    context as the keyword-only ``ctx``, and produces an ``OutT`` value.
    """

    def __call__(
        self,
        *args: Any,
        ctx: RunContextWrapper[CtxT] | None,
        **kwargs: Any,
    ) -> OutT:
        ...
|
15
|
+
|
16
|
+
|
17
|
+
class BaseAgent(ABC, Generic[OutT, StateT, CtxT]):
    """Abstract root of the agent hierarchy.

    Stores the agent id, the output schema, and an optional registered output
    parser. Concrete agents implement ``__init__``, ``run`` and ``as_tool``.
    """

    @abstractmethod
    def __init__(
        self,
        agent_id: AgentID,
        *,
        out_schema: type[OutT] = AgentPayload,
        **kwargs: Any,
    ) -> None:
        # Only declared here - nothing is assigned, so reading ``state``
        # raises AttributeError until a subclass sets ``self._state``.
        self._state: StateT
        self._agent_id = agent_id
        self._out_schema = out_schema
        self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None

    def parse_output_handler(
        self, func: ParseOutputHandler[OutT, CtxT]
    ) -> ParseOutputHandler[OutT, CtxT]:
        """Register ``func`` as the output parser.

        ``func`` is returned unchanged, so this method works as a decorator.
        """
        self._parse_output_impl = func
        return func

    def _parse_output(
        self, *args: Any, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
    ) -> OutT:
        """Produce the agent output.

        Delegates to the registered parser when one is set; otherwise calls
        ``out_schema`` with no arguments.
        """
        parser = self._parse_output_impl
        if not parser:
            return self._out_schema()
        return parser(*args, ctx=ctx, **kwargs)

    @property
    def agent_id(self) -> AgentID:
        """Identifier given at construction."""
        return self._agent_id

    @property
    def state(self) -> StateT:
        """Current agent state (must have been assigned by the subclass)."""
        return self._state

    @property
    def out_schema(self) -> type[OutT]:
        """Type used for the agent's output."""
        return self._out_schema

    @abstractmethod
    async def run(
        self,
        inp_items: Any,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Execute the agent on ``inp_items``; implemented by subclasses."""

    @abstractmethod
    def as_tool(
        self, tool_name: str, tool_description: str, tool_strict: bool = True
    ) -> BaseTool[BaseModel, BaseModel, CtxT]:
        """Expose this agent as a tool; implemented by subclasses."""
|
@@ -0,0 +1,353 @@
|
|
1
|
+
import fnmatch
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
from abc import abstractmethod
|
5
|
+
from collections.abc import AsyncIterator, Sequence
|
6
|
+
from typing import Any, Generic, Literal
|
7
|
+
|
8
|
+
import httpx
|
9
|
+
from pydantic import BaseModel, TypeAdapter
|
10
|
+
from tenacity import (
|
11
|
+
RetryCallState,
|
12
|
+
retry,
|
13
|
+
stop_after_attempt,
|
14
|
+
wait_random_exponential,
|
15
|
+
)
|
16
|
+
from typing_extensions import TypedDict
|
17
|
+
|
18
|
+
from .data_retrieval.rate_limiter_chunked import ( # type: ignore
|
19
|
+
RateLimiterC,
|
20
|
+
limit_rate_chunked,
|
21
|
+
)
|
22
|
+
|
23
|
+
from .http_client import AsyncHTTPClientParams, create_async_http_client
|
24
|
+
from .llm import LLM, ConvertT, LLMSettings, SettingsT
|
25
|
+
from .memory import MessageHistory
|
26
|
+
from .typing.completion import Completion, CompletionChunk
|
27
|
+
from .typing.message import AssistantMessage, Conversation
|
28
|
+
from .typing.tool import BaseTool, ToolChoice
|
29
|
+
from .utils import extract_json
|
30
|
+
|
31
|
+
logger = logging.getLogger(__name__)
|
32
|
+
|
33
|
+
|
34
|
+
APIProvider = Literal["openai", "openrouter", "google_ai_studio"]
|
35
|
+
|
36
|
+
|
37
|
+
class APIProviderInfo(TypedDict):
|
38
|
+
name: APIProvider
|
39
|
+
base_url: str
|
40
|
+
api_key: str | None
|
41
|
+
struct_output_support: list[str]
|
42
|
+
|
43
|
+
|
44
|
+
PROVIDERS: dict[APIProvider, APIProviderInfo] = {
|
45
|
+
"openai": APIProviderInfo(
|
46
|
+
name="openai",
|
47
|
+
base_url="https://api.openai.com/v1",
|
48
|
+
api_key=os.getenv("OPENAI_API_KEY"),
|
49
|
+
struct_output_support=["*"],
|
50
|
+
),
|
51
|
+
"openrouter": APIProviderInfo(
|
52
|
+
name="openrouter",
|
53
|
+
base_url="https://openrouter.ai/api/v1",
|
54
|
+
api_key=os.getenv("OPENROUTER_API_KEY"),
|
55
|
+
struct_output_support=[],
|
56
|
+
),
|
57
|
+
"google_ai_studio": APIProviderInfo(
|
58
|
+
name="google_ai_studio",
|
59
|
+
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
60
|
+
api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
|
61
|
+
struct_output_support=["*"],
|
62
|
+
),
|
63
|
+
}
|
64
|
+
|
65
|
+
|
66
|
+
def retry_error_callback(retry_state: RetryCallState) -> None:
    """Log the terminal failure of a CloudLLM completion request.

    Installed as tenacity's ``retry_error_callback``, so it runs only once the
    stop condition is reached. Tenacity then returns this callback's value
    (``None``) from the retried call instead of raising, which means callers
    of the retried function receive ``None`` on failure.

    Both terminal paths are logged at ERROR. Previously, a failure after
    exhausting retries was logged at WARNING, which is lower than a
    single-attempt failure, so ERROR-only handlers missed it.
    """
    assert retry_state.outcome is not None
    exception = retry_state.outcome.exception()
    if not exception:
        return

    if retry_state.attempt_number == 1:
        # No retries were configured/used: the first attempt was the last.
        logger.error(
            "CloudLLM completion request failed:\n%s",
            exception,
            exc_info=exception,
        )
    else:
        logger.error(
            "CloudLLM completion request failed after retrying:\n%s",
            exception,
            exc_info=exception,
        )
|
80
|
+
|
81
|
+
|
82
|
+
def retry_before_callback(retry_state: RetryCallState) -> None:
    """Log an INFO line before every retry (the initial attempt is silent).

    The number shown is the retry count, i.e. ``attempt_number - 1``.
    """
    attempt = retry_state.attempt_number
    if attempt <= 1:
        return
    retry_no = attempt - 1
    logger.info(f"Retrying CloudLLM completion request (attempt {retry_no}) ...")
|
88
|
+
|
89
|
+
|
90
|
+
class CloudLLMSettings(LLMSettings, total=False):
    """Optional generation settings for CloudLLM.

    Every key is optional (``total=False``). CloudLLM unpacks whatever keys
    are present as extra keyword arguments of the provider request.
    """

    max_completion_tokens: int | None
    temperature: float | None
    top_p: float | None
    seed: int | None
|
95
|
+
|
96
|
+
|
97
|
+
class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
    """Shared machinery for LLMs reached through a hosted provider API.

    Concrete subclasses implement the three ``_get_*`` coroutines that talk
    to the provider. This class handles:

    * provider lookup in ``PROVIDERS``;
    * choosing between native structured output and client-side validation;
    * tenacity-based retries;
    * chunked rate limiting for batch generation.
    """

    def __init__(
        self,
        # Base LLM args
        model_name: str,
        converters: ConvertT,
        llm_settings: SettingsT | None = None,
        model_id: str | None = None,
        tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
        response_format: type | None = None,
        # Connection settings
        api_provider: APIProvider = "openai",
        async_http_client_params: (
            dict[str, Any] | AsyncHTTPClientParams | None
        ) = None,
        # Rate limiting
        rate_limiter: (RateLimiterC[Conversation, AssistantMessage] | None) = None,
        rate_limiter_rpm: float | None = None,
        rate_limiter_chunk_size: int = 1000,
        rate_limiter_max_concurrency: int = 300,
        # Retries
        num_generation_retries: int = 0,
        # Disable tqdm for batch processing
        no_tqdm: bool = True,
        **kwargs: Any,
    ) -> None:
        self.llm_settings: CloudLLMSettings | None

        super().__init__(
            model_name=model_name,
            llm_settings=llm_settings,
            converters=converters,
            model_id=model_id,
            tools=tools,
            response_format=response_format,
            **kwargs,
        )

        self._model_name = model_name
        self._api_provider: APIProvider = api_provider
        provider = PROVIDERS[api_provider]

        # Native structured output is requested only when the model name
        # matches one of the provider's fnmatch patterns.
        self._struct_output_support: bool = any(
            fnmatch.fnmatch(self._model_name, pattern)
            for pattern in provider["struct_output_support"]
        )
        # Client-side validator used when the provider cannot enforce the
        # response format itself (see generate_message).
        self._response_format_pyd: TypeAdapter[Any] | None = (
            TypeAdapter(self._response_format) if response_format else None
        )

        self._rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = (
            self._get_rate_limiter(
                rate_limiter=rate_limiter,
                rate_limiter_rpm=rate_limiter_rpm,
                rate_limiter_chunk_size=rate_limiter_chunk_size,
                rate_limiter_max_concurrency=rate_limiter_max_concurrency,
            )
        )
        self.no_tqdm = no_tqdm

        self._base_url: str = provider["base_url"]
        self._api_key: str | None = provider["api_key"]
        self._client: Any

        self._async_http_client: httpx.AsyncClient | None = None
        if async_http_client_params is not None:
            client_params = AsyncHTTPClientParams.model_validate(
                async_http_client_params
            )
            self._async_http_client = create_async_http_client(client_params)

        self.num_generation_retries = num_generation_retries

    @property
    def api_provider(self) -> APIProvider:
        """Provider name this model was configured with."""
        return self._api_provider

    @property
    def rate_limiter(
        self,
    ) -> RateLimiterC[Conversation, AssistantMessage] | None:
        """Rate limiter used for batch generation, or None when unlimited."""
        return self._rate_limiter

    def _make_completion_kwargs(
        self, conversation: Conversation, tool_choice: ToolChoice | None = None
    ) -> dict[str, Any]:
        """Convert a conversation (plus tools/settings) into provider kwargs.

        Tool choice is only converted when tools are configured. The LLM
        settings are merged in as extra keyword arguments.
        """
        settings = self.llm_settings or {}
        converters = self._converters

        messages = [converters.to_message(msg) for msg in conversation]

        tools = None
        chosen_tool = None
        if self.tools:
            tools = [converters.to_tool(tool) for tool in self.tools.values()]
            if tool_choice is not None:
                chosen_tool = converters.to_tool_choice(tool_choice)

        # dict(...) rather than a literal: a setting whose name collides with
        # one of the fixed keys raises TypeError instead of overriding it.
        return dict(
            api_messages=messages,
            api_tools=tools,
            api_tool_choice=chosen_tool,
            api_response_format=self._response_format,
            **settings,
        )

    @abstractmethod
    async def _get_completion(
        self,
        api_messages: list[Any],
        *,
        api_tools: list[Any] | None = None,
        api_tool_choice: Any | None = None,
        **api_llm_settings: Any,
    ) -> Any:
        """Request a plain (unparsed) completion from the provider."""

    @abstractmethod
    async def _get_parsed_completion(
        self,
        api_messages: list[Any],
        *,
        api_tools: list[Any] | None = None,
        api_tool_choice: Any | None = None,
        **api_llm_settings: Any,
    ) -> Any:
        """Request a completion using the provider's native structured output.

        ``api_response_format`` arrives inside ``api_llm_settings``.
        """

    @abstractmethod
    async def _get_completion_stream(
        self,
        api_messages: list[Any],
        *,
        api_tools: list[Any] | None = None,
        api_tool_choice: Any | None = None,
        **api_llm_settings: Any,
    ) -> AsyncIterator[Any]:
        """Request a streamed completion; awaiting yields an async iterator."""

    async def generate_completion(
        self,
        conversation: Conversation,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> Completion:
        """Generate one completion, using native structured output when possible."""
        request = self._make_completion_kwargs(
            conversation=conversation, tool_choice=tool_choice
        )

        wants_native_format = (
            self._response_format is not None and self._struct_output_support
        )
        if wants_native_format:
            raw = await self._get_parsed_completion(**request, **kwargs)
        else:
            request.pop("api_response_format", None)
            raw = await self._get_completion(**request, **kwargs)

        return self._converters.from_completion(raw, model_id=self.model_id)

    async def generate_completion_stream(
        self,
        conversation: Conversation,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[CompletionChunk]:
        """Stream a completion; the response format is never sent when streaming."""
        request = self._make_completion_kwargs(
            conversation=conversation, tool_choice=tool_choice
        )
        request.pop("api_response_format", None)

        raw_chunks = await self._get_completion_stream(**request, **kwargs)

        return self._converters.from_completion_chunk_iterator(
            raw_chunks, model_id=self.model_id
        )

    async def generate_message(
        self,
        conversation: Conversation,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> AssistantMessage:
        """Return the first choice's message, validating it client-side if needed.

        When a response format is set but the provider lacks native support,
        the JSON extracted from the message content is validated against it;
        a validation error propagates (and can trigger a retry).
        """
        completion = await self.generate_completion(
            conversation, tool_choice=tool_choice, **kwargs
        )
        reply = completion.choices[0].message

        validator = self._response_format_pyd
        if validator is not None and not self._struct_output_support:
            validator.validate_python(extract_json(reply.content))

        return reply

    async def _generate_message_with_retry(
        self,
        conversation: Conversation,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> AssistantMessage:
        """Call generate_message with up to ``num_generation_retries`` retries.

        Because ``retry_error_callback`` is installed, a final failure makes
        this return that callback's result (``None``) instead of raising.
        """
        retry_policy = retry(
            wait=wait_random_exponential(min=1, max=8),
            stop=stop_after_attempt(self.num_generation_retries + 1),
            before=retry_before_callback,
            retry_error_callback=retry_error_callback,
        )
        generate_with_retry = retry_policy(self.__class__.generate_message)

        return await generate_with_retry(
            self, conversation, tool_choice=tool_choice, **kwargs
        )

    @limit_rate_chunked  # type: ignore
    async def _generate_message_batch_with_retry_and_rate_lim(
        self,
        conversation: Conversation,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> AssistantMessage:
        """Per-conversation worker; the decorator fans it out over a batch."""
        return await self._generate_message_with_retry(
            conversation, tool_choice=tool_choice, **kwargs
        )

    async def generate_message_batch(
        self,
        message_history: MessageHistory,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> Sequence[AssistantMessage]:
        """Generate one message per batched conversation in ``message_history``."""
        conversations = list(message_history.batched_conversations)
        return await self._generate_message_batch_with_retry_and_rate_lim(
            conversations,  # type: ignore
            tool_choice=tool_choice,
            **kwargs,
        )

    def _get_rate_limiter(
        self,
        rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = None,
        rate_limiter_rpm: float | None = None,
        rate_limiter_chunk_size: int = 1000,
        rate_limiter_max_concurrency: int = 300,
    ) -> RateLimiterC[Conversation, AssistantMessage] | None:
        """Pick the rate limiter to use.

        An explicit ``rate_limiter`` wins. Otherwise a new one is built from
        ``rate_limiter_rpm`` if given. With neither, no limiter is used.
        """
        owner = self.__class__.__name__

        if rate_limiter is not None:
            logger.info(f"[{owner}] Set rate limit to {rate_limiter.rpm} RPM")
            return rate_limiter

        if rate_limiter_rpm is None:
            return None

        logger.info(f"[{owner}] Set rate limit to {rate_limiter_rpm} RPM")
        return RateLimiterC(
            rpm=rate_limiter_rpm,
            chunk_size=rate_limiter_chunk_size,
            max_concurrency=rate_limiter_max_concurrency,
        )
|