pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +156 -44
- pydantic_ai/models/__init__.py +20 -7
- pydantic_ai/models/anthropic.py +10 -17
- pydantic_ai/models/bedrock.py +55 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +13 -14
- pydantic_ai/models/google.py +19 -5
- pydantic_ai/models/groq.py +127 -39
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +49 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +37 -42
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -5,10 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal,
+from typing import Any, Literal, cast, overload
 
+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never
 
+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
@@ -48,7 +51,7 @@ from . import (
 )
 
 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -88,7 +91,7 @@ PreviewGroqModelNames = Literal[
 ]
 """Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""
 
-GroqModelName =
+GroqModelName = str | ProductionGroqModelNames | PreviewGroqModelNames
 """Possible Groq model names.
 
 Since Groq supports a variety of models and the list changes frequencly, we explicitly list the named models as of 2025-03-31
@@ -169,9 +172,24 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-
-
-
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response
 
@@ -228,6 +246,18 @@ class GroqModel(Model):
 
         groq_messages = self._map_messages(messages)
 
+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +270,7 @@ class GroqModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -285,11 +316,11 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-
+            provider_response_id=response.id,
             provider_name=self._provider.name,
         )
 
@@ -347,7 +378,7 @@ class GroqModel(Model):
                 elif isinstance(item, ThinkingPart):
                     # Skip thinking parts when mapping to Groq messages
                     continue
-                elif isinstance(item,
+                elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                     # This is currently never returned from groq
                     pass
                 else:
@@ -385,6 +416,19 @@ class GroqModel(Model):
             },
         }
 
+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
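To illustrate how the new `_map_json_schema` and the `output_mode == 'native'` branch are expected to be exercised, here is a minimal usage sketch; `NativeOutput`, the `groq:` model prefix, and `run_sync`/`.output` are assumed from Pydantic AI's public API and are not part of this diff.

from pydantic import BaseModel

from pydantic_ai import Agent, NativeOutput


class CityInfo(BaseModel):
    city: str
    country: str


# With output_mode == 'native', GroqModel builds a
# response_format={'type': 'json_schema', ...} request parameter via _map_json_schema().
agent = Agent('groq:llama-3.3-70b-versatile', output_type=NativeOutput(CityInfo))

result = agent.run_sync('What is the capital of France?')
print(result.output)
# > CityInfo(city='Paris', country='France')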
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        (30 lines removed; previous implementation not shown in this view)
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover
 
     @property
     def model_name(self) -> GroqModelName:
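For reference, a minimal streaming sketch that drives the iterator above; the `run_stream` and `stream_text` APIs are assumed from Pydantic AI's documentation and are not shown in this diff.

import asyncio

from pydantic_ai import Agent

agent = Agent('groq:llama-3.3-70b-versatile')


async def main() -> None:
    # Each Groq SSE chunk is routed through GroqStreamedResponse._get_event_iterator(),
    # which converts content deltas and tool-call deltas into response stream events.
    async with agent.run_stream('Tell me a short joke.') as result:
        async for text in result.stream_text():
            print(text)


asyncio.run(main())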
@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+    # Example payload from `exception.body`:
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
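A small standalone sketch (not part of the diff) showing how models like these decode the documented payload; the class and variable names here are hypothetical stand-ins, and `Json[...]` is what makes Pydantic parse the `failed_generation` JSON string into the nested model.

from typing import Any, Literal

from pydantic import BaseModel, Json


class FailedGeneration(BaseModel):
    name: str
    arguments: dict[str, Any]


class InnerError(BaseModel):
    message: str
    type: Literal['invalid_request_error']
    code: Literal['tool_use_failed']
    # Json[...] tells Pydantic this field arrives as a JSON-encoded string.
    failed_generation: Json[FailedGeneration]


class ToolUseFailedError(BaseModel):
    error: InnerError


# Hypothetical payload shaped like the comment in the diff above.
body = {
    'error': {
        'message': 'Tool call validation failed: ...',
        'type': 'invalid_request_error',
        'code': 'tool_use_failed',
        'failed_generation': '{"name": "get_something_by_name", "arguments": {"foo": "bar"}}',
    }
}

parsed = ToolUseFailedError.model_validate(body)
print(parsed.error.failed_generation.name)       # get_something_by_name
print(parsed.error.failed_generation.arguments)  # {'foo': 'bar'}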
pydantic_ai/models/huggingface.py
CHANGED
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Any, Literal,
+from typing import Any, Literal, cast, overload
 
 from typing_extensions import assert_never
 
@@ -88,7 +88,7 @@ LatestHuggingFaceModelNames = Literal[
 """Latest Hugging Face models."""
 
 
-HuggingFaceModelName =
+HuggingFaceModelName = str | LatestHuggingFaceModelNames
 """Possible Hugging Face model names.
 
 You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
@@ -267,11 +267,11 @@ class HuggingFaceModel(Model):
             for c in tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-
+            provider_response_id=response.id,
             provider_name=self._provider.name,
         )
 
@@ -320,7 +320,7 @@ class HuggingFaceModel(Model):
                     # please open an issue. The below code is the code to send thinking to the provider.
                     # texts.append(f'<think>\n{item.content}\n</think>')
                     pass
-                elif isinstance(item,
+                elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                     # This is currently never returned from huggingface
                     pass
                 else:
pydantic_ai/models/instrumented.py
CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 import itertools
 import json
-
+import warnings
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from typing import Any,
+from typing import Any, Literal, cast
 from urllib.parse import urlparse
 
 from opentelemetry._events import (
@@ -93,36 +94,41 @@ class InstrumentationSettings:
     def __init__(
         self,
         *,
-        event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
         meter_provider: MeterProvider | None = None,
-        event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
         include_content: bool = True,
-        version: Literal[1, 2] =
+        version: Literal[1, 2] = 2,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
+        event_logger_provider: EventLoggerProvider | None = None,
     ):
         """Create instrumentation options.
 
         Args:
-            event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
-                If `'logs'`, events are emitted as OpenTelemetry log-based events.
             tracer_provider: The OpenTelemetry tracer provider to use.
                 If not provided, the global tracer provider is used.
                 Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
            meter_provider: The OpenTelemetry meter provider to use.
                 If not provided, the global meter provider is used.
                 Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
-            event_logger_provider: The OpenTelemetry event logger provider to use.
-                If not provided, the global event logger provider is used.
-                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
-                This is only used if `event_mode='logs'`.
             include_binary_content: Whether to include binary content in the instrumentation events.
             include_content: Whether to include prompts, completions, and tool call arguments and responses
                 in the instrumentation events.
-            version: Version of the data format.
-                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
-
-
+            version: Version of the data format. This is unrelated to the Pydantic AI package version.
+                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
+                and will be removed in a future release.
+                The parameters `event_mode` and `event_logger_provider` are only relevant for version 1.
+                Version 2 uses the newer OpenTelemetry GenAI spec and stores messages in the following attributes:
+                - `gen_ai.system_instructions` for instructions passed to the agent.
+                - `gen_ai.input.messages` and `gen_ai.output.messages` on model request spans.
+                - `pydantic_ai.all_messages` on agent run spans.
+            event_mode: The mode for emitting events in version 1.
+                If `'attributes'`, events are attached to the span as attributes.
+                If `'logs'`, events are emitted as OpenTelemetry log-based events.
+            event_logger_provider: The OpenTelemetry event logger provider to use.
+                If not provided, the global event logger provider is used.
+                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
+                This is only used if `event_mode='logs'` and `version=1`.
         """
         from pydantic_ai import __version__
 
@@ -136,6 +142,14 @@ class InstrumentationSettings:
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
         self.include_content = include_content
+
+        if event_mode == 'logs' and version != 1:
+            warnings.warn(
+                'event_mode is only relevant for version=1 which is deprecated and will be removed in a future release.',
+                stacklevel=2,
+            )
+            version = 1
+
         self.version = version
 
         # As specified in the OpenTelemetry GenAI metrics spec:
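A short sketch of how the new defaults behave; the `instrument=` agent parameter and the `groq:` model prefix are assumed from Pydantic AI's public API and are not part of this diff.

from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

# Version 2 (now the default) stores messages in gen_ai.input.messages /
# gen_ai.output.messages span attributes; event_mode and event_logger_provider
# only matter for the deprecated version 1.
settings = InstrumentationSettings(version=2, include_content=True)

# Passing event_mode='logs' without version=1 triggers the warning added above
# and falls back to version=1.
legacy = InstrumentationSettings(event_mode='logs')

# The `instrument=` parameter is assumed from the public API; it is not shown in this diff.
agent = Agent('groq:llama-3.3-70b-versatile', instrument=settings)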
@@ -236,27 +250,36 @@ class InstrumentationSettings:
         if response.provider_details and 'finish_reason' in response.provider_details:
             output_message['finish_reason'] = response.provider_details['finish_reason']
         instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
+        system_instructions_attributes = self.system_instructions_attributes(instructions)
         attributes = {
             'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
             'gen_ai.output.messages': json.dumps([output_message]),
+            **system_instructions_attributes,
             'logfire.json_schema': json.dumps(
                 {
                     'type': 'object',
                     'properties': {
                         'gen_ai.input.messages': {'type': 'array'},
                         'gen_ai.output.messages': {'type': 'array'},
-                        **(
+                        **(
+                            {'gen_ai.system_instructions': {'type': 'array'}}
+                            if system_instructions_attributes
+                            else {}
+                        ),
                         'model_request_parameters': {'type': 'object'},
                     },
                 }
             ),
         }
-        if instructions is not None:
-            attributes['gen_ai.system_instructions'] = json.dumps(
-                [_otel_messages.TextPart(type='text', content=instructions)]
-            )
         span.set_attributes(attributes)
 
+    def system_instructions_attributes(self, instructions: str | None) -> dict[str, str]:
+        if instructions and self.include_content:
+            return {
+                'gen_ai.system_instructions': json.dumps([_otel_messages.TextPart(type='text', content=instructions)]),
+            }
+        return {}
+
     def _emit_events(self, span: Span, events: list[Event]) -> None:
         if self.event_mode == 'logs':
             for event in events:
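Roughly, the new helper yields the following span attribute (a sketch based on the code above; exact JSON spacing may differ):

import json

# With include_content=True and non-empty instructions:
attrs = {'gen_ai.system_instructions': json.dumps([{'type': 'text', 'content': 'Be concise.'}])}
# -> {'gen_ai.system_instructions': '[{"type": "text", "content": "Be concise."}]'}

# With include_content=False, or no instructions, the helper returns {} and the
# gen_ai.system_instructions attribute is simply omitted from the span.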
@@ -357,7 +380,7 @@ class InstrumentedModel(WrapperModel):
 
         if model_settings:
             for key in MODEL_SETTING_ATTRIBUTES:
-                if isinstance(value := model_settings.get(key),
+                if isinstance(value := model_settings.get(key), float | int):
                     attributes[f'gen_ai.request.{key}'] = value
 
         record_metrics: Callable[[], None] | None = None
@@ -397,10 +420,15 @@ class InstrumentedModel(WrapperModel):
                 return
 
             self.instrumentation_settings.handle_messages(messages, response, system, span)
+            try:
+                cost_attributes = {'operation.cost': float(response.cost().total_price)}
+            except LookupError:
+                cost_attributes = {}
             span.set_attributes(
                 {
                     **response.usage.opentelemetry_attributes(),
                     'gen_ai.response.model': response_model,
+                    **cost_attributes,
                 }
             )
             span.update_name(f'{operation} {request_model}')
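A hedged sketch of reading the same cost figure from user code, mirroring the instrumentation logic above; `response` stands for a `pydantic_ai.messages.ModelResponse`, and the pricing lookup behind `cost()` is assumed to raise `LookupError` for unknown models, as the except clause above implies.

# `response` is a ModelResponse taken from a completed run's message history.
try:
    total = float(response.cost().total_price)
    print(f'operation.cost = {total}')
except LookupError:
    print('No pricing data for this model; the span simply omits operation.cost.')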
pydantic_ai/models/mcp_sampling.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import KW_ONLY, dataclass
 from typing import TYPE_CHECKING, Any, cast
 
 from .. import _mcp, exceptions
@@ -36,6 +36,8 @@ class MCPSamplingModel(Model):
     session: ServerSession
     """The MCP server session to use for sampling."""
 
+    _: KW_ONLY
+
     default_max_tokens: int = 16_384
     """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].
 
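For context, a minimal illustration (generic Python, not taken from the package) of what the `KW_ONLY` sentinel does to dataclass fields declared after it:

from dataclasses import KW_ONLY, dataclass


@dataclass
class Example:
    session: str          # still positional-or-keyword
    _: KW_ONLY            # everything declared after this marker is keyword-only
    default_max_tokens: int = 16_384


Example('abc', default_max_tokens=1024)   # OK
# Example('abc', 1024)                    # TypeError: takes 2 positional arguments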
pydantic_ai/models/mistral.py
CHANGED
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal,
+from typing import Any, Literal, cast
 
 import pydantic_core
 from httpx import Timeout
@@ -79,7 +79,7 @@ try:
     from mistralai.models.usermessage import UserMessage as MistralUserMessage
     from mistralai.types.basemodel import Unset as MistralUnset
     from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
-except ImportError as e:
+except ImportError as e:
     raise ImportError(
         'Please install `mistral` to use the Mistral model, '
         'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -90,7 +90,7 @@ LatestMistralModelNames = Literal[
 ]
 """Latest Mistral models."""
 
-MistralModelName =
+MistralModelName = str | LatestMistralModelNames
 """Possible Mistral model names.
 
 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -117,7 +117,7 @@ class MistralModel(Model):
     """
 
     client: Mistral = field(repr=False)
-    json_mode_schema_prompt: str
+    json_mode_schema_prompt: str
 
     _model_name: MistralModelName = field(repr=False)
     _provider: Provider[Mistral] = field(repr=False)
@@ -348,11 +348,11 @@ class MistralModel(Model):
             parts.append(tool)
 
         return ModelResponse(
-            parts,
+            parts=parts,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-
+            provider_response_id=response.id,
             provider_name=self._provider.name,
         )
 
@@ -515,7 +515,7 @@ class MistralModel(Model):
                 pass
             elif isinstance(part, ToolCallPart):
                 tool_calls.append(self._map_tool_call(part))
-            elif isinstance(part,
+            elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from mistral
                 pass
            else:
@@ -576,7 +576,7 @@ class MistralModel(Model):
         return MistralUserMessage(content=content)
 
 
-MistralToolCallId =
+MistralToolCallId = str | None
 
 
 @dataclass