pydantic-ai-slim 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff shows the changes between these publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Note: this version of pydantic-ai-slim has been flagged as potentially problematic; see its registry page for details.
- pydantic_ai/__init__.py +5 -1
- pydantic_ai/_agent_graph.py +256 -346
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +574 -149
- pydantic_ai/messages.py +31 -0
- pydantic_ai/models/__init__.py +29 -13
- pydantic_ai/models/anthropic.py +60 -50
- pydantic_ai/models/cohere.py +11 -1
- pydantic_ai/models/function.py +21 -3
- pydantic_ai/models/gemini.py +40 -3
- pydantic_ai/models/groq.py +19 -1
- pydantic_ai/models/instrumented.py +225 -0
- pydantic_ai/models/mistral.py +19 -4
- pydantic_ai/models/openai.py +23 -7
- pydantic_ai/models/test.py +24 -7
- pydantic_ai/models/vertexai.py +10 -0
- pydantic_ai/models/wrapper.py +45 -0
- pydantic_ai/result.py +107 -145
- {pydantic_ai_slim-0.0.23.dist-info → pydantic_ai_slim-0.0.25.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.25.dist-info/RECORD +32 -0
- pydantic_ai_slim-0.0.23.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.23.dist-info → pydantic_ai_slim-0.0.25.dist-info}/WHEEL +0 -0
pydantic_ai/messages.py
CHANGED

@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import uuid
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union, cast, overload
@@ -445,3 +446,33 @@ class PartDeltaEvent:
 
 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
+
+
+@dataclass
+class FunctionToolCallEvent:
+    """An event indicating the start to a call to a function tool."""
+
+    part: ToolCallPart
+    """The (function) tool call to make."""
+    call_id: str = field(init=False)
+    """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+    event_kind: Literal['function_tool_call'] = 'function_tool_call'
+    """Event type identifier, used as a discriminator."""
+
+    def __post_init__(self):
+        self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+
+
+@dataclass
+class FunctionToolResultEvent:
+    """An event indicating the result of a function tool call."""
+
+    result: ToolReturnPart | RetryPromptPart
+    """The result of the call to the function tool."""
+    call_id: str
+    """An ID used to match the result to its original call."""
+    event_kind: Literal['function_tool_result'] = 'function_tool_result'
+    """Event type identifier, used as a discriminator."""
+
+
+HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')]
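The new event classes fire around function-tool execution; `FunctionToolCallEvent` derives its `call_id` in `__post_init__` so a result can always be matched to its call even when the vendor supplies no `tool_call_id`. A standalone sketch of that defaulting pattern (the `ToolCall` stand-in below is illustrative, not the package's `ToolCallPart`):

from __future__ import annotations

import uuid
from dataclasses import dataclass, field

@dataclass
class ToolCall:
    # Illustrative stand-in for pydantic_ai's ToolCallPart.
    tool_name: str
    tool_call_id: str | None = None

@dataclass
class FunctionToolCallEvent:
    part: ToolCall
    call_id: str = field(init=False)

    def __post_init__(self) -> None:
        # Reuse the vendor-supplied ID when present, otherwise mint a UUID,
        # so a FunctionToolResultEvent can be matched back to this call.
        self.call_id = self.part.tool_call_id or str(uuid.uuid4())

assert FunctionToolCallEvent(ToolCall('get_weather', 'abc123')).call_id == 'abc123'
print(FunctionToolCallEvent(ToolCall('get_weather')).call_id)  # a random UUID4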
pydantic_ai/models/__init__.py
CHANGED

@@ -54,6 +54,8 @@ KnownModelName = Literal[
     'google-gla:gemini-2.0-flash-exp',
     'google-gla:gemini-2.0-flash-thinking-exp-01-21',
     'google-gla:gemini-exp-1206',
+    'google-gla:gemini-2.0-flash',
+    'google-gla:gemini-2.0-flash-lite-preview-02-05',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
@@ -61,6 +63,8 @@ KnownModelName = Literal[
     'google-vertex:gemini-2.0-flash-exp',
     'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
     'google-vertex:gemini-exp-1206',
+    'google-vertex:gemini-2.0-flash',
+    'google-vertex:gemini-2.0-flash-lite-preview-02-05',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
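The two hunks above add the Gemini 2.0 Flash names under both prefixes. A minimal usage sketch, assuming pydantic-ai 0.0.25 and a `GEMINI_API_KEY` in the environment (in this version the output lives on `result.data`):

from pydantic_ai import Agent

# 'google-gla:...' routes to the Generative Language API,
# 'google-vertex:...' to Vertex AI.
agent = Agent('google-gla:gemini-2.0-flash')
result = agent.run_sync('What is the capital of France?')
print(result.data)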
@@ -173,9 +177,6 @@ class ModelRequestParameters:
 class Model(ABC):
     """Abstract class for a model."""
 
-    _model_name: str
-    _system: str | None
-
     @abstractmethod
     async def request(
         self,
@@ -201,24 +202,25 @@ class Model(ABC):
         yield  # pragma: no cover
 
     @property
+    @abstractmethod
     def model_name(self) -> str:
         """The model name."""
-        return self._model_name
+        raise NotImplementedError()
 
     @property
+    @abstractmethod
     def system(self) -> str | None:
         """The system / model provider, ex: openai."""
-        return self._system
+        raise NotImplementedError()
 
 
 @dataclass
 class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""
 
-    _model_name: str
-    _usage: Usage = field(default_factory=Usage, init=False)
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
+    _usage: Usage = field(default_factory=Usage, init=False)
 
     def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
         """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
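`Model` and `StreamedResponse` stop carrying `_model_name`/`_system` on the base class; `model_name` and `system` become abstract properties each provider must implement. A minimal sketch of the pattern the concrete models in this release follow (`ExampleModel` is illustrative, not part of the package):

from abc import ABC, abstractmethod

class Model(ABC):
    @property
    @abstractmethod
    def model_name(self) -> str:
        """The model name."""
        raise NotImplementedError()

class ExampleModel(Model):
    def __init__(self, model_name: str):
        # Concrete classes now own the attribute...
        self._model_name = model_name

    @property
    def model_name(self) -> str:
        # ...and satisfy the abstract property by exposing it.
        return self._model_name

print(ExampleModel('example:latest').model_name)  # example:latest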
@@ -232,6 +234,8 @@ class StreamedResponse(ABC):
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
         """
         raise NotImplementedError()
         # noinspection PyUnreachableCode
@@ -240,17 +244,20 @@ class StreamedResponse(ABC):
     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
         return ModelResponse(
-            parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
+            parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp
         )
 
-    def model_name(self) -> str:
-        """Get the model name of the response."""
-        return self._model_name
-
     def usage(self) -> Usage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
         return self._usage
 
+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        raise NotImplementedError()
+
+    @property
     @abstractmethod
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -357,7 +364,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
     raise UserError(f'Unknown model: {model}')
 
 
-@cache
 def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     """Cached HTTPX async client so multiple agents and calls can share the same client.
 
@@ -368,6 +374,16 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     The default timeouts match those of OpenAI,
     see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
     """
+    client = _cached_async_http_client(timeout=timeout, connect=connect)
+    if client.is_closed:
+        # This happens if the context manager is used, so we need to create a new client.
+        _cached_async_http_client.cache_clear()
+        client = _cached_async_http_client(timeout=timeout, connect=connect)
+    return client
+
+
+@cache
+def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     return httpx.AsyncClient(
         timeout=httpx.Timeout(timeout=timeout, connect=connect),
         headers={'User-Agent': get_user_agent()},
pydantic_ai/models/anthropic.py
CHANGED

@@ -162,6 +162,16 @@ class AnthropicModel(Model):
         async with response:
             yield await self._process_streamed_response(response)
 
+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _messages_create(
         self,
@@ -236,7 +246,7 @@ class AnthropicModel(Model):
             )
         )
 
-        return ModelResponse(items, model_name=self._model_name)
+        return ModelResponse(items, model_name=response.model)
 
     async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -262,64 +272,56 @@ class AnthropicModel(Model):
         anthropic_messages: list[MessageParam] = []
         for m in messages:
             if isinstance(m, ModelRequest):
-                for part in m.parts:
-                    if isinstance(part, SystemPromptPart):
-                        system_prompt += part.content
-                    elif isinstance(part, UserPromptPart):
-                        anthropic_messages.append(MessageParam(role='user', content=part.content))
-                    elif isinstance(part, ToolReturnPart):
-                        anthropic_messages.append(
-                            MessageParam(
-                                role='user',
-                                content=[
-                                    ToolResultBlockParam(
-                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                        type='tool_result',
-                                        content=part.model_response_str(),
-                                        is_error=False,
-                                    )
-                                ],
-                            )
+                user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
+                for request_part in m.parts:
+                    if isinstance(request_part, SystemPromptPart):
+                        system_prompt += request_part.content
+                    elif isinstance(request_part, UserPromptPart):
+                        text_block_param = TextBlockParam(type='text', text=request_part.content)
+                        user_content_params.append(text_block_param)
+                    elif isinstance(request_part, ToolReturnPart):
+                        tool_result_block_param = ToolResultBlockParam(
+                            tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                            type='tool_result',
+                            content=request_part.model_response_str(),
+                            is_error=False,
                         )
-                    elif isinstance(part, RetryPromptPart):
-                        if part.tool_name is None:
-                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        user_content_params.append(tool_result_block_param)
+                    elif isinstance(request_part, RetryPromptPart):
+                        if request_part.tool_name is None:
+                            retry_param = TextBlockParam(type='text', text=request_part.model_response())
                         else:
-                            anthropic_messages.append(
-                                MessageParam(
-                                    role='user',
-                                    content=[
-                                        ToolResultBlockParam(
-                                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                            type='tool_result',
-                                            content=part.model_response(),
-                                            is_error=True,
-                                        ),
-                                    ],
-                                )
+                            retry_param = ToolResultBlockParam(
+                                tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                                type='tool_result',
+                                content=request_part.model_response(),
+                                is_error=True,
                             )
+                        user_content_params.append(retry_param)
+                anthropic_messages.append(
+                    MessageParam(
+                        role='user',
+                        content=user_content_params,
+                    )
+                )
             elif isinstance(m, ModelResponse):
-                content: list[TextBlockParam | ToolUseBlockParam] = []
-                for item in m.parts:
-                    if isinstance(item, TextPart):
-                        content.append(TextBlockParam(text=item.content, type='text'))
+                assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
+                for response_part in m.parts:
+                    if isinstance(response_part, TextPart):
+                        assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                     else:
-                        assert isinstance(item, ToolCallPart)
-                        content.append(self._map_tool_call(item))
-                anthropic_messages.append(MessageParam(role='assistant', content=content))
+                        tool_use_block_param = ToolUseBlockParam(
+                            id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
+                            type='tool_use',
+                            name=response_part.tool_name,
+                            input=response_part.args_as_dict(),
+                        )
+                        assistant_content_params.append(tool_use_block_param)
+                anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
             else:
                 assert_never(m)
         return system_prompt, anthropic_messages
 
-    @staticmethod
-    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-        return ToolUseBlockParam(
-            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-            type='tool_use',
-            name=t.tool_name,
-            input=t.args_as_dict(),
-        )
-
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
         return {
@@ -362,6 +364,7 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
 class AnthropicStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Anthropic models."""
 
+    _model_name: AnthropicModelName
     _response: AsyncIterable[RawMessageStreamEvent]
     _timestamp: datetime
 
@@ -414,5 +417,12 @@ class AnthropicStreamedResponse(StreamedResponse):
         elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
             current_block = None
 
+    @property
+    def model_name(self) -> AnthropicModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
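The `_map_message` rewrite changes message granularity: instead of emitting one `MessageParam` per request part, every part of a `ModelRequest` is collected into a single user message whose `content` is a list of blocks. A simplified sketch of that grouping, using plain dicts in place of the Anthropic SDK's typed params:

from dataclasses import dataclass

@dataclass
class UserPromptPart:
    content: str

@dataclass
class ToolReturnPart:
    tool_call_id: str
    content: str

def map_request_parts(parts: list) -> dict:
    blocks: list[dict] = []
    for part in parts:
        if isinstance(part, UserPromptPart):
            blocks.append({'type': 'text', 'text': part.content})
        elif isinstance(part, ToolReturnPart):
            blocks.append({
                'type': 'tool_result',
                'tool_use_id': part.tool_call_id,
                'content': part.content,
                'is_error': False,
            })
    # One user message carrying every block, rather than one message per part.
    return {'role': 'user', 'content': blocks}

msg = map_request_parts([ToolReturnPart('toolu_1', '4'), UserPromptPart('Now double it.')])
assert len(msg['content']) == 2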
pydantic_ai/models/cohere.py
CHANGED

@@ -124,7 +124,7 @@ class CohereModel(Model):
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
             self.client = cohere_client
         else:
-            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
 
     async def request(
         self,
@@ -136,6 +136,16 @@ class CohereModel(Model):
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         return self._process_response(response), _map_usage(response)
 
+    @property
+    def model_name(self) -> CohereModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     async def _chat(
         self,
         messages: list[ModelMessage],
pydantic_ai/models/function.py
CHANGED

@@ -109,9 +109,9 @@ class FunctionModel(Model):
             model_settings,
         )
 
-        assert (
-            self.stream_function is not None
-        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+        assert self.stream_function is not None, (
+            'FunctionModel must receive a `stream_function` to support streamed requests'
+        )
 
         response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
 
@@ -121,6 +121,16 @@ class FunctionModel(Model):
 
         yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
 
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
 
 @dataclass(frozen=True)
 class AgentInfo:
@@ -178,6 +188,7 @@ E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
 class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
+    _model_name: str
     _iter: AsyncIterator[str | DeltaToolCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
@@ -205,7 +216,14 @@ class FunctionStreamedResponse(StreamedResponse):
         if maybe_event is not None:
             yield maybe_event
 
+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 
pydantic_ai/models/gemini.py
CHANGED

@@ -47,6 +47,8 @@ LatestGeminiModelNames = Literal[
     'gemini-2.0-flash-exp',
     'gemini-2.0-flash-thinking-exp-01-21',
     'gemini-exp-1206',
+    'gemini-2.0-flash',
+    'gemini-2.0-flash-lite-preview-02-05',
 ]
 """Latest Gemini models."""
 
@@ -147,6 +149,16 @@ class GeminiModel(Model):
         ) as http_response:
             yield await self._process_streamed_response(http_response)
 
+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
         tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
         if model_request_parameters.result_tools:
@@ -231,7 +243,7 @@ class GeminiModel(Model):
         else:
             raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts, model_name=self._model_name)
+        return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -242,7 +254,7 @@ class GeminiModel(Model):
         async for chunk in aiter_bytes:
             content.extend(chunk)
             responses = _gemini_streamed_response_ta.validate_json(
-                content,
+                _ensure_decodeable(content),
                 experimental_allow_partial='trailing-strings',
             )
             if responses:
@@ -313,6 +325,7 @@ class ApiKeyAuth:
 class GeminiStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for the Gemini model."""
 
+    _model_name: GeminiModelName
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -357,7 +370,7 @@ class GeminiStreamedResponse(StreamedResponse):
             self._content.extend(chunk)
 
             gemini_responses = _gemini_streamed_response_ta.validate_json(
-                self._content,
+                _ensure_decodeable(self._content),
                 experimental_allow_partial='trailing-strings',
             )
 
@@ -376,7 +389,14 @@ class GeminiStreamedResponse(StreamedResponse):
             self._usage += _metadata_as_usage(r)
             yield r
 
+    @property
+    def model_name(self) -> GeminiModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 
@@ -608,6 +628,7 @@ class _GeminiResponse(TypedDict):
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
     usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
+    model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
 
 
 class _GeminiCandidates(TypedDict):
@@ -753,3 +774,19 @@ class _GeminiJsonSchema:
 
         if items_schema := schema.get('items'):  # pragma: no branch
             self._simplify(items_schema, refs_stack)
+
+
+def _ensure_decodeable(content: bytearray) -> bytearray:
+    """Trim any invalid unicode point bytes off the end of a bytearray.
+
+    This is necessary before attempting to parse streaming JSON bytes.
+
+    This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
+    """
+    while True:
+        try:
+            content.decode()
+        except UnicodeDecodeError:
+            content = content[:-1]  # this will definitely succeed before we run out of bytes
+        else:
+            return content
pydantic_ai/models/groq.py
CHANGED

@@ -146,6 +146,16 @@ class GroqModel(Model):
         async with response:
             yield await self._process_streamed_response(response)
 
+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _completions_create(
         self,
@@ -212,7 +222,7 @@ class GroqModel(Model):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, model_name=self._model_name, timestamp=timestamp)
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -305,6 +315,7 @@ class GroqModel(Model):
 class GroqStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Groq models."""
 
+    _model_name: GroqModelName
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
 
@@ -333,7 +344,14 @@ class GroqStreamedResponse(StreamedResponse):
         if maybe_event is not None:
             yield maybe_event
 
+    @property
+    def model_name(self) -> GroqModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
 
 