pydantic-ai-slim 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +5 -1
- pydantic_ai/_agent_graph.py +256 -346
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +574 -149
- pydantic_ai/messages.py +31 -0
- pydantic_ai/models/__init__.py +29 -13
- pydantic_ai/models/anthropic.py +60 -50
- pydantic_ai/models/cohere.py +11 -1
- pydantic_ai/models/function.py +21 -3
- pydantic_ai/models/gemini.py +40 -3
- pydantic_ai/models/groq.py +19 -1
- pydantic_ai/models/instrumented.py +225 -0
- pydantic_ai/models/mistral.py +19 -4
- pydantic_ai/models/openai.py +23 -7
- pydantic_ai/models/test.py +24 -7
- pydantic_ai/models/vertexai.py +10 -0
- pydantic_ai/models/wrapper.py +45 -0
- pydantic_ai/result.py +107 -145
- {pydantic_ai_slim-0.0.23.dist-info → pydantic_ai_slim-0.0.25.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.25.dist-info/RECORD +32 -0
- pydantic_ai_slim-0.0.23.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.23.dist-info → pydantic_ai_slim-0.0.25.dist-info}/WHEEL +0 -0
|
pydantic_ai/models/instrumented.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator, Iterator
|
|
4
|
+
from contextlib import asynccontextmanager, contextmanager
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import Any, Callable, Literal
|
|
8
|
+
|
|
9
|
+
import logfire_api
|
|
10
|
+
from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
|
|
11
|
+
from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
|
|
12
|
+
|
|
13
|
+
from ..messages import (
|
|
14
|
+
ModelMessage,
|
|
15
|
+
ModelRequest,
|
|
16
|
+
ModelRequestPart,
|
|
17
|
+
ModelResponse,
|
|
18
|
+
RetryPromptPart,
|
|
19
|
+
SystemPromptPart,
|
|
20
|
+
TextPart,
|
|
21
|
+
ToolCallPart,
|
|
22
|
+
ToolReturnPart,
|
|
23
|
+
UserPromptPart,
|
|
24
|
+
)
|
|
25
|
+
from ..settings import ModelSettings
|
|
26
|
+
from ..usage import Usage
|
|
27
|
+
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
|
|
28
|
+
from .wrapper import WrapperModel
|
|
29
|
+
|
|
30
|
+
# Keys of `ModelSettings` that `InstrumentedModel._instrument` copies onto the span as
# `gen_ai.request.<key>` attributes whenever they are present in the settings mapping.
MODEL_SETTING_ATTRIBUTES: tuple[
    Literal[
        'max_tokens',
        'top_p',
        'seed',
        'temperature',
        'presence_penalty',
        'frequency_penalty',
    ],
    ...,
] = (
    'max_tokens',
    'top_p',
    'seed',
    'temperature',
    'presence_penalty',
    'frequency_penalty',
)

# Unique sentinel used as the `.get()` default when reading settings, so that a key that is
# absent can be told apart from a key that is present with any value (including `None`).
NOT_GIVEN = object()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class InstrumentedModel(WrapperModel):
    """Wrapper model that records every request as a span plus a series of events.

    Spans are created through `tracer` and events are emitted through `event_logger`;
    both are obtained from the providers passed to `__init__` (or the global ones).
    """

    tracer: Tracer = field(repr=False)
    event_logger: EventLogger = field(repr=False)

    def __init__(
        self,
        wrapped: Model | KnownModelName,
        tracer_provider: TracerProvider | None = None,
        event_logger_provider: EventLoggerProvider | None = None,
    ):
        super().__init__(wrapped)
        # Fall back to the globally registered providers when none are supplied.
        span_source = tracer_provider or get_tracer_provider()
        event_source = event_logger_provider or get_event_logger_provider()
        self.tracer = span_source.get_tracer('pydantic-ai')
        self.event_logger = event_source.get_event_logger('pydantic-ai')

    @classmethod
    def from_logfire(
        cls,
        wrapped: Model | KnownModelName,
        logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
    ) -> InstrumentedModel:
        """Build an instance using the providers exposed by a logfire instance's config."""
        config = logfire_instance.config
        # The config may not expose an event logger provider; `None` makes `__init__`
        # fall back to the global provider.
        event_provider = (
            config.get_event_logger_provider() if hasattr(config, 'get_event_logger_provider') else None
        )
        return cls(wrapped, config.get_tracer_provider(), event_provider)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Forward the request to the wrapped model inside an instrumentation span."""
        with self._instrument(messages, model_settings) as record_result:
            model_response, request_usage = await super().request(
                messages, model_settings, model_request_parameters
            )
            record_result(model_response, request_usage)
            return model_response, request_usage

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Forward a streamed request inside an instrumentation span.

        The response and usage are recorded in a `finally` clause, i.e. once the caller
        leaves the stream context (also when it leaves through an exception).
        """
        with self._instrument(messages, model_settings) as record_result:
            stream: StreamedResponse | None = None
            try:
                inner = super().request_stream(messages, model_settings, model_request_parameters)
                async with inner as stream:
                    yield stream
            finally:
                # `stream` is still `None` when entering the wrapped stream failed.
                if stream:
                    record_result(stream.get(), stream.usage())

    @contextmanager
    def _instrument(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
        """Open a span for one request and yield a callback that records its outcome.

        While the span is recording, one event is emitted per message-history entry on
        entry; the yielded callback emits the `gen_ai.choice` events and sets the
        response/usage attributes on the span.
        """
        operation = 'chat'
        model_name = self.model_name
        span_name = f'{operation} {model_name}'

        wrapped = self.wrapped
        # Prefer the wrapped model's own `system`; otherwise derive one from its class name.
        system = getattr(wrapped, 'system', '') or wrapped.__class__.__name__.removesuffix('Model').lower()
        system_aliases = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}
        system = system_aliases.get(system, system)

        # TODO Missing attributes:
        #  - server.address: requires a Model.base_url abstract method or similar
        #  - server.port: to parse from the base_url
        #  - error.type: unclear if we should do something here or just always rely on span exceptions
        #  - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
        attributes: dict[str, Any] = {
            'gen_ai.operation.name': operation,
            'gen_ai.system': system,
            'gen_ai.request.model': model_name,
        }

        if model_settings:
            for setting in MODEL_SETTING_ATTRIBUTES:
                value = model_settings.get(setting, NOT_GIVEN)
                if value is NOT_GIVEN:
                    continue
                attributes[f'gen_ai.request.{setting}'] = value

        def emit_event(event_name: str, body: dict[str, Any]) -> None:
            self._emit_event(system, event_name, body)

        with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
            if span.is_recording():
                for message in messages:
                    if isinstance(message, ModelRequest):
                        for part in message.parts:
                            event_name, body = _request_part_body(part)
                            # An empty event name marks a part that produces no event.
                            if event_name:
                                emit_event(event_name, body)
                    elif isinstance(message, ModelResponse):
                        for body in _response_bodies(message):
                            emit_event('gen_ai.assistant.message', body)

            def finish(response: ModelResponse, usage: Usage):
                if not span.is_recording():
                    return

                for response_body in _response_bodies(response):
                    if response_body:
                        # TODO finish_reason
                        choice = {'index': 0, 'message': response_body}
                        emit_event('gen_ai.choice', choice)

                # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
                #  https://github.com/pydantic/pydantic-ai/issues/886
                candidates = {
                    'gen_ai.response.model': response.model_name or model_name,
                    'gen_ai.usage.input_tokens': usage.request_tokens,
                    'gen_ai.usage.output_tokens': usage.response_tokens,
                }
                # Only values that are not `None` are set on the span.
                span.set_attributes({name: value for name, value in candidates.items() if value is not None})

            yield finish

    def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
        """Emit one event tagged with the `gen_ai.system` attribute."""
        event = Event(event_name, body=body, attributes={'gen_ai.system': system})
        self.event_logger.emit(event)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
    """Translate one request part into an `(event_name, body)` pair.

    A part of any other type yields `('', {})`; the empty event name tells the caller
    that there is nothing to emit for it.
    """
    if isinstance(part, SystemPromptPart):
        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}

    if isinstance(part, UserPromptPart):
        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}

    if isinstance(part, ToolReturnPart):
        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}

    if isinstance(part, RetryPromptPart):
        # A retry prompt with a tool name is reported as a tool message,
        # one without a tool name as a user message.
        if part.tool_name is not None:
            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
        return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}

    return '', {}
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
    """Split a model response into one or more assistant message bodies.

    The result always starts with one `{'role': 'assistant'}` body. Tool calls are
    appended to the current body's `tool_calls` list; a text part starts a new body
    only when the current body already holds non-empty `content`.
    """
    current: dict[str, Any] = {'role': 'assistant'}
    bodies = [current]

    for part in message.parts:
        if isinstance(part, ToolCallPart):
            tool_calls = current.setdefault('tool_calls', [])
            tool_calls.append(
                {
                    'id': part.tool_call_id,
                    'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
                    'function': {
                        'name': part.tool_name,
                        'arguments': part.args,
                    },
                }
            )
        elif isinstance(part, TextPart):
            if current.get('content'):
                # The current body already carries text, so open a fresh one.
                current = {'role': 'assistant'}
                bodies.append(current)
            current['content'] = part.content

    return bodies
|
pydantic_ai/models/mistral.py
CHANGED
|
@@ -134,9 +134,6 @@ class MistralModel(Model):
|
|
|
134
134
|
api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
|
|
135
135
|
self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
|
|
136
136
|
|
|
137
|
-
def name(self) -> str:
|
|
138
|
-
return f'mistral:{self._model_name}'
|
|
139
|
-
|
|
140
137
|
async def request(
|
|
141
138
|
self,
|
|
142
139
|
messages: list[ModelMessage],
|
|
@@ -165,6 +162,16 @@ class MistralModel(Model):
|
|
|
165
162
|
async with response:
|
|
166
163
|
yield await self._process_streamed_response(model_request_parameters.result_tools, response)
|
|
167
164
|
|
|
165
|
+
@property
|
|
166
|
+
def model_name(self) -> MistralModelName:
|
|
167
|
+
"""The model name."""
|
|
168
|
+
return self._model_name
|
|
169
|
+
|
|
170
|
+
@property
|
|
171
|
+
def system(self) -> str | None:
|
|
172
|
+
"""The system / model provider."""
|
|
173
|
+
return self._system
|
|
174
|
+
|
|
168
175
|
async def _completions_create(
|
|
169
176
|
self,
|
|
170
177
|
messages: list[ModelMessage],
|
|
@@ -296,7 +303,7 @@ class MistralModel(Model):
|
|
|
296
303
|
tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
|
|
297
304
|
parts.append(tool)
|
|
298
305
|
|
|
299
|
-
return ModelResponse(parts, model_name=
|
|
306
|
+
return ModelResponse(parts, model_name=response.model, timestamp=timestamp)
|
|
300
307
|
|
|
301
308
|
async def _process_streamed_response(
|
|
302
309
|
self,
|
|
@@ -461,6 +468,7 @@ MistralToolCallId = Union[str, None]
|
|
|
461
468
|
class MistralStreamedResponse(StreamedResponse):
|
|
462
469
|
"""Implementation of `StreamedResponse` for Mistral models."""
|
|
463
470
|
|
|
471
|
+
_model_name: MistralModelName
|
|
464
472
|
_response: AsyncIterable[MistralCompletionEvent]
|
|
465
473
|
_timestamp: datetime
|
|
466
474
|
_result_tools: dict[str, ToolDefinition]
|
|
@@ -502,7 +510,14 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
502
510
|
vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
|
|
503
511
|
)
|
|
504
512
|
|
|
513
|
+
@property
|
|
514
|
+
def model_name(self) -> MistralModelName:
|
|
515
|
+
"""Get the model name of the response."""
|
|
516
|
+
return self._model_name
|
|
517
|
+
|
|
518
|
+
@property
|
|
505
519
|
def timestamp(self) -> datetime:
|
|
520
|
+
"""Get the timestamp of the response."""
|
|
506
521
|
return self._timestamp
|
|
507
522
|
|
|
508
523
|
@staticmethod
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -119,10 +119,11 @@ class OpenAIModel(Model):
|
|
|
119
119
|
"""
|
|
120
120
|
self._model_name = model_name
|
|
121
121
|
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
|
|
122
|
-
# openai compatible models do not always need an API key.
|
|
122
|
+
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
|
|
123
123
|
if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
|
|
124
|
-
api_key = ''
|
|
125
|
-
|
|
124
|
+
api_key = 'api-key-not-set'
|
|
125
|
+
|
|
126
|
+
if openai_client is not None:
|
|
126
127
|
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
|
|
127
128
|
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
|
|
128
129
|
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
|
|
@@ -134,9 +135,6 @@ class OpenAIModel(Model):
|
|
|
134
135
|
self.system_prompt_role = system_prompt_role
|
|
135
136
|
self._system = system
|
|
136
137
|
|
|
137
|
-
def name(self) -> str:
|
|
138
|
-
return f'openai:{self._model_name}'
|
|
139
|
-
|
|
140
138
|
async def request(
|
|
141
139
|
self,
|
|
142
140
|
messages: list[ModelMessage],
|
|
@@ -163,6 +161,16 @@ class OpenAIModel(Model):
|
|
|
163
161
|
async with response:
|
|
164
162
|
yield await self._process_streamed_response(response)
|
|
165
163
|
|
|
164
|
+
@property
|
|
165
|
+
def model_name(self) -> OpenAIModelName:
|
|
166
|
+
"""The model name."""
|
|
167
|
+
return self._model_name
|
|
168
|
+
|
|
169
|
+
@property
|
|
170
|
+
def system(self) -> str | None:
|
|
171
|
+
"""The system / model provider."""
|
|
172
|
+
return self._system
|
|
173
|
+
|
|
166
174
|
@overload
|
|
167
175
|
async def _completions_create(
|
|
168
176
|
self,
|
|
@@ -232,7 +240,7 @@ class OpenAIModel(Model):
|
|
|
232
240
|
if choice.message.tool_calls is not None:
|
|
233
241
|
for c in choice.message.tool_calls:
|
|
234
242
|
items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
|
|
235
|
-
return ModelResponse(items, model_name=
|
|
243
|
+
return ModelResponse(items, model_name=response.model, timestamp=timestamp)
|
|
236
244
|
|
|
237
245
|
async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
|
|
238
246
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
@@ -331,6 +339,7 @@ class OpenAIModel(Model):
|
|
|
331
339
|
class OpenAIStreamedResponse(StreamedResponse):
|
|
332
340
|
"""Implementation of `StreamedResponse` for OpenAI models."""
|
|
333
341
|
|
|
342
|
+
_model_name: OpenAIModelName
|
|
334
343
|
_response: AsyncIterable[ChatCompletionChunk]
|
|
335
344
|
_timestamp: datetime
|
|
336
345
|
|
|
@@ -358,7 +367,14 @@ class OpenAIStreamedResponse(StreamedResponse):
|
|
|
358
367
|
if maybe_event is not None:
|
|
359
368
|
yield maybe_event
|
|
360
369
|
|
|
370
|
+
@property
|
|
371
|
+
def model_name(self) -> OpenAIModelName:
|
|
372
|
+
"""Get the model name of the response."""
|
|
373
|
+
return self._model_name
|
|
374
|
+
|
|
375
|
+
@property
|
|
361
376
|
def timestamp(self) -> datetime:
|
|
377
|
+
"""Get the timestamp of the response."""
|
|
362
378
|
return self._timestamp
|
|
363
379
|
|
|
364
380
|
|
pydantic_ai/models/test.py
CHANGED
|
@@ -107,6 +107,16 @@ class TestModel(Model):
|
|
|
107
107
|
_model_name=self._model_name, _structured_response=model_response, _messages=messages
|
|
108
108
|
)
|
|
109
109
|
|
|
110
|
+
@property
|
|
111
|
+
def model_name(self) -> str:
|
|
112
|
+
"""The model name."""
|
|
113
|
+
return self._model_name
|
|
114
|
+
|
|
115
|
+
@property
|
|
116
|
+
def system(self) -> str | None:
|
|
117
|
+
"""The system / model provider."""
|
|
118
|
+
return self._system
|
|
119
|
+
|
|
110
120
|
def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
|
|
111
121
|
return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
|
|
112
122
|
|
|
@@ -120,15 +130,15 @@ class TestModel(Model):
|
|
|
120
130
|
|
|
121
131
|
def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
|
|
122
132
|
if self.custom_result_text is not None:
|
|
123
|
-
assert (
|
|
124
|
-
|
|
125
|
-
)
|
|
133
|
+
assert model_request_parameters.allow_text_result, (
|
|
134
|
+
'Plain response not allowed, but `custom_result_text` is set.'
|
|
135
|
+
)
|
|
126
136
|
assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
|
|
127
137
|
return _TextResult(self.custom_result_text)
|
|
128
138
|
elif self.custom_result_args is not None:
|
|
129
|
-
assert (
|
|
130
|
-
|
|
131
|
-
)
|
|
139
|
+
assert model_request_parameters.result_tools is not None, (
|
|
140
|
+
'No result tools provided, but `custom_result_args` is set.'
|
|
141
|
+
)
|
|
132
142
|
result_tool = model_request_parameters.result_tools[0]
|
|
133
143
|
|
|
134
144
|
if k := result_tool.outer_typed_dict_key:
|
|
@@ -221,9 +231,9 @@ class TestModel(Model):
|
|
|
221
231
|
class TestStreamedResponse(StreamedResponse):
|
|
222
232
|
"""A structured response that streams test data."""
|
|
223
233
|
|
|
234
|
+
_model_name: str
|
|
224
235
|
_structured_response: ModelResponse
|
|
225
236
|
_messages: InitVar[Iterable[ModelMessage]]
|
|
226
|
-
|
|
227
237
|
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
|
|
228
238
|
|
|
229
239
|
def __post_init__(self, _messages: Iterable[ModelMessage]):
|
|
@@ -249,7 +259,14 @@ class TestStreamedResponse(StreamedResponse):
|
|
|
249
259
|
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
|
|
250
260
|
)
|
|
251
261
|
|
|
262
|
+
@property
|
|
263
|
+
def model_name(self) -> str:
|
|
264
|
+
"""Get the model name of the response."""
|
|
265
|
+
return self._model_name
|
|
266
|
+
|
|
267
|
+
@property
|
|
252
268
|
def timestamp(self) -> datetime:
|
|
269
|
+
"""Get the timestamp of the response."""
|
|
253
270
|
return self._timestamp
|
|
254
271
|
|
|
255
272
|
|
pydantic_ai/models/vertexai.py
CHANGED
|
@@ -161,6 +161,16 @@ class VertexAIModel(GeminiModel):
|
|
|
161
161
|
async with super().request_stream(messages, model_settings, model_request_parameters) as value:
|
|
162
162
|
yield value
|
|
163
163
|
|
|
164
|
+
@property
|
|
165
|
+
def model_name(self) -> GeminiModelName:
|
|
166
|
+
"""The model name."""
|
|
167
|
+
return self._model_name
|
|
168
|
+
|
|
169
|
+
@property
|
|
170
|
+
def system(self) -> str | None:
|
|
171
|
+
"""The system / model provider."""
|
|
172
|
+
return self._system
|
|
173
|
+
|
|
164
174
|
|
|
165
175
|
# pyright: reportUnknownMemberType=false
|
|
166
176
|
def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
|
|
pydantic_ai/models/wrapper.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from ..messages import ModelMessage, ModelResponse
|
|
9
|
+
from ..settings import ModelSettings
|
|
10
|
+
from ..usage import Usage
|
|
11
|
+
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(init=False)
class WrapperModel(Model):
    """Model which wraps another model.

    Requests, `model_name` and `system` are forwarded to `wrapped`. Any attribute that
    normal lookup does not find on the wrapper is looked up on `wrapped` as well.
    """

    wrapped: Model

    def __init__(self, wrapped: Model | KnownModelName):
        # Whatever is passed in goes through `infer_model`; its result is the delegate.
        self.wrapped = infer_model(wrapped)

    async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
        """Forward the request, with all arguments unchanged, to the wrapped model."""
        return await self.wrapped.request(*args, **kwargs)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Open the wrapped model's stream and yield its response stream unchanged."""
        async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
            yield response_stream

    @property
    def model_name(self) -> str:
        """The wrapped model's `model_name`."""
        return self.wrapped.model_name

    @property
    def system(self) -> str | None:
        """The wrapped model's `system`."""
        return self.wrapped.system

    def __getattr__(self, item: str):
        """Delegate attributes that normal lookup did not find to the wrapped model.

        Raises:
            AttributeError: if `wrapped` itself is not set on this instance, or if the
                wrapped model does not have the attribute either.
        """
        # `__getattr__` only runs after normal lookup has failed. If the missing name is
        # `wrapped` itself (an instance created without running `__init__`, e.g. while
        # `copy` or `pickle` rebuilds the object), reading `self.wrapped` below would
        # re-enter this method without end and surface as RecursionError.
        if item == 'wrapped':
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
        return getattr(self.wrapped, item)
|