pydantic-ai-slim 0.0.53__tar.gz → 0.0.55__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/PKG-INFO +3 -3
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_agent_graph.py +8 -6
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/agent.py +22 -10
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/__init__.py +1 -1
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/anthropic.py +10 -1
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/bedrock.py +16 -14
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/cohere.py +2 -1
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/gemini.py +1 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/groq.py +3 -1
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/mistral.py +7 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/openai.py +5 -1
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/wrapper.py +3 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/__init__.py +4 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/azure.py +2 -2
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/settings.py +13 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/tools.py +1 -1
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/README.md +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pydantic_ai/usage.py +0 -0
- {pydantic_ai_slim-0.0.53 → pydantic_ai_slim-0.0.55}/pyproject.toml +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.53
+Version: 0.0.55
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.53
+Requires-Dist: pydantic-graph==0.0.55
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.0.53; extra == 'evals'
+Requires-Dist: pydantic-evals==0.0.55; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
pydantic_ai/_agent_graph.py
@@ -79,7 +79,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

     user_deps: DepsT

-    prompt: str | Sequence[_messages.UserContent]
+    prompt: str | Sequence[_messages.UserContent] | None
     new_message_index: int

     model: models.Model
@@ -124,7 +124,7 @@ def is_agent_node(

 @dataclasses.dataclass
 class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
-    user_prompt: str | Sequence[_messages.UserContent]
+    user_prompt: str | Sequence[_messages.UserContent] | None

     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
@@ -151,7 +151,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

     async def _prepare_messages(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None,
         message_history: list[_messages.ModelMessage] | None,
         run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
@@ -166,16 +166,18 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
             messages = ctx_messages.messages
             ctx_messages.used = True

+        parts: list[_messages.ModelRequestPart] = []
         if message_history:
             # Shallow copy messages
             messages.extend(message_history)
             # Reevaluate any dynamic system prompt parts
             await self._reevaluate_dynamic_prompts(messages, run_context)
-            return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
         else:
-            parts = await self._sys_parts(run_context)
+            parts.extend(await self._sys_parts(run_context))
+
+        if user_prompt is not None:
             parts.append(_messages.UserPromptPart(user_prompt))
-
+        return messages, _messages.ModelRequest(parts)

     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
pydantic_ai/agent.py
@@ -242,7 +242,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -257,7 +257,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -271,7 +271,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -335,7 +335,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @asynccontextmanager
     async def iter(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -372,6 +372,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 print(nodes)
                 '''
                 [
+                    UserPromptNode(
+                        user_prompt='What is the capital of France?',
+                        system_prompts=(),
+                        system_prompt_functions=[],
+                        system_prompt_dynamic_functions={},
+                    ),
                     ModelRequestNode(
                         request=ModelRequest(
                             parts=[
@@ -497,7 +503,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
@@ -511,7 +517,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -525,7 +531,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -586,7 +592,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_stream(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -601,7 +607,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_stream(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -616,7 +622,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @asynccontextmanager
     async def run_stream(  # noqa C901
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -1355,6 +1361,12 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
             print(nodes)
             '''
             [
+                UserPromptNode(
+                    user_prompt='What is the capital of France?',
+                    system_prompts=(),
+                    system_prompt_functions=[],
+                    system_prompt_dynamic_functions={},
+                ),
                 ModelRequestNode(
                     request=ModelRequest(
                         parts=[
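With `user_prompt` now defaulting to `None` across `run`, `run_sync`, `run_stream` and `iter`, a run can be driven from `message_history` alone. A minimal sketch of the new call shape — the model string and history contents below are illustrative, not taken from this diff:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart

agent = Agent('openai:gpt-4o', system_prompt='Be concise.')

# Prior messages from an earlier run; with user_prompt omitted (None),
# only this history plus any re-evaluated system prompt parts are sent.
history = [
    ModelRequest(parts=[UserPromptPart(content='What is the capital of France?')]),
    ModelResponse(parts=[TextPart(content='Paris.')]),
]
result = agent.run_sync(message_history=history)
print(result.data)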
pydantic_ai/models/__init__.py
@@ -427,7 +427,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
        from .cohere import CohereModel

        return CohereModel(model_name, provider=provider)
-    elif provider in ('deepseek', 'openai'):
+    elif provider in ('deepseek', 'openai', 'azure'):
        from .openai import OpenAIModel

        return OpenAIModel(model_name, provider=provider)
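Together with the `AzureProvider` registration in `pydantic_ai/providers/__init__.py` further down, this lets an Azure OpenAI deployment be selected by provider prefix. A hedged sketch — the endpoint, key, API version and model name are placeholders, not part of this diff:

import os

from pydantic_ai import Agent

# AzureProvider() reads its configuration from the environment when built with
# no arguments; these variable names match the checks in providers/azure.py.
os.environ['AZURE_OPENAI_ENDPOINT'] = 'https://my-resource.openai.azure.com'  # placeholder
os.environ['AZURE_OPENAI_API_KEY'] = '...'  # placeholder
os.environ['OPENAI_API_VERSION'] = '2024-06-01'  # placeholder

# 'azure:<model>' now resolves to OpenAIModel(model_name, provider='azure').
agent = Agent('azure:gpt-4o')
result = agent.run_sync('What is the capital of France?')
print(result.data)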
pydantic_ai/models/anthropic.py
@@ -31,7 +31,14 @@ from ..messages import (
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    cached_async_http_client,
+    check_allow_model_requests,
+    get_user_agent,
+)

 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
@@ -226,10 +233,12 @@ class AnthropicModel(Model):
                 tools=tools or NOT_GIVEN,
                 tool_choice=tool_choice or NOT_GIVEN,
                 stream=stream,
+                stop_sequences=model_settings.get('stop_sequences', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
pydantic_ai/models/bedrock.py
@@ -42,12 +42,14 @@ if TYPE_CHECKING:
     from mypy_boto3_bedrock_runtime.type_defs import (
         ContentBlockOutputTypeDef,
         ContentBlockUnionTypeDef,
+        ConverseRequestTypeDef,
         ConverseResponseTypeDef,
         ConverseStreamMetadataEventTypeDef,
         ConverseStreamOutputTypeDef,
         ImageBlockTypeDef,
         InferenceConfigurationTypeDef,
         MessageUnionTypeDef,
+        SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
         ToolTypeDef,
     )
@@ -258,20 +260,19 @@ class BedrockConverseModel(Model):
         else:
             tool_choice = {'auto': {}}

-        system_prompt, bedrock_messages = await self.
+        system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)

-        params = {
+        params: ConverseRequestTypeDef = {
             'modelId': self.model_name,
             'messages': bedrock_messages,
-            'system':
+            'system': system_prompt,
             'inferenceConfig': inference_config,
-            **(
-                {'toolConfig': {'tools': tools, **({'toolChoice': tool_choice} if tool_choice else {})}}
-                if tools
-                else {}
-            ),
         }
+        if tools:
+            params['toolConfig'] = {'tools': tools}
+        if tool_choice:
+            params['toolConfig']['toolChoice'] = tool_choice

         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
@@ -293,21 +294,22 @@ class BedrockConverseModel(Model):
             inference_config['temperature'] = temperature
         if top_p := model_settings.get('top_p'):
             inference_config['topP'] = top_p
-
-
-        # inference_config['stopSequences'] = stop_sequences
+        if stop_sequences := model_settings.get('stop_sequences'):
+            inference_config['stopSequences'] = stop_sequences

         return inference_config

-    async def
+    async def _map_messages(
+        self, messages: list[ModelMessage]
+    ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
         """Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
-        system_prompt:
+        system_prompt: list[SystemContentBlockTypeDef] = []
         bedrock_messages: list[MessageUnionTypeDef] = []
         for m in messages:
             if isinstance(m, ModelRequest):
                 for part in m.parts:
                     if isinstance(part, SystemPromptPart):
-                        system_prompt
+                        system_prompt.append({'text': part.content})
                     elif isinstance(part, UserPromptPart):
                         bedrock_messages.extend(await self._map_user_prompt(part))
                     elif isinstance(part, ToolReturnPart):
pydantic_ai/models/cohere.py
@@ -118,7 +118,7 @@ class CohereModel(Model):
            'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
            created using the other parameters.
        """
-        self._model_name
+        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
@@ -163,6 +163,7 @@ class CohereModel(Model):
            messages=cohere_messages,
            tools=tools or OMIT,
            max_tokens=model_settings.get('max_tokens', OMIT),
+            stop_sequences=model_settings.get('stop_sequences', OMIT),
            temperature=model_settings.get('temperature', OMIT),
            p=model_settings.get('top_p', OMIT),
            seed=model_settings.get('seed', OMIT),
pydantic_ai/models/groq.py
@@ -31,7 +31,7 @@ from ..messages import (
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, get_user_agent

 try:
     from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
@@ -208,6 +208,7 @@ class GroqModel(Model):
                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
+                stop=model_settings.get('stop_sequences', NOT_GIVEN),
                stream=stream,
                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                temperature=model_settings.get('temperature', NOT_GIVEN),
@@ -217,6 +218,7 @@ class GroqModel(Model):
                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
pydantic_ai/models/mistral.py
@@ -39,6 +39,7 @@ from . import (
     ModelRequestParameters,
     StreamedResponse,
     check_allow_model_requests,
+    get_user_agent,
 )

 try:
@@ -199,6 +200,8 @@ class MistralModel(Model):
                top_p=model_settings.get('top_p', 1),
                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                random_seed=model_settings.get('seed', UNSET),
+                stop=model_settings.get('stop_sequences', None),
+                http_headers={'User-Agent': get_user_agent()},
            )
        except SDKError as e:
            if (status_code := e.status_code) >= 400:
@@ -236,6 +239,8 @@ class MistralModel(Model):
                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                presence_penalty=model_settings.get('presence_penalty'),
                frequency_penalty=model_settings.get('frequency_penalty'),
+                stop=model_settings.get('stop_sequences', None),
+                http_headers={'User-Agent': get_user_agent()},
            )

        elif model_request_parameters.result_tools:
@@ -249,6 +254,7 @@ class MistralModel(Model):
                messages=mistral_messages,
                response_format={'type': 'json_object'},
                stream=True,
+                http_headers={'User-Agent': get_user_agent()},
            )

        else:
@@ -257,6 +263,7 @@ class MistralModel(Model):
                model=str(self._model_name),
                messages=mistral_messages,
                stream=True,
+                http_headers={'User-Agent': get_user_agent()},
            )
        assert response, 'A unexpected empty response from Mistral.'
        return response
pydantic_ai/models/openai.py
@@ -39,6 +39,7 @@ from . import (
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
+    get_user_agent,
 )

 try:
@@ -271,6 +272,7 @@ class OpenAIModel(Model):
                tool_choice=tool_choice or NOT_GIVEN,
                stream=stream,
                stream_options={'include_usage': True} if stream else NOT_GIVEN,
+                stop=model_settings.get('stop_sequences', NOT_GIVEN),
                max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -281,6 +283,7 @@ class OpenAIModel(Model):
                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
                user=model_settings.get('openai_user', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
@@ -611,7 +614,8 @@ class OpenAIResponsesModel(Model):
                truncation=model_settings.get('openai_truncation', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                reasoning=reasoning,
-                user=model_settings.get('
+                user=model_settings.get('openai_user', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
pydantic_ai/models/wrapper.py
@@ -37,6 +37,9 @@ class WrapperModel(Model):
        async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
            yield response_stream

+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        return self.wrapped.customize_request_parameters(model_request_parameters)
+
    @property
    def model_name(self) -> str:
        return self.wrapped.model_name
pydantic_ai/providers/__init__.py
@@ -52,6 +52,10 @@ def infer_provider(provider: str) -> Provider[Any]:
        from .deepseek import DeepSeekProvider

        return DeepSeekProvider()
+    elif provider == 'azure':
+        from .azure import AzureProvider
+
+        return AzureProvider()
    elif provider == 'google-vertex':
        from .google_vertex import GoogleVertexProvider

pydantic_ai/providers/azure.py
@@ -87,9 +87,9 @@ class AzureProvider(Provider[AsyncOpenAI]):
                'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable'
            )

-        if not api_key and '
+        if not api_key and 'AZURE_OPENAI_API_KEY' not in os.environ:  # pragma: no cover
            raise UserError(
-                'Must provide one of the `api_key` argument or the `
+                'Must provide one of the `api_key` argument or the `AZURE_OPENAI_API_KEY` environment variable'
            )

        if not api_version and 'OPENAI_API_VERSION' not in os.environ:  # pragma: no cover
pydantic_ai/settings.py
@@ -128,6 +128,19 @@ class ModelSettings(TypedDict, total=False):
    * Groq
    """

+    stop_sequences: list[str]
+    """Sequences that will cause the model to stop generating.
+
+    Supported by:
+
+    * OpenAI
+    * Anthropic
+    * Bedrock
+    * Mistral
+    * Groq
+    * Cohere
+    """
+

 def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
    """Merge two sets of model settings, preferring the overrides.
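The new `stop_sequences` setting is what the OpenAI, Anthropic, Bedrock, Mistral, Groq and Cohere request changes above read from `model_settings`. A minimal usage sketch (model string and sequences are illustrative):

from pydantic_ai import Agent

# stop_sequences is an ordinary ModelSettings key, so it can be set per agent
# (as here) or passed per run via model_settings=...; generation stops as soon
# as the model emits one of the listed sequences.
agent = Agent('openai:gpt-4o', model_settings={'stop_sequences': ['END']})

result = agent.run_sync('Count to three, then write END.')
print(result.data)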
pydantic_ai/tools.py
@@ -48,7 +48,7 @@ class RunContext(Generic[AgentDepsT]):
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
-    prompt: str | Sequence[_messages.UserContent]
+    prompt: str | Sequence[_messages.UserContent] | None
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""