pydantic-ai-slim 0.0.47__py3-none-any.whl → 0.0.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_cli.py +82 -59
- pydantic_ai/_result.py +7 -3
- pydantic_ai/models/__init__.py +2 -0
- pydantic_ai/models/anthropic.py +5 -41
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/gemini.py +1 -0
- pydantic_ai/models/openai.py +461 -13
- pydantic_ai/tools.py +2 -2
- {pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/METADATA +3 -3
- {pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/RECORD +12 -12
- {pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/entry_points.txt +0 -0
pydantic_ai/_cli.py
CHANGED

@@ -3,19 +3,20 @@ from __future__ import annotations as _annotations
 import argparse
 import asyncio
 import sys
+from asyncio import CancelledError
 from collections.abc import Sequence
 from contextlib import ExitStack
 from datetime import datetime, timezone
 from importlib.metadata import version
 from pathlib import Path
-from typing import cast
+from typing import Any, cast

 from typing_inspection.introspection import get_literal_values

 from pydantic_ai.agent import Agent
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
-from pydantic_ai.models import KnownModelName
+from pydantic_ai.models import KnownModelName, infer_model

 try:
     import argcomplete

@@ -47,7 +48,7 @@ class SimpleCodeBlock(CodeBlock):
     This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix.
     """

-    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
+    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
         code = str(self.text).rstrip()
         yield Text(self.lexer_name, style='dim')
         yield Syntax(code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True)

@@ -57,7 +58,7 @@ class SimpleCodeBlock(CodeBlock):
 class LeftHeading(Heading):
     """Customised headings in markdown to stop centering and prepend markdown style hashes."""

-    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
+    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
         # note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO
         yield Text(f'{"#" * int(self.tag[1:])} {self.text.plain}', style=Style(bold=True))

@@ -68,7 +69,21 @@ Markdown.elements.update(
 )


-def cli(args_list: Sequence[str] | None = None) -> int:
+cli_agent = Agent()
+
+
+@cli_agent.system_prompt
+def cli_system_prompt() -> str:
+    now_utc = datetime.now(timezone.utc)
+    tzinfo = now_utc.astimezone().tzinfo
+    tzname = tzinfo.tzname(now_utc) if tzinfo else ''
+    return f"""\
+Help the user by responding to their request, the output should be concise and always written in markdown.
+The current date and time is {datetime.now()} {tzname}.
+The user is running {sys.platform}."""
+
+
+def cli(args_list: Sequence[str] | None = None) -> int:
     parser = argparse.ArgumentParser(
         prog='pai',
         description=f"""\

@@ -124,18 +139,10 @@ Special prompt:
         console.print(f' {model}', highlight=False)
         return 0

-    now_utc = datetime.now(timezone.utc)
-    tzname = now_utc.astimezone().tzinfo.tzname(now_utc)  # type: ignore
     try:
-
-
-
-Help the user by responding to their request, the output should be concise and always written in markdown.
-The current date and time is {datetime.now()} {tzname}.
-The user is running {sys.platform}.""",
-        )
-    except UserError:
-        console.print(f'[red]Invalid model "{args.model}"[/red]')
+        cli_agent.model = infer_model(args.model)
+    except UserError as e:
+        console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
         return 1

     stream = not args.no_stream

@@ -148,67 +155,44 @@ Special prompt:

     if prompt := cast(str, args.prompt):
         try:
-            asyncio.run(ask_agent(
+            asyncio.run(ask_agent(cli_agent, prompt, stream, console, code_theme))
         except KeyboardInterrupt:
             pass
         return 0

     history = Path.home() / '.pai-prompt-history.txt'
-
+    # doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests
+    session: PromptSession[Any] = PromptSession(history=FileHistory(str(history)))
+    try:
+        return asyncio.run(run_chat(session, stream, cli_agent, console, code_theme))
+    except KeyboardInterrupt:  # pragma: no cover
+        return 0
+
+
+async def run_chat(session: PromptSession[Any], stream: bool, agent: Agent, console: Console, code_theme: str) -> int:
     multiline = False
     messages: list[ModelMessage] = []

     while True:
         try:
             auto_suggest = CustomAutoSuggest(['/markdown', '/multiline', '/exit'])
-            text =
-        except (KeyboardInterrupt, EOFError):
+            text = await session.prompt_async('pai ➤ ', auto_suggest=auto_suggest, multiline=multiline)
+        except (KeyboardInterrupt, EOFError):  # pragma: no cover
             return 0

         if not text.strip():
             continue

-        ident_prompt = text.lower().strip(
+        ident_prompt = text.lower().strip().replace(' ', '-')
         if ident_prompt.startswith('/'):
-
-
-
-            except IndexError:
-                console.print('[dim]No markdown output available.[/dim]')
-                continue
-            console.print('[dim]Markdown output of last question:[/dim]\n')
-            for part in parts:
-                if part.part_kind == 'text':
-                    console.print(
-                        Syntax(
-                            part.content,
-                            lexer='markdown',
-                            theme=code_theme,
-                            word_wrap=True,
-                            background_color='default',
-                        )
-                    )
-
-        elif ident_prompt == '/multiline':
-            multiline = not multiline
-            if multiline:
-                console.print(
-                    'Enabling multiline mode. '
-                    '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
-                )
-            else:
-                console.print('Disabling multiline mode.')
-        elif ident_prompt == '/exit':
-            console.print('[dim]Exiting…[/dim]')
-            return 0
-        else:
-            console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
+            exit_value, multiline = handle_slash_command(ident_prompt, messages, multiline, console, code_theme)
+            if exit_value is not None:
+                return exit_value
         else:
             try:
-                messages =
-            except
+                messages = await ask_agent(agent, text, stream, console, code_theme, messages)
+            except CancelledError:  # pragma: no cover
                 console.print('[dim]Interrupted[/dim]')
-                messages = []


 async def ask_agent(

@@ -218,7 +202,7 @@ async def ask_agent(
     console: Console,
     code_theme: str,
     messages: list[ModelMessage] | None = None,
-) -> list[ModelMessage]:
+) -> list[ModelMessage]:
     status = Status('[dim]Working on it…[/dim]', console=console)

     if not stream:

@@ -248,7 +232,7 @@ async def ask_agent(


 class CustomAutoSuggest(AutoSuggestFromHistory):
-    def __init__(self, special_suggestions: list[str] | None = None):
+    def __init__(self, special_suggestions: list[str] | None = None):
         super().__init__()
         self.special_suggestions = special_suggestions or []

@@ -264,5 +248,44 @@ class CustomAutoSuggest(AutoSuggestFromHistory):
         return suggestion


+def handle_slash_command(
+    ident_prompt: str, messages: list[ModelMessage], multiline: bool, console: Console, code_theme: str
+) -> tuple[int | None, bool]:
+    if ident_prompt == '/markdown':
+        try:
+            parts = messages[-1].parts
+        except IndexError:
+            console.print('[dim]No markdown output available.[/dim]')
+        else:
+            console.print('[dim]Markdown output of last question:[/dim]\n')
+            for part in parts:
+                if part.part_kind == 'text':
+                    console.print(
+                        Syntax(
+                            part.content,
+                            lexer='markdown',
+                            theme=code_theme,
+                            word_wrap=True,
+                            background_color='default',
+                        )
+                    )
+
+    elif ident_prompt == '/multiline':
+        multiline = not multiline
+        if multiline:
+            console.print(
+                'Enabling multiline mode. [dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
+            )
+        else:
+            console.print('Disabling multiline mode.')
+        return None, multiline
+    elif ident_prompt == '/exit':
+        console.print('[dim]Exiting…[/dim]')
+        return 0, multiline
+    else:
+        console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
+    return None, multiline
+
+
 def app():  # pragma: no cover
     sys.exit(cli())
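For orientation: the refactor above replaces the inline agent construction with a module-level `cli_agent` whose model is attached at runtime via `infer_model`, and splits the REPL loop into the testable `run_chat` and `handle_slash_command` helpers. A minimal sketch of the same pattern outside the CLI (the model string and prompt are illustrative, not part of this diff):

    from datetime import datetime

    from pydantic_ai import Agent
    from pydantic_ai.models import infer_model

    agent = Agent()  # no model yet, exactly like the new module-level cli_agent

    @agent.system_prompt
    def system_prompt() -> str:
        # re-evaluated on each run, so the timestamp is always current
        return f'The current date and time is {datetime.now()}.'

    # attach the model only once the user's choice is known
    agent.model = infer_model('openai:gpt-4o')  # illustrative model choice
    print(agent.run_sync('hello').data)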
pydantic_ai/_result.py
CHANGED

@@ -13,7 +13,7 @@ from typing_inspection.introspection import is_union_origin
 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
 from .result import ResultDataT, ResultDataT_inv, ResultValidatorFunc
-from .tools import AgentDepsT, RunContext, ToolDefinition
+from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition

 T = TypeVar('T')
 """An invariant TypeVar."""

@@ -159,7 +159,9 @@ class ResultTool(Generic[ResultDataT]):
             self.type_adapter = TypeAdapter(response_type)
             outer_typed_dict_key: str | None = None
             # noinspection PyArgumentList
-            parameters_json_schema = _utils.check_object_json_schema(
+            parameters_json_schema = _utils.check_object_json_schema(
+                self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
+            )
         else:
             response_data_typed_dict = TypedDict(  # noqa: UP013
                 'response_data_typed_dict',

@@ -168,7 +170,9 @@ class ResultTool(Generic[ResultDataT]):
             self.type_adapter = TypeAdapter(response_data_typed_dict)
             outer_typed_dict_key = 'response'
             # noinspection PyArgumentList
-            parameters_json_schema = _utils.check_object_json_schema(
+            parameters_json_schema = _utils.check_object_json_schema(
+                self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
+            )
             # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
             parameters_json_schema.pop('title')
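The `GenerateToolJsonSchema` generator now threaded through here is the one defined in `pydantic_ai/tools.py` (see its diff below). A hedged sketch of the call these hunks add, using an illustrative TypedDict:

    from pydantic import TypeAdapter
    from typing_extensions import TypedDict

    from pydantic_ai.tools import GenerateToolJsonSchema

    class ResponseData(TypedDict):  # illustrative result type
        answer: str

    schema = TypeAdapter(ResponseData).json_schema(schema_generator=GenerateToolJsonSchema)
    # total TypedDicts are now emitted as closed schemas for tool parameters
    print(schema['additionalProperties'])  # False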
pydantic_ai/models/__init__.py
CHANGED

@@ -106,6 +106,7 @@ KnownModelName = TypeAliasType(
     'google-gla:gemini-2.0-flash',
     'google-gla:gemini-2.0-flash-lite-preview-02-05',
     'google-gla:gemini-2.0-pro-exp-02-05',
+    'google-gla:gemini-2.5-pro-exp-03-25',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',

@@ -116,6 +117,7 @@ KnownModelName = TypeAliasType(
     'google-vertex:gemini-2.0-flash',
     'google-vertex:gemini-2.0-flash-lite-preview-02-05',
     'google-vertex:gemini-2.0-pro-exp-02-05',
+    'google-vertex:gemini-2.5-pro-exp-03-25',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
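Both entries make the new experimental Gemini model usable by its plain string name. A hedged sketch (assumes Gemini credentials, e.g. GEMINI_API_KEY, are configured; the prompt is illustrative):

    from pydantic_ai import Agent

    # the model string added to KnownModelName in this release
    agent = Agent('google-gla:gemini-2.5-pro-exp-03-25')
    result = agent.run_sync('Say hello in one word.')
    print(result.data)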
pydantic_ai/models/anthropic.py
CHANGED

@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations

-import base64
 import io
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager

@@ -9,7 +8,6 @@ from datetime import datetime, timezone
 from json import JSONDecodeError, loads as json_loads
 from typing import Any, Literal, Union, cast, overload

-from anthropic.types import DocumentBlockParam
 from typing_extensions import assert_never

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage

@@ -40,6 +38,7 @@ try:
     from anthropic.types import (
         Base64PDFSourceParam,
         ContentBlock,
+        DocumentBlockParam,
         ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,

@@ -354,48 +353,13 @@ class AnthropicModel(Model):
             else:
                 raise RuntimeError('Only images and PDFs are supported for binary content')
         elif isinstance(item, ImageUrl):
-            try:
-                response = await cached_async_http_client().get(item.url)
-                response.raise_for_status()
-                yield ImageBlockParam(
-                    source={
-                        'data': io.BytesIO(response.content),
-                        'media_type': item.media_type,
-                        'type': 'base64',
-                    },
-                    type='image',
-                )
-            except ValueError:
-                # Download the file if can't find the mime type.
-                client = cached_async_http_client()
-                response = await client.get(item.url, follow_redirects=True)
-                response.raise_for_status()
-                base64_encoded = base64.b64encode(response.content).decode('utf-8')
-                if (mime_type := response.headers['Content-Type']) in (
-                    'image/jpeg',
-                    'image/png',
-                    'image/gif',
-                    'image/webp',
-                ):
-                    yield ImageBlockParam(
-                        source={'data': base64_encoded, 'media_type': mime_type, 'type': 'base64'},
-                        type='image',
-                    )
-                else:  # pragma: no cover
-                    raise RuntimeError(f'Unsupported image type: {mime_type}')
+            yield ImageBlockParam(source={'type': 'url', 'url': item.url}, type='image')
         elif isinstance(item, DocumentUrl):
-            response = await cached_async_http_client().get(item.url)
-            response.raise_for_status()
             if item.media_type == 'application/pdf':
-                yield DocumentBlockParam(
-                    source=Base64PDFSourceParam(
-                        data=io.BytesIO(response.content),
-                        media_type=item.media_type,
-                        type='base64',
-                    ),
-                    type='document',
-                )
+                yield DocumentBlockParam(source={'url': item.url, 'type': 'url'}, type='document')
             elif item.media_type == 'text/plain':
+                response = await cached_async_http_client().get(item.url)
+                response.raise_for_status()
                 yield DocumentBlockParam(
                     source=PlainTextSourceParam(data=response.text, media_type=item.media_type, type='text'),
                     type='document',
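The rewrite switches image and PDF URLs from download-and-base64 to Anthropic's URL source blocks, so URLs are now forwarded as-is (only `text/plain` documents are still fetched). A hedged usage sketch (model string and URL are illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.messages import ImageUrl

    agent = Agent('anthropic:claude-3-5-sonnet-latest')
    result = agent.run_sync(
        [
            'Describe this image.',
            ImageUrl(url='https://example.com/photo.jpg'),  # passed through as a URL source, not downloaded
        ]
    )
    print(result.data)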
pydantic_ai/models/cohere.py
CHANGED

@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
 from itertools import chain
 from typing import Literal, Union, cast

-from cohere import TextAssistantMessageContentItem
 from typing_extensions import assert_never

 from .. import ModelHTTPError, result

@@ -38,6 +37,7 @@ try:
         ChatMessageV2,
         ChatResponse,
         SystemChatMessageV2,
+        TextAssistantMessageContentItem,
         ToolCallV2,
         ToolCallV2Function,
         ToolChatMessageV2,
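This is purely an import reshuffle: `TextAssistantMessageContentItem` moves inside the existing guarded import block, so a missing optional `cohere` dependency surfaces as the package's friendly install hint rather than an unguarded `ImportError` at module import time. The pattern, sketched (the message text here paraphrases the one used for the other optional extras; it is not quoted from this diff):

    try:
        from cohere import TextAssistantMessageContentItem  # optional dependency
    except ImportError as _import_error:
        raise ImportError(
            'Please install `cohere` to use the Cohere model, '
            'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
        ) from _import_error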
pydantic_ai/models/gemini.py
CHANGED
pydantic_ai/models/openai.py
CHANGED

@@ -1,7 +1,8 @@
 from __future__ import annotations as _annotations

 import base64
-from collections.abc import AsyncIterable, AsyncIterator
+import warnings
+from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone

@@ -41,8 +42,8 @@ from . import (
 )

 try:
-    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
-    from openai.types import ChatModel, chat
+    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
+    from openai.types import ChatModel, chat, responses
     from openai.types.chat import (
         ChatCompletionChunk,
         ChatCompletionContentPartImageParam,

@@ -52,12 +53,24 @@ try:
     )
     from openai.types.chat.chat_completion_content_part_image_param import ImageURL
     from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
+    from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
+    from openai.types.responses.response_input_param import FunctionCallOutput, Message
+    from openai.types.shared import ReasoningEffort
+    from openai.types.shared_params import Reasoning
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error

+__all__ = (
+    'OpenAIModel',
+    'OpenAIResponsesModel',
+    'OpenAIModelSettings',
+    'OpenAIResponsesModelSettings',
+    'OpenAIModelName',
+)
+
 OpenAIModelName = Union[str, ChatModel]
 """
 Possible OpenAI model names.

@@ -79,9 +92,9 @@ class OpenAIModelSettings(ModelSettings, total=False):
     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
     """

-    openai_reasoning_effort:
-    """
-
+    openai_reasoning_effort: ReasoningEffort
+    """Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+
     Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
     result in faster responses and fewer tokens used on reasoning in a response.
     """
@@ -93,6 +106,40 @@ class OpenAIModelSettings(ModelSettings, total=False):
     """


+class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
+    """Settings used for an OpenAI Responses model request.
+
+    ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    openai_builtin_tools: Sequence[FileSearchToolParam | WebSearchToolParam | ComputerToolParam]
+    """The provided OpenAI built-in tools to use.
+
+    See [OpenAI's built-in tools](https://platform.openai.com/docs/guides/tools?api-mode=responses) for more details.
+    """
+
+    openai_reasoning_generate_summary: Literal['detailed', 'concise']
+    """A summary of the reasoning performed by the model.
+
+    This can be useful for debugging and understanding the model's reasoning process.
+    One of `concise` or `detailed`.
+
+    Check the [OpenAI Computer use documentation](https://platform.openai.com/docs/guides/tools-computer-use#1-send-a-request-to-the-model)
+    for more details.
+    """
+
+    openai_truncation: Literal['disabled', 'auto']
+    """The truncation strategy to use for the model response.
+
+    It can be either:
+    - `disabled` (default): If a model response will exceed the context window size for a model, the
+      request will fail with a 400 error.
+    - `auto`: If the context of this response and previous ones exceeds the model's context window size,
+      the model will truncate the response to fit the context window by dropping input items in the
+      middle of the conversation.
+    """
+
+
 @dataclass(init=False)
 class OpenAIModel(Model):
     """A model that uses the OpenAI API.
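These settings travel per-request alongside the generic ones; a hedged sketch of passing them to the new Responses model defined later in this file (model name and prompt are illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

    settings = OpenAIResponsesModelSettings(
        openai_builtin_tools=[{'type': 'web_search_preview'}],  # a WebSearchToolParam
        openai_truncation='auto',
    )
    agent = Agent(OpenAIResponsesModel('gpt-4o'))
    result = agent.run_sync('What was one positive news story today?', model_settings=settings)
    print(result.data)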
@@ -178,8 +225,7 @@ class OpenAIModel(Model):
         stream: Literal[True],
         model_settings: OpenAIModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> AsyncStream[ChatCompletionChunk]:
-        pass
+    ) -> AsyncStream[ChatCompletionChunk]: ...

     @overload
     async def _completions_create(

@@ -188,8 +234,7 @@ class OpenAIModel(Model):
         stream: Literal[False],
         model_settings: OpenAIModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion:
-        pass
+    ) -> chat.ChatCompletion: ...

     async def _completions_create(
         self,

@@ -248,7 +293,7 @@ class OpenAIModel(Model):
                 items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=response.model, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
@@ -399,6 +444,318 @@ class OpenAIModel(Model):
         return chat.ChatCompletionUserMessageParam(role='user', content=content)


+@dataclass(init=False)
+class OpenAIResponsesModel(Model):
+    """A model that uses the OpenAI Responses API.
+
+    The [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) is the
+    new API for OpenAI models.
+
+    The Responses API has built-in tools, that you can use instead of building your own:
+
+    - [Web search](https://platform.openai.com/docs/guides/tools-web-search)
+    - [File search](https://platform.openai.com/docs/guides/tools-file-search)
+    - [Computer use](https://platform.openai.com/docs/guides/tools-computer-use)
+
+    Use the `openai_builtin_tools` setting to add these tools to your model.
+
+    If you are interested in the differences between the Responses API and the Chat Completions API,
+    see the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions).
+    """
+
+    client: AsyncOpenAI = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
+
+    _model_name: OpenAIModelName = field(repr=False)
+    _system: str = field(default='openai', repr=False)
+
+    def __init__(
+        self,
+        model_name: OpenAIModelName,
+        *,
+        provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
+    ):
+        """Initialize an OpenAI Responses model.
+
+        Args:
+            model_name: The name of the OpenAI model to use.
+            provider: The provider to use. Defaults to `'openai'`.
+        """
+        self._model_name = model_name
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
+
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._system
+
+    async def request(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
+        check_allow_model_requests()
+        response = await self._responses_create(
+            messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
+        )
+        return self._process_response(response), _map_usage(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        check_allow_model_requests()
+        response = await self._responses_create(
+            messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
+        )
+        async with response:
+            yield await self._process_streamed_response(response)
+
+    def _process_response(self, response: responses.Response) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
+        items: list[ModelResponsePart] = []
+        items.append(TextPart(response.output_text))
+        for item in response.output:
+            if item.type == 'function_call':
+                items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+
+    async def _process_streamed_response(
+        self, response: AsyncStream[responses.ResponseStreamEvent]
+    ) -> OpenAIResponsesStreamedResponse:
+        """Process a streamed response, and prepare a streaming response to return."""
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):  # pragma: no cover
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        assert isinstance(first_chunk, responses.ResponseCreatedEvent)
+        return OpenAIResponsesStreamedResponse(
+            _model_name=self._model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.response.created_at, tz=timezone.utc),
+        )
+
+    @overload
+    async def _responses_create(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        stream: Literal[False],
+        model_settings: OpenAIResponsesModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> responses.Response: ...
+
+    @overload
+    async def _responses_create(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        stream: Literal[True],
+        model_settings: OpenAIResponsesModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncStream[responses.ResponseStreamEvent]: ...
+
+    async def _responses_create(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        stream: bool,
+        model_settings: OpenAIResponsesModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
+        tools = self._get_tools(model_request_parameters)
+        tools = list(model_settings.get('openai_builtin_tools', [])) + tools
+
+        # standalone function to make it easier to override
+        if not tools:
+            tool_choice: Literal['none', 'required', 'auto'] | None = None
+        elif not model_request_parameters.allow_text_result:
+            tool_choice = 'required'
+        else:
+            tool_choice = 'auto'
+
+        system_prompt, openai_messages = await self._map_message(messages)
+        reasoning = self._get_reasoning(model_settings)
+
+        try:
+            return await self.client.responses.create(
+                input=openai_messages,
+                model=self._model_name,
+                instructions=system_prompt,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                stream=stream,
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                truncation=model_settings.get('openai_truncation', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                reasoning=reasoning,
+                user=model_settings.get('user', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
+
+    def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
+        reasoning_effort = model_settings.get('openai_reasoning_effort', None)
+        reasoning_generate_summary = model_settings.get('openai_reasoning_generate_summary', None)
+
+        if reasoning_effort is None and reasoning_generate_summary is None:
+            return NOT_GIVEN
+        return Reasoning(effort=reasoning_effort, generate_summary=reasoning_generate_summary)
+
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
+        return {
+            'name': f.name,
+            'parameters': f.parameters_json_schema,
+            'type': 'function',
+            'description': f.description,
+            'strict': True,
+        }
+
+    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[responses.ResponseInputItemParam]]:
+        """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
+        system_prompt: str = ''
+        openai_messages: list[responses.ResponseInputItemParam] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for part in message.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_prompt += part.content
+                    elif isinstance(part, UserPromptPart):
+                        openai_messages.append(await self._map_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        openai_messages.append(
+                            FunctionCallOutput(
+                                type='function_call_output',
+                                call_id=_guard_tool_call_id(t=part),
+                                output=part.model_response_str(),
+                            )
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        # TODO(Marcelo): How do we test this conditional branch?
+                        if part.tool_name is None:  # pragma: no cover
+                            openai_messages.append(
+                                Message(role='user', content=[{'type': 'input_text', 'text': part.model_response()}])
+                            )
+                        else:
+                            openai_messages.append(
+                                FunctionCallOutput(
+                                    type='function_call_output',
+                                    call_id=_guard_tool_call_id(t=part),
+                                    output=part.model_response(),
+                                )
+                            )
+                    else:
+                        assert_never(part)
+            elif isinstance(message, ModelResponse):
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content))
+                    elif isinstance(item, ToolCallPart):
+                        openai_messages.append(self._map_tool_call(item))
+                    else:
+                        assert_never(item)
+            else:
+                assert_never(message)
+        return system_prompt, openai_messages
+
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
+        return responses.ResponseFunctionToolCallParam(
+            arguments=t.args_as_json_str(),
+            call_id=_guard_tool_call_id(t=t),
+            name=t.tool_name,
+            type='function_call',
+        )
+
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam:
+        content: str | list[responses.ResponseInputContentParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(responses.ResponseInputTextParam(text=item, type='input_text'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        content.append(
+                            responses.ResponseInputImageParam(
+                                image_url=f'data:{item.media_type};base64,{base64_encoded}',
+                                type='input_image',
+                                detail='auto',
+                            )
+                        )
+                    elif item.is_document:
+                        content.append(
+                            responses.ResponseInputFileParam(
+                                type='input_file',
+                                file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                                # NOTE: Type wise it's not necessary to include the filename, but it's required by the
+                                # API itself. If we add empty string, the server sends a 500 error - which OpenAI needs
+                                # to fix. In any case, we add a placeholder name.
+                                filename=f'filename.{item.format}',
+                            )
+                        )
+                    elif item.is_audio:
+                        raise NotImplementedError('Audio as binary content is not supported for OpenAI Responses API.')
+                    else:  # pragma: no cover
+                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                elif isinstance(item, ImageUrl):
+                    content.append(
+                        responses.ResponseInputImageParam(image_url=item.url, type='input_image', detail='auto')
+                    )
+                elif isinstance(item, AudioUrl):  # pragma: no cover
+                    client = cached_async_http_client()
+                    response = await client.get(item.url)
+                    response.raise_for_status()
+                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    content.append(
+                        responses.ResponseInputFileParam(
+                            type='input_file',
+                            file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                        )
+                    )
+                elif isinstance(item, DocumentUrl):  # pragma: no cover
+                    client = cached_async_http_client()
+                    response = await client.get(item.url)
+                    response.raise_for_status()
+                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    content.append(
+                        responses.ResponseInputFileParam(
+                            type='input_file',
+                            file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                            filename=f'filename.{item.format}',
+                        )
+                    )
+                else:
+                    assert_never(item)
+        return responses.EasyInputMessageParam(role='user', content=content)
+
+
 @dataclass
 class OpenAIStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for OpenAI models."""
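End to end, the new model class plugs into `Agent` like any other; a hedged streaming sketch (model string and prompt are illustrative):

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIResponsesModel

    agent = Agent(OpenAIResponsesModel('gpt-4o'))

    async def main() -> None:
        # streamed text arrives through the OpenAIResponsesStreamedResponse added below
        async with agent.run_stream('Tell me a short joke.') as result:
            async for text in result.stream_text():
                print(text)

    asyncio.run(main())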
@@ -442,12 +799,103 @@ class OpenAIStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
+@dataclass
+class OpenAIResponsesStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for OpenAI Responses API."""
+
+    _model_name: OpenAIModelName
+    _response: AsyncIterable[responses.ResponseStreamEvent]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
+        async for chunk in self._response:
+            if isinstance(chunk, responses.ResponseCompletedEvent):
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseContentPartAddedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseContentPartDoneEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseCreatedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseFailedEvent):  # pragma: no cover
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDeltaEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=chunk.item_id,
+                    tool_name=None,
+                    args=chunk.delta,
+                    tool_call_id=chunk.item_id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDoneEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseIncompleteEvent):  # pragma: no cover
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseInProgressEvent):
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseOutputItemAddedEvent):
+                if isinstance(chunk.item, responses.ResponseFunctionToolCall):
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id=chunk.item.id,
+                        tool_name=chunk.item.name,
+                        args=chunk.item.arguments,
+                        tool_call_id=chunk.item.id,
+                    )
+
+            elif isinstance(chunk, responses.ResponseOutputItemDoneEvent):
+                # NOTE: We only need this if the tool call deltas don't include the final info.
+                pass
+
+            elif isinstance(chunk, responses.ResponseTextDeltaEvent):
+                yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta)
+
+            elif isinstance(chunk, responses.ResponseTextDoneEvent):
+                pass  # there's nothing we need to do here
+
+            else:  # pragma: no cover
+                warnings.warn(
+                    f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
+                    UserWarning,
+                )
+
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._timestamp
+
+
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.Usage:
     response_usage = response.usage
     if response_usage is None:
         return usage.Usage()
-    else:
+    elif isinstance(response_usage, responses.ResponseUsage):
         details: dict[str, int] = {}
+        return usage.Usage(
+            request_tokens=response_usage.input_tokens,
+            response_tokens=response_usage.output_tokens,
+            total_tokens=response_usage.total_tokens,
+            details={
+                'reasoning_tokens': response_usage.output_tokens_details.reasoning_tokens,
+                'cached_tokens': response_usage.input_tokens_details.cached_tokens,
+            },
+        )
+    else:
+        details = {}
         if response_usage.completion_tokens_details is not None:
             details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
         if response_usage.prompt_tokens_details is not None:
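Since `_map_usage` now folds `ResponseUsage` into the same `Usage` type, reasoning and cache token counts surface through the normal result API; a hedged sketch (model string is illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIResponsesModel

    agent = Agent(OpenAIResponsesModel('o3-mini'))
    result = agent.run_sync('What is 2 + 2?')
    # details now carries 'reasoning_tokens' and 'cached_tokens' for Responses API runs
    print(result.usage().details)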
pydantic_ai/tools.py
CHANGED

@@ -149,8 +149,8 @@ class GenerateToolJsonSchema(GenerateJsonSchema):
     def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
         s = super().typed_dict_schema(schema)
         total = schema.get('total')
-        if total is
-        s['additionalProperties'] =
+        if 'additionalProperties' not in s and (total is True or total is None):
+            s['additionalProperties'] = False
         return s

     def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
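The rewritten guard only closes schemas for total TypedDicts and never overwrites an explicit `additionalProperties`; a hedged illustration with a non-total TypedDict:

    from pydantic import TypeAdapter
    from typing_extensions import TypedDict

    from pydantic_ai.tools import GenerateToolJsonSchema

    class OpenDict(TypedDict, total=False):  # illustrative
        note: str

    schema = TypeAdapter(OpenDict).json_schema(schema_generator=GenerateToolJsonSchema)
    print(schema.get('additionalProperties'))  # None — non-total dicts stay open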
{pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.47
+Version: 0.0.49
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT

@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.47
+Requires-Dist: pydantic-graph==0.0.49
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic

@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.0.47; extra == 'evals'
+Requires-Dist: pydantic-evals==0.0.49; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
{pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/RECORD
CHANGED

@@ -1,11 +1,11 @@
 pydantic_ai/__init__.py,sha256=5or1fE25gmemJGCznkFHC4VMeNT7vTLU6BiGxkmSA2A,959
 pydantic_ai/__main__.py,sha256=AW8FzscUWPFtIrQBG0QExLxTQehKtt5FnFVnOT200OE,122
 pydantic_ai/_agent_graph.py,sha256=MWNNug-bGJcgOVonkkD38Yz3X4R0XfUxzFJ9_fpNnyQ,32532
-pydantic_ai/_cli.py,sha256=
+pydantic_ai/_cli.py,sha256=TU6Kqlu4aDmgrwmNAnnGPG5mSpTZPjtOtyhq582BCec,10953
 pydantic_ai/_griffe.py,sha256=Sf_DisE9k2TA0VFeVIK2nf1oOct5MygW86PBCACJkFA,5244
 pydantic_ai/_parts_manager.py,sha256=HIi6eth7z2g0tOn6iQYc633xMqy4d_xZ8vwka8J8150,12016
 pydantic_ai/_pydantic.py,sha256=12hX5hON88meO1QxbWrEPXSvr6RTNgr6ubKY6KRwab4,8890
-pydantic_ai/_result.py,sha256=
+pydantic_ai/_result.py,sha256=9cDWMiXv3ef_qPywv02SAH_8r3SS5FVfxm5iAYDI_is,10375
 pydantic_ai/_system_prompt.py,sha256=602c2jyle2R_SesOrITBDETZqsLk4BZ8Cbo8yEhmx04,1120
 pydantic_ai/_utils.py,sha256=VXbR6FG6WQU_TTN1NMmBDbZ6e6DT7PDiUiBHFHf348U,9703
 pydantic_ai/agent.py,sha256=qBUE9uPgomSj09sHe6kQYxjXGMVsRs3nkPn5neRVlYg,69285

@@ -16,22 +16,22 @@ pydantic_ai/messages.py,sha256=IQS6vabH72yhvTQY1ciQxRqJDHpGMMR0MKiO5xcJ0SE,27112
 pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pydantic_ai/result.py,sha256=LXKxRzy_rGMkdZ8xJ7yknPP3wGZtGNeZl-gh5opXbaQ,22542
 pydantic_ai/settings.py,sha256=q__Hordc4dypesNxpy_cBT5rFdSiEY-rQt9G6zfyFaM,3101
-pydantic_ai/tools.py,sha256=
+pydantic_ai/tools.py,sha256=S7pg9cc7OjIkACXAOjE5RUz7VkKXapH4GciP7VVyFc8,15791
 pydantic_ai/usage.py,sha256=9sqoIv_RVVUhKXQScTDqUJc074gifsuSzc9_NOt7C3g,5394
 pydantic_ai/common_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pydantic_ai/common_tools/duckduckgo.py,sha256=Iw8Dl2YQ28S483mzfa8CXs-dc-ujS8un085R2O6oOEw,2241
 pydantic_ai/common_tools/tavily.py,sha256=h8deBDrpG-8BGzydM_zXs7z1ASrhdVvUxL4-CAbncBo,2589
-pydantic_ai/models/__init__.py,sha256=
-pydantic_ai/models/anthropic.py,sha256=
+pydantic_ai/models/__init__.py,sha256=0f_9smleJgn3Jgmm96gsW7lbyQIJ6-5chXxClLqzfHk,18367
+pydantic_ai/models/anthropic.py,sha256=bqx-yQH4TyLoC9Jn5m7IegRsHXn0h2vv9aOkO0elfu0,19957
 pydantic_ai/models/bedrock.py,sha256=EmnBwxx_FzKMM8xAed_1aTmDwASWo26GA6Q8Z92waVc,20706
-pydantic_ai/models/cohere.py,sha256=
+pydantic_ai/models/cohere.py,sha256=UTRCl3sa5-ptceg7k9WqLMIhNo5oFExUPMpLLi-A46c,11328
 pydantic_ai/models/fallback.py,sha256=y0bYXM3DfzJNAsyyMzclt33lzZazL-5_hwdgc33gfuM,4876
 pydantic_ai/models/function.py,sha256=HUSgPB3mKVfYI0OSJJJJRiQN-yeewjYIbrtrPfsvlgI,11365
-pydantic_ai/models/gemini.py,sha256=
+pydantic_ai/models/gemini.py,sha256=0N5ccI5HsCLaPd2-jxareQvrAePKSALOiRmV0ufxIcc,33336
 pydantic_ai/models/groq.py,sha256=IbO2jMNC5yiYRuUJf-j4blpPvoQDromQNBKJyPjs2A4,16518
 pydantic_ai/models/instrumented.py,sha256=ErFRDiOehOYlJBp4mSNj7yEIMtMqjlGcamEAwgW_Il4,11163
 pydantic_ai/models/mistral.py,sha256=8pdG9oRW6Dx7H5P88ZgiRIZbkGkPUfEateuyvzoULOE,27439
-pydantic_ai/models/openai.py,sha256=
+pydantic_ai/models/openai.py,sha256=asLsbtyQwvAtLjdJs4B87qE9ylqWJmWCIoWs2xvvjM0,40675
 pydantic_ai/models/test.py,sha256=qQ8ZIaVRdbJv-tKGu6lrdakVAhOsTlyf68TFWyGwOWE,16861
 pydantic_ai/models/wrapper.py,sha256=ff6JPTuIv9C_6Zo4kyYIO7Cn0VI1uSICz1v1aKUyeOc,1506
 pydantic_ai/providers/__init__.py,sha256=lsJn3BStrPMMAFWEkCYPyfMj3fEVfaeS2xllnvE6Gdk,2489

@@ -45,7 +45,7 @@ pydantic_ai/providers/google_vertex.py,sha256=WAwPxKTARVzs8DFs2veEUOJSur0krDOo9-
 pydantic_ai/providers/groq.py,sha256=DoY6qkfhuemuKB5JXhUkqG-3t1HQkxwSXoE_kHQIAK0,2788
 pydantic_ai/providers/mistral.py,sha256=fcR1uSwORo0jtevX7-wOjvcfT8ojMAaKY81uN5uYymM,2661
 pydantic_ai/providers/openai.py,sha256=ePF-QWwLkGkSE5w245gTTDVR3VoTIUqFoIhQ0TAoUiA,2866
-pydantic_ai_slim-0.0.
-pydantic_ai_slim-0.0.
-pydantic_ai_slim-0.0.
-pydantic_ai_slim-0.0.
+pydantic_ai_slim-0.0.49.dist-info/METADATA,sha256=lj5N1y3Dff7zJqTUQJy-z-kRBa7Fh3odZRuDW_GqOFc,3555
+pydantic_ai_slim-0.0.49.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.49.dist-info/entry_points.txt,sha256=KxQSmlMS8GMTkwTsl4_q9a5nJvBjj3HWeXx688wLrKg,45
+pydantic_ai_slim-0.0.49.dist-info/RECORD,,
{pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/WHEEL
File without changes

{pydantic_ai_slim-0.0.47.dist-info → pydantic_ai_slim-0.0.49.dist-info}/entry_points.txt
File without changes