pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +5 -1
- pydantic_ai/_agent_graph.py +256 -346
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +572 -147
- pydantic_ai/messages.py +31 -0
- pydantic_ai/models/__init__.py +12 -1
- pydantic_ai/models/anthropic.py +41 -49
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +3 -3
- pydantic_ai/models/gemini.py +18 -2
- pydantic_ai/models/instrumented.py +225 -0
- pydantic_ai/models/mistral.py +0 -3
- pydantic_ai/models/openai.py +2 -5
- pydantic_ai/models/test.py +6 -6
- pydantic_ai/models/wrapper.py +45 -0
- pydantic_ai/result.py +106 -144
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.25.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.25.dist-info/RECORD +32 -0
- pydantic_ai_slim-0.0.24.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.25.dist-info}/WHEEL +0 -0
pydantic_ai/messages.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import uuid
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union, cast, overload
@@ -445,3 +446,33 @@ class PartDeltaEvent:
 
 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
+
+
+@dataclass
+class FunctionToolCallEvent:
+    """An event indicating the start to a call to a function tool."""
+
+    part: ToolCallPart
+    """The (function) tool call to make."""
+    call_id: str = field(init=False)
+    """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+    event_kind: Literal['function_tool_call'] = 'function_tool_call'
+    """Event type identifier, used as a discriminator."""
+
+    def __post_init__(self):
+        self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+
+
+@dataclass
+class FunctionToolResultEvent:
+    """An event indicating the result of a function tool call."""
+
+    result: ToolReturnPart | RetryPromptPart
+    """The result of the call to the function tool."""
+    call_id: str
+    """An ID used to match the result to its original call."""
+    event_kind: Literal['function_tool_result'] = 'function_tool_result'
+    """Event type identifier, used as a discriminator."""
+
+
+HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')]
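
Note: the two new event classes form the `HandleResponseEvent` union, and `call_id` lets a result be matched back to its originating call. A minimal sketch of consuming such events (hypothetical consumer code, not part of the package; it assumes only the fields shown above plus the parts' existing `tool_name` and `part_kind` attributes):

    from pydantic_ai.messages import FunctionToolCallEvent, FunctionToolResultEvent

    def summarize(events):
        """Pair each tool result with its originating call via call_id."""
        pending = {}  # call_id -> FunctionToolCallEvent
        for event in events:
            if isinstance(event, FunctionToolCallEvent):
                pending[event.call_id] = event
            elif isinstance(event, FunctionToolResultEvent):
                call = pending.pop(event.call_id, None)
                name = call.part.tool_name if call else '<unknown>'
                print(f'{name} -> {event.result.part_kind}')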
pydantic_ai/models/__init__.py
CHANGED
@@ -234,6 +234,8 @@ class StreamedResponse(ABC):
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
         """
         raise NotImplementedError()
         # noinspection PyUnreachableCode
@@ -362,7 +364,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
     raise UserError(f'Unknown model: {model}')
 
 
-@cache
 def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     """Cached HTTPX async client so multiple agents and calls can share the same client.
 
@@ -373,6 +374,16 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     The default timeouts match those of OpenAI,
     see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
     """
+    client = _cached_async_http_client(timeout=timeout, connect=connect)
+    if client.is_closed:
+        # This happens if the context manager is used, so we need to create a new client.
+        _cached_async_http_client.cache_clear()
+        client = _cached_async_http_client(timeout=timeout, connect=connect)
+    return client
+
+
+@cache
+def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     return httpx.AsyncClient(
         timeout=httpx.Timeout(timeout=timeout, connect=connect),
         headers={'User-Agent': get_user_agent()},
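
Note: the split above exists because `@cache` used to sit directly on the public function, so if any caller used the shared client as a context manager, `httpx` closed it on exit and every later caller received a dead client. A short sketch of the behaviour this fixes (assuming the function is imported from `pydantic_ai.models`, where it is defined):

    import asyncio
    from pydantic_ai.models import cached_async_http_client

    async def main():
        async with cached_async_http_client() as client:
            pass  # exiting the block closes the shared client
        # Previously this returned the closed client; now the cache is
        # cleared and a fresh client is built instead.
        assert not cached_async_http_client().is_closed

    asyncio.run(main())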
pydantic_ai/models/anthropic.py
CHANGED
@@ -272,64 +272,56 @@ class AnthropicModel(Model):
         anthropic_messages: list[MessageParam] = []
         for m in messages:
             if isinstance(m, ModelRequest):
-                …
-                            content=part.model_response_str(),
-                            is_error=False,
-                        )
-                    ],
-                )
+                user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
+                for request_part in m.parts:
+                    if isinstance(request_part, SystemPromptPart):
+                        system_prompt += request_part.content
+                    elif isinstance(request_part, UserPromptPart):
+                        text_block_param = TextBlockParam(type='text', text=request_part.content)
+                        user_content_params.append(text_block_param)
+                    elif isinstance(request_part, ToolReturnPart):
+                        tool_result_block_param = ToolResultBlockParam(
+                            tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                            type='tool_result',
+                            content=request_part.model_response_str(),
+                            is_error=False,
                         )
-                …
+                        user_content_params.append(tool_result_block_param)
+                    elif isinstance(request_part, RetryPromptPart):
+                        if request_part.tool_name is None:
+                            retry_param = TextBlockParam(type='text', text=request_part.model_response())
                         else:
-                …
-                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                            type='tool_result',
-                            content=part.model_response(),
-                            is_error=True,
-                        ),
-                    ],
-                )
+                            retry_param = ToolResultBlockParam(
+                                tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                                type='tool_result',
+                                content=request_part.model_response(),
+                                is_error=True,
                             )
+                        user_content_params.append(retry_param)
+                anthropic_messages.append(
+                    MessageParam(
+                        role='user',
+                        content=user_content_params,
+                    )
+                )
             elif isinstance(m, ModelResponse):
-                …
-                for
-                if isinstance(
-                …
+                assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
+                for response_part in m.parts:
+                    if isinstance(response_part, TextPart):
+                        assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                     else:
-                …
+                        tool_use_block_param = ToolUseBlockParam(
+                            id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
+                            type='tool_use',
+                            name=response_part.tool_name,
+                            input=response_part.args_as_dict(),
+                        )
+                        assistant_content_params.append(tool_use_block_param)
+                anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
             else:
                 assert_never(m)
         return system_prompt, anthropic_messages
 
-    @staticmethod
-    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-        return ToolUseBlockParam(
-            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-            type='tool_use',
-            name=t.tool_name,
-            input=t.args_as_dict(),
-        )
-
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
         return {
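
Note: judging from the added code and the removed fragments (marked `…` where this view lost them), the refactor folds all parts of one ModelRequest into a single Anthropic `user` message, and all parts of one ModelResponse into a single `assistant` message, retiring the per-part `_map_tool_call` helper. As a hand-written illustration (placeholder values, not runnable against the real API), a request with a user prompt and a tool return now yields one message with two content blocks:

    anthropic_message = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'What should I do next?'},
            {
                'type': 'tool_result',
                'tool_use_id': 'toolu_123',  # placeholder ID
                'content': '{"temperature": 21}',
                'is_error': False,
            },
        ],
    }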
pydantic_ai/models/cohere.py
CHANGED
@@ -124,7 +124,7 @@ class CohereModel(Model):
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
             self.client = cohere_client
         else:
-            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
 
     async def request(
         self,
pydantic_ai/models/function.py
CHANGED
@@ -109,9 +109,9 @@ class FunctionModel(Model):
             model_settings,
         )
 
-        assert (
-            …
-        )
+        assert self.stream_function is not None, (
+            'FunctionModel must receive a `stream_function` to support streamed requests'
+        )
 
         response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
 
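
Note: the reformatted assert guards streamed runs of a FunctionModel built without a stream function. A minimal sketch of providing one (assuming this release's `FunctionModel(stream_function=...)` keyword and `AgentInfo` callback signature):

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models.function import AgentInfo, FunctionModel

    async def stream_f(messages: list[ModelMessage], info: AgentInfo):
        # Yield text deltas; a real stream function could also yield tool-call deltas.
        yield 'hello '
        yield 'world'

    agent = Agent(FunctionModel(stream_function=stream_f))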
pydantic_ai/models/gemini.py
CHANGED
@@ -254,7 +254,7 @@ class GeminiModel(Model):
         async for chunk in aiter_bytes:
             content.extend(chunk)
             responses = _gemini_streamed_response_ta.validate_json(
-                content,
+                _ensure_decodeable(content),
                 experimental_allow_partial='trailing-strings',
             )
             if responses:
@@ -370,7 +370,7 @@ class GeminiStreamedResponse(StreamedResponse):
             self._content.extend(chunk)
 
             gemini_responses = _gemini_streamed_response_ta.validate_json(
-                self._content,
+                _ensure_decodeable(self._content),
                 experimental_allow_partial='trailing-strings',
             )
 
@@ -774,3 +774,19 @@ class _GeminiJsonSchema:
 
         if items_schema := schema.get('items'):  # pragma: no branch
             self._simplify(items_schema, refs_stack)
+
+
+def _ensure_decodeable(content: bytearray) -> bytearray:
+    """Trim any invalid unicode point bytes off the end of a bytearray.
+
+    This is necessary before attempting to parse streaming JSON bytes.
+
+    This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
+    """
+    while True:
+        try:
+            content.decode()
+        except UnicodeDecodeError:
+            content = content[:-1]  # this will definitely succeed before we run out of bytes
+        else:
+            return content
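
Note: a quick demonstration of the trimming behaviour (calling the private helper directly, illustrative only): a UTF-8 code point split across stream chunks decodes once the dangling bytes are dropped.

    from pydantic_ai.models.gemini import _ensure_decodeable

    data = bytearray('{"text": "café"}'.encode())
    truncated = data[:-3]  # cut mid-way through the two-byte 'é'
    # truncated.decode() would raise UnicodeDecodeError; after trimming the
    # lone continuation byte, the remaining prefix decodes cleanly.
    assert _ensure_decodeable(truncated).decode() == '{"text": "caf'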
pydantic_ai/models/instrumented.py
ADDED
|
@@ -0,0 +1,225 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator, Iterator
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Any, Callable, Literal
+
+import logfire_api
+from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
+from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelRequestPart,
+    ModelResponse,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..usage import Usage
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
+from .wrapper import WrapperModel
+
+MODEL_SETTING_ATTRIBUTES: tuple[
+    Literal[
+        'max_tokens',
+        'top_p',
+        'seed',
+        'temperature',
+        'presence_penalty',
+        'frequency_penalty',
+    ],
+    ...,
+] = (
+    'max_tokens',
+    'top_p',
+    'seed',
+    'temperature',
+    'presence_penalty',
+    'frequency_penalty',
+)
+
+NOT_GIVEN = object()
+
+
+@dataclass
+class InstrumentedModel(WrapperModel):
+    """Model which is instrumented with logfire."""
+
+    tracer: Tracer = field(repr=False)
+    event_logger: EventLogger = field(repr=False)
+
+    def __init__(
+        self,
+        wrapped: Model | KnownModelName,
+        tracer_provider: TracerProvider | None = None,
+        event_logger_provider: EventLoggerProvider | None = None,
+    ):
+        super().__init__(wrapped)
+        tracer_provider = tracer_provider or get_tracer_provider()
+        event_logger_provider = event_logger_provider or get_event_logger_provider()
+        self.tracer = tracer_provider.get_tracer('pydantic-ai')
+        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+
+    @classmethod
+    def from_logfire(
+        cls,
+        wrapped: Model | KnownModelName,
+        logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+    ) -> InstrumentedModel:
+        if hasattr(logfire_instance.config, 'get_event_logger_provider'):
+            event_provider = logfire_instance.config.get_event_logger_provider()
+        else:
+            event_provider = None
+        tracer_provider = logfire_instance.config.get_tracer_provider()
+        return cls(wrapped, tracer_provider, event_provider)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        with self._instrument(messages, model_settings) as finish:
+            response, usage = await super().request(messages, model_settings, model_request_parameters)
+            finish(response, usage)
+            return response, usage
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        with self._instrument(messages, model_settings) as finish:
+            response_stream: StreamedResponse | None = None
+            try:
+                async with super().request_stream(
+                    messages, model_settings, model_request_parameters
+                ) as response_stream:
+                    yield response_stream
+            finally:
+                if response_stream:
+                    finish(response_stream.get(), response_stream.usage())
+
+    @contextmanager
+    def _instrument(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
+        operation = 'chat'
+        model_name = self.model_name
+        span_name = f'{operation} {model_name}'
+        system = getattr(self.wrapped, 'system', '') or self.wrapped.__class__.__name__.removesuffix('Model').lower()
+        system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
+        # TODO Missing attributes:
+        #  - server.address: requires a Model.base_url abstract method or similar
+        #  - server.port: to parse from the base_url
+        #  - error.type: unclear if we should do something here or just always rely on span exceptions
+        #  - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
+        attributes: dict[str, Any] = {
+            'gen_ai.operation.name': operation,
+            'gen_ai.system': system,
+            'gen_ai.request.model': model_name,
+        }
+
+        if model_settings:
+            for key in MODEL_SETTING_ATTRIBUTES:
+                if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
+                    attributes[f'gen_ai.request.{key}'] = value
+
+        emit_event = partial(self._emit_event, system)
+
+        with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+            if span.is_recording():
+                for message in messages:
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            event_name, body = _request_part_body(part)
+                            if event_name:
+                                emit_event(event_name, body)
+                    elif isinstance(message, ModelResponse):
+                        for body in _response_bodies(message):
+                            emit_event('gen_ai.assistant.message', body)
+
+            def finish(response: ModelResponse, usage: Usage):
+                if not span.is_recording():
+                    return
+
+                for response_body in _response_bodies(response):
+                    if response_body:
+                        emit_event(
+                            'gen_ai.choice',
+                            {
+                                # TODO finish_reason
+                                'index': 0,
+                                'message': response_body,
+                            },
+                        )
+                span.set_attributes(
+                    {
+                        k: v
+                        for k, v in {
+                            # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
+                            #  https://github.com/pydantic/pydantic-ai/issues/886
+                            'gen_ai.response.model': response.model_name or model_name,
+                            'gen_ai.usage.input_tokens': usage.request_tokens,
+                            'gen_ai.usage.output_tokens': usage.response_tokens,
+                        }.items()
+                        if v is not None
+                    }
+                )
+
+            yield finish
+
+    def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
+        self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
+
+
+def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
+    if isinstance(part, SystemPromptPart):
+        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
+    elif isinstance(part, UserPromptPart):
+        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
+    elif isinstance(part, ToolReturnPart):
+        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
+    elif isinstance(part, RetryPromptPart):
+        if part.tool_name is None:
+            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+        else:
+            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
+    else:
+        return '', {}
+
+
+def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
+    body: dict[str, Any] = {'role': 'assistant'}
+    result = [body]
+    for part in message.parts:
+        if isinstance(part, ToolCallPart):
+            body.setdefault('tool_calls', []).append(
+                {
+                    'id': part.tool_call_id,
+                    'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                    'function': {
+                        'name': part.tool_name,
+                        'arguments': part.args,
+                    },
+                }
+            )
+        elif isinstance(part, TextPart):
+            if body.get('content'):
+                body = {'role': 'assistant'}
+                result.append(body)
+            body['content'] = part.content
+
+    return result
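
Note: a minimal usage sketch of the new wrapper (assuming an OpenTelemetry SDK or Logfire is already configured; the model name is illustrative). Wrapping any model records a `chat {model_name}` span with `gen_ai.*` attributes and per-message events:

    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentedModel

    agent = Agent(InstrumentedModel('openai:gpt-4o'))

    # Or wire it to Logfire explicitly:
    # import logfire
    # logfire.configure()
    # agent = Agent(InstrumentedModel.from_logfire('openai:gpt-4o'))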
pydantic_ai/models/mistral.py
CHANGED
@@ -134,9 +134,6 @@ class MistralModel(Model):
         api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
         self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
 
-    def name(self) -> str:
-        return f'mistral:{self._model_name}'
-
     async def request(
         self,
         messages: list[ModelMessage],
pydantic_ai/models/openai.py
CHANGED
@@ -119,9 +119,9 @@ class OpenAIModel(Model):
         """
         self._model_name = model_name
         # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
-        # openai compatible models do not always need an API key.
+        # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
         if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
-            api_key = ''
+            api_key = 'api-key-not-set'
 
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
@@ -135,9 +135,6 @@ class OpenAIModel(Model):
         self.system_prompt_role = system_prompt_role
         self._system = system
 
-    def name(self) -> str:
-        return f'openai:{self._model_name}'
-
     async def request(
         self,
         messages: list[ModelMessage],
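
Note: per the updated comment, the OpenAI SDK insists on a non-empty API key even when a locally served, OpenAI-compatible endpoint ignores it, so the fallback is now a visible placeholder rather than an empty string. A sketch of the case this targets (the local URL and model name are illustrative):

    from pydantic_ai.models.openai import OpenAIModel

    # No api_key argument and no OPENAI_API_KEY in the environment, but a
    # base_url is given: api_key becomes 'api-key-not-set' so the client builds.
    model = OpenAIModel('llama3.2', base_url='http://localhost:11434/v1')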
pydantic_ai/models/test.py
CHANGED
@@ -130,15 +130,15 @@ class TestModel(Model):
 
     def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
         if self.custom_result_text is not None:
-            assert (
-                …
-            )
+            assert model_request_parameters.allow_text_result, (
+                'Plain response not allowed, but `custom_result_text` is set.'
+            )
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
             return _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
-            assert (
-                …
-            )
+            assert model_request_parameters.result_tools is not None, (
+                'No result tools provided, but `custom_result_args` is set.'
+            )
             result_tool = model_request_parameters.result_tools[0]
 
             if k := result_tool.outer_typed_dict_key:
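
Note: these guards fire when TestModel's canned results conflict with what the agent allows. A small sketch of the happy path (assuming this release's `.data` result attribute):

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(custom_result_text='canned answer'))
    result = agent.run_sync('anything')
    assert result.data == 'canned answer'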
pydantic_ai/models/wrapper.py
ADDED
|
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any
+
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
+from ..usage import Usage
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+
+@dataclass(init=False)
+class WrapperModel(Model):
+    """Model which wraps another model."""
+
+    wrapped: Model
+
+    def __init__(self, wrapped: Model | KnownModelName):
+        self.wrapped = infer_model(wrapped)
+
+    async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
+        return await self.wrapped.request(*args, **kwargs)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
+            yield response_stream
+
+    @property
+    def model_name(self) -> str:
+        return self.wrapped.model_name
+
+    @property
+    def system(self) -> str | None:
+        return self.wrapped.system
+
+    def __getattr__(self, item: str):
+        return getattr(self.wrapped, item)