pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff compares two publicly available package versions as released to their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +29 -35
- pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
- pydantic_ai/_output.py +265 -118
- pydantic_ai/agent.py +15 -15
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +39 -3
- pydantic_ai/models/anthropic.py +4 -0
- pydantic_ai/models/bedrock.py +43 -16
- pydantic_ai/models/cohere.py +4 -0
- pydantic_ai/models/gemini.py +68 -108
- pydantic_ai/models/google.py +45 -110
- pydantic_ai/models/groq.py +17 -2
- pydantic_ai/models/mistral.py +4 -0
- pydantic_ai/models/openai.py +22 -157
- pydantic_ai/profiles/__init__.py +39 -0
- pydantic_ai/{models → profiles}/_json_schema.py +23 -2
- pydantic_ai/profiles/amazon.py +9 -0
- pydantic_ai/profiles/anthropic.py +8 -0
- pydantic_ai/profiles/cohere.py +8 -0
- pydantic_ai/profiles/deepseek.py +8 -0
- pydantic_ai/profiles/google.py +100 -0
- pydantic_ai/profiles/grok.py +8 -0
- pydantic_ai/profiles/meta.py +9 -0
- pydantic_ai/profiles/mistral.py +8 -0
- pydantic_ai/profiles/openai.py +144 -0
- pydantic_ai/profiles/qwen.py +9 -0
- pydantic_ai/providers/__init__.py +18 -0
- pydantic_ai/providers/anthropic.py +5 -0
- pydantic_ai/providers/azure.py +34 -0
- pydantic_ai/providers/bedrock.py +60 -1
- pydantic_ai/providers/cohere.py +5 -0
- pydantic_ai/providers/deepseek.py +12 -0
- pydantic_ai/providers/fireworks.py +99 -0
- pydantic_ai/providers/google.py +5 -0
- pydantic_ai/providers/google_gla.py +5 -0
- pydantic_ai/providers/google_vertex.py +5 -0
- pydantic_ai/providers/grok.py +82 -0
- pydantic_ai/providers/groq.py +25 -0
- pydantic_ai/providers/mistral.py +5 -0
- pydantic_ai/providers/openai.py +5 -0
- pydantic_ai/providers/openrouter.py +36 -0
- pydantic_ai/providers/together.py +96 -0
- pydantic_ai/result.py +34 -103
- pydantic_ai/tools.py +28 -58
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/METADATA +4 -4
- pydantic_ai_slim-0.2.12.dist-info/RECORD +73 -0
- pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/licenses/LICENSE +0 -0
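The common thread in this release is visible in the file list: a new pydantic_ai/profiles/ package, and a model_profile(model_name) method on every provider that maps a model name to a ModelProfile describing model quirks such as which JSON-schema transformer to apply. A minimal sketch of exercising the hook, assuming the classes behave as the hunks below show (the model id is illustrative and no network request is made):

    from pydantic_ai.providers.openai import OpenAIProvider

    provider = OpenAIProvider(api_key='dummy-key')
    # Delegates to openai_model_profile, per the pydantic_ai/providers/openai.py hunk below.
    profile = provider.model_profile('gpt-4o')
    print(profile)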
pydantic_ai/providers/fireworks.py
ADDED

@@ -0,0 +1,99 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Fireworks AI provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class FireworksProvider(Provider[AsyncOpenAI]):
+    """Provider for Fireworks AI API."""
+
+    @property
+    def name(self) -> str:
+        return 'fireworks'
+
+    @property
+    def base_url(self) -> str:
+        return 'https://api.fireworks.ai/inference/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {
+            'llama': meta_model_profile,
+            'qwen': qwen_model_profile,
+            'deepseek': deepseek_model_profile,
+            'mistral': mistral_model_profile,
+            'gemma': google_model_profile,
+        }
+
+        prefix = 'accounts/fireworks/models/'
+
+        profile = None
+        if model_name.startswith(prefix):
+            model_name = model_name[len(prefix) :]
+            for provider, profile_func in prefix_to_profile.items():
+                if model_name.startswith(provider):
+                    profile = profile_func(model_name)
+                    break
+
+        # As the Fireworks API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+        # unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('FIREWORKS_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `FIREWORKS_API_KEY` environment variable or pass it via `FireworksProvider(api_key=...)`'
+                'to use the Fireworks AI provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='fireworks')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
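A usage sketch for the new provider, assuming the OpenAIModel(model_name, provider=...) wiring from the pydantic-ai docs rather than from this diff (the model name and prompt are illustrative; requires FIREWORKS_API_KEY or an explicit api_key):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.providers.fireworks import FireworksProvider

    model = OpenAIModel(
        'accounts/fireworks/models/llama-v3p1-8b-instruct',  # illustrative model name
        provider=FireworksProvider(),  # reads FIREWORKS_API_KEY from the environment
    )
    agent = Agent(model)
    result = agent.run_sync('What is the capital of France?')
    print(result.output)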
pydantic_ai/providers/google.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Literal, overload
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import get_user_agent
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.providers import Provider
 
 try:
@@ -32,6 +34,9 @@ class GoogleProvider(Provider[genai.Client]):
     def client(self) -> genai.Client:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        return google_model_profile(model_name)
+
     @overload
     def __init__(self, *, api_key: str) -> None: ...
 

pydantic_ai/providers/google_gla.py
CHANGED

@@ -6,6 +6,8 @@ import httpx
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.providers import Provider
 
 
@@ -24,6 +26,9 @@ class GoogleGLAProvider(Provider[httpx.AsyncClient]):
     def client(self) -> httpx.AsyncClient:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        return google_model_profile(model_name)
+
     def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None:
         """Create a new Google GLA provider.
 

pydantic_ai/providers/google_vertex.py
CHANGED

@@ -10,6 +10,8 @@ import httpx
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.providers import Provider
 
 try:
@@ -47,6 +49,9 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
     def client(self) -> httpx.AsyncClient:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        return google_model_profile(model_name)
+
     @overload
     def __init__(
         self,
pydantic_ai/providers/grok.py
ADDED

@@ -0,0 +1,82 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.grok import grok_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Grok provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class GrokProvider(Provider[AsyncOpenAI]):
+    """Provider for Grok API."""
+
+    @property
+    def name(self) -> str:
+        return 'grok'
+
+    @property
+    def base_url(self) -> str:
+        return 'https://api.x.ai/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        profile = grok_model_profile(model_name)
+
+        # As the Grok API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+        # unless json_schema_transformer is set explicitly.
+        # Also, Grok does not support strict tool definitions: https://github.com/pydantic/pydantic-ai/issues/1846
+        return OpenAIModelProfile(
+            json_schema_transformer=OpenAIJsonSchemaTransformer, openai_supports_strict_tool_definition=False
+        ).update(profile)
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('GROK_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `GROK_API_KEY` environment variable or pass it via `GrokProvider(api_key=...)`'
+                'to use the Grok provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='grok')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
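Unlike the other OpenAI-compatible providers, Grok's profile also disables strict tool definitions. A quick way to see the effect, under the same assumptions as the earlier sketches (the model name is illustrative; constructing the provider makes no request):

    from pydantic_ai.providers.grok import GrokProvider

    provider = GrokProvider(api_key='dummy-key')
    profile = provider.model_profile('grok-3')  # illustrative model name
    # Expect an OpenAIModelProfile with openai_supports_strict_tool_definition=False,
    # so tool definitions are sent without strict mode (see issue #1846 above).
    print(profile)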
pydantic_ai/providers/groq.py
CHANGED
@@ -7,6 +7,12 @@ from httpx import AsyncClient as AsyncHTTPClient
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
 try:
@@ -33,6 +39,25 @@ class GroqProvider(Provider[AsyncGroq]):
     def client(self) -> AsyncGroq:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {
+            'llama': meta_model_profile,
+            'meta-llama/': meta_model_profile,
+            'gemma': google_model_profile,
+            'qwen': qwen_model_profile,
+            'deepseek': deepseek_model_profile,
+            'mistral': mistral_model_profile,
+        }
+
+        for prefix, profile_func in prefix_to_profile.items():
+            model_name = model_name.lower()
+            if model_name.startswith(prefix):
+                if prefix.endswith('/'):
+                    model_name = model_name[len(prefix) :]
+                return profile_func(model_name)
+
+        return None
+
     @overload
     def __init__(self, *, groq_client: AsyncGroq | None = None) -> None: ...
 
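The table above matches both bare model names and vendor-prefixed ones, stripping the prefix only when it ends with a slash. A stand-alone re-implementation of just that lookup, for illustration (the real method returns ModelProfile objects and lives in the hunk above; the model ids here are made up):

    prefix_to_profile = {
        'llama': 'meta',
        'meta-llama/': 'meta',
        'gemma': 'google',
        'qwen': 'qwen',
        'deepseek': 'deepseek',
        'mistral': 'mistral',
    }

    def lookup(model_name: str) -> str | None:
        for prefix, profile in prefix_to_profile.items():
            model_name = model_name.lower()  # lowercased on every iteration, as in the source
            if model_name.startswith(prefix):
                if prefix.endswith('/'):
                    model_name = model_name[len(prefix):]  # strip 'meta-llama/' style prefixes
                return profile
        return None

    assert lookup('Llama-3.3-70b-Versatile') == 'meta'       # bare name, matched via 'llama'
    assert lookup('meta-llama/Llama-Guard-4-12B') == 'meta'  # matched via 'meta-llama/', prefix stripped
    assert lookup('unknown-model') is None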
pydantic_ai/providers/mistral.py
CHANGED
@@ -7,6 +7,8 @@ from httpx import AsyncClient as AsyncHTTPClient
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.mistral import mistral_model_profile
 from pydantic_ai.providers import Provider
 
 try:
@@ -33,6 +35,9 @@ class MistralProvider(Provider[Mistral]):
     def client(self) -> Mistral:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        return mistral_model_profile(model_name)
+
     @overload
     def __init__(self, *, mistral_client: Mistral | None = None) -> None: ...
 
pydantic_ai/providers/openai.py
CHANGED
@@ -5,6 +5,8 @@ import os
 import httpx
 
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.openai import openai_model_profile
 from pydantic_ai.providers import Provider
 
 try:
@@ -31,6 +33,9 @@ class OpenAIProvider(Provider[AsyncOpenAI]):
     def client(self) -> AsyncOpenAI:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        return openai_model_profile(model_name)
+
     def __init__(
         self,
         base_url: str | None = None,
pydantic_ai/providers/openrouter.py
CHANGED

@@ -8,6 +8,17 @@ from openai import AsyncOpenAI
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.amazon import amazon_model_profile
+from pydantic_ai.profiles.anthropic import anthropic_model_profile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.grok import grok_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
 try:
@@ -34,6 +45,31 @@ class OpenRouterProvider(Provider[AsyncOpenAI]):
     def client(self) -> AsyncOpenAI:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        provider_to_profile = {
+            'google': google_model_profile,
+            'openai': openai_model_profile,
+            'anthropic': anthropic_model_profile,
+            'mistralai': mistral_model_profile,
+            'qwen': qwen_model_profile,
+            'x-ai': grok_model_profile,
+            'cohere': cohere_model_profile,
+            'amazon': amazon_model_profile,
+            'deepseek': deepseek_model_profile,
+            'meta-llama': meta_model_profile,
+        }
+
+        profile = None
+
+        provider, model_name = model_name.split('/', 1)
+        if provider in provider_to_profile:
+            model_name, *_ = model_name.split(':', 1)  # drop tags
+            profile = provider_to_profile[provider](model_name)
+
+        # As OpenRouterProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # we need to maintain that behavior unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
     @overload
     def __init__(self) -> None: ...
 
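OpenRouter model names are 'vendor/model' ids with optional ':tag' suffixes, so the lookup above splits twice. An illustration of just the name parsing (the ids are made up; like the source, it assumes a 'vendor/model' shape):

    def parse(model_name: str) -> tuple[str, str]:
        provider, model_name = model_name.split('/', 1)
        model_name, *_ = model_name.split(':', 1)  # drop tags like ':free' or ':beta'
        return provider, model_name

    assert parse('anthropic/claude-3.7-sonnet:thinking') == ('anthropic', 'claude-3.7-sonnet')
    assert parse('meta-llama/llama-4-maverick:free') == ('meta-llama', 'llama-4-maverick')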
pydantic_ai/providers/together.py
ADDED

@@ -0,0 +1,96 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Together AI provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class TogetherProvider(Provider[AsyncOpenAI]):
+    """Provider for Together AI API."""
+
+    @property
+    def name(self) -> str:
+        return 'together'
+
+    @property
+    def base_url(self) -> str:
+        return 'https://api.together.xyz/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        provider_to_profile = {
+            'deepseek-ai': deepseek_model_profile,
+            'google': google_model_profile,
+            'qwen': qwen_model_profile,
+            'meta-llama': meta_model_profile,
+            'mistralai': mistral_model_profile,
+        }
+
+        profile = None
+
+        model_name = model_name.lower()
+        provider, model_name = model_name.split('/', 1)
+        if provider in provider_to_profile:
+            profile = provider_to_profile[provider](model_name)
+
+        # As the Together API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+        # unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('TOGETHER_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `TOGETHER_API_KEY` environment variable or pass it via `TogetherProvider(api_key=...)`'
+                'to use the Together AI provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='together')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
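Same pattern as Fireworks, with one difference: the whole name is lowercased before the vendor split, so mixed-case ids still resolve. For example (the id is illustrative; constructing the provider makes no request):

    from pydantic_ai.providers.together import TogetherProvider

    provider = TogetherProvider(api_key='dummy-key')
    # 'Qwen/...' is lowercased, split on '/', and 'qwen' selects qwen_model_profile.
    profile = provider.model_profile('Qwen/Qwen2.5-72B-Instruct-Turbo')
    print(profile)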
pydantic_ai/result.py
CHANGED
@@ -5,100 +5,35 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import
+from typing import Generic, cast
 
 from typing_extensions import TypeVar, assert_type, deprecated, overload
 
-from . import _utils, exceptions, messages as _messages, models
+from . import _output, _utils, exceptions, messages as _messages, models
+from ._output import (
+    OutputDataT,
+    OutputDataT_inv,
+    OutputSchema,
+    OutputValidator,
+    OutputValidatorFunc,
+    ToolOutput,
+)
 from .messages import AgentStreamEvent, FinalResultEvent
 from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits
 
-if TYPE_CHECKING:
-    from . import _output
-
 __all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'OutputValidatorFunc'
 
 
 T = TypeVar('T')
 """An invariant TypeVar."""
-OutputDataT_inv = TypeVar('OutputDataT_inv', default=str)
-"""
-An invariant type variable for the result data of a model.
-
-We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used
-in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types
-possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and
-changing it would have negative consequences for the ergonomics of the library.
-
-At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would
-resolve these potential variance issues.
-"""
-OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
-"""Covariant type variable for the result data type of a run."""
-
-OutputValidatorFunc = Union[
-    Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv],
-    Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]],
-    Callable[[OutputDataT_inv], OutputDataT_inv],
-    Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]],
-]
-"""
-A function that always takes and returns the same type of data (which is the result type of an agent run), and:
-
-* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
-* may or may not be async
-
-Usage `OutputValidatorFunc[AgentDepsT, T]`.
-"""
-
-DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
-
-
-@dataclass(init=False)
-class ToolOutput(Generic[OutputDataT]):
-    """Marker class to use tools for structured outputs, and customize the tool."""
-
-    output_type: type[OutputDataT]
-    # TODO: Add `output_call` support, for calling a function to get the output
-    # output_call: Callable[..., OutputDataT] | None
-    name: str
-    description: str | None
-    max_retries: int | None
-    strict: bool | None
-
-    def __init__(
-        self,
-        *,
-        type_: type[OutputDataT],
-        # call: Callable[..., OutputDataT] | None = None,
-        name: str = 'final_result',
-        description: str | None = None,
-        max_retries: int | None = None,
-        strict: bool | None = None,
-    ):
-        self.output_type = type_
-        self.name = name
-        self.description = description
-        self.max_retries = max_retries
-        self.strict = strict
-
-        # TODO: add support for call and make type_ optional, with the following logic:
-        # if type_ is None and call is None:
-        #     raise ValueError('Either type_ or call must be provided')
-        # if call is not None:
-        #     if type_ is None:
-        #         type_ = get_type_hints(call).get('return')
-        #     if type_ is None:
-        #         raise ValueError('Unable to determine type_ from call signature; please provide it explicitly')
-        # self.output_call = call
 
 
 @dataclass
 class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _raw_stream_response: models.StreamedResponse
-    _output_schema:
-    _output_validators: list[
+    _output_schema: OutputSchema[OutputDataT] | None
+    _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _run_ctx: RunContext[AgentDepsT]
     _usage_limits: UsageLimits | None
@@ -144,6 +79,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
+        call = None
         if self._output_schema is not None and output_tool_name is not None:
             match = self._output_schema.find_named_tool(message.parts, output_tool_name)
             if match is None:
@@ -152,21 +88,17 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 )
 
             call, output_tool = match
-            result_data = output_tool.
-
-
-                result_data = await validator.validate(result_data, call, self._run_ctx)
-            return result_data
+            result_data = await output_tool.process(
+                call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+            )
         else:
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
-
-
-
-
-
-            # Since there is no output tool, we can assume that str is compatible with OutputDataT
-            return cast(OutputDataT, text)
+            # The following cast is safe because we know `str` is an allowed output type
+            result_data = cast(OutputDataT, text)
+
+        for validator in self._output_validators:
+            result_data = await validator.validate(result_data, call, self._run_ctx)
+        return result_data
 
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
@@ -180,7 +112,6 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
 
         async def aiter():
             output_schema = self._output_schema
-            allow_text_output = output_schema is None or output_schema.allow_text_output
 
             def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
                 """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
@@ -192,7 +123,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                     return _messages.FinalResultEvent(
                         tool_name=call.tool_name, tool_call_id=call.tool_call_id
                     )
-                elif allow_text_output:  # pragma: no branch
+                elif _output.allow_text_output(output_schema):  # pragma: no branch
                     assert_type(e, _messages.PartStartEvent)
                     return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
 
@@ -224,9 +155,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
 
     _usage_limits: UsageLimits | None
    _stream_response: models.StreamedResponse
-    _output_schema:
+    _output_schema: OutputSchema[OutputDataT] | None
     _run_ctx: RunContext[AgentDepsT]
-    _output_validators: list[
+    _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _output_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
 
@@ -458,6 +389,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
+        call = None
         if self._output_schema is not None and self._output_tool_name is not None:
             match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
             if match is None:
@@ -466,17 +398,16 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
                 )
 
             call, output_tool = match
-            result_data = output_tool.
-
-
-                result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
-            return result_data
+            result_data = await output_tool.process(
+                call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+            )
         else:
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
-
-
-
-
+            result_data = cast(OutputDataT, text)
+
+        for validator in self._output_validators:
+            result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
+        return result_data
 
     async def _validate_text_output(self, text: str) -> str:
         for validator in self._output_validators: