pydantic-ai-slim 0.0.10 → 0.0.12 (py3-none-any.whl)
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
This release has been flagged as potentially problematic.
- pydantic_ai/_pydantic.py +6 -4
- pydantic_ai/_result.py +18 -22
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +11 -6
- pydantic_ai/agent.py +156 -69
- pydantic_ai/messages.py +5 -2
- pydantic_ai/models/__init__.py +30 -37
- pydantic_ai/models/function.py +8 -14
- pydantic_ai/models/gemini.py +11 -10
- pydantic_ai/models/groq.py +31 -34
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +43 -38
- pydantic_ai/models/test.py +70 -49
- pydantic_ai/models/vertexai.py +7 -6
- pydantic_ai/tools.py +119 -34
- {pydantic_ai_slim-0.0.10.dist-info → pydantic_ai_slim-0.0.12.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.10.dist-info/RECORD +0 -22
- {pydantic_ai_slim-0.0.10.dist-info → pydantic_ai_slim-0.0.12.dist-info}/WHEEL +0 -0
pydantic_ai/models/function.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -14,14 +14,8 @@ from typing_extensions import TypeAlias, assert_never, overload
 
 from .. import _utils, result
 from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
-from . import (
-    AbstractToolDefinition,
-    AgentModel,
-    EitherStreamedResponse,
-    Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
-)
+from ..tools import ToolDefinition
+from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
 
 
 @dataclass(init=False)
@@ -59,11 +53,11 @@ class FunctionModel(Model):
 
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        result_tools = list(result_tools) if result_tools is not None else None
         return FunctionAgentModel(
             self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
         )
@@ -84,7 +78,7 @@ class AgentInfo:
     This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
     """
 
-    function_tools: Mapping[str, AbstractToolDefinition]
+    function_tools: list[ToolDefinition]
     """The function tools available on this agent.
 
     These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
@@ -92,7 +86,7 @@ class AgentInfo:
     """
     allow_text_result: bool
    """Whether a plain text result is allowed."""
-    result_tools: list[AbstractToolDefinition] | None
+    result_tools: list[ToolDefinition]
     """The tools that can be called as the final result of the run."""
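Taken together, these changes mean a `FunctionModel` callback now receives its tools as a plain `list[ToolDefinition]` (previously a `Mapping` of `AbstractToolDefinition`, with `result_tools` optional). A minimal sketch of what that looks like from user code against 0.0.12; the `echo_model` function and its output are illustrative, not part of the package:

    from pydantic_ai import Agent
    from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
    from pydantic_ai.models.function import AgentInfo, FunctionModel

    def echo_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
        # AgentInfo.function_tools is now list[ToolDefinition], so it can be
        # iterated directly instead of via .values()
        tool_names = [t.name for t in info.function_tools]
        return ModelTextResponse(content=f'available tools: {tool_names}')

    agent = Agent(FunctionModel(echo_model))
    result = agent.run_sync('hello')
    print(result.data)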
pydantic_ai/models/gemini.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
@@ -25,8 +25,8 @@ from ..messages import (
     ToolCall,
     ToolReturn,
 )
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -90,9 +90,10 @@ class GeminiModel(Model):
 
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
         return GeminiAgentModel(
             http_client=self.http_client,
@@ -142,13 +143,13 @@ class GeminiAgentModel(AgentModel):
         model_name: GeminiModelName,
         auth: AuthProtocol,
         url: str,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ):
         check_allow_model_requests()
-        tools = [_function_from_abstract_tool(t) for t in function_tools.values()]
-        if result_tools is not None:
+        tools = [_function_from_abstract_tool(t) for t in function_tools]
+        if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
 
         if allow_text_result:
@@ -504,8 +505,8 @@ class _GeminiFunction(TypedDict):
     """
 
 
-def _function_from_abstract_tool(tool: AbstractToolDefinition) -> _GeminiFunction:
-    json_schema = _GeminiJsonSchema(tool.json_schema).simplify()
+def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
+    json_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
     f = _GeminiFunction(
         name=tool.name,
         description=tool.description,
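As in function.py, the Gemini backend now consumes `ToolDefinition` from `pydantic_ai.tools`, whose schema attribute is `parameters_json_schema` (the old `AbstractToolDefinition` exposed `json_schema`). A rough sketch of the shape the backends receive; the `get_weather` tool is invented for illustration:

    from pydantic_ai.tools import ToolDefinition

    tool = ToolDefinition(
        name='get_weather',
        description='Get the current weather for a city.',
        parameters_json_schema={
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    )
    # gemini.py turns this into a _GeminiFunction using tool.name,
    # tool.description and a simplified tool.parameters_json_schema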
pydantic_ai/models/groq.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -21,8 +21,8 @@ from ..messages import (
     ToolReturn,
 )
 from ..result import Cost
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -37,11 +37,11 @@ try:
     from groq.types import chat
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
     from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
-except ImportError as e:
+except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
         "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
-    ) from e
+    ) from _import_error
 
 GroqModelName = Literal[
     'llama-3.1-70b-versatile',
@@ -109,13 +109,14 @@ class GroqModel(Model):
 
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
        allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools.values()]
-        if result_tools is not None:
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
             tools += [self._map_tool_definition(r) for r in result_tools]
         return GroqAgentModel(
             self.client,
@@ -128,13 +129,13 @@ class GroqModel(Model):
         return f'groq:{self.model_name}'
 
     @staticmethod
-    def _map_tool_definition(f: AbstractToolDefinition) -> chat.ChatCompletionToolParam:
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
         return {
             'type': 'function',
             'function': {
                 'name': f.name,
                 'description': f.description,
-                'parameters': f.json_schema,
+                'parameters': f.parameters_json_schema,
             },
         }
 
@@ -208,33 +209,29 @@ class GroqAgentModel(AgentModel):
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        try:
-            first_chunk = await response.__anext__()
-        except StopAsyncIteration as e:
-            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-        timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-        delta = first_chunk.choices[0].delta
-        start_cost = _map_cost(first_chunk)
-
-        # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-        while delta.tool_calls is None and delta.content is None:
+        timestamp: datetime | None = None
+        start_cost = Cost()
+        # the first chunk may not contain enough information, so we iterate until we get either `tool_calls` or `content`
+        while True:
             try:
-                next_chunk = await response.__anext__()
+                chunk = await response.__anext__()
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-            delta = next_chunk.choices[0].delta
-            start_cost += _map_cost(next_chunk)
-
-        if delta.content is not None:
-            return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
-        else:
-            assert delta.tool_calls is not None
-            return GroqStreamStructuredResponse(
-                response,
-                {c.index: c for c in delta.tool_calls},
-                timestamp,
-                start_cost,
-            )
+            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+            start_cost += _map_cost(chunk)
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+
+                if delta.content is not None:
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                elif delta.tool_calls is not None:
+                    return GroqStreamStructuredResponse(
+                        response,
+                        {c.index: c for c in delta.tool_calls},
+                        timestamp,
+                        start_cost,
+                    )
 
     @staticmethod
     def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
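The rewritten loop above replaces "inspect the first chunk, then poll" with "iterate until a chunk carries a usable delta", which also tolerates chunks whose `choices` list is empty. A self-contained sketch of the same pattern against a fake stream; all `Fake*` types are stand-ins invented for this example, not pydantic-ai APIs:

    from __future__ import annotations

    import asyncio
    from dataclasses import dataclass, field

    @dataclass
    class FakeDelta:
        content: str | None = None
        tool_calls: list | None = None

    @dataclass
    class FakeChoice:
        delta: FakeDelta

    @dataclass
    class FakeChunk:
        choices: list[FakeChoice] = field(default_factory=list)

    async def first_meaningful_delta(stream) -> FakeDelta:
        # iterate until a chunk carries either `content` or `tool_calls`;
        # chunks with no choices or an empty delta are skipped
        async for chunk in stream:
            if chunk.choices:
                delta = chunk.choices[0].delta
                if delta.content is not None or delta.tool_calls is not None:
                    return delta
        raise RuntimeError('stream ended without content or tool calls')

    async def main() -> None:
        async def stream():
            yield FakeChunk()                                       # e.g. a role-only chunk
            yield FakeChunk([FakeChoice(FakeDelta())])              # empty delta
            yield FakeChunk([FakeChoice(FakeDelta(content='hi'))])  # first usable delta

        print((await first_meaningful_delta(stream())).content)  # -> hi

    asyncio.run(main())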
pydantic_ai/models/ollama.py
ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations as _annotations
+
+from dataclasses import dataclass
+from typing import Literal, Union
+
+from httpx import AsyncClient as AsyncHTTPClient
+
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    cached_async_http_client,
+)
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as e:
+    raise ImportError(
+        'Please install `openai` to use the OpenAI model, '
+        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+    ) from e
+
+
+from .openai import OpenAIModel
+
+CommonOllamaModelNames = Literal[
+    'codellama',
+    'gemma',
+    'gemma2',
+    'llama3',
+    'llama3.1',
+    'llama3.2',
+    'llama3.2-vision',
+    'llama3.3',
+    'mistral',
+    'mistral-nemo',
+    'mixtral',
+    'phi3',
+    'qwq',
+    'qwen',
+    'qwen2',
+    'qwen2.5',
+    'starcoder2',
+]
+"""This contains just the most common ollama models.
+
+For a full list see [ollama.com/library](https://ollama.com/library).
+"""
+OllamaModelName = Union[CommonOllamaModelNames, str]
+"""Possible ollama models.
+
+Since Ollama supports hundreds of models, we explicitly list the most common models but
+allow any name in the type hints.
+"""
+
+
+@dataclass(init=False)
+class OllamaModel(Model):
+    """A model that implements Ollama using the OpenAI API.
+
+    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: OllamaModelName
+    openai_model: OpenAIModel
+
+    def __init__(
+        self,
+        model_name: OllamaModelName,
+        *,
+        base_url: str | None = 'http://localhost:11434/v1/',
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Ollama model.
+
+        Ollama has built-in compatibility with the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the
+        [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here.
+
+        Args:
+            model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library).
+                You must first download the model (`ollama pull <MODEL-NAME>`) in order to use it.
+            base_url: The base url for the ollama requests. The default value is the ollama default.
+            openai_client: An existing
+                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                client to use; if provided, `base_url` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if openai_client is not None:
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client)
+        else:
+            # API key is not required for ollama but a value is required to create the client
+            http_client_ = http_client or cached_async_http_client()
+            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        return await self.openai_model.agent_model(
+            function_tools=function_tools,
+            allow_text_result=allow_text_result,
+            result_tools=result_tools,
+        )
+
+    def name(self) -> str:
+        return f'ollama:{self.model_name}'
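Since `OllamaModel` is a thin wrapper around `OpenAIModel` pointed at the local Ollama endpoint, basic usage is a one-liner. A minimal sketch, assuming an Ollama server on the default port with `llama3.2` already pulled:

    from pydantic_ai import Agent
    from pydantic_ai.models.ollama import OllamaModel

    agent = Agent(OllamaModel('llama3.2'))
    result = agent.run_sync('Where were the 2012 Olympics held?')
    print(result.data)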
pydantic_ai/models/openai.py
CHANGED
@@ -1,10 +1,10 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Literal, overload
+from typing import Literal, Union, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -21,8 +21,8 @@ from ..messages import (
     ToolReturn,
 )
 from ..result import Cost
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -37,11 +37,17 @@ try:
     from openai.types import ChatModel, chat
     from openai.types.chat import ChatCompletionChunk
     from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
-except ImportError as e:
+except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
         "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
-    ) from e
+    ) from _import_error
+
+OpenAIModelName = Union[ChatModel, str]
+"""
+Using this broader type for the model name instead of the `ChatModel` definition
+allows this model to be used more easily with other model types (e.g. Ollama).
+"""
 
 
 @dataclass(init=False)
@@ -53,12 +59,12 @@ class OpenAIModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: ChatModel
+    model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
 
     def __init__(
         self,
-        model_name: ChatModel,
+        model_name: OpenAIModelName,
         *,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
@@ -77,7 +83,7 @@ class OpenAIModel(Model):
                client to use; if provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name: ChatModel = model_name
+        self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
@@ -89,13 +95,14 @@ class OpenAIModel(Model):
 
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools.values()]
-        if result_tools is not None:
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
             tools += [self._map_tool_definition(r) for r in result_tools]
         return OpenAIAgentModel(
             self.client,
@@ -108,13 +115,13 @@ class OpenAIModel(Model):
         return f'openai:{self.model_name}'
 
     @staticmethod
-    def _map_tool_definition(f: AbstractToolDefinition) -> chat.ChatCompletionToolParam:
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
         return {
             'type': 'function',
             'function': {
                 'name': f.name,
                 'description': f.description,
-                'parameters': f.json_schema,
+                'parameters': f.parameters_json_schema,
             },
         }
 
@@ -124,7 +131,7 @@ class OpenAIAgentModel(AgentModel):
     """Implementation of `AgentModel` for OpenAI models."""
 
     client: AsyncOpenAI
-    model_name: ChatModel
+    model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
@@ -188,33 +195,31 @@ class OpenAIAgentModel(AgentModel):
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        try:
-            first_chunk = await response.__anext__()
-        except StopAsyncIteration as e:
-            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-        timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-        delta = first_chunk.choices[0].delta
-        start_cost = _map_cost(first_chunk)
-
-        # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-        while delta.tool_calls is None and delta.content is None:
+        timestamp: datetime | None = None
+        start_cost = Cost()
+        # the first chunk may not contain enough information, so we iterate until we get either `tool_calls` or `content`
+        while True:
             try:
-                next_chunk = await response.__anext__()
+                chunk = await response.__anext__()
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-            delta = next_chunk.choices[0].delta
-            start_cost += _map_cost(next_chunk)
 
-        if delta.content is not None:
-            return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
-        else:
-            assert delta.tool_calls is not None
-            return OpenAIStreamStructuredResponse(
-                response,
-                {c.index: c for c in delta.tool_calls},
-                timestamp,
-                start_cost,
-            )
+            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+            start_cost += _map_cost(chunk)
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+
+                if delta.content is not None:
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                elif delta.tool_calls is not None:
+                    return OpenAIStreamStructuredResponse(
+                        response,
+                        {c.index: c for c in delta.tool_calls},
+                        timestamp,
+                        start_cost,
+                    )
+                # else continue until we get either delta.content or delta.tool_calls
 
     @staticmethod
     def _map_message(message: Message) -> chat.ChatCompletionMessageParam: