pydantic-ai-slim 0.0.43__py3-none-any.whl → 0.0.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

The registry flags this version of pydantic-ai-slim as a potentially problematic release.

@@ -6,7 +6,9 @@ from typing import overload
 import httpx
 from openai import AsyncOpenAI
 
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
 
 try:
     from openai import AsyncAzureOpenAI
@@ -17,9 +19,6 @@ except ImportError as _import_error:  # pragma: no cover
     ) from _import_error


-from . import Provider
-
-
 class AzureProvider(Provider[AsyncOpenAI]):
     """Provider for Azure OpenAI API.

@@ -83,18 +82,18 @@ class AzureProvider(Provider[AsyncOpenAI]):
             self._client = openai_client
         else:
             azure_endpoint = azure_endpoint or os.getenv('AZURE_OPENAI_ENDPOINT')
-            if azure_endpoint is None:  # pragma: no cover
-                raise ValueError(
+            if not azure_endpoint:  # pragma: no cover
+                raise UserError(
                     'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable'
                 )

-            if api_key is None and 'OPENAI_API_KEY' not in os.environ:  # pragma: no cover
-                raise ValueError(
+            if not api_key and 'OPENAI_API_KEY' not in os.environ:  # pragma: no cover
+                raise UserError(
                     'Must provide one of the `api_key` argument or the `OPENAI_API_KEY` environment variable'
                 )

-            if api_version is None and 'OPENAI_API_VERSION' not in os.environ:  # pragma: no cover
-                raise ValueError(
+            if not api_version and 'OPENAI_API_VERSION' not in os.environ:  # pragma: no cover
+                raise UserError(
                     'Must provide one of the `api_version` argument or the `OPENAI_API_VERSION` environment variable'
                 )

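The AzureProvider hunks above replace ValueError with pydantic_ai.exceptions.UserError for missing configuration and switch the checks from `is None` to `not ...`, so empty strings are treated like missing values. A minimal sketch of the configuration those checks expect; the module path comes from the RECORD below, the endpoint, key and version values are placeholders, and whether the except branch fires depends on which environment variables are set:

    import os

    from pydantic_ai.exceptions import UserError
    from pydantic_ai.providers.azure import AzureProvider

    # Configuration can be passed explicitly or picked up from the environment variables
    # the checks fall back to: AZURE_OPENAI_ENDPOINT, OPENAI_API_KEY and OPENAI_API_VERSION.
    provider = AzureProvider(
        azure_endpoint='https://my-resource.openai.azure.com',  # placeholder endpoint
        api_version='2024-06-01',                               # placeholder API version
        api_key=os.environ.get('OPENAI_API_KEY', 'placeholder-key'),
    )

    # With nothing passed and none of those variables set, 0.0.45 raises UserError
    # instead of ValueError, and an empty endpoint string now counts as missing.
    try:
        AzureProvider(azure_endpoint='')
    except UserError as exc:
        print(exc)
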
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
 
 from typing import overload
 
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.providers import Provider
 
 try:
@@ -73,4 +74,4 @@ class BedrockProvider(Provider[BaseClient]):
                     region_name=region_name,
                 )
             except NoRegionError as exc:  # pragma: no cover
-                raise ValueError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
+                raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
@@ -0,0 +1,71 @@
+from __future__ import annotations as _annotations
+
+import os
+
+from httpx import AsyncClient as AsyncHTTPClient
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
+
+try:
+    from cohere import AsyncClientV2
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `cohere` package to use the Cohere provider, '
+        'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
+    ) from _import_error
+
+
+class CohereProvider(Provider[AsyncClientV2]):
+    """Provider for Cohere API."""
+
+    @property
+    def name(self) -> str:
+        return 'cohere'
+
+    @property
+    def base_url(self) -> str:
+        client_wrapper = self.client._client_wrapper  # type: ignore
+        return str(client_wrapper.get_base_url())
+
+    @property
+    def client(self) -> AsyncClientV2:
+        return self._client
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        cohere_client: AsyncClientV2 | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        """Create a new Cohere provider.
+
+        Args:
+            api_key: The API key to use for authentication, if not provided, the `CO_API_KEY` environment variable
+                will be used if available.
+            cohere_client: An existing
+                [AsyncClientV2](https://github.com/cohere-ai/cohere-python)
+                client to use. If provided, `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        if cohere_client is not None:
+            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
+            self._client = cohere_client
+        else:
+            api_key = api_key or os.environ.get('CO_API_KEY')
+            if not api_key:
+                raise UserError(
+                    'Set the `CO_API_KEY` environment variable or pass it via `CohereProvider(api_key=...)`'
+                    'to use the Cohere provider.'
+                )
+
+            base_url = os.environ.get('CO_BASE_URL')
+            if http_client is not None:
+                self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url)
+            else:
+                self._client = AsyncClientV2(
+                    api_key=api_key, httpx_client=cached_async_http_client(), base_url=base_url
+                )
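
The Cohere provider file above is new in this release; its ImportError message names the `pip install "pydantic-ai-slim[cohere]"` extra. A minimal usage sketch based only on the constructor shown above; the module path comes from the RECORD below, the key value is a placeholder, and wiring the provider into a model (for example via pydantic_ai/models/cohere.py) is not shown in this diff, so only provider construction is sketched:

    import os

    from pydantic_ai.exceptions import UserError
    from pydantic_ai.providers.cohere import CohereProvider

    # An explicit key wins; otherwise CO_API_KEY is read from the environment,
    # and CO_BASE_URL can override the default Cohere endpoint.
    provider = CohereProvider(api_key=os.environ.get('CO_API_KEY', 'placeholder-key'))
    print(provider.name)      # 'cohere'
    print(provider.base_url)  # resolved from the underlying cohere.AsyncClientV2

    # With no argument and no CO_API_KEY set, construction fails with UserError.
    try:
        CohereProvider()
    except UserError as exc:
        print(exc)
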
@@ -6,7 +6,9 @@ from typing import overload
 from httpx import AsyncClient as AsyncHTTPClient
 from openai import AsyncOpenAI
 
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
 
 try:
     from openai import AsyncOpenAI
@@ -16,8 +18,6 @@ except ImportError as _import_error:  # pragma: no cover
         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
-from . import Provider
-
 
 class DeepSeekProvider(Provider[AsyncOpenAI]):
     """Provider for DeepSeek API."""
@@ -54,8 +54,8 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
         http_client: AsyncHTTPClient | None = None,
     ) -> None:
         api_key = api_key or os.getenv('DEEPSEEK_API_KEY')
-        if api_key is None and openai_client is None:
-            raise ValueError(
+        if not api_key and openai_client is None:
+            raise UserError(
                 'Set the `DEEPSEEK_API_KEY` environment variable or pass it via `DeepSeekProvider(api_key=...)`'
                 'to use the DeepSeek provider.'
             )
@@ -4,6 +4,7 @@ import os
 
 import httpx
 
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
 from pydantic_ai.providers import Provider
 
@@ -32,8 +33,8 @@ class GoogleGLAProvider(Provider[httpx.AsyncClient]):
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         api_key = api_key or os.environ.get('GEMINI_API_KEY')
-        if api_key is None:
-            raise ValueError(
+        if not api_key:
+            raise UserError(
                 'Set the `GEMINI_API_KEY` environment variable or pass it via `GoogleGLAProvider(api_key=...)`'
                 'to use the Google GLA provider.'
             )
@@ -9,9 +9,8 @@ import anyio.to_thread
 import httpx
 
 from pydantic_ai.exceptions import UserError
-
-from ..models import cached_async_http_client
-from . import Provider
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
 
 try:
     import google.auth
@@ -5,7 +5,9 @@ from typing import overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
 
 try:
     from groq import AsyncGroq
@@ -16,9 +18,6 @@ except ImportError as _import_error:  # pragma: no cover
     ) from _import_error


-from . import Provider
-
-
 class GroqProvider(Provider[AsyncGroq]):
     """Provider for Groq API."""

@@ -64,8 +63,8 @@ class GroqProvider(Provider[AsyncGroq]):
         else:
             api_key = api_key or os.environ.get('GROQ_API_KEY')

-            if api_key is None:
-                raise ValueError(
+            if not api_key:
+                raise UserError(
                     'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
                     'to use the Groq provider.'
                 )
@@ -5,7 +5,9 @@ from typing import overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 
+from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
 
 try:
     from mistralai import Mistral
@@ -16,9 +18,6 @@ except ImportError as e:  # pragma: no cover
     ) from e


-from . import Provider
-
-
 class MistralProvider(Provider[Mistral]):
     """Provider for Mistral API."""

@@ -62,8 +61,8 @@ class MistralProvider(Provider[Mistral]):
         else:
             api_key = api_key or os.environ.get('MISTRAL_API_KEY')

-            if api_key is None:
-                raise ValueError(
+            if not api_key:
+                raise UserError(
                     'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
                     'to use the Mistral provider.'
                 )
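
The DeepSeek, Google GLA, Groq and Mistral providers all get the same treatment in this release: Provider is imported from pydantic_ai.providers instead of the relative `from . import Provider`, the key checks use `not api_key` so empty strings count as missing, and the resulting error is UserError rather than ValueError. A sketch of what that looks like from the caller's side, using GroqProvider as the representative case; the module path comes from the RECORD below, the key value is a placeholder, and whether the except branch fires depends on whether GROQ_API_KEY is set:

    import os

    from pydantic_ai.exceptions import UserError
    from pydantic_ai.providers.groq import GroqProvider

    # An explicit key wins; otherwise GROQ_API_KEY is read from the environment.
    provider = GroqProvider(api_key=os.environ.get('GROQ_API_KEY', 'placeholder-key'))

    # An empty string now falls through to the environment lookup, and a missing key
    # raises UserError in 0.0.45 where 0.0.43 raised ValueError.
    try:
        GroqProvider(api_key='')
    except UserError as exc:
        print(exc)
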
@@ -5,6 +5,7 @@ import os
 import httpx
 
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
 
 try:
     from openai import AsyncOpenAI
@@ -15,9 +16,6 @@ except ImportError as _import_error:  # pragma: no cover
     ) from _import_error


-from . import Provider
-
-
 class OpenAIProvider(Provider[AsyncOpenAI]):
     """Provider for OpenAI API."""

@@ -27,7 +25,7 @@ class OpenAIProvider(Provider[AsyncOpenAI]):

     @property
     def base_url(self) -> str:
-        return self._base_url
+        return str(self.client.base_url)

     @property
     def client(self) -> AsyncOpenAI:
@@ -52,10 +50,9 @@ class OpenAIProvider(Provider[AsyncOpenAI]):
                 client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self._base_url = base_url or os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1')
         # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
         # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
-        if api_key is None and 'OPENAI_API_KEY' not in os.environ and openai_client is None:
+        if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
             api_key = 'api-key-not-set'

         if openai_client is not None:
@@ -64,6 +61,6 @@ class OpenAIProvider(Provider[AsyncOpenAI]):
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
             self._client = openai_client
         elif http_client is not None:
-            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+            self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
-            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
+            self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
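
Two behavioural changes in OpenAIProvider are worth noting: `base_url` is no longer cached on the provider but read back from the underlying AsyncOpenAI client, which applies the same OPENAI_BASE_URL / https://api.openai.com/v1 fallback the removed line used to do by hand, and the `'api-key-not-set'` placeholder is now only injected when a custom `base_url` is given, so the stock configuration with no key at all fails inside the OpenAI client rather than silently using a dummy key. A sketch of both paths; the local URL is a placeholder and the outcome depends on which OPENAI_* variables are set in your environment:

    from pydantic_ai.providers.openai import OpenAIProvider

    # Custom base_url (e.g. a locally served OpenAI-compatible endpoint) with no key:
    # the 'api-key-not-set' placeholder is still injected so the client can be built.
    local = OpenAIProvider(base_url='http://localhost:11434/v1')  # placeholder local endpoint
    print(local.base_url)  # now read from the AsyncOpenAI client, not a cached attribute

    # No base_url: the provider passes base_url=None straight through, and without an
    # api_key argument or OPENAI_API_KEY set, the AsyncOpenAI constructor raises its
    # usual missing-key error instead of receiving the placeholder.
    remote = OpenAIProvider(api_key='placeholder-key')
    print(remote.base_url)  # AsyncOpenAI's default, normally https://api.openai.com/v1
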
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.43
+Version: 0.0.45
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.43
+Requires-Dist: pydantic-graph==0.0.45
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -53,7 +53,7 @@ Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
-Requires-Dist: openai>=1.65.1; extra == 'openai'
+Requires-Dist: openai>=1.67.0; extra == 'openai'
 Provides-Extra: tavily
 Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
 Provides-Extra: vertexai
@@ -0,0 +1,50 @@
+pydantic_ai/__init__.py,sha256=5or1fE25gmemJGCznkFHC4VMeNT7vTLU6BiGxkmSA2A,959
+pydantic_ai/_agent_graph.py,sha256=aZHgDDEL0kYl7G0LAAPrrf4UufqB_FUN8s4PlnVqA-o,32557
+pydantic_ai/_cli.py,sha256=YrNi4vodEH-o5MlfG4CemdkdJl5kQx56nyCMp2QqRMw,8610
+pydantic_ai/_griffe.py,sha256=Sf_DisE9k2TA0VFeVIK2nf1oOct5MygW86PBCACJkFA,5244
+pydantic_ai/_parts_manager.py,sha256=HIi6eth7z2g0tOn6iQYc633xMqy4d_xZ8vwka8J8150,12016
+pydantic_ai/_pydantic.py,sha256=12hX5hON88meO1QxbWrEPXSvr6RTNgr6ubKY6KRwab4,8890
+pydantic_ai/_result.py,sha256=SlxqR-AKWzDoc7cRRN2jmIZ7pCv3DKzaP-dnZW-e7us,10117
+pydantic_ai/_system_prompt.py,sha256=602c2jyle2R_SesOrITBDETZqsLk4BZ8Cbo8yEhmx04,1120
+pydantic_ai/_utils.py,sha256=s_cVIKiJk1wkLhXDRxxWZGd1QgXFey6HYJ0OGU8Kezs,9657
+pydantic_ai/agent.py,sha256=QF3MQWKjglbtGVWLHSAizN0gDrC1d0B09BKNF03dSow,69491
+pydantic_ai/exceptions.py,sha256=gvbFsFkAzSXOo_d1nfjy09kDHUGv1j5q70Uk-wKYGi8,3167
+pydantic_ai/format_as_xml.py,sha256=QE7eMlg5-YUMw1_2kcI3h0uKYPZZyGkgXFDtfZTMeeI,4480
+pydantic_ai/mcp.py,sha256=d6odfllUQ94ROnkZ1MeEMg5L23rKhOnEs_JVWkWAC-Y,7068
+pydantic_ai/messages.py,sha256=KQXG8BLnQtZzgU4ykOJT6LIs03Vem2eb0VqMyXmFt7I,27067
+pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pydantic_ai/result.py,sha256=LXKxRzy_rGMkdZ8xJ7yknPP3wGZtGNeZl-gh5opXbaQ,22542
+pydantic_ai/settings.py,sha256=q__Hordc4dypesNxpy_cBT5rFdSiEY-rQt9G6zfyFaM,3101
+pydantic_ai/tools.py,sha256=ImFy3V4fw_tvh5h8FZFpCrgnzlmoHCo-Y_9EjMwOSWc,14393
+pydantic_ai/usage.py,sha256=9sqoIv_RVVUhKXQScTDqUJc074gifsuSzc9_NOt7C3g,5394
+pydantic_ai/common_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pydantic_ai/common_tools/duckduckgo.py,sha256=Iw8Dl2YQ28S483mzfa8CXs-dc-ujS8un085R2O6oOEw,2241
+pydantic_ai/common_tools/tavily.py,sha256=h8deBDrpG-8BGzydM_zXs7z1ASrhdVvUxL4-CAbncBo,2589
+pydantic_ai/models/__init__.py,sha256=HSjLtlnhtYGXHql46WK_V01lmc9PO7X66Dkf-plCafo,16671
+pydantic_ai/models/anthropic.py,sha256=O7gHDqM6V4zQLwjg5xyB6zDCAUXGaXcWtyXjV6vHLu0,21640
+pydantic_ai/models/bedrock.py,sha256=Fh5kNnH0H_OyKOUzSlVJYhc0K_wz3mrOH5Y4-oS6lmU,20532
+pydantic_ai/models/cohere.py,sha256=ogu97Strxsp6wzlNqT22SfgvltTeP1lLC5S86d037Ic,11249
+pydantic_ai/models/fallback.py,sha256=y0bYXM3DfzJNAsyyMzclt33lzZazL-5_hwdgc33gfuM,4876
+pydantic_ai/models/function.py,sha256=HUSgPB3mKVfYI0OSJJJJRiQN-yeewjYIbrtrPfsvlgI,11365
+pydantic_ai/models/gemini.py,sha256=ox9WoqWgZ7Q-xTRJBv9loTy9P49uAwvJwCPoiqljhPM,33215
+pydantic_ai/models/groq.py,sha256=H-7Eu61EOxKoIPjI6wvofrA2PxSQxhd-BJOHT1p7KiA,15862
+pydantic_ai/models/instrumented.py,sha256=FMEcQ8RnASD7bR8Ol5a16W6yTq1P1F8FojgBIUXdt3w,10962
+pydantic_ai/models/mistral.py,sha256=0bA5vRGXOJulY7r-8jUT8gicAuTSg_sSN0riDP_j9oY,27243
+pydantic_ai/models/openai.py,sha256=W_J_pnoJLSYKTfs_4G9SpFzLPdD2trIDFqtNVSmJ1D4,20060
+pydantic_ai/models/test.py,sha256=qQ8ZIaVRdbJv-tKGu6lrdakVAhOsTlyf68TFWyGwOWE,16861
+pydantic_ai/models/wrapper.py,sha256=ff6JPTuIv9C_6Zo4kyYIO7Cn0VI1uSICz1v1aKUyeOc,1506
+pydantic_ai/providers/__init__.py,sha256=lsJn3BStrPMMAFWEkCYPyfMj3fEVfaeS2xllnvE6Gdk,2489
+pydantic_ai/providers/anthropic.py,sha256=RfYpsKMZxUqE1_PbfJi3JCVmYelN-cwtC5vmw-PmIIA,2750
+pydantic_ai/providers/azure.py,sha256=M1QYzoLGBg23V8eXo4e7xSNJgoDGvLeMFgiYD-3iTNc,4197
+pydantic_ai/providers/bedrock.py,sha256=lSfK0mDqrmWLxzDKvtiY_nN2J3S_GGSPMRLJYeyvLrQ,2544
+pydantic_ai/providers/cohere.py,sha256=qk5fu0ao1EjLVgvDOGWjpGw-rE5iEggeyEgm6tou9O4,2646
+pydantic_ai/providers/deepseek.py,sha256=q_-ybngI2IvRl_sKlkARKuKwYCNC2lElIT4eo43UCf8,2136
+pydantic_ai/providers/google_gla.py,sha256=vrzzf8BkGLfJKy4GO3Ywswhf-b9OPr54O6LvBKVSwko,1579
+pydantic_ai/providers/google_vertex.py,sha256=TrMUzNwEJpFdE9lMuTwiQzGA-Qvf8TCOBrztgwk8u0c,9111
+pydantic_ai/providers/groq.py,sha256=BM0hNwSmf4gTuQoH2iiT96HApdqskEafeK_Yq9V_L6c,2769
+pydantic_ai/providers/mistral.py,sha256=oap2wjzc20byC8W0O6qaRUkMo7fFpibwhRvqEXAFI3g,2601
+pydantic_ai/providers/openai.py,sha256=gKbq12z4rzwDmTvWnpMULcyP5bCSagTQKApK9ukROOA,2811
+pydantic_ai_slim-0.0.45.dist-info/METADATA,sha256=TuKSPs8TaUefJ9drnwrQqZ2CmDqXkK1SXbYJ_iwYX38,3436
+pydantic_ai_slim-0.0.45.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.45.dist-info/entry_points.txt,sha256=KxQSmlMS8GMTkwTsl4_q9a5nJvBjj3HWeXx688wLrKg,45
+pydantic_ai_slim-0.0.45.dist-info/RECORD,,
@@ -1,260 +0,0 @@
-from __future__ import annotations as _annotations
-
-import warnings
-from collections.abc import AsyncIterator
-from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Literal
-
-from httpx import AsyncClient as AsyncHTTPClient
-from typing_extensions import deprecated
-
-from .. import usage
-from .._utils import run_in_executor
-from ..exceptions import UserError
-from ..messages import ModelMessage, ModelResponse
-from ..settings import ModelSettings
-from . import ModelRequestParameters, StreamedResponse, cached_async_http_client
-from .gemini import GeminiModel, GeminiModelName
-
-try:
-    import google.auth
-    from google.auth.credentials import Credentials as BaseCredentials
-    from google.auth.transport.requests import Request
-    from google.oauth2.service_account import Credentials as ServiceAccountCredentials
-except ImportError as _import_error:
-    raise ImportError(
-        'Please install `google-auth` to use the VertexAI model, '
-        'you can use the `vertexai` optional group — `pip install "pydantic-ai-slim[vertexai]"`'
-    ) from _import_error
-
-VERTEX_AI_URL_TEMPLATE = (
-    'https://{region}-aiplatform.googleapis.com/v1'
-    '/projects/{project_id}'
-    '/locations/{region}'
-    '/publishers/{model_publisher}'
-    '/models/{model}'
-    ':'
-)
-"""URL template for Vertex AI.
-
-See
-[`generateContent` docs](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
-and
-[`streamGenerateContent` docs](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent)
-for more information.
-
-The template is used thus:
-
-* `region` is substituted with the `region` argument,
-  see [available regions][pydantic_ai.models.vertexai.VertexAiRegion]
-* `model_publisher` is substituted with the `model_publisher` argument
-* `model` is substituted with the `model_name` argument
-* `project_id` is substituted with the `project_id` from auth/credentials
-* `function` (`generateContent` or `streamGenerateContent`) is added to the end of the URL
-"""
-
-
-@deprecated('Please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.')
-@dataclass(init=False)
-class VertexAIModel(GeminiModel):
-    """A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API."""
-
-    service_account_file: Path | str | None
-    project_id: str | None
-    region: VertexAiRegion
-    model_publisher: Literal['google']
-    url_template: str
-
-    _model_name: GeminiModelName = field(repr=False)
-    _system: str = field(default='vertex_ai', repr=False)
-
-    # TODO __init__ can be removed once we drop 3.9 and we can set kw_only correctly on the dataclass
-    def __init__(
-        self,
-        model_name: GeminiModelName,
-        *,
-        service_account_file: Path | str | None = None,
-        project_id: str | None = None,
-        region: VertexAiRegion = 'us-central1',
-        model_publisher: Literal['google'] = 'google',
-        http_client: AsyncHTTPClient | None = None,
-        url_template: str = VERTEX_AI_URL_TEMPLATE,
-    ):
-        """Initialize a Vertex AI Gemini model.
-
-        Args:
-            model_name: The name of the model to use. I couldn't find a list of supported Google models, in VertexAI
-                so for now this uses the same models as the [Gemini model][pydantic_ai.models.gemini.GeminiModel].
-            service_account_file: Path to a service account file.
-                If not provided, the default environment credentials will be used.
-            project_id: The project ID to use, if not provided it will be taken from the credentials.
-            region: The region to make requests to.
-            model_publisher: The model publisher to use, I couldn't find a good list of available publishers,
-                and from trial and error it seems non-google models don't work with the `generateContent` and
-                `streamGenerateContent` functions, hence only `google` is currently supported.
-                Please create an issue or PR if you know how to use other publishers.
-            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
-            url_template: URL template for Vertex AI, see
-                [`VERTEX_AI_URL_TEMPLATE` docs][pydantic_ai.models.vertexai.VERTEX_AI_URL_TEMPLATE]
-                for more information.
-        """
-        self._model_name = model_name
-        self.service_account_file = service_account_file
-        self.project_id = project_id
-        self.region = region
-        self.model_publisher = model_publisher
-        self.client = http_client or cached_async_http_client()
-        self.url_template = url_template
-
-        self._auth = None
-        self._url = None
-        warnings.warn(
-            'VertexAIModel is deprecated, please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.',
-            DeprecationWarning,
-        )
-        self._provider = None
-
-    async def ainit(self) -> None:
-        """Initialize the model, setting the URL and auth.
-
-        This will raise an error if authentication fails.
-        """
-        if self._url is not None and self._auth is not None:
-            return
-
-        if self.service_account_file is not None:
-            creds: BaseCredentials | ServiceAccountCredentials = _creds_from_file(self.service_account_file)
-            assert creds.project_id is None or isinstance(creds.project_id, str)
-            creds_project_id: str | None = creds.project_id
-            creds_source = 'service account file'
-        else:
-            creds, creds_project_id = await _async_google_auth()
-            creds_source = '`google.auth.default()`'
-
-        if self.project_id is None:
-            if creds_project_id is None:
-                raise UserError(f'No project_id provided and none found in {creds_source}')
-            project_id = creds_project_id
-        else:
-            project_id = self.project_id
-
-        self._url = self.url_template.format(
-            region=self.region,
-            project_id=project_id,
-            model_publisher=self.model_publisher,
-            model=self._model_name,
-        )
-        self._auth = BearerTokenAuth(creds)
-
-    async def request(
-        self,
-        messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
-        model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, usage.Usage]:
-        await self.ainit()
-        return await super().request(messages, model_settings, model_request_parameters)
-
-    @asynccontextmanager
-    async def request_stream(
-        self,
-        messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
-        model_request_parameters: ModelRequestParameters,
-    ) -> AsyncIterator[StreamedResponse]:
-        await self.ainit()
-        async with super().request_stream(messages, model_settings, model_request_parameters) as value:
-            yield value
-
-    @property
-    def model_name(self) -> GeminiModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
-
-# pyright: reportUnknownMemberType=false
-def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
-    return ServiceAccountCredentials.from_service_account_file(
-        str(service_account_file), scopes=['https://www.googleapis.com/auth/cloud-platform']
-    )
-
-
-# pyright: reportReturnType=false
-# pyright: reportUnknownVariableType=false
-# pyright: reportUnknownArgumentType=false
-async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
-    return await run_in_executor(google.auth.default, scopes=['https://www.googleapis.com/auth/cloud-platform'])
-
-
-# default expiry is 3600 seconds
-MAX_TOKEN_AGE = timedelta(seconds=3000)
-
-
-@dataclass
-class BearerTokenAuth:
-    """Authentication using a bearer token generated by google-auth."""
-
-    credentials: BaseCredentials | ServiceAccountCredentials
-    token_created: datetime | None = field(default=None, init=False)
-
-    async def headers(self) -> dict[str, str]:
-        if self.credentials.token is None or self._token_expired():
-            await run_in_executor(self._refresh_token)
-            self.token_created = datetime.now()
-        return {'Authorization': f'Bearer {self.credentials.token}'}
-
-    def _token_expired(self) -> bool:
-        if self.token_created is None:
-            return True
-        else:
-            return (datetime.now() - self.token_created) > MAX_TOKEN_AGE
-
-    def _refresh_token(self) -> str:
-        self.credentials.refresh(Request())
-        assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}'
-        return self.credentials.token
-
-
-VertexAiRegion = Literal[
-    'asia-east1',
-    'asia-east2',
-    'asia-northeast1',
-    'asia-northeast3',
-    'asia-south1',
-    'asia-southeast1',
-    'australia-southeast1',
-    'europe-central2',
-    'europe-north1',
-    'europe-southwest1',
-    'europe-west1',
-    'europe-west2',
-    'europe-west3',
-    'europe-west4',
-    'europe-west6',
-    'europe-west8',
-    'europe-west9',
-    'me-central1',
-    'me-central2',
-    'me-west1',
-    'northamerica-northeast1',
-    'southamerica-east1',
-    'us-central1',
-    'us-east1',
-    'us-east4',
-    'us-east5',
-    'us-south1',
-    'us-west1',
-    'us-west4',
-]
-"""Regions available for Vertex AI.
-
-More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
-"""