pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +70 -9
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +4 -2
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +149 -42
- pydantic_ai/models/__init__.py +6 -4
- pydantic_ai/models/anthropic.py +9 -16
- pydantic_ai/models/bedrock.py +50 -56
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +12 -13
- pydantic_ai/models/google.py +18 -4
- pydantic_ai/models/groq.py +126 -38
- pydantic_ai/models/huggingface.py +4 -4
- pydantic_ai/models/instrumented.py +35 -16
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +6 -6
- pydantic_ai/models/openai.py +35 -40
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +144 -41
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -5,10 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal,
+from typing import Any, Literal, cast, overload

+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never

+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
@@ -48,7 +51,7 @@ from . import (
 )

 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -88,7 +91,7 @@ PreviewGroqModelNames = Literal[
 ]
 """Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""

-GroqModelName =
+GroqModelName = str | ProductionGroqModelNames | PreviewGroqModelNames
 """Possible Groq model names.

 Since Groq supports a variety of models and the list changes frequencly, we explicitly list the named models as of 2025-03-31
@@ -169,9 +172,24 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-
-
-
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response

@@ -228,6 +246,18 @@ class GroqModel(Model):

         groq_messages = self._map_messages(messages)

+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +270,7 @@ class GroqModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -285,7 +316,7 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -347,7 +378,7 @@ class GroqModel(Model):
                 elif isinstance(item, ThinkingPart):
                     # Skip thinking parts when mapping to Groq messages
                     continue
-                elif isinstance(item,
+                elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                     # This is currently never returned from groq
                     pass
                 else:
@@ -385,6 +416,19 @@ class GroqModel(Model):
             },
         }

+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover

     @property
     def model_name(self) -> GroqModelName:
@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+    # Example payload from `exception.body`:
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n    "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
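The hunks above add two behaviors to `GroqModel`: native structured output (mapping the agent's output schema to Groq's `json_schema` response format) and recovery from Groq's `tool_use_failed` errors by turning the failed generation back into a `ToolCallPart` so the agent can ask the model to retry. Below is a minimal sketch, not taken from the diff, of how the native output path can be exercised from user code; the model name and credential handling are assumptions.

```python
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.output import NativeOutput


class CityInfo(BaseModel):
    city: str
    country: str


# NativeOutput requests a JSON-schema-constrained response; with the change above,
# GroqModel maps this to a 'json_schema' response_format on the completions call.
# 'groq:llama-3.3-70b-versatile' is one of the listed production models; any Groq
# model name string is accepted now that GroqModelName includes plain `str`.
agent = Agent('groq:llama-3.3-70b-versatile', output_type=NativeOutput(CityInfo))

result = agent.run_sync('Which city is the capital of France?')
print(result.output)  # e.g. CityInfo(city='Paris', country='France')
```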
pydantic_ai/models/huggingface.py
CHANGED
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Any, Literal,
+from typing import Any, Literal, cast, overload

 from typing_extensions import assert_never

@@ -88,7 +88,7 @@ LatestHuggingFaceModelNames = Literal[
 """Latest Hugging Face models."""


-HuggingFaceModelName =
+HuggingFaceModelName = str | LatestHuggingFaceModelNames
 """Possible Hugging Face model names.

 You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
@@ -267,7 +267,7 @@ class HuggingFaceModel(Model):
         for c in tool_calls:
             items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -320,7 +320,7 @@ class HuggingFaceModel(Model):
                     # please open an issue. The below code is the code to send thinking to the provider.
                     # texts.append(f'<think>\n{item.content}\n</think>')
                     pass
-                elif isinstance(item,
+                elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                     # This is currently never returned from huggingface
                     pass
                 else:
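As in the Groq hunks, the Hugging Face changes replace the positional `items,` argument with the keyword form `parts=items,` when building a `ModelResponse`. A small sketch of that construction style, using only names from `pydantic_ai.messages` rather than code from this diff:

```python
from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart

# In 1.0 the response parts are passed by keyword; usage, model_name and timestamp
# fall back to their defaults here.
response = ModelResponse(
    parts=[
        TextPart(content='Looking that up now.'),
        ToolCallPart(tool_name='get_weather', args={'city': 'Paris'}),
    ],
)
```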
pydantic_ai/models/instrumented.py
CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations

 import itertools
 import json
-
+import warnings
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from typing import Any,
+from typing import Any, Literal, cast
 from urllib.parse import urlparse

 from opentelemetry._events import (
@@ -93,36 +94,41 @@ class InstrumentationSettings:
     def __init__(
         self,
         *,
-        event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
         meter_provider: MeterProvider | None = None,
-        event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
         include_content: bool = True,
-        version: Literal[1, 2] =
+        version: Literal[1, 2] = 2,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
+        event_logger_provider: EventLoggerProvider | None = None,
     ):
         """Create instrumentation options.

         Args:
-            event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
-                If `'logs'`, events are emitted as OpenTelemetry log-based events.
             tracer_provider: The OpenTelemetry tracer provider to use.
                 If not provided, the global tracer provider is used.
                 Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
             meter_provider: The OpenTelemetry meter provider to use.
                 If not provided, the global meter provider is used.
                 Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
-            event_logger_provider: The OpenTelemetry event logger provider to use.
-                If not provided, the global event logger provider is used.
-                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
-                This is only used if `event_mode='logs'`.
             include_binary_content: Whether to include binary content in the instrumentation events.
             include_content: Whether to include prompts, completions, and tool call arguments and responses
                 in the instrumentation events.
-            version: Version of the data format.
-                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
-
-
+            version: Version of the data format. This is unrelated to the Pydantic AI package version.
+                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
+                and will be removed in a future release.
+                The parameters `event_mode` and `event_logger_provider` are only relevant for version 1.
+                Version 2 uses the newer OpenTelemetry GenAI spec and stores messages in the following attributes:
+                - `gen_ai.system_instructions` for instructions passed to the agent.
+                - `gen_ai.input.messages` and `gen_ai.output.messages` on model request spans.
+                - `pydantic_ai.all_messages` on agent run spans.
+            event_mode: The mode for emitting events in version 1.
+                If `'attributes'`, events are attached to the span as attributes.
+                If `'logs'`, events are emitted as OpenTelemetry log-based events.
+            event_logger_provider: The OpenTelemetry event logger provider to use.
+                If not provided, the global event logger provider is used.
+                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
+                This is only used if `event_mode='logs'` and `version=1`.
         """
         from pydantic_ai import __version__

@@ -136,6 +142,14 @@ class InstrumentationSettings:
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
         self.include_content = include_content
+
+        if event_mode == 'logs' and version != 1:
+            warnings.warn(
+                'event_mode is only relevant for version=1 which is deprecated and will be removed in a future release.',
+                stacklevel=2,
+            )
+            version = 1
+
         self.version = version

         # As specified in the OpenTelemetry GenAI metrics spec:
@@ -366,7 +380,7 @@ class InstrumentedModel(WrapperModel):

             if model_settings:
                 for key in MODEL_SETTING_ATTRIBUTES:
-                    if isinstance(value := model_settings.get(key),
+                    if isinstance(value := model_settings.get(key), float | int):
                         attributes[f'gen_ai.request.{key}'] = value

         record_metrics: Callable[[], None] | None = None
@@ -406,10 +420,15 @@ class InstrumentedModel(WrapperModel):
                 return

             self.instrumentation_settings.handle_messages(messages, response, system, span)
+            try:
+                cost_attributes = {'operation.cost': float(response.cost().total_price)}
+            except LookupError:
+                cost_attributes = {}
             span.set_attributes(
                 {
                     **response.usage.opentelemetry_attributes(),
                     'gen_ai.response.model': response_model,
+                    **cost_attributes,
                 }
             )
             span.update_name(f'{operation} {request_model}')
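The `InstrumentationSettings` changes make `version=2` the default, document the new span attributes, and warn when `event_mode='logs'` is combined with `version=2` (falling back to version 1). A minimal sketch of opting into the new settings on an agent follows; the model name and the rest of the OpenTelemetry setup are assumptions, not part of the diff.

```python
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

# version=2 is now the default; event_mode and event_logger_provider only matter for version=1.
settings = InstrumentationSettings(version=2, include_binary_content=False)

agent = Agent('openai:gpt-4o', instrument=settings)
```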
pydantic_ai/models/mcp_sampling.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import KW_ONLY, dataclass
 from typing import TYPE_CHECKING, Any, cast

 from .. import _mcp, exceptions
@@ -36,6 +36,8 @@ class MCPSamplingModel(Model):
     session: ServerSession
     """The MCP server session to use for sampling."""

+    _: KW_ONLY
+
     default_max_tokens: int = 16_384
     """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].

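`MCPSamplingModel` gains a `_: KW_ONLY` sentinel, so fields declared after it (such as `default_max_tokens`) become keyword-only. A self-contained illustration of that stdlib mechanism, mirroring the change rather than importing the real class:

```python
from dataclasses import KW_ONLY, dataclass


@dataclass
class SamplingModelSketch:
    session: object  # stand-in for the MCP ServerSession
    _: KW_ONLY
    default_max_tokens: int = 16_384


SamplingModelSketch(object(), default_max_tokens=8_192)  # ok: keyword after KW_ONLY
# SamplingModelSketch(object(), 8_192)  # TypeError: fields after KW_ONLY are keyword-only
```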
pydantic_ai/models/mistral.py
CHANGED
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal,
+from typing import Any, Literal, cast

 import pydantic_core
 from httpx import Timeout
@@ -90,7 +90,7 @@ LatestMistralModelNames = Literal[
 ]
 """Latest Mistral models."""

-MistralModelName =
+MistralModelName = str | LatestMistralModelNames
 """Possible Mistral model names.

 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -117,7 +117,7 @@ class MistralModel(Model):
     """

     client: Mistral = field(repr=False)
-    json_mode_schema_prompt: str
+    json_mode_schema_prompt: str

     _model_name: MistralModelName = field(repr=False)
     _provider: Provider[Mistral] = field(repr=False)
@@ -348,7 +348,7 @@ class MistralModel(Model):
             parts.append(tool)

         return ModelResponse(
-            parts,
+            parts=parts,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -515,7 +515,7 @@ class MistralModel(Model):
                 pass
             elif isinstance(part, ToolCallPart):
                 tool_calls.append(self._map_tool_call(part))
-            elif isinstance(part,
+            elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from mistral
                 pass
             else:
@@ -576,7 +576,7 @@ class MistralModel(Model):
             return MistralUserMessage(content=content)


-MistralToolCallId =
+MistralToolCallId = str | None


 @dataclass
pydantic_ai/models/openai.py
CHANGED
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal,
+from typing import Any, Literal, cast, overload

 from pydantic import ValidationError
 from typing_extensions import assert_never, deprecated
@@ -90,7 +90,7 @@ __all__ = (
     'OpenAIModelName',
 )

-OpenAIModelName =
+OpenAIModelName = str | AllModels
 """
 Possible OpenAI model names.

@@ -225,6 +225,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -252,6 +253,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -278,6 +280,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -409,13 +412,6 @@ class OpenAIChatModel(Model):
         for setting in unsupported_model_settings:
             model_settings.pop(setting, None)

-        # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
-        sampling_settings = (
-            model_settings
-            if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
-            else OpenAIChatModelSettings()
-        )
-
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -437,13 +433,13 @@ class OpenAIChatModel(Model):
                 web_search_options=web_search_options or NOT_GIVEN,
                 service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
                 prediction=model_settings.get('openai_prediction', NOT_GIVEN),
-                temperature=
-                top_p=
-                presence_penalty=
-                frequency_penalty=
-                logit_bias=
-                logprobs=
-                top_logprobs=
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+                logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
+                top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
                 extra_headers=extra_headers,
                 extra_body=model_settings.get('extra_body'),
             )
@@ -512,7 +508,7 @@ class OpenAIChatModel(Model):
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -582,7 +578,7 @@ class OpenAIChatModel(Model):
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
             # OpenAI doesn't return built-in tool calls
-            elif isinstance(item,
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 pass
             else:
                 assert_never(item)
@@ -613,7 +609,7 @@ class OpenAIChatModel(Model):
     def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
         response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = {  # pyright: ignore[reportPrivateImportUsage]
             'type': 'json_schema',
-            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema
+            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
         }
         if o.description:
             response_format_param['json_schema']['description'] = o.description
@@ -828,7 +824,7 @@ class OpenAIResponsesModel(Model):
             elif item.type == 'function_call':
                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             provider_response_id=response.id,
@@ -918,11 +914,9 @@ class OpenAIResponsesModel(Model):
             text = text or {}
             text['verbosity'] = verbosity

-
-
-
-            else OpenAIResponsesModelSettings()
-        )
+        unsupported_model_settings = OpenAIModelProfile.from_profile(self.profile).openai_unsupported_model_settings
+        for setting in unsupported_model_settings:
+            model_settings.pop(setting, None)

         try:
             extra_headers = model_settings.get('extra_headers', {})
@@ -936,8 +930,8 @@ class OpenAIResponsesModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 stream=stream,
-                temperature=
-                top_p=
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
                 truncation=model_settings.get('openai_truncation', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
@@ -1049,7 +1043,7 @@ class OpenAIResponsesModel(Model):
             elif isinstance(item, ToolCallPart):
                 openai_messages.append(self._map_tool_call(item))
             # OpenAI doesn't return built-in tool calls
-            elif isinstance(item,
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
                 pass
             elif isinstance(item, ThinkingPart):
                 # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
@@ -1180,6 +1174,10 @@ class OpenAIStreamedResponse(StreamedResponse):
             except IndexError:
                 continue

+            # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
+            if choice.delta is None:  # pyright: ignore[reportUnnecessaryComparison]
+                continue
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
@@ -1279,12 +1277,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     tool_call_id=chunk.item.call_id,
                 )
             elif isinstance(chunk.item, responses.ResponseReasoningItem):
-
-                yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item.id,
-                    content=content,
-                    signature=chunk.item.id,
-                )
+                pass
             elif isinstance(chunk.item, responses.ResponseOutputMessage):
                 pass
             elif isinstance(chunk.item, responses.ResponseFunctionWebSearch):
@@ -1300,7 +1293,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
-
+                yield self._parts_manager.handle_thinking_delta(
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
+                    content=chunk.part.text,
+                    id=chunk.item_id,
+                )

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
                 pass  # there's nothing we need to do here
@@ -1310,9 +1307,9 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):

             elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
                 yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item_id,
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
                     content=chunk.delta,
-
+                    id=chunk.item_id,
                 )

                 # TODO(Marcelo): We should support annotations in the future.
@@ -1320,9 +1317,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass  # there's nothing we need to do here

             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id=chunk.content_index, content=chunk.delta
-                )
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=chunk.item_id, content=chunk.delta)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event

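The `OpenAIChatModel` overloads now accept `'litellm'` as a provider key, matching the new `pydantic_ai/providers/litellm.py` module listed at the top of this diff. A minimal sketch of selecting it by name; how the LiteLLM provider resolves its base URL and API key (environment variables vs. constructor arguments) is not shown in these hunks, so defaults are assumed:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel

# 'litellm' routes requests through the LiteLLM provider added in this release.
model = OpenAIChatModel('gpt-4o', provider='litellm')
agent = Agent(model)
```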