pydantic-ai-slim 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff shows the published content of pydantic-ai-slim 0.2.10 and 0.2.12 as released to the public registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in the registry.

Note: this version of pydantic-ai-slim has been flagged as a potentially problematic release.

Files changed (51):
  1. pydantic_ai/_agent_graph.py +29 -35
  2. pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
  3. pydantic_ai/_output.py +265 -118
  4. pydantic_ai/agent.py +15 -15
  5. pydantic_ai/mcp.py +1 -1
  6. pydantic_ai/messages.py +2 -2
  7. pydantic_ai/models/__init__.py +39 -3
  8. pydantic_ai/models/anthropic.py +6 -1
  9. pydantic_ai/models/bedrock.py +43 -16
  10. pydantic_ai/models/cohere.py +4 -0
  11. pydantic_ai/models/gemini.py +68 -108
  12. pydantic_ai/models/google.py +45 -110
  13. pydantic_ai/models/groq.py +17 -2
  14. pydantic_ai/models/mistral.py +4 -0
  15. pydantic_ai/models/openai.py +22 -157
  16. pydantic_ai/profiles/__init__.py +39 -0
  17. pydantic_ai/{models → profiles}/_json_schema.py +23 -2
  18. pydantic_ai/profiles/amazon.py +9 -0
  19. pydantic_ai/profiles/anthropic.py +8 -0
  20. pydantic_ai/profiles/cohere.py +8 -0
  21. pydantic_ai/profiles/deepseek.py +8 -0
  22. pydantic_ai/profiles/google.py +100 -0
  23. pydantic_ai/profiles/grok.py +8 -0
  24. pydantic_ai/profiles/meta.py +9 -0
  25. pydantic_ai/profiles/mistral.py +8 -0
  26. pydantic_ai/profiles/openai.py +144 -0
  27. pydantic_ai/profiles/qwen.py +9 -0
  28. pydantic_ai/providers/__init__.py +18 -0
  29. pydantic_ai/providers/anthropic.py +5 -0
  30. pydantic_ai/providers/azure.py +34 -0
  31. pydantic_ai/providers/bedrock.py +60 -1
  32. pydantic_ai/providers/cohere.py +5 -0
  33. pydantic_ai/providers/deepseek.py +12 -0
  34. pydantic_ai/providers/fireworks.py +99 -0
  35. pydantic_ai/providers/google.py +5 -0
  36. pydantic_ai/providers/google_gla.py +5 -0
  37. pydantic_ai/providers/google_vertex.py +5 -0
  38. pydantic_ai/providers/grok.py +82 -0
  39. pydantic_ai/providers/groq.py +25 -0
  40. pydantic_ai/providers/mistral.py +5 -0
  41. pydantic_ai/providers/openai.py +5 -0
  42. pydantic_ai/providers/openrouter.py +36 -0
  43. pydantic_ai/providers/together.py +96 -0
  44. pydantic_ai/result.py +34 -103
  45. pydantic_ai/tools.py +28 -58
  46. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/METADATA +5 -5
  47. pydantic_ai_slim-0.2.12.dist-info/RECORD +73 -0
  48. pydantic_ai_slim-0.2.10.dist-info/RECORD +0 -59
  49. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/WHEEL +0 -0
  50. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/entry_points.txt +0 -0
  51. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/providers/fireworks.py ADDED
@@ -0,0 +1,99 @@
+ from __future__ import annotations as _annotations
+
+ import os
+ from typing import overload
+
+ from httpx import AsyncClient as AsyncHTTPClient
+ from openai import AsyncOpenAI
+
+ from pydantic_ai.exceptions import UserError
+ from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
+ from pydantic_ai.profiles.google import google_model_profile
+ from pydantic_ai.profiles.meta import meta_model_profile
+ from pydantic_ai.profiles.mistral import mistral_model_profile
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+ from pydantic_ai.profiles.qwen import qwen_model_profile
+ from pydantic_ai.providers import Provider
+
+ try:
+     from openai import AsyncOpenAI
+ except ImportError as _import_error:  # pragma: no cover
+     raise ImportError(
+         'Please install the `openai` package to use the Fireworks AI provider, '
+         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+     ) from _import_error
+
+
+ class FireworksProvider(Provider[AsyncOpenAI]):
+     """Provider for Fireworks AI API."""
+
+     @property
+     def name(self) -> str:
+         return 'fireworks'
+
+     @property
+     def base_url(self) -> str:
+         return 'https://api.fireworks.ai/inference/v1'
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         return self._client
+
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         prefix_to_profile = {
+             'llama': meta_model_profile,
+             'qwen': qwen_model_profile,
+             'deepseek': deepseek_model_profile,
+             'mistral': mistral_model_profile,
+             'gemma': google_model_profile,
+         }
+
+         prefix = 'accounts/fireworks/models/'
+
+         profile = None
+         if model_name.startswith(prefix):
+             model_name = model_name[len(prefix) :]
+             for provider, profile_func in prefix_to_profile.items():
+                 if model_name.startswith(provider):
+                     profile = profile_func(model_name)
+                     break
+
+         # As the Fireworks API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+         # unless json_schema_transformer is set explicitly
+         return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+     @overload
+     def __init__(self) -> None: ...
+
+     @overload
+     def __init__(self, *, api_key: str) -> None: ...
+
+     @overload
+     def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+     @overload
+     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+     def __init__(
+         self,
+         *,
+         api_key: str | None = None,
+         openai_client: AsyncOpenAI | None = None,
+         http_client: AsyncHTTPClient | None = None,
+     ) -> None:
+         api_key = api_key or os.getenv('FIREWORKS_API_KEY')
+         if not api_key and openai_client is None:
+             raise UserError(
+                 'Set the `FIREWORKS_API_KEY` environment variable or pass it via `FireworksProvider(api_key=...)`'
+                 'to use the Fireworks AI provider.'
+             )
+
+         if openai_client is not None:
+             self._client = openai_client
+         elif http_client is not None:
+             self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+         else:
+             http_client = cached_async_http_client(provider='fireworks')
+             self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
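Since FireworksProvider wraps an AsyncOpenAI client pointed at Fireworks' OpenAI-compatible endpoint, it is designed to be passed to OpenAIModel. A minimal usage sketch (the model name is illustrative, and FIREWORKS_API_KEY is assumed to be set in the environment):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.fireworks import FireworksProvider

# Illustrative model name; any 'accounts/fireworks/models/...' ID follows the same path.
model = OpenAIModel(
    'accounts/fireworks/models/llama-v3p3-70b-instruct',
    provider=FireworksProvider(),
)
agent = Agent(model)

For this name, model_profile above strips the 'accounts/fireworks/models/' prefix and matches 'llama', so meta_model_profile supplies the base profile underneath the OpenAI-compatible defaults.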
pydantic_ai/providers/google.py CHANGED
@@ -5,6 +5,8 @@ from typing import Literal, overload
  
  from pydantic_ai.exceptions import UserError
  from pydantic_ai.models import get_user_agent
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.google import google_model_profile
  from pydantic_ai.providers import Provider
  
  try:
@@ -32,6 +34,9 @@ class GoogleProvider(Provider[genai.Client]):
      def client(self) -> genai.Client:
          return self._client
  
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         return google_model_profile(model_name)
+
      @overload
      def __init__(self, *, api_key: str) -> None: ...
  
pydantic_ai/providers/google_gla.py CHANGED
@@ -6,6 +6,8 @@ import httpx
  
  from pydantic_ai.exceptions import UserError
  from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.google import google_model_profile
  from pydantic_ai.providers import Provider
  
  
@@ -24,6 +26,9 @@ class GoogleGLAProvider(Provider[httpx.AsyncClient]):
      def client(self) -> httpx.AsyncClient:
          return self._client
  
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         return google_model_profile(model_name)
+
      def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None:
          """Create a new Google GLA provider.
  
pydantic_ai/providers/google_vertex.py CHANGED
@@ -10,6 +10,8 @@ import httpx
  
  from pydantic_ai.exceptions import UserError
  from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.google import google_model_profile
  from pydantic_ai.providers import Provider
  
  try:
@@ -47,6 +49,9 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
      def client(self) -> httpx.AsyncClient:
          return self._client
  
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         return google_model_profile(model_name)
+
      @overload
      def __init__(
          self,
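The three Google providers show the general shape of the new hook: a provider maps a model name to a ModelProfile, or returns None when it has nothing model-specific to contribute. A custom provider can implement the same method; a minimal sketch, assuming a hypothetical self-hosted, OpenAI-compatible gateway (the class name, URL, and key handling are all illustrative):

from __future__ import annotations as _annotations

from openai import AsyncOpenAI

from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.providers import Provider


class MyGatewayProvider(Provider[AsyncOpenAI]):
    """Hypothetical provider for a self-hosted, OpenAI-compatible gateway."""

    @property
    def name(self) -> str:
        return 'my-gateway'

    @property
    def base_url(self) -> str:
        return 'http://localhost:8000/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        # Delegate to a built-in profile when the name is recognized, else None.
        if model_name.startswith('llama'):
            return meta_model_profile(model_name)
        return None

    def __init__(self) -> None:
        self._client = AsyncOpenAI(base_url=self.base_url, api_key='unused')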
pydantic_ai/providers/grok.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import annotations as _annotations
+
+ import os
+ from typing import overload
+
+ from httpx import AsyncClient as AsyncHTTPClient
+ from openai import AsyncOpenAI
+
+ from pydantic_ai.exceptions import UserError
+ from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.grok import grok_model_profile
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+ from pydantic_ai.providers import Provider
+
+ try:
+     from openai import AsyncOpenAI
+ except ImportError as _import_error:  # pragma: no cover
+     raise ImportError(
+         'Please install the `openai` package to use the Grok provider, '
+         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+     ) from _import_error
+
+
+ class GrokProvider(Provider[AsyncOpenAI]):
+     """Provider for Grok API."""
+
+     @property
+     def name(self) -> str:
+         return 'grok'
+
+     @property
+     def base_url(self) -> str:
+         return 'https://api.x.ai/v1'
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         return self._client
+
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         profile = grok_model_profile(model_name)
+
+         # As the Grok API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+         # unless json_schema_transformer is set explicitly.
+         # Also, Grok does not support strict tool definitions: https://github.com/pydantic/pydantic-ai/issues/1846
+         return OpenAIModelProfile(
+             json_schema_transformer=OpenAIJsonSchemaTransformer, openai_supports_strict_tool_definition=False
+         ).update(profile)
+
+     @overload
+     def __init__(self) -> None: ...
+
+     @overload
+     def __init__(self, *, api_key: str) -> None: ...
+
+     @overload
+     def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+     @overload
+     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+     def __init__(
+         self,
+         *,
+         api_key: str | None = None,
+         openai_client: AsyncOpenAI | None = None,
+         http_client: AsyncHTTPClient | None = None,
+     ) -> None:
+         api_key = api_key or os.getenv('GROK_API_KEY')
+         if not api_key and openai_client is None:
+             raise UserError(
+                 'Set the `GROK_API_KEY` environment variable or pass it via `GrokProvider(api_key=...)`'
+                 'to use the Grok provider.'
+             )
+
+         if openai_client is not None:
+             self._client = openai_client
+         elif http_client is not None:
+             self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+         else:
+             http_client = cached_async_http_client(provider='grok')
+             self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
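Usage mirrors the other OpenAI-compatible providers; because the profile sets openai_supports_strict_tool_definition=False, strict tool schemas should not be sent to the xAI API. A minimal sketch (model name illustrative, GROK_API_KEY assumed set):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.grok import GrokProvider

model = OpenAIModel('grok-3', provider=GrokProvider())  # illustrative model name
agent = Agent(model)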
pydantic_ai/providers/groq.py CHANGED
@@ -7,6 +7,12 @@ from httpx import AsyncClient as AsyncHTTPClient
  
  from pydantic_ai.exceptions import UserError
  from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
+ from pydantic_ai.profiles.google import google_model_profile
+ from pydantic_ai.profiles.meta import meta_model_profile
+ from pydantic_ai.profiles.mistral import mistral_model_profile
+ from pydantic_ai.profiles.qwen import qwen_model_profile
  from pydantic_ai.providers import Provider
  
  try:
@@ -33,6 +39,25 @@ class GroqProvider(Provider[AsyncGroq]):
      def client(self) -> AsyncGroq:
          return self._client
  
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         prefix_to_profile = {
+             'llama': meta_model_profile,
+             'meta-llama/': meta_model_profile,
+             'gemma': google_model_profile,
+             'qwen': qwen_model_profile,
+             'deepseek': deepseek_model_profile,
+             'mistral': mistral_model_profile,
+         }
+
+         for prefix, profile_func in prefix_to_profile.items():
+             model_name = model_name.lower()
+             if model_name.startswith(prefix):
+                 if prefix.endswith('/'):
+                     model_name = model_name[len(prefix) :]
+                 return profile_func(model_name)
+
+         return None
+
      @overload
      def __init__(self, *, groq_client: AsyncGroq | None = None) -> None: ...
  
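The lookup is prefix-based on the lowercased name, and a vendor prefix ending in '/' is stripped before the profile function is called. An illustrative trace, assuming the groq extra is installed (the dummy key exists only so the constructor passes its check):

import os

from pydantic_ai.providers.groq import GroqProvider

os.environ.setdefault('GROQ_API_KEY', 'dummy-key')  # illustration only
provider = GroqProvider()

# Matched via 'meta-llama/', which is stripped: meta_model_profile('llama-4-scout-17b-16e-instruct')
profile = provider.model_profile('meta-llama/Llama-4-Scout-17B-16E-Instruct')

assert provider.model_profile('whisper-large-v3') is None  # no matching prefix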
pydantic_ai/providers/mistral.py CHANGED
@@ -7,6 +7,8 @@ from httpx import AsyncClient as AsyncHTTPClient
  
  from pydantic_ai.exceptions import UserError
  from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.mistral import mistral_model_profile
  from pydantic_ai.providers import Provider
  
  try:
@@ -33,6 +35,9 @@ class MistralProvider(Provider[Mistral]):
      def client(self) -> Mistral:
          return self._client
  
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         return mistral_model_profile(model_name)
+
      @overload
      def __init__(self, *, mistral_client: Mistral | None = None) -> None: ...
  
pydantic_ai/providers/openai.py CHANGED
@@ -5,6 +5,8 @@ import os
  import httpx
  
  from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.openai import openai_model_profile
  from pydantic_ai.providers import Provider
  
  try:
@@ -31,6 +33,9 @@ class OpenAIProvider(Provider[AsyncOpenAI]):
      def client(self) -> AsyncOpenAI:
          return self._client
  
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         return openai_model_profile(model_name)
+
      def __init__(
          self,
          base_url: str | None = None,
pydantic_ai/providers/openrouter.py CHANGED
@@ -8,6 +8,17 @@ from openai import AsyncOpenAI
  
  from pydantic_ai.exceptions import UserError
  from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.amazon import amazon_model_profile
+ from pydantic_ai.profiles.anthropic import anthropic_model_profile
+ from pydantic_ai.profiles.cohere import cohere_model_profile
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
+ from pydantic_ai.profiles.google import google_model_profile
+ from pydantic_ai.profiles.grok import grok_model_profile
+ from pydantic_ai.profiles.meta import meta_model_profile
+ from pydantic_ai.profiles.mistral import mistral_model_profile
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
+ from pydantic_ai.profiles.qwen import qwen_model_profile
  from pydantic_ai.providers import Provider
  
  try:
@@ -34,6 +45,31 @@ class OpenRouterProvider(Provider[AsyncOpenAI]):
      def client(self) -> AsyncOpenAI:
          return self._client
  
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         provider_to_profile = {
+             'google': google_model_profile,
+             'openai': openai_model_profile,
+             'anthropic': anthropic_model_profile,
+             'mistralai': mistral_model_profile,
+             'qwen': qwen_model_profile,
+             'x-ai': grok_model_profile,
+             'cohere': cohere_model_profile,
+             'amazon': amazon_model_profile,
+             'deepseek': deepseek_model_profile,
+             'meta-llama': meta_model_profile,
+         }
+
+         profile = None
+
+         provider, model_name = model_name.split('/', 1)
+         if provider in provider_to_profile:
+             model_name, *_ = model_name.split(':', 1)  # drop tags
+             profile = provider_to_profile[provider](model_name)
+
+         # As OpenRouterProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+         # we need to maintain that behavior unless json_schema_transformer is set explicitly
+         return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
      @overload
      def __init__(self) -> None: ...
  
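OpenRouter names follow 'vendor/model' with optional ':tag' suffixes, which is exactly what the split above handles. A standalone re-implementation of just that parsing, for illustration:

def split_openrouter_name(model_name: str) -> tuple[str, str]:
    """Illustrative copy of the name parsing in OpenRouterProvider.model_profile."""
    provider, model_name = model_name.split('/', 1)
    model_name, *_ = model_name.split(':', 1)  # drop tags like ':free' or ':beta'
    return provider, model_name


assert split_openrouter_name('anthropic/claude-3.5-sonnet:beta') == ('anthropic', 'claude-3.5-sonnet')
assert split_openrouter_name('x-ai/grok-3') == ('x-ai', 'grok-3')

Note that a name without a '/' separator would raise ValueError here, just as it would in the provider method above.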
pydantic_ai/providers/together.py ADDED
@@ -0,0 +1,96 @@
+ from __future__ import annotations as _annotations
+
+ import os
+ from typing import overload
+
+ from httpx import AsyncClient as AsyncHTTPClient
+ from openai import AsyncOpenAI
+
+ from pydantic_ai.exceptions import UserError
+ from pydantic_ai.models import cached_async_http_client
+ from pydantic_ai.profiles import ModelProfile
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
+ from pydantic_ai.profiles.google import google_model_profile
+ from pydantic_ai.profiles.meta import meta_model_profile
+ from pydantic_ai.profiles.mistral import mistral_model_profile
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+ from pydantic_ai.profiles.qwen import qwen_model_profile
+ from pydantic_ai.providers import Provider
+
+ try:
+     from openai import AsyncOpenAI
+ except ImportError as _import_error:  # pragma: no cover
+     raise ImportError(
+         'Please install the `openai` package to use the Together AI provider, '
+         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+     ) from _import_error
+
+
+ class TogetherProvider(Provider[AsyncOpenAI]):
+     """Provider for Together AI API."""
+
+     @property
+     def name(self) -> str:
+         return 'together'
+
+     @property
+     def base_url(self) -> str:
+         return 'https://api.together.xyz/v1'
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         return self._client
+
+     def model_profile(self, model_name: str) -> ModelProfile | None:
+         provider_to_profile = {
+             'deepseek-ai': deepseek_model_profile,
+             'google': google_model_profile,
+             'qwen': qwen_model_profile,
+             'meta-llama': meta_model_profile,
+             'mistralai': mistral_model_profile,
+         }
+
+         profile = None
+
+         model_name = model_name.lower()
+         provider, model_name = model_name.split('/', 1)
+         if provider in provider_to_profile:
+             profile = provider_to_profile[provider](model_name)
+
+         # As the Together API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+         # unless json_schema_transformer is set explicitly
+         return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+     @overload
+     def __init__(self) -> None: ...
+
+     @overload
+     def __init__(self, *, api_key: str) -> None: ...
+
+     @overload
+     def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+     @overload
+     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+     def __init__(
+         self,
+         *,
+         api_key: str | None = None,
+         openai_client: AsyncOpenAI | None = None,
+         http_client: AsyncHTTPClient | None = None,
+     ) -> None:
+         api_key = api_key or os.getenv('TOGETHER_API_KEY')
+         if not api_key and openai_client is None:
+             raise UserError(
+                 'Set the `TOGETHER_API_KEY` environment variable or pass it via `TogetherProvider(api_key=...)`'
+                 'to use the Together AI provider.'
+             )
+
+         if openai_client is not None:
+             self._client = openai_client
+         elif http_client is not None:
+             self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+         else:
+             http_client = cached_async_http_client(provider='together')
+             self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
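Usage matches the other OpenAI-compatible providers; Together model IDs already carry the vendor segment that the profile lookup keys on. A minimal sketch (model name illustrative, TOGETHER_API_KEY assumed set):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.together import TogetherProvider

model = OpenAIModel(
    'mistralai/Mixtral-8x7B-Instruct-v0.1',  # illustrative; 'mistralai' selects mistral_model_profile
    provider=TogetherProvider(),
)
agent = Agent(model)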
pydantic_ai/result.py CHANGED
@@ -5,100 +5,35 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
  from copy import copy
  from dataclasses import dataclass, field
  from datetime import datetime
- from typing import TYPE_CHECKING, Generic, Union, cast
+ from typing import Generic, cast
  
  from typing_extensions import TypeVar, assert_type, deprecated, overload
  
- from . import _utils, exceptions, messages as _messages, models
+ from . import _output, _utils, exceptions, messages as _messages, models
+ from ._output import (
+     OutputDataT,
+     OutputDataT_inv,
+     OutputSchema,
+     OutputValidator,
+     OutputValidatorFunc,
+     ToolOutput,
+ )
  from .messages import AgentStreamEvent, FinalResultEvent
  from .tools import AgentDepsT, RunContext
  from .usage import Usage, UsageLimits
  
- if TYPE_CHECKING:
-     from . import _output
-
  __all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'OutputValidatorFunc'
  
  
  T = TypeVar('T')
  """An invariant TypeVar."""
- OutputDataT_inv = TypeVar('OutputDataT_inv', default=str)
- """
- An invariant type variable for the result data of a model.
-
- We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used
- in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types
- possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and
- changing it would have negative consequences for the ergonomics of the library.
-
- At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would
- resolve these potential variance issues.
- """
- OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
- """Covariant type variable for the result data type of a run."""
-
- OutputValidatorFunc = Union[
-     Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv],
-     Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]],
-     Callable[[OutputDataT_inv], OutputDataT_inv],
-     Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]],
- ]
- """
- A function that always takes and returns the same type of data (which is the result type of an agent run), and:
-
- * may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
- * may or may not be async
-
- Usage `OutputValidatorFunc[AgentDepsT, T]`.
- """
-
- DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
-
-
- @dataclass(init=False)
- class ToolOutput(Generic[OutputDataT]):
-     """Marker class to use tools for structured outputs, and customize the tool."""
-
-     output_type: type[OutputDataT]
-     # TODO: Add `output_call` support, for calling a function to get the output
-     # output_call: Callable[..., OutputDataT] | None
-     name: str
-     description: str | None
-     max_retries: int | None
-     strict: bool | None
-
-     def __init__(
-         self,
-         *,
-         type_: type[OutputDataT],
-         # call: Callable[..., OutputDataT] | None = None,
-         name: str = 'final_result',
-         description: str | None = None,
-         max_retries: int | None = None,
-         strict: bool | None = None,
-     ):
-         self.output_type = type_
-         self.name = name
-         self.description = description
-         self.max_retries = max_retries
-         self.strict = strict
-
-     # TODO: add support for call and make type_ optional, with the following logic:
-     # if type_ is None and call is None:
-     #     raise ValueError('Either type_ or call must be provided')
-     # if call is not None:
-     #     if type_ is None:
-     #         type_ = get_type_hints(call).get('return')
-     #     if type_ is None:
-     #         raise ValueError('Unable to determine type_ from call signature; please provide it explicitly')
-     #     self.output_call = call
  
  
  @dataclass
  class AgentStream(Generic[AgentDepsT, OutputDataT]):
      _raw_stream_response: models.StreamedResponse
-     _output_schema: _output.OutputSchema[OutputDataT] | None
-     _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]]
+     _output_schema: OutputSchema[OutputDataT] | None
+     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
      _run_ctx: RunContext[AgentDepsT]
      _usage_limits: UsageLimits | None
  
@@ -144,6 +79,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
          self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
      ) -> OutputDataT:
          """Validate a structured result message."""
+         call = None
          if self._output_schema is not None and output_tool_name is not None:
              match = self._output_schema.find_named_tool(message.parts, output_tool_name)
              if match is None:
@@ -152,21 +88,17 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                  )
  
              call, output_tool = match
-             result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
-
-             for validator in self._output_validators:
-                 result_data = await validator.validate(result_data, call, self._run_ctx)
-             return result_data
+             result_data = await output_tool.process(
+                 call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+             )
          else:
              text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
-             for validator in self._output_validators:
-                 text = await validator.validate(
-                     text,
-                     None,
-                     self._run_ctx,
-                 )
-             # Since there is no output tool, we can assume that str is compatible with OutputDataT
-             return cast(OutputDataT, text)
+             # The following cast is safe because we know `str` is an allowed output type
+             result_data = cast(OutputDataT, text)
+
+         for validator in self._output_validators:
+             result_data = await validator.validate(result_data, call, self._run_ctx)
+         return result_data
  
      def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
          """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
@@ -180,7 +112,6 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
  
          async def aiter():
              output_schema = self._output_schema
-             allow_text_output = output_schema is None or output_schema.allow_text_output
  
              def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
                  """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
@@ -192,7 +123,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                                  return _messages.FinalResultEvent(
                                      tool_name=call.tool_name, tool_call_id=call.tool_call_id
                                  )
-                     elif allow_text_output:  # pragma: no branch
+                     elif _output.allow_text_output(output_schema):  # pragma: no branch
                          assert_type(e, _messages.PartStartEvent)
                          return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
  
@@ -224,9 +155,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
  
      _usage_limits: UsageLimits | None
      _stream_response: models.StreamedResponse
-     _output_schema: _output.OutputSchema[OutputDataT] | None
+     _output_schema: OutputSchema[OutputDataT] | None
      _run_ctx: RunContext[AgentDepsT]
-     _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]]
+     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
      _output_tool_name: str | None
      _on_complete: Callable[[], Awaitable[None]]
  
@@ -458,6 +389,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
          self, message: _messages.ModelResponse, *, allow_partial: bool = False
      ) -> OutputDataT:
          """Validate a structured result message."""
+         call = None
          if self._output_schema is not None and self._output_tool_name is not None:
              match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
              if match is None:
@@ -466,17 +398,16 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
                  )
  
              call, output_tool = match
-             result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
-
-             for validator in self._output_validators:
-                 result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
-             return result_data
+             result_data = await output_tool.process(
+                 call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+             )
          else:
              text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
-             for validator in self._output_validators:
-                 text = await validator.validate(text, None, self._run_ctx)  # pragma: no cover
-             # Since there is no output tool, we can assume that str is compatible with OutputDataT
-             return cast(OutputDataT, text)
+             result_data = cast(OutputDataT, text)
+
+         for validator in self._output_validators:
+             result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
+         return result_data
  
      async def _validate_text_output(self, text: str) -> str:
          for validator in self._output_validators:
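Both `_validate_response` implementations now funnel tool-call results and plain-text results through the same validator loop, with `call` set to None in the text case, so a validator registered on the agent sees the output either way. A minimal sketch of such a validator (the model name is illustrative):

from pydantic_ai import Agent, ModelRetry

agent = Agent('openai:gpt-4o')  # illustrative model

@agent.output_validator
def limit_length(output: str) -> str:
    # Applied uniformly, whether the output came from an output tool or plain text.
    if len(output) > 2000:
        raise ModelRetry('Response too long, please answer more concisely.')
    return output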