pydantic-ai-slim 0.0.6a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +8 -0
- pydantic_ai/_griffe.py +128 -0
- pydantic_ai/_pydantic.py +216 -0
- pydantic_ai/_result.py +258 -0
- pydantic_ai/_retriever.py +114 -0
- pydantic_ai/_system_prompt.py +33 -0
- pydantic_ai/_utils.py +247 -0
- pydantic_ai/agent.py +795 -0
- pydantic_ai/dependencies.py +83 -0
- pydantic_ai/exceptions.py +56 -0
- pydantic_ai/messages.py +205 -0
- pydantic_ai/models/__init__.py +300 -0
- pydantic_ai/models/function.py +268 -0
- pydantic_ai/models/gemini.py +720 -0
- pydantic_ai/models/groq.py +400 -0
- pydantic_ai/models/openai.py +379 -0
- pydantic_ai/models/test.py +389 -0
- pydantic_ai/models/vertexai.py +306 -0
- pydantic_ai/py.typed +0 -0
- pydantic_ai/result.py +314 -0
- pydantic_ai_slim-0.0.6a1.dist-info/METADATA +49 -0
- pydantic_ai_slim-0.0.6a1.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6a1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
"""Utilities for testing apps built with PydanticAI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations as _annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
import string
|
|
7
|
+
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
|
|
8
|
+
from contextlib import asynccontextmanager
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import Any, Literal
|
|
12
|
+
|
|
13
|
+
import pydantic_core
|
|
14
|
+
|
|
15
|
+
from .. import _utils
|
|
16
|
+
from ..messages import (
|
|
17
|
+
Message,
|
|
18
|
+
ModelAnyResponse,
|
|
19
|
+
ModelStructuredResponse,
|
|
20
|
+
ModelTextResponse,
|
|
21
|
+
RetryPrompt,
|
|
22
|
+
ToolCall,
|
|
23
|
+
ToolReturn,
|
|
24
|
+
)
|
|
25
|
+
from ..result import Cost
|
|
26
|
+
from . import (
|
|
27
|
+
AbstractToolDefinition,
|
|
28
|
+
AgentModel,
|
|
29
|
+
EitherStreamedResponse,
|
|
30
|
+
Model,
|
|
31
|
+
StreamStructuredResponse,
|
|
32
|
+
StreamTextResponse,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UnSetType:
    """Sentinel type; `UnSet`, defined just below, is its module-level instance."""

    def __repr__(self) -> str:
        label = 'UnSet'
        return label


UnSet = UnSetType()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class TestModel(Model):
    """A model specifically for testing purposes.

    This will (by default) call all retrievers in the agent model, then return a tool response if possible,
    otherwise a plain response.

    How useful this model will be is unknown, it may be useless, it may require significant changes to be useful.

    Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those
    of the base class.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    call_retrievers: list[str] | Literal['all'] = 'all'
    """List of retrievers to call. If `'all'`, all retrievers will be called."""
    custom_result_text: str | None = None
    """If set, this text is returned as the final result."""
    custom_result_args: Any | None = None
    """If set, these args will be passed to the result tool."""
    seed: int = 0
    """Seed for generating random data."""
    # these fields are set when the model is called by the agent
    agent_model_retrievers: Mapping[str, AbstractToolDefinition] | None = field(default=None, init=False)
    agent_model_allow_text_result: bool | None = field(default=None, init=False)
    agent_model_result_tools: list[AbstractToolDefinition] | None = field(default=None, init=False)

    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> AgentModel:
        """Record the agent's configuration and build a `TestAgentModel` that replays it.

        Args:
            retrievers: Retriever definitions keyed by name.
            allow_text_result: Whether a plain-text final response is permitted.
            result_tools: Result tool definitions, or `None` if there are none.

        Returns:
            A `TestAgentModel` that calls the selected retrievers and then returns the final result.

        Raises:
            KeyError: If `call_retrievers` names a retriever missing from `retrievers`.
            AssertionError: If `custom_result_text` / `custom_result_args` conflict with the agent's configuration.
        """
        self.agent_model_retrievers = retrievers
        self.agent_model_allow_text_result = allow_text_result
        self.agent_model_result_tools = list(result_tools) if result_tools is not None else None

        if self.call_retrievers == 'all':
            retriever_calls = [(r.name, r) for r in retrievers.values()]
        else:
            retrievers_to_call = (retrievers[name] for name in self.call_retrievers)
            retriever_calls = [(r.name, r) for r in retrievers_to_call]

        # left means the final response is plain text; right means it's a call to a result tool
        result: _utils.Either[str | None, Any | None]
        if self.custom_result_text is not None:
            assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
            result = _utils.Either(left=self.custom_result_text)
        elif self.custom_result_args is not None:
            # an empty sequence is as unusable as `None` here; fail with the explicit message rather than IndexError
            assert result_tools, 'No result tools provided, but `custom_result_args` is set.'
            # Pick the same tool `TestAgentModel._request` will call (`result_tools[seed % len(result_tools)]`),
            # otherwise the args could be wrapped for one tool and sent to a different one.
            result_tool = result_tools[self.seed % len(result_tools)]

            if k := result_tool.outer_typed_dict_key:
                result = _utils.Either(right={k: self.custom_result_args})
            else:
                result = _utils.Either(right=self.custom_result_args)
        elif allow_text_result:
            result = _utils.Either(left=None)
        elif result_tools is not None:
            result = _utils.Either(right=None)
        else:
            result = _utils.Either(left=None)
        return TestAgentModel(retriever_calls, result, self.agent_model_result_tools, self.seed)

    def name(self) -> str:
        return 'test-model'
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class TestAgentModel(AgentModel):
    """Implementation of `AgentModel` for testing purposes.

    Responses are scripted:
    - The first request calls every retriever in `retriever_calls`.
    - A request that follows new `RetryPrompt` messages re-calls the retrievers named in them.
    - Any other request returns the final result described by `result`.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    retriever_calls: list[tuple[str, AbstractToolDefinition]]
    # left means the text is plain text; right means it's a function call
    result: _utils.Either[str | None, Any | None]
    result_tools: list[AbstractToolDefinition] | None
    seed: int
    # number of responses returned so far; only `step == 0` changes behaviour
    step: int = 0
    # `len(messages)` at the previous request, used to find messages added since then
    last_message_count: int = 0

    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
        response = self._request(messages)
        return response, Cost()

    @asynccontextmanager
    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
        response = self._request(messages)
        cost = Cost()
        streamed: EitherStreamedResponse
        if isinstance(response, ModelTextResponse):
            streamed = TestStreamTextResponse(response.content, cost)
        else:
            streamed = TestStreamStructuredResponse(response, cost)
        yield streamed

    def gen_retriever_args(self, tool_def: AbstractToolDefinition) -> Any:
        """Generate arguments matching `tool_def.json_schema`, derived deterministically from `seed`."""
        generator = _JsonSchemaTestData(tool_def.json_schema, self.seed)
        return generator.generate()

    def _request(self, messages: list[Message]) -> ModelAnyResponse:
        # First request: call every selected retriever in one go.
        if self.step == 0 and self.retriever_calls:
            initial_calls = [
                ToolCall.from_object(tool_name, self.gen_retriever_args(tool_def))
                for tool_name, tool_def in self.retriever_calls
            ]
            self.step += 1
            self.last_message_count = len(messages)
            return ModelStructuredResponse(calls=initial_calls)

        # Only messages added since the previous request are checked for retries.
        fresh_messages = messages[self.last_message_count :]
        self.last_message_count = len(messages)
        names_to_retry = {msg.tool_name for msg in fresh_messages if isinstance(msg, RetryPrompt)}

        if names_to_retry:
            retry_calls = [
                ToolCall.from_object(tool_name, self.gen_retriever_args(tool_def))
                for tool_name, tool_def in self.retriever_calls
                if tool_name in names_to_retry
            ]
            self.step += 1
            return ModelStructuredResponse(calls=retry_calls)

        text_result = self.result.left
        if not text_result:
            return self._result_tool_response()
        self.step += 1
        return self._text_response(messages, text_result.value)

    def _text_response(self, messages: list[Message], text: str | None) -> ModelTextResponse:
        """Build the final plain-text response; `text=None` summarises the `ToolReturn`s in `messages`."""
        if text is not None:
            return ModelTextResponse(content=text)
        # a later return from the same tool overwrites an earlier one
        returns_by_tool: dict[str, Any] = {
            msg.tool_name: msg.content for msg in messages if isinstance(msg, ToolReturn)
        }
        if not returns_by_tool:
            return ModelTextResponse(content='success (no retriever calls)')
        return ModelTextResponse(content=pydantic_core.to_json(returns_by_tool).decode())

    def _result_tool_response(self) -> ModelStructuredResponse:
        """Build the final response as a call to one of `result_tools`, with custom or generated args."""
        assert self.result_tools is not None, 'No result tools provided'
        custom_args = self.result.right
        tool = self.result_tools[self.seed % len(self.result_tools)]
        if custom_args is None:
            args = self.gen_retriever_args(tool)
        else:
            args = custom_args
        self.step += 1
        return ModelStructuredResponse(calls=[ToolCall.from_object(tool.name, args)])
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@dataclass
class TestStreamTextResponse(StreamTextResponse):
    """A `StreamTextResponse` that streams a fixed text back in a few chunks."""

    _text: str
    _cost: Cost
    _iter: Iterator[str] = field(init=False)
    _timestamp: datetime = field(default_factory=_utils.now_utc)
    # chunks received by `__anext__` that `get` has not handed out yet
    _buffer: list[str] = field(default_factory=list, init=False)

    def __post_init__(self):
        # Split on single spaces and re-attach the separator to every chunk except the last,
        # so that ''.join(chunks) == self._text.
        parts = self._text.split(' ')
        chunks = [part + ' ' for part in parts[:-1]]
        chunks.append(parts[-1])
        # A single word longer than two characters is still split in half, so it streams in two steps.
        if len(chunks) == 1 and len(self._text) > 2:
            midpoint = len(self._text) // 2
            chunks = [self._text[:midpoint], self._text[midpoint:]]
        self._iter = iter(chunks)

    async def __anext__(self) -> None:
        next_chunk = _utils.sync_anext(self._iter)
        self._buffer.append(next_chunk)

    def get(self, *, final: bool = False) -> Iterable[str]:
        # Generator: yields the buffered chunks; the buffer is cleared only once it has been fully iterated.
        yield from self._buffer
        self._buffer.clear()

    def cost(self) -> Cost:
        return self._cost

    def timestamp(self) -> datetime:
        return self._timestamp
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@dataclass
class TestStreamStructuredResponse(StreamStructuredResponse):
    """A `StreamStructuredResponse` that hands back a prepared structured response after one step."""

    _structured_response: ModelStructuredResponse
    _cost: Cost
    # yields a single `None`; presumably `sync_anext` ends the stream once this is exhausted
    _iter: Iterator[None] = field(default_factory=lambda: iter((None,)))
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def __anext__(self) -> None:
        step_value = _utils.sync_anext(self._iter)
        return step_value

    def get(self, *, final: bool = False) -> ModelStructuredResponse:
        # the full response is known up front, so `final` makes no difference
        return self._structured_response

    def cost(self) -> Cost:
        return self._cost

    def timestamp(self) -> datetime:
        return self._timestamp
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
_chars = string.ascii_letters + string.digits + string.punctuation
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class _JsonSchemaTestData:
|
|
244
|
+
"""Generate data that matches a JSON schema.
|
|
245
|
+
|
|
246
|
+
This tries to generate the minimal viable data for the schema.
|
|
247
|
+
"""
|
|
248
|
+
|
|
249
|
+
def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
|
|
250
|
+
self.schema = schema
|
|
251
|
+
self.defs = schema.get('$defs', {})
|
|
252
|
+
self.seed = seed
|
|
253
|
+
|
|
254
|
+
def generate(self) -> Any:
|
|
255
|
+
"""Generate data for the JSON schema."""
|
|
256
|
+
return self._gen_any(self.schema)
|
|
257
|
+
|
|
258
|
+
def _gen_any(self, schema: dict[str, Any]) -> Any:
|
|
259
|
+
"""Generate data for any JSON Schema."""
|
|
260
|
+
if const := schema.get('const'):
|
|
261
|
+
return const
|
|
262
|
+
elif enum := schema.get('enum'):
|
|
263
|
+
return enum[self.seed % len(enum)]
|
|
264
|
+
elif examples := schema.get('examples'):
|
|
265
|
+
return examples[self.seed % len(examples)]
|
|
266
|
+
elif ref := schema.get('$ref'):
|
|
267
|
+
key = re.sub(r'^#/\$defs/', '', ref)
|
|
268
|
+
js_def = self.defs[key]
|
|
269
|
+
return self._gen_any(js_def)
|
|
270
|
+
elif any_of := schema.get('anyOf'):
|
|
271
|
+
return self._gen_any(any_of[self.seed % len(any_of)])
|
|
272
|
+
|
|
273
|
+
type_ = schema.get('type')
|
|
274
|
+
if type_ is None:
|
|
275
|
+
# if there's no type or ref, we can't generate anything
|
|
276
|
+
return self._char()
|
|
277
|
+
elif type_ == 'object':
|
|
278
|
+
return self._object_gen(schema)
|
|
279
|
+
elif type_ == 'string':
|
|
280
|
+
return self._str_gen(schema)
|
|
281
|
+
elif type_ == 'integer':
|
|
282
|
+
return self._int_gen(schema)
|
|
283
|
+
elif type_ == 'number':
|
|
284
|
+
return float(self._int_gen(schema))
|
|
285
|
+
elif type_ == 'boolean':
|
|
286
|
+
return self._bool_gen()
|
|
287
|
+
elif type_ == 'array':
|
|
288
|
+
return self._array_gen(schema)
|
|
289
|
+
elif type_ == 'null':
|
|
290
|
+
return None
|
|
291
|
+
else:
|
|
292
|
+
raise NotImplementedError(f'Unknown type: {type_}, please submit a PR to extend JsonSchemaTestData!')
|
|
293
|
+
|
|
294
|
+
def _object_gen(self, schema: dict[str, Any]) -> dict[str, Any]:
|
|
295
|
+
"""Generate data for a JSON Schema object."""
|
|
296
|
+
required = set(schema.get('required', []))
|
|
297
|
+
|
|
298
|
+
data: dict[str, Any] = {}
|
|
299
|
+
if properties := schema.get('properties'):
|
|
300
|
+
for key, value in properties.items():
|
|
301
|
+
if key in required:
|
|
302
|
+
data[key] = self._gen_any(value)
|
|
303
|
+
|
|
304
|
+
if addition_props := schema.get('additionalProperties'):
|
|
305
|
+
add_prop_key = 'additionalProperty'
|
|
306
|
+
while add_prop_key in data:
|
|
307
|
+
add_prop_key += '_'
|
|
308
|
+
if addition_props is True:
|
|
309
|
+
data[add_prop_key] = self._char()
|
|
310
|
+
else:
|
|
311
|
+
data[add_prop_key] = self._gen_any(addition_props)
|
|
312
|
+
|
|
313
|
+
return data
|
|
314
|
+
|
|
315
|
+
def _str_gen(self, schema: dict[str, Any]) -> str:
|
|
316
|
+
"""Generate a string from a JSON Schema string."""
|
|
317
|
+
min_len = schema.get('minLength')
|
|
318
|
+
if min_len is not None:
|
|
319
|
+
return self._char() * min_len
|
|
320
|
+
|
|
321
|
+
if schema.get('maxLength') == 0:
|
|
322
|
+
return ''
|
|
323
|
+
else:
|
|
324
|
+
return self._char()
|
|
325
|
+
|
|
326
|
+
def _int_gen(self, schema: dict[str, Any]) -> int:
|
|
327
|
+
"""Generate an integer from a JSON Schema integer."""
|
|
328
|
+
maximum = schema.get('maximum')
|
|
329
|
+
if maximum is None:
|
|
330
|
+
exc_max = schema.get('exclusiveMaximum')
|
|
331
|
+
if exc_max is not None:
|
|
332
|
+
maximum = exc_max - 1
|
|
333
|
+
|
|
334
|
+
minimum = schema.get('minimum')
|
|
335
|
+
if minimum is None:
|
|
336
|
+
exc_min = schema.get('exclusiveMinimum')
|
|
337
|
+
if exc_min is not None:
|
|
338
|
+
minimum = exc_min + 1
|
|
339
|
+
|
|
340
|
+
if minimum is not None and maximum is not None:
|
|
341
|
+
return minimum + self.seed % (maximum - minimum)
|
|
342
|
+
elif minimum is not None:
|
|
343
|
+
return minimum + self.seed
|
|
344
|
+
elif maximum is not None:
|
|
345
|
+
return maximum - self.seed
|
|
346
|
+
else:
|
|
347
|
+
return self.seed
|
|
348
|
+
|
|
349
|
+
def _bool_gen(self) -> bool:
|
|
350
|
+
"""Generate a boolean from a JSON Schema boolean."""
|
|
351
|
+
return bool(self.seed % 2)
|
|
352
|
+
|
|
353
|
+
def _array_gen(self, schema: dict[str, Any]) -> list[Any]:
|
|
354
|
+
"""Generate an array from a JSON Schema array."""
|
|
355
|
+
data: list[Any] = []
|
|
356
|
+
unique_items = schema.get('uniqueItems')
|
|
357
|
+
if prefix_items := schema.get('prefixItems'):
|
|
358
|
+
for item in prefix_items:
|
|
359
|
+
data.append(self._gen_any(item))
|
|
360
|
+
if unique_items:
|
|
361
|
+
self.seed += 1
|
|
362
|
+
|
|
363
|
+
items_schema = schema.get('items', {})
|
|
364
|
+
min_items = schema.get('minItems', 0)
|
|
365
|
+
if min_items > len(data):
|
|
366
|
+
for _ in range(min_items - len(data)):
|
|
367
|
+
data.append(self._gen_any(items_schema))
|
|
368
|
+
if unique_items:
|
|
369
|
+
self.seed += 1
|
|
370
|
+
elif items_schema:
|
|
371
|
+
# if there is an `items` schema, add an item unless it would break `maxItems` rule
|
|
372
|
+
max_items = schema.get('maxItems')
|
|
373
|
+
if max_items is None or max_items > len(data):
|
|
374
|
+
data.append(self._gen_any(items_schema))
|
|
375
|
+
if unique_items:
|
|
376
|
+
self.seed += 1
|
|
377
|
+
|
|
378
|
+
return data
|
|
379
|
+
|
|
380
|
+
def _char(self) -> str:
|
|
381
|
+
"""Generate a character on the same principle as Excel columns, e.g. a-z, aa-az..."""
|
|
382
|
+
chars = len(_chars)
|
|
383
|
+
s = ''
|
|
384
|
+
rem = self.seed // chars
|
|
385
|
+
while rem > 0:
|
|
386
|
+
s += _chars[(rem - 1) % chars]
|
|
387
|
+
rem //= chars
|
|
388
|
+
s += _chars[self.seed % chars]
|
|
389
|
+
return s
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
"""Custom interface to the `*-aiplatform.googleapis.com` API for Gemini models.
|
|
2
|
+
|
|
3
|
+
This model uses [`GeminiAgentModel`][pydantic_ai.models.gemini.GeminiAgentModel] with just the URL and auth method
|
|
4
|
+
changed from the default `GeminiModel`; it relies on the VertexAI
|
|
5
|
+
[`generateContent` function endpoint](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
|
|
6
|
+
and `streamGenerateContent` function endpoints
|
|
7
|
+
having the same schemas as the equivalent [Gemini endpoints][pydantic_ai.models.gemini.GeminiModel].
|
|
8
|
+
|
|
9
|
+
There are four advantages of using this API over the `generativelanguage.googleapis.com` API which
|
|
10
|
+
[`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] uses, and one big disadvantage.
|
|
11
|
+
|
|
12
|
+
Advantages:
|
|
13
|
+
|
|
14
|
+
1. The VertexAI API seems to be less flaky, and less likely to occasionally return a 503 response.
|
|
15
|
+
2. You can
|
|
16
|
+
[purchase provisioned throughput](https://cloud.google.com/vertex-ai/generative-ai/docs/provisioned-throughput#purchase-provisioned-throughput)
|
|
17
|
+
with VertexAI.
|
|
18
|
+
3. If you're running PydanticAI inside GCP, you don't need to set up authentication, it should "just work".
|
|
19
|
+
4. You can decide which region to use, which might be important from a regulatory perspective,
|
|
20
|
+
and might improve latency.
|
|
21
|
+
|
|
22
|
+
Disadvantage:
|
|
23
|
+
|
|
24
|
+
1. When authorization doesn't just work, it's much more painful to set up than an API key.
|
|
25
|
+
|
|
26
|
+
## Example Usage
|
|
27
|
+
|
|
28
|
+
With the default google project already configured in your environment:
|
|
29
|
+
|
|
30
|
+
```py title="vertex_example_env.py"
|
|
31
|
+
from pydantic_ai import Agent
|
|
32
|
+
from pydantic_ai.models.vertexai import VertexAIModel
|
|
33
|
+
|
|
34
|
+
model = VertexAIModel('gemini-1.5-flash')
|
|
35
|
+
agent = Agent(model)
|
|
36
|
+
result = agent.run_sync('Tell me a joke.')
|
|
37
|
+
print(result.data)
|
|
38
|
+
#> Did you hear about the toothpaste scandal? They called it Colgate.
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Or using a service account JSON file:
|
|
42
|
+
|
|
43
|
+
```py title="vertex_example_service_account.py"
|
|
44
|
+
from pydantic_ai import Agent
|
|
45
|
+
from pydantic_ai.models.vertexai import VertexAIModel
|
|
46
|
+
|
|
47
|
+
model = VertexAIModel(
|
|
48
|
+
'gemini-1.5-flash',
|
|
49
|
+
service_account_file='path/to/service-account.json',
|
|
50
|
+
)
|
|
51
|
+
agent = Agent(model)
|
|
52
|
+
result = agent.run_sync('Tell me a joke.')
|
|
53
|
+
print(result.data)
|
|
54
|
+
#> Did you hear about the toothpaste scandal? They called it Colgate.
|
|
55
|
+
```
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
from __future__ import annotations as _annotations
|
|
59
|
+
|
|
60
|
+
from collections.abc import Mapping, Sequence
|
|
61
|
+
from dataclasses import dataclass, field
|
|
62
|
+
from datetime import datetime, timedelta, timezone
|
|
63
|
+
from pathlib import Path
|
|
64
|
+
from typing import Literal
|
|
65
|
+
|
|
66
|
+
from httpx import AsyncClient as AsyncHTTPClient
|
|
67
|
+
|
|
68
|
+
from .._utils import run_in_executor
|
|
69
|
+
from ..exceptions import UserError
|
|
70
|
+
from . import AbstractToolDefinition, Model, cached_async_http_client
|
|
71
|
+
from .gemini import GeminiAgentModel, GeminiModelName
|
|
72
|
+
|
|
73
|
+
try:
|
|
74
|
+
import google.auth
|
|
75
|
+
from google.auth.credentials import Credentials as BaseCredentials
|
|
76
|
+
from google.auth.transport.requests import Request
|
|
77
|
+
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
|
78
|
+
except ImportError as e:
|
|
79
|
+
raise ImportError(
|
|
80
|
+
'Please install `google-auth` to use the VertexAI model, '
|
|
81
|
+
"you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
|
|
82
|
+
) from e
|
|
83
|
+
|
|
84
|
+
VERTEX_AI_URL_TEMPLATE = (
    'https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}'
    '/locations/{region}/publishers/{model_publisher}/models/{model}:'
)
"""URL template for Vertex AI.

See
[`generateContent` docs](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
and
[`streamGenerateContent` docs](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent)
for more information.

The placeholders are filled in as follows:

* `region` is substituted with the `region` argument,
  see [available regions][pydantic_ai.models.vertexai.VertexAiRegion]
* `model_publisher` is substituted with the `model_publisher` argument
* `model` is substituted with the `model_name` argument
* `project_id` is substituted with the `project_id` argument, or the project ID from the credentials
* `function` (`generateContent` or `streamGenerateContent`) is added to the end of the URL
"""
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@dataclass(init=False)
class VertexAIModel(Model):
    """A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API."""

    model_name: GeminiModelName
    service_account_file: Path | str | None
    project_id: str | None
    region: VertexAiRegion
    model_publisher: Literal['google']
    http_client: AsyncHTTPClient
    url_template: str

    # computed on the first `_ainit` call, then reused
    auth: BearerTokenAuth | None
    url: str | None

    # TODO __init__ can be removed once we drop 3.9 and we can set kw_only correctly on the dataclass
    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: Literal['google'] = 'google',
        http_client: AsyncHTTPClient | None = None,
        url_template: str = VERTEX_AI_URL_TEMPLATE,
    ):
        """Initialize a Vertex AI Gemini model.

        Args:
            model_name: The name of the model to use. I couldn't find a list of supported Google models in VertexAI,
                so for now this uses the same models as the [Gemini model][pydantic_ai.models.gemini.GeminiModel].
            service_account_file: Path to a service account file.
                If not provided, the default environment credentials will be used.
            project_id: The project ID to use; if not provided it will be taken from the credentials.
            region: The region to make requests to.
            model_publisher: The model publisher to use. I couldn't find a good list of available publishers,
                and from trial and error it seems non-google models don't work with the `generateContent` and
                `streamGenerateContent` functions, hence only `google` is currently supported.
                Please create an issue or PR if you know how to use other publishers.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests;
                if omitted, `cached_async_http_client()` is used.
            url_template: URL template for Vertex AI, see
                [`VERTEX_AI_URL_TEMPLATE` docs][pydantic_ai.models.vertexai.VERTEX_AI_URL_TEMPLATE]
                for more information.
        """
        self.model_name = model_name
        self.service_account_file = service_account_file
        self.project_id = project_id
        self.region = region
        self.model_publisher = model_publisher
        self.http_client = http_client or cached_async_http_client()
        self.url_template = url_template

        # resolved lazily, since finding credentials may block
        self.auth = None
        self.url = None

    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> GeminiAgentModel:
        """Resolve the endpoint URL and credentials, then build a `GeminiAgentModel` that calls VertexAI."""
        endpoint_url, bearer_auth = await self._ainit()
        return GeminiAgentModel(
            http_client=self.http_client,
            model_name=self.model_name,
            auth=bearer_auth,
            url=endpoint_url,
            retrievers=retrievers,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )

    async def _ainit(self) -> tuple[str, BearerTokenAuth]:
        """Return `(url, auth)`, loading credentials and formatting the URL on the first call only.

        Raises:
            UserError: If no project ID is available, or the given one disagrees with the credentials.
        """
        if self.url is not None and self.auth is not None:
            return self.url, self.auth

        creds: BaseCredentials | ServiceAccountCredentials
        creds_project_id: str | None
        if self.service_account_file is None:
            creds, creds_project_id = await _async_google_auth()
            creds_source = '`google.auth.default()`'
        else:
            creds = _creds_from_file(self.service_account_file)
            assert creds.project_id is None or isinstance(creds.project_id, str)
            creds_project_id = creds.project_id
            creds_source = 'service account file'

        if self.project_id is None:
            if creds_project_id is None:
                raise UserError(f'No project_id provided and none found in {creds_source}')
            project_id = creds_project_id
        elif creds_project_id is not None and self.project_id != creds_project_id:
            raise UserError(
                f'The project_id you provided does not match the one from {creds_source}: '
                f'{self.project_id!r} != {creds_project_id!r}'
            )
        else:
            project_id = self.project_id

        url = self.url_template.format(
            region=self.region,
            project_id=project_id,
            model_publisher=self.model_publisher,
            model=self.model_name,
        )
        self.url = url
        auth = BearerTokenAuth(creds)
        self.auth = auth
        return url, auth

    def name(self) -> str:
        return f'vertexai:{self.model_name}'
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# pyright: reportUnknownMemberType=false
def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
    """Load service-account credentials from `service_account_file`, scoped to `cloud-platform`."""
    file_path = str(service_account_file)
    scopes = ['https://www.googleapis.com/auth/cloud-platform']
    return ServiceAccountCredentials.from_service_account_file(file_path, scopes=scopes)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# pyright: reportReturnType=false
# pyright: reportUnknownVariableType=false
# pyright: reportUnknownArgumentType=false
async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
    """Call `google.auth.default()` through `run_in_executor` and return `(credentials, project_id)`.

    Presumably this keeps the blocking credential lookup off the event loop.
    """
    lookup = google.auth.default
    return await run_in_executor(lookup)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# default expiry is 3600 seconds; refresh earlier to leave a safety margin
MAX_TOKEN_AGE = timedelta(seconds=3000)


@dataclass
class BearerTokenAuth:
    """Build `Authorization: Bearer ...` headers from Google credentials, refreshing the token when it gets old."""

    credentials: BaseCredentials | ServiceAccountCredentials
    # when the current token was obtained (timezone-aware UTC); `None` until the first refresh
    token_created: datetime | None = field(default=None, init=False)

    async def headers(self) -> dict[str, str]:
        """Return the auth headers, first refreshing the token if it is missing or older than `MAX_TOKEN_AGE`."""
        if self.credentials.token is None or self._token_expired():
            await run_in_executor(self._refresh_token)
            # Aware UTC rather than naive local time: local time can jump backwards (e.g. at the end of DST),
            # which would make an expired token look fresh.
            self.token_created = datetime.now(timezone.utc)
        return {'Authorization': f'Bearer {self.credentials.token}'}

    def _token_expired(self) -> bool:
        """Whether the token is unknown-age or older than `MAX_TOKEN_AGE`."""
        created = self.token_created
        if created is None:
            return True
        if created.tzinfo is None:
            # tolerate a naive (local-time) timestamp, as older code stored
            created = created.astimezone(timezone.utc)
        return (datetime.now(timezone.utc) - created) > MAX_TOKEN_AGE

    def _refresh_token(self) -> str:
        """Synchronously refresh the credentials and return the new token string."""
        self.credentials.refresh(Request())
        assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}'
        return self.credentials.token
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
VertexAiRegion = Literal[
    # us-*
    'us-central1', 'us-east1', 'us-east4', 'us-south1', 'us-west1',
    'us-west2', 'us-west3', 'us-west4', 'us-east5',
    # europe-*
    'europe-central2', 'europe-north1', 'europe-southwest1', 'europe-west1', 'europe-west2',
    'europe-west3', 'europe-west4', 'europe-west6', 'europe-west8', 'europe-west9', 'europe-west12',
    # africa-*
    'africa-south1',
    # asia-*
    'asia-east1', 'asia-east2', 'asia-northeast1', 'asia-northeast2', 'asia-northeast3',
    'asia-south1', 'asia-southeast1', 'asia-southeast2',
    # australia-*
    'australia-southeast1', 'australia-southeast2',
    # me-*
    'me-central1', 'me-central2', 'me-west1',
    # northamerica-* / southamerica-*
    'northamerica-northeast1', 'northamerica-northeast2',
    'southamerica-east1', 'southamerica-west1',
]
"""Regions available for Vertex AI.

More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
"""
|
pydantic_ai/py.typed
ADDED
|
File without changes
|