pydantic-ai-slim 0.0.6a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

@@ -0,0 +1,389 @@
1
+ """Utilities for testing apps built with PydanticAI."""
2
+
3
+ from __future__ import annotations as _annotations
4
+
5
+ import re
6
+ import string
7
+ from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
8
+ from contextlib import asynccontextmanager
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from typing import Any, Literal
12
+
13
+ import pydantic_core
14
+
15
+ from .. import _utils
16
+ from ..messages import (
17
+ Message,
18
+ ModelAnyResponse,
19
+ ModelStructuredResponse,
20
+ ModelTextResponse,
21
+ RetryPrompt,
22
+ ToolCall,
23
+ ToolReturn,
24
+ )
25
+ from ..result import Cost
26
+ from . import (
27
+ AbstractToolDefinition,
28
+ AgentModel,
29
+ EitherStreamedResponse,
30
+ Model,
31
+ StreamStructuredResponse,
32
+ StreamTextResponse,
33
+ )
34
+
35
+
36
class UnSetType:
    """Sentinel type; the only behaviour defined here is that its `repr()` is `'UnSet'`."""

    def __repr__(self) -> str:
        label = 'UnSet'
        return label


# Shared module-level instance of `UnSetType`.
UnSet = UnSetType()
42
+
43
+
44
@dataclass
class TestModel(Model):
    """A model specifically for testing purposes.

    This will (by default) call all retrievers in the agent model, then return a tool response if possible,
    otherwise a plain response.

    How useful this model will be is unknown; it may be useless, or it may require significant changes to be useful.

    Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those
    of the base class.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    call_retrievers: list[str] | Literal['all'] = 'all'
    """List of retrievers to call. If `'all'`, all retrievers will be called."""
    custom_result_text: str | None = None
    """If set, this text is returned as the final result."""
    custom_result_args: Any | None = None
    """If set, these args will be passed to the result tool."""
    seed: int = 0
    """Seed for generating random data."""
    # these fields are set when the model is called by the agent
    agent_model_retrievers: Mapping[str, AbstractToolDefinition] | None = field(default=None, init=False)
    agent_model_allow_text_result: bool | None = field(default=None, init=False)
    agent_model_result_tools: list[AbstractToolDefinition] | None = field(default=None, init=False)

    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> AgentModel:
        """Record the agent's configuration and build a `TestAgentModel` for it.

        Args:
            retrievers: Retrievers registered on the agent, keyed by name.
            allow_text_result: Whether the agent accepts a plain-text result.
            result_tools: Result tools registered on the agent, if any.

        Returns:
            A `TestAgentModel` that first calls the selected retrievers, then produces the final result.

        Raises:
            KeyError: If `call_retrievers` names a retriever missing from `retrievers`.
            AssertionError: If `custom_result_text` / `custom_result_args` conflict with each other
                or with the agent's configuration.
        """
        self.agent_model_retrievers = retrievers
        self.agent_model_allow_text_result = allow_text_result
        self.agent_model_result_tools = list(result_tools) if result_tools is not None else None

        if self.call_retrievers == 'all':
            retriever_calls = [(r.name, r) for r in retrievers.values()]
        else:
            retrievers_to_call = (retrievers[name] for name in self.call_retrievers)
            retriever_calls = [(r.name, r) for r in retrievers_to_call]

        # `left` holds a plain-text result, `right` holds result-tool args; a `None` value means "generate it"
        if self.custom_result_text is not None:
            assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
            result: _utils.Either[str | None, Any | None] = _utils.Either(left=self.custom_result_text)
        elif self.custom_result_args is not None:
            assert result_tools, 'No result tools provided, but `custom_result_args` is set.'
            # Pick the same tool `TestAgentModel` will call (`seed % len(result_tools)`); otherwise, with
            # several result tools and a non-zero seed, the args (and their `outer_typed_dict_key` wrapper)
            # would be built for one tool but sent to another.
            result_tool = result_tools[self.seed % len(result_tools)]

            if k := result_tool.outer_typed_dict_key:
                result = _utils.Either(right={k: self.custom_result_args})
            else:
                result = _utils.Either(right=self.custom_result_args)
        elif allow_text_result:
            result = _utils.Either(left=None)
        elif result_tools is not None:
            result = _utils.Either(right=None)
        else:
            result = _utils.Either(left=None)
        return TestAgentModel(retriever_calls, result, self.agent_model_result_tools, self.seed)

    def name(self) -> str:
        return 'test-model'
111
+
112
+
113
@dataclass
class TestAgentModel(AgentModel):
    """Implementation of `AgentModel` for testing purposes.

    On the first request (if any retrievers are configured) every retriever is called. Later
    requests re-call the retrievers that received a `RetryPrompt` since the previous request.
    Otherwise the final result is produced: plain text (custom, or a JSON summary of retriever
    returns) or a call to a result tool (custom or generated args), depending on `result`.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    retriever_calls: list[tuple[str, AbstractToolDefinition]]
    # left means the text is plain text; right means it's a function call
    result: _utils.Either[str | None, Any | None]
    result_tools: list[AbstractToolDefinition] | None
    seed: int
    step: int = 0
    last_message_count: int = 0

    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
        response = self._request(messages)
        return response, Cost()

    @asynccontextmanager
    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
        response = self._request(messages)
        cost = Cost()
        if not isinstance(response, ModelTextResponse):
            yield TestStreamStructuredResponse(response, cost)
        else:
            yield TestStreamTextResponse(response.content, cost)

    def gen_retriever_args(self, tool_def: AbstractToolDefinition) -> Any:
        """Generate minimal arguments matching `tool_def.json_schema`, driven by `self.seed`."""
        generator = _JsonSchemaTestData(tool_def.json_schema, self.seed)
        return generator.generate()

    def _make_retriever_calls(self, only: set[str] | None = None) -> list[ToolCall]:
        """Build one `ToolCall` per configured retriever, restricted to names in `only` when given."""
        return [
            ToolCall.from_object(name, self.gen_retriever_args(tool_def))
            for name, tool_def in self.retriever_calls
            if only is None or name in only
        ]

    def _request(self, messages: list[Message]) -> ModelAnyResponse:
        # first request: call every configured retriever
        if self.step == 0 and self.retriever_calls:
            tool_calls = self._make_retriever_calls()
            self.step += 1
            self.last_message_count = len(messages)
            return ModelStructuredResponse(calls=tool_calls)

        # only messages added since the previous request are checked for retry prompts
        fresh_messages = messages[self.last_message_count :]
        self.last_message_count = len(messages)
        retry_names = {m.tool_name for m in fresh_messages if isinstance(m, RetryPrompt)}
        if retry_names:
            tool_calls = self._make_retriever_calls(retry_names)
            self.step += 1
            return ModelStructuredResponse(calls=tool_calls)

        text_result = self.result.left
        if text_result:
            self.step += 1
            if text_result.value is not None:
                return ModelTextResponse(content=text_result.value)
            # no custom text: report every retriever return seen so far
            returns = {m.tool_name: m.content for m in messages if isinstance(m, ToolReturn)}
            if not returns:
                return ModelTextResponse(content='success (no retriever calls)')
            return ModelTextResponse(content=pydantic_core.to_json(returns).decode())

        assert self.result_tools is not None, 'No result tools provided'
        custom_args = self.result.right
        tool = self.result_tools[self.seed % len(self.result_tools)]
        if custom_args is None:
            tool_args = self.gen_retriever_args(tool)
        else:
            tool_args = custom_args
        self.step += 1
        return ModelStructuredResponse(calls=[ToolCall.from_object(tool.name, tool_args)])
187
+
188
+
189
@dataclass
class TestStreamTextResponse(StreamTextResponse):
    """A `StreamTextResponse` that replays a fixed text in small chunks, for testing."""

    _text: str
    _cost: Cost
    _iter: Iterator[str] = field(init=False)
    _timestamp: datetime = field(default_factory=_utils.now_utc)
    _buffer: list[str] = field(default_factory=list, init=False)

    def __post_init__(self):
        # split on spaces, keeping the separator on each word so the chunks concatenate back to `_text`
        *words, last_word = self._text.split(' ')
        words = [f'{word} ' for word in words]
        words.append(last_word)
        # a single word longer than 2 chars is halved so there is still more than one chunk to stream
        if len(words) == 1 and len(self._text) > 2:
            mid = len(self._text) // 2
            words = [self._text[:mid], self._text[mid:]]
        self._iter = iter(words)

    async def __anext__(self) -> None:
        # move the next chunk into the buffer (`_utils.sync_anext` presumably signals exhaustion)
        self._buffer.append(_utils.sync_anext(self._iter))

    def get(self, *, final: bool = False) -> Iterable[str]:
        """Return the chunks received since the previous call, and reset the buffer.

        The buffer is drained eagerly, so chunks are never returned twice, even if the caller
        does not fully consume the returned iterator.
        """
        chunks, self._buffer = self._buffer, []
        return iter(chunks)

    def cost(self) -> Cost:
        return self._cost

    def timestamp(self) -> datetime:
        return self._timestamp
218
+
219
+
220
@dataclass
class TestStreamStructuredResponse(StreamStructuredResponse):
    """A `StreamStructuredResponse` that streams a single step and exposes a fixed structured response."""

    _structured_response: ModelStructuredResponse
    _cost: Cost
    _iter: Iterator[None] = field(default_factory=lambda: iter((None,)))
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def __anext__(self) -> None:
        # a single step: the iterator holds exactly one `None`
        return _utils.sync_anext(self._iter)

    def get(self, *, final: bool = False) -> ModelStructuredResponse:
        # the whole response is available up-front, so `final` makes no difference
        return self._structured_response

    def timestamp(self) -> datetime:
        return self._timestamp

    def cost(self) -> Cost:
        return self._cost
238
+
239
+
240
+ _chars = string.ascii_letters + string.digits + string.punctuation
241
+
242
+
243
+ class _JsonSchemaTestData:
244
+ """Generate data that matches a JSON schema.
245
+
246
+ This tries to generate the minimal viable data for the schema.
247
+ """
248
+
249
+ def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
250
+ self.schema = schema
251
+ self.defs = schema.get('$defs', {})
252
+ self.seed = seed
253
+
254
+ def generate(self) -> Any:
255
+ """Generate data for the JSON schema."""
256
+ return self._gen_any(self.schema)
257
+
258
+ def _gen_any(self, schema: dict[str, Any]) -> Any:
259
+ """Generate data for any JSON Schema."""
260
+ if const := schema.get('const'):
261
+ return const
262
+ elif enum := schema.get('enum'):
263
+ return enum[self.seed % len(enum)]
264
+ elif examples := schema.get('examples'):
265
+ return examples[self.seed % len(examples)]
266
+ elif ref := schema.get('$ref'):
267
+ key = re.sub(r'^#/\$defs/', '', ref)
268
+ js_def = self.defs[key]
269
+ return self._gen_any(js_def)
270
+ elif any_of := schema.get('anyOf'):
271
+ return self._gen_any(any_of[self.seed % len(any_of)])
272
+
273
+ type_ = schema.get('type')
274
+ if type_ is None:
275
+ # if there's no type or ref, we can't generate anything
276
+ return self._char()
277
+ elif type_ == 'object':
278
+ return self._object_gen(schema)
279
+ elif type_ == 'string':
280
+ return self._str_gen(schema)
281
+ elif type_ == 'integer':
282
+ return self._int_gen(schema)
283
+ elif type_ == 'number':
284
+ return float(self._int_gen(schema))
285
+ elif type_ == 'boolean':
286
+ return self._bool_gen()
287
+ elif type_ == 'array':
288
+ return self._array_gen(schema)
289
+ elif type_ == 'null':
290
+ return None
291
+ else:
292
+ raise NotImplementedError(f'Unknown type: {type_}, please submit a PR to extend JsonSchemaTestData!')
293
+
294
+ def _object_gen(self, schema: dict[str, Any]) -> dict[str, Any]:
295
+ """Generate data for a JSON Schema object."""
296
+ required = set(schema.get('required', []))
297
+
298
+ data: dict[str, Any] = {}
299
+ if properties := schema.get('properties'):
300
+ for key, value in properties.items():
301
+ if key in required:
302
+ data[key] = self._gen_any(value)
303
+
304
+ if addition_props := schema.get('additionalProperties'):
305
+ add_prop_key = 'additionalProperty'
306
+ while add_prop_key in data:
307
+ add_prop_key += '_'
308
+ if addition_props is True:
309
+ data[add_prop_key] = self._char()
310
+ else:
311
+ data[add_prop_key] = self._gen_any(addition_props)
312
+
313
+ return data
314
+
315
+ def _str_gen(self, schema: dict[str, Any]) -> str:
316
+ """Generate a string from a JSON Schema string."""
317
+ min_len = schema.get('minLength')
318
+ if min_len is not None:
319
+ return self._char() * min_len
320
+
321
+ if schema.get('maxLength') == 0:
322
+ return ''
323
+ else:
324
+ return self._char()
325
+
326
+ def _int_gen(self, schema: dict[str, Any]) -> int:
327
+ """Generate an integer from a JSON Schema integer."""
328
+ maximum = schema.get('maximum')
329
+ if maximum is None:
330
+ exc_max = schema.get('exclusiveMaximum')
331
+ if exc_max is not None:
332
+ maximum = exc_max - 1
333
+
334
+ minimum = schema.get('minimum')
335
+ if minimum is None:
336
+ exc_min = schema.get('exclusiveMinimum')
337
+ if exc_min is not None:
338
+ minimum = exc_min + 1
339
+
340
+ if minimum is not None and maximum is not None:
341
+ return minimum + self.seed % (maximum - minimum)
342
+ elif minimum is not None:
343
+ return minimum + self.seed
344
+ elif maximum is not None:
345
+ return maximum - self.seed
346
+ else:
347
+ return self.seed
348
+
349
+ def _bool_gen(self) -> bool:
350
+ """Generate a boolean from a JSON Schema boolean."""
351
+ return bool(self.seed % 2)
352
+
353
+ def _array_gen(self, schema: dict[str, Any]) -> list[Any]:
354
+ """Generate an array from a JSON Schema array."""
355
+ data: list[Any] = []
356
+ unique_items = schema.get('uniqueItems')
357
+ if prefix_items := schema.get('prefixItems'):
358
+ for item in prefix_items:
359
+ data.append(self._gen_any(item))
360
+ if unique_items:
361
+ self.seed += 1
362
+
363
+ items_schema = schema.get('items', {})
364
+ min_items = schema.get('minItems', 0)
365
+ if min_items > len(data):
366
+ for _ in range(min_items - len(data)):
367
+ data.append(self._gen_any(items_schema))
368
+ if unique_items:
369
+ self.seed += 1
370
+ elif items_schema:
371
+ # if there is an `items` schema, add an item unless it would break `maxItems` rule
372
+ max_items = schema.get('maxItems')
373
+ if max_items is None or max_items > len(data):
374
+ data.append(self._gen_any(items_schema))
375
+ if unique_items:
376
+ self.seed += 1
377
+
378
+ return data
379
+
380
+ def _char(self) -> str:
381
+ """Generate a character on the same principle as Excel columns, e.g. a-z, aa-az..."""
382
+ chars = len(_chars)
383
+ s = ''
384
+ rem = self.seed // chars
385
+ while rem > 0:
386
+ s += _chars[(rem - 1) % chars]
387
+ rem //= chars
388
+ s += _chars[self.seed % chars]
389
+ return s
@@ -0,0 +1,306 @@
1
+ """Custom interface to the `*-aiplatform.googleapis.com` API for Gemini models.
2
+
3
+ This model uses [`GeminiAgentModel`][pydantic_ai.models.gemini.GeminiAgentModel] with just the URL and auth method
4
+ changed from the default `GeminiModel`, it relies on the VertexAI
5
+ [`generateContent` function endpoint](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
6
+ and `streamGenerateContent` function endpoints
7
+ having the same schemas as the equivalent [Gemini endpoints][pydantic_ai.models.gemini.GeminiModel].
8
+
9
+ There are four advantages of using this API over the `generativelanguage.googleapis.com` API which
10
+ [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] uses, and one big disadvantage.
11
+
12
+ Advantages:
13
+
14
+ 1. The VertexAI API seems to be less flaky and less likely to occasionally return a 503 response.
15
+ 2. You can
16
+ [purchase provisioned throughput](https://cloud.google.com/vertex-ai/generative-ai/docs/provisioned-throughput#purchase-provisioned-throughput)
17
+ with VertexAI.
18
+ 3. If you're running PydanticAI inside GCP, you don't need to set up authentication, it should "just work".
19
+ 4. You can decide which region to use, which might be important from a regulatory perspective,
20
+ and might improve latency.
21
+
22
+ Disadvantage:
23
+
24
+ 1. When authorization doesn't just work, it's much more painful to set up than an API key.
25
+
26
+ ## Example Usage
27
+
28
+ With the default google project already configured in your environment:
29
+
30
+ ```py title="vertex_example_env.py"
31
+ from pydantic_ai import Agent
32
+ from pydantic_ai.models.vertexai import VertexAIModel
33
+
34
+ model = VertexAIModel('gemini-1.5-flash')
35
+ agent = Agent(model)
36
+ result = agent.run_sync('Tell me a joke.')
37
+ print(result.data)
38
+ #> Did you hear about the toothpaste scandal? They called it Colgate.
39
+ ```
40
+
41
+ Or using a service account JSON file:
42
+
43
+ ```py title="vertex_example_service_account.py"
44
+ from pydantic_ai import Agent
45
+ from pydantic_ai.models.vertexai import VertexAIModel
46
+
47
+ model = VertexAIModel(
48
+ 'gemini-1.5-flash',
49
+ service_account_file='path/to/service-account.json',
50
+ )
51
+ agent = Agent(model)
52
+ result = agent.run_sync('Tell me a joke.')
53
+ print(result.data)
54
+ #> Did you hear about the toothpaste scandal? They called it Colgate.
55
+ ```
56
+ """
57
+
58
+ from __future__ import annotations as _annotations
59
+
60
+ from collections.abc import Mapping, Sequence
61
+ from dataclasses import dataclass, field
62
+ from datetime import datetime, timedelta
63
+ from pathlib import Path
64
+ from typing import Literal
65
+
66
+ from httpx import AsyncClient as AsyncHTTPClient
67
+
68
+ from .._utils import run_in_executor
69
+ from ..exceptions import UserError
70
+ from . import AbstractToolDefinition, Model, cached_async_http_client
71
+ from .gemini import GeminiAgentModel, GeminiModelName
72
+
73
+ try:
74
+ import google.auth
75
+ from google.auth.credentials import Credentials as BaseCredentials
76
+ from google.auth.transport.requests import Request
77
+ from google.oauth2.service_account import Credentials as ServiceAccountCredentials
78
+ except ImportError as e:
79
+ raise ImportError(
80
+ 'Please install `google-auth` to use the VertexAI model, '
81
+ "you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
82
+ ) from e
83
+
84
VERTEX_AI_URL_TEMPLATE = (
    'https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}'
    '/locations/{region}/publishers/{model_publisher}/models/{model}:'
)
"""URL template for Vertex AI.

See
[`generateContent` docs](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
and
[`streamGenerateContent` docs](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent)
for more information.

The template is filled in as follows:

* `region` is substituted with the `region` argument,
  see [available regions][pydantic_ai.models.vertexai.VertexAiRegion]
* `model_publisher` is substituted with the `model_publisher` argument
* `model` is substituted with the `model_name` argument
* `project_id` is substituted with the `project_id` from auth/credentials
* `function` (`generateContent` or `streamGenerateContent`) is appended after the trailing `:`
"""
109
+
110
+
111
@dataclass(init=False)
class VertexAIModel(Model):
    """A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API.

    The request URL and bearer-token auth are resolved on the first `agent_model()` call and then
    cached on the instance (`url` / `auth`).
    """

    model_name: GeminiModelName
    service_account_file: Path | str | None
    project_id: str | None
    region: VertexAiRegion
    model_publisher: Literal['google']
    http_client: AsyncHTTPClient
    url_template: str

    # populated by `_ainit()`; `None` until the first `agent_model()` call
    auth: BearerTokenAuth | None
    url: str | None

    # TODO __init__ can be removed once we drop 3.9 and we can set kw_only correctly on the dataclass
    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: Literal['google'] = 'google',
        http_client: AsyncHTTPClient | None = None,
        url_template: str = VERTEX_AI_URL_TEMPLATE,
    ):
        """Initialize a Vertex AI Gemini model.

        Args:
            model_name: The name of the model to use. I couldn't find a list of supported Google models in VertexAI,
                so for now this uses the same models as the [Gemini model][pydantic_ai.models.gemini.GeminiModel].
            service_account_file: Path to a service account file.
                If not provided, the default environment credentials will be used.
            project_id: The project ID to use; if not provided it will be taken from the credentials.
            region: The region to make requests to.
            model_publisher: The model publisher to use. I couldn't find a good list of available publishers,
                and from trial and error it seems non-google models don't work with the `generateContent` and
                `streamGenerateContent` functions, hence only `google` is currently supported.
                Please create an issue or PR if you know how to use other publishers.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
            url_template: URL template for Vertex AI, see
                [`VERTEX_AI_URL_TEMPLATE` docs][pydantic_ai.models.vertexai.VERTEX_AI_URL_TEMPLATE]
                for more information.
        """
        # caller-supplied configuration
        self.model_name = model_name
        self.service_account_file = service_account_file
        self.project_id = project_id
        self.region = region
        self.model_publisher = model_publisher
        self.url_template = url_template
        self.http_client = http_client if http_client else cached_async_http_client()

        # resolved lazily by `_ainit()`
        self.url = None
        self.auth = None

    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> GeminiAgentModel:
        """Build a `GeminiAgentModel` that targets the Vertex AI endpoint, resolving URL and auth first."""
        url, auth = await self._ainit()
        return GeminiAgentModel(
            http_client=self.http_client,
            model_name=self.model_name,
            auth=auth,
            url=url,
            retrievers=retrievers,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )

    async def _ainit(self) -> tuple[str, BearerTokenAuth]:
        """Resolve credentials, project ID and request URL, caching the result on the instance.

        Raises:
            UserError: If no project ID is available, or the given one disagrees with the credentials.
        """
        if self.url is not None and self.auth is not None:
            return self.url, self.auth

        if self.service_account_file is not None:
            creds: BaseCredentials | ServiceAccountCredentials = _creds_from_file(self.service_account_file)
            assert creds.project_id is None or isinstance(creds.project_id, str)
            creds_project_id: str | None = creds.project_id
            creds_source = 'service account file'
        else:
            creds, creds_project_id = await _async_google_auth()
            creds_source = '`google.auth.default()`'

        project_id = self.project_id
        if project_id is None:
            if creds_project_id is None:
                raise UserError(f'No project_id provided and none found in {creds_source}')
            project_id = creds_project_id
        elif creds_project_id is not None and project_id != creds_project_id:
            raise UserError(
                f'The project_id you provided does not match the one from {creds_source}: '
                f'{project_id!r} != {creds_project_id!r}'
            )

        url = self.url_template.format(
            region=self.region,
            project_id=project_id,
            model_publisher=self.model_publisher,
            model=self.model_name,
        )
        self.url = url
        auth = BearerTokenAuth(creds)
        self.auth = auth
        return url, auth

    def name(self) -> str:
        return f'vertexai:{self.model_name}'
220
+
221
+
222
+ # pyright: reportUnknownMemberType=false
223
def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
    """Load service-account credentials from a key file, scoped to `cloud-platform`."""
    scopes = ['https://www.googleapis.com/auth/cloud-platform']
    path = str(service_account_file)
    return ServiceAccountCredentials.from_service_account_file(path, scopes=scopes)
227
+
228
+
229
+ # pyright: reportReturnType=false
230
+ # pyright: reportUnknownVariableType=false
231
+ # pyright: reportUnknownArgumentType=false
232
async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
    """Call `google.auth.default()` via `run_in_executor` and return its `(credentials, project_id)` pair."""
    lookup = google.auth.default
    credentials_and_project = await run_in_executor(lookup)
    return credentials_and_project
234
+
235
+
236
# default expiry is 3600 seconds; tokens older than this (50 minutes) are refreshed before use
MAX_TOKEN_AGE = timedelta(minutes=50)
238
+
239
+
240
@dataclass
class BearerTokenAuth:
    """Provide `Authorization: Bearer ...` headers from google-auth credentials.

    The token is refreshed (the synchronous refresh runs via `run_in_executor`) when the credentials
    have no token yet, or when the last refresh by this object is older than `MAX_TOKEN_AGE`.
    """

    credentials: BaseCredentials | ServiceAccountCredentials
    # aware UTC time of the last refresh performed by this object; `None` until the first refresh
    token_created: datetime | None = field(default=None, init=False)

    async def headers(self) -> dict[str, str]:
        """Return HTTP headers carrying a bearer token, refreshing the token first if needed."""
        if self.credentials.token is None or self._token_expired():
            await run_in_executor(self._refresh_token)
            self.token_created = self._now()
        return {'Authorization': f'Bearer {self.credentials.token}'}

    def _token_expired(self) -> bool:
        if self.token_created is None:
            return True
        else:
            return (self._now() - self.token_created) > MAX_TOKEN_AGE

    @staticmethod
    def _now() -> datetime:
        # Aware UTC time: naive local `datetime.now()` jumps at DST transitions, which would make the
        # measured token age wrong by up to an hour (refreshing too late or too early).
        from datetime import timezone

        return datetime.now(tz=timezone.utc)

    def _refresh_token(self) -> str:
        self.credentials.refresh(Request())
        assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}'
        return self.credentials.token
261
+
262
+
263
VertexAiRegion = Literal[
    # United States
    'us-central1',
    'us-east1',
    'us-east4',
    'us-south1',
    'us-west1',
    'us-west2',
    'us-west3',
    'us-west4',
    'us-east5',
    # Europe
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    'europe-west12',
    # Africa
    'africa-south1',
    # Asia
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast2',
    'asia-northeast3',
    'asia-south1',
    'asia-southeast1',
    'asia-southeast2',
    # Australia
    'australia-southeast1',
    'australia-southeast2',
    # Middle East
    'me-central1',
    'me-central2',
    'me-west1',
    # North America (outside the US)
    'northamerica-northeast1',
    'northamerica-northeast2',
    # South America
    'southamerica-east1',
    'southamerica-west1',
]
"""Regions available for Vertex AI.

More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
"""
pydantic_ai/py.typed ADDED
File without changes