pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

pydantic_ai/messages.py CHANGED
@@ -1,6 +1,5 @@
  from __future__ import annotations as _annotations

- import json
  from dataclasses import dataclass, field
  from datetime import datetime
  from typing import Annotated, Any, Literal, Union
@@ -74,6 +73,9 @@ class ToolReturn:
          return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}


+ ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
+
+
  @dataclass
  class RetryPrompt:
      """A message back to a model asking it to try again.
@@ -109,7 +111,8 @@ class RetryPrompt:
          if isinstance(self.content, str):
              description = self.content
          else:
-             description = f'{len(self.content)} validation errors: {json.dumps(self.content, indent=2)}'
+             json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
+             description = f'{len(self.content)} validation errors: {json_errors.decode()}'
          return f'{description}\n\nFix the errors and try again.'

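The change above swaps `json.dumps` for a pydantic `TypeAdapter` dump that excludes each error's `ctx` field, which can hold values that are not JSON-serializable. A minimal sketch of the same serialization, using a plain `TypeAdapter` rather than the library's lazy adapter:

from pydantic import TypeAdapter, ValidationError
from pydantic_core import ErrorDetails

try:
    TypeAdapter(list[int]).validate_python(['not-an-int'])
except ValidationError as exc:
    errors: list[ErrorDetails] = exc.errors()

# Mirrors the new RetryPrompt code: dump the errors as JSON while excluding
# each error's `ctx`, which may contain arbitrary Python objects.
json_errors = TypeAdapter(list[ErrorDetails]).dump_json(
    errors, exclude={'__all__': {'ctx'}}, indent=2
)
print(f'{len(errors)} validation errors: {json_errors.decode()}')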
pydantic_ai/models/__init__.py CHANGED
@@ -7,11 +7,11 @@ specific LLM being used.
  from __future__ import annotations as _annotations

  from abc import ABC, abstractmethod
- from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
+ from collections.abc import AsyncIterator, Iterable, Iterator
  from contextlib import asynccontextmanager, contextmanager
  from datetime import datetime
  from functools import cache
- from typing import TYPE_CHECKING, Literal, Protocol, Union
+ from typing import TYPE_CHECKING, Literal, Union

  import httpx

@@ -19,8 +19,8 @@ from ..exceptions import UserError
  from ..messages import Message, ModelAnyResponse, ModelStructuredResponse

  if TYPE_CHECKING:
-     from .._utils import ObjectJsonSchema
      from ..result import Cost
+     from ..tools import ToolDefinition


  KnownModelName = Literal[
@@ -49,6 +49,23 @@ KnownModelName = Literal[
      'gemini-1.5-pro',
      'vertexai:gemini-1.5-flash',
      'vertexai:gemini-1.5-pro',
+     'ollama:codellama',
+     'ollama:gemma',
+     'ollama:gemma2',
+     'ollama:llama3',
+     'ollama:llama3.1',
+     'ollama:llama3.2',
+     'ollama:llama3.2-vision',
+     'ollama:llama3.3',
+     'ollama:mistral',
+     'ollama:mistral-nemo',
+     'ollama:mixtral',
+     'ollama:phi3',
+     'ollama:qwq',
+     'ollama:qwen',
+     'ollama:qwen2',
+     'ollama:qwen2.5',
+     'ollama:starcoder2',
      'test',
  ]
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -63,11 +80,12 @@ class Model(ABC):
      @abstractmethod
      async def agent_model(
          self,
-         function_tools: Mapping[str, AbstractToolDefinition],
+         *,
+         function_tools: list[ToolDefinition],
          allow_text_result: bool,
-         result_tools: Sequence[AbstractToolDefinition] | None,
+         result_tools: list[ToolDefinition],
      ) -> AgentModel:
-         """Create an agent model.
+         """Create an agent model, this is called for each step of an agent run.

          This is async in case slow/async config checks need to be performed that can't be done in `__init__`.

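The signature change makes the arguments keyword-only and replaces the `Mapping`/optional-`Sequence` pair with two plain lists. A sketch of the new call convention (illustrative; `prepare` is not part of the library):

from pydantic_ai.models import AgentModel, Model

async def prepare(model: Model) -> AgentModel:
    # arguments are now keyword-only, and result_tools is always a list
    # (possibly empty) rather than None
    return await model.agent_model(
        function_tools=[],
        allow_text_result=True,
        result_tools=[],
    )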
@@ -87,7 +105,7 @@ class Model(ABC):


  class AgentModel(ABC):
-     """Model configured for a specific agent."""
+     """Model configured for each step of an Agent run."""

      @abstractmethod
      async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
@@ -238,7 +256,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
      elif model.startswith('openai:'):
          from .openai import OpenAIModel

-         return OpenAIModel(model[7:])  # pyright: ignore[reportArgumentType]
+         return OpenAIModel(model[7:])
      elif model.startswith('gemini'):
          from .gemini import GeminiModel

@@ -252,39 +270,14 @@ def infer_model(model: Model | KnownModelName) -> Model:
          from .vertexai import VertexAIModel

          return VertexAIModel(model[9:])  # pyright: ignore[reportArgumentType]
+     elif model.startswith('ollama:'):
+         from .ollama import OllamaModel
+
+         return OllamaModel(model[7:])
      else:
          raise UserError(f'Unknown model: {model}')


- class AbstractToolDefinition(Protocol):
-     """Abstract definition of a function/tool.
-
-     This is used for both function tools and result tools.
-     """
-
-     @property
-     def name(self) -> str:
-         """The name of the tool."""
-         ...
-
-     @property
-     def description(self) -> str:
-         """The description of the tool."""
-         ...
-
-     @property
-     def json_schema(self) -> ObjectJsonSchema:
-         """The JSON schema for the tool's arguments."""
-         ...
-
-     @property
-     def outer_typed_dict_key(self) -> str | None:
-         """The key in the outer [TypedDict] that wraps a result tool.
-
-         This will only be set for result tools which don't have an `object` JSON schema.
-         """
-
-
  @cache
  def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
      """Cached HTTPX async client so multiple agents and calls can share the same client.
pydantic_ai/models/function.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
  import inspect
  import re
- from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence
+ from collections.abc import AsyncIterator, Awaitable, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime
@@ -14,14 +14,8 @@ from typing_extensions import TypeAlias, assert_never, overload

  from .. import _utils, result
  from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
- from . import (
-     AbstractToolDefinition,
-     AgentModel,
-     EitherStreamedResponse,
-     Model,
-     StreamStructuredResponse,
-     StreamTextResponse,
- )
+ from ..tools import ToolDefinition
+ from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse


  @dataclass(init=False)
  @dataclass(init=False)
@@ -59,11 +53,11 @@ class FunctionModel(Model):
59
53
 
60
54
  async def agent_model(
61
55
  self,
62
- function_tools: Mapping[str, AbstractToolDefinition],
56
+ *,
57
+ function_tools: list[ToolDefinition],
63
58
  allow_text_result: bool,
64
- result_tools: Sequence[AbstractToolDefinition] | None,
59
+ result_tools: list[ToolDefinition],
65
60
  ) -> AgentModel:
66
- result_tools = list(result_tools) if result_tools is not None else None
67
61
  return FunctionAgentModel(
68
62
  self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
69
63
  )
@@ -84,7 +78,7 @@ class AgentInfo:
      This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
      """

-     function_tools: Mapping[str, AbstractToolDefinition]
+     function_tools: list[ToolDefinition]
      """The function tools available on this agent.

      These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
@@ -92,7 +86,7 @@ class AgentInfo:
      """
      allow_text_result: bool
      """Whether a plain text result is allowed."""
-     result_tools: list[AbstractToolDefinition] | None
+     result_tools: list[ToolDefinition]
      """The tools that can be called as the final result of the run."""


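With `AgentInfo.function_tools` now a `list[ToolDefinition]` rather than a `Mapping` keyed by tool name, custom functions iterate over it directly. A minimal sketch (assuming `ModelTextResponse` is available from `pydantic_ai.messages`, which is not shown in this diff):

from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel

def model_fn(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
    # previously info.function_tools.values(); now just iterate the list
    tool_names = [t.name for t in info.function_tools]
    return ModelTextResponse(content=f"tools: {', '.join(tool_names)}")

model = FunctionModel(model_fn)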
pydantic_ai/models/gemini.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

  import os
  import re
- from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+ from collections.abc import AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from copy import deepcopy
  from dataclasses import dataclass, field
@@ -25,8 +25,8 @@ from ..messages import (
      ToolCall,
      ToolReturn,
  )
+ from ..tools import ToolDefinition
  from . import (
-     AbstractToolDefinition,
      AgentModel,
      EitherStreamedResponse,
      Model,
@@ -90,9 +90,10 @@ class GeminiModel(Model):

      async def agent_model(
          self,
-         function_tools: Mapping[str, AbstractToolDefinition],
+         *,
+         function_tools: list[ToolDefinition],
          allow_text_result: bool,
-         result_tools: Sequence[AbstractToolDefinition] | None,
+         result_tools: list[ToolDefinition],
      ) -> GeminiAgentModel:
          return GeminiAgentModel(
              http_client=self.http_client,
@@ -142,13 +143,13 @@ class GeminiAgentModel(AgentModel):
          model_name: GeminiModelName,
          auth: AuthProtocol,
          url: str,
-         function_tools: Mapping[str, AbstractToolDefinition],
+         function_tools: list[ToolDefinition],
          allow_text_result: bool,
-         result_tools: Sequence[AbstractToolDefinition] | None,
+         result_tools: list[ToolDefinition],
      ):
          check_allow_model_requests()
-         tools = [_function_from_abstract_tool(t) for t in function_tools.values()]
-         if result_tools is not None:
+         tools = [_function_from_abstract_tool(t) for t in function_tools]
+         if result_tools:
              tools += [_function_from_abstract_tool(t) for t in result_tools]

          if allow_text_result:
@@ -504,8 +505,8 @@ class _GeminiFunction(TypedDict):
      """


- def _function_from_abstract_tool(tool: AbstractToolDefinition) -> _GeminiFunction:
-     json_schema = _GeminiJsonSchema(tool.json_schema).simplify()
+ def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
+     json_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
      f = _GeminiFunction(
          name=tool.name,
          description=tool.description,
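`ToolDefinition` (imported from `pydantic_ai.tools`) replaces the old `AbstractToolDefinition` protocol, and `json_schema` is renamed to `parameters_json_schema`. A hedged construction sketch, using only the fields visible in this diff (the dataclass may have additional fields):

from pydantic_ai.tools import ToolDefinition

tool = ToolDefinition(
    name='get_weather',
    description='Return the current weather for a city.',
    parameters_json_schema={
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
        'required': ['city'],
    },
)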
pydantic_ai/models/groq.py CHANGED
@@ -1,6 +1,6 @@
  from __future__ import annotations as _annotations

- from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+ from collections.abc import AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
@@ -21,8 +21,8 @@ from ..messages import (
      ToolReturn,
  )
  from ..result import Cost
+ from ..tools import ToolDefinition
  from . import (
-     AbstractToolDefinition,
      AgentModel,
      EitherStreamedResponse,
      Model,
@@ -37,11 +37,11 @@ try:
      from groq.types import chat
      from groq.types.chat import ChatCompletion, ChatCompletionChunk
      from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
- except ImportError as e:
+ except ImportError as _import_error:
      raise ImportError(
          'Please install `groq` to use the Groq model, '
          "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
-     ) from e
+     ) from _import_error

  GroqModelName = Literal[
      'llama-3.1-70b-versatile',
@@ -109,13 +109,14 @@ class GroqModel(Model):

      async def agent_model(
          self,
-         function_tools: Mapping[str, AbstractToolDefinition],
+         *,
+         function_tools: list[ToolDefinition],
          allow_text_result: bool,
-         result_tools: Sequence[AbstractToolDefinition] | None,
+         result_tools: list[ToolDefinition],
      ) -> AgentModel:
          check_allow_model_requests()
-         tools = [self._map_tool_definition(r) for r in function_tools.values()]
-         if result_tools is not None:
+         tools = [self._map_tool_definition(r) for r in function_tools]
+         if result_tools:
              tools += [self._map_tool_definition(r) for r in result_tools]
          return GroqAgentModel(
              self.client,
@@ -128,13 +129,13 @@ class GroqModel(Model):
          return f'groq:{self.model_name}'

      @staticmethod
-     def _map_tool_definition(f: AbstractToolDefinition) -> chat.ChatCompletionToolParam:
+     def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
          return {
              'type': 'function',
              'function': {
                  'name': f.name,
                  'description': f.description,
-                 'parameters': f.json_schema,
+                 'parameters': f.parameters_json_schema,
              },
          }

@@ -208,33 +209,29 @@ class GroqAgentModel(AgentModel):
      @staticmethod
      async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
          """Process a streamed response, and prepare a streaming response to return."""
-         try:
-             first_chunk = await response.__anext__()
-         except StopAsyncIteration as e:  # pragma: no cover
-             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-         timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-         delta = first_chunk.choices[0].delta
-         start_cost = _map_cost(first_chunk)
-
-         # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-         while delta.tool_calls is None and delta.content is None:
+         timestamp: datetime | None = None
+         start_cost = Cost()
+         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
+         while True:
              try:
-                 next_chunk = await response.__anext__()
+                 chunk = await response.__anext__()
              except StopAsyncIteration as e:
                  raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-             delta = next_chunk.choices[0].delta
-             start_cost += _map_cost(next_chunk)
-
-         if delta.content is not None:
-             return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
-         else:
-             assert delta.tool_calls is not None, f'Expected delta with tool_calls, got {delta}'
-             return GroqStreamStructuredResponse(
-                 response,
-                 {c.index: c for c in delta.tool_calls},
-                 timestamp,
-                 start_cost,
-             )
+             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+             start_cost += _map_cost(chunk)
+
+             if chunk.choices:
+                 delta = chunk.choices[0].delta
+
+                 if delta.content is not None:
+                     return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                 elif delta.tool_calls is not None:
+                     return GroqStreamStructuredResponse(
+                         response,
+                         {c.index: c for c in delta.tool_calls},
+                         timestamp,
+                         start_cost,
+                     )

      @staticmethod
      def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
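The reworked loop no longer assumes the first chunk has choices: chunks without choices, or with a role-only delta, just accumulate cost until a delta with `content` or `tool_calls` arrives. A self-contained toy model of the control flow (illustrative types, not the groq API):

import asyncio
from dataclasses import dataclass, field

@dataclass
class Delta:
    content: str | None = None
    tool_calls: list | None = None

@dataclass
class Chunk:
    deltas: list[Delta] = field(default_factory=list)

async def first_meaningful_delta(stream):
    # skip chunks with no choices and deltas carrying neither content nor tool_calls
    async for chunk in stream:
        if chunk.deltas:
            delta = chunk.deltas[0]
            if delta.content is not None or delta.tool_calls is not None:
                return delta
    raise RuntimeError('Streamed response ended without content or tool calls')

async def main():
    async def stream():
        yield Chunk()                        # e.g. a cost-only chunk with no choices
        yield Chunk([Delta()])               # role-only delta
        yield Chunk([Delta(content='hi')])   # first real content
    print(await first_meaningful_delta(stream()))

asyncio.run(main())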
pydantic_ai/models/ollama.py ADDED
@@ -0,0 +1,116 @@
+ from __future__ import annotations as _annotations
+
+ from dataclasses import dataclass
+ from typing import Literal, Union
+
+ from httpx import AsyncClient as AsyncHTTPClient
+
+ from ..tools import ToolDefinition
+ from . import (
+     AgentModel,
+     Model,
+     cached_async_http_client,
+ )
+
+ try:
+     from openai import AsyncOpenAI
+ except ImportError as e:
+     raise ImportError(
+         'Please install `openai` to use the OpenAI model, '
+         "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+     ) from e
+
+
+ from .openai import OpenAIModel
+
+ CommonOllamaModelNames = Literal[
+     'codellama',
+     'gemma',
+     'gemma2',
+     'llama3',
+     'llama3.1',
+     'llama3.2',
+     'llama3.2-vision',
+     'llama3.3',
+     'mistral',
+     'mistral-nemo',
+     'mixtral',
+     'phi3',
+     'qwq',
+     'qwen',
+     'qwen2',
+     'qwen2.5',
+     'starcoder2',
+ ]
+ """This contains just the most common ollama models.
+
+ For a full list see [ollama.com/library](https://ollama.com/library).
+ """
+ OllamaModelName = Union[CommonOllamaModelNames, str]
+ """Possible ollama models.
+
+ Since Ollama supports hundreds of models, we explicitly list the most common models but
+ allow any name in the type hints.
+ """
+
+
+ @dataclass(init=False)
+ class OllamaModel(Model):
+     """A model that implements Ollama using the OpenAI API.
+
+     Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server.
+
+     Apart from `__init__`, all methods are private or match those of the base class.
+     """
+
+     model_name: OllamaModelName
+     openai_model: OpenAIModel
+
+     def __init__(
+         self,
+         model_name: OllamaModelName,
+         *,
+         base_url: str | None = 'http://localhost:11434/v1/',
+         openai_client: AsyncOpenAI | None = None,
+         http_client: AsyncHTTPClient | None = None,
+     ):
+         """Initialize an Ollama model.
+
+         Ollama has built-in compatibility for the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the
+         [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here.
+
+         Args:
+             model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library).
+                 You must first download the model (`ollama pull <MODEL-NAME>`) in order to use it.
+             base_url: The base url for the ollama requests. The default value is the ollama default.
+             openai_client: An existing
+                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                 client to use; if provided, `base_url` and `http_client` must be `None`.
+             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+         """
+         self.model_name = model_name
+         if openai_client is not None:
+             assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+             self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client)
+         else:
+             # API key is not required for ollama but a value is required to create the client
+             http_client_ = http_client or cached_async_http_client()
+             oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+             self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
+
+     async def agent_model(
+         self,
+         *,
+         function_tools: list[ToolDefinition],
+         allow_text_result: bool,
+         result_tools: list[ToolDefinition],
+     ) -> AgentModel:
+         return await self.openai_model.agent_model(
+             function_tools=function_tools,
+             allow_text_result=allow_text_result,
+             result_tools=result_tools,
+         )
+
+     def name(self) -> str:
+         return f'ollama:{self.model_name}'
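A usage sketch for the new model class (assumes a local or reachable Ollama server; the `base_url` below is illustrative, showing how to point at a remote host instead of the default):

from pydantic_ai import Agent
from pydantic_ai.models.ollama import OllamaModel

# default base_url is http://localhost:11434/v1/; override to reach another host
model = OllamaModel('llama3.1', base_url='http://192.168.1.10:11434/v1/')
agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.data)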
pydantic_ai/models/openai.py CHANGED
@@ -1,10 +1,10 @@
  from __future__ import annotations as _annotations

- from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+ from collections.abc import AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
- from typing import Literal, overload
+ from typing import Literal, Union, overload

  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never
@@ -21,8 +21,8 @@ from ..messages import (
      ToolReturn,
  )
  from ..result import Cost
+ from ..tools import ToolDefinition
  from . import (
-     AbstractToolDefinition,
      AgentModel,
      EitherStreamedResponse,
      Model,
@@ -37,11 +37,17 @@ try:
      from openai.types import ChatModel, chat
      from openai.types.chat import ChatCompletionChunk
      from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
- except ImportError as e:
+ except ImportError as _import_error:
      raise ImportError(
          'Please install `openai` to use the OpenAI model, '
          "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
-     ) from e
+     ) from _import_error
+
+ OpenAIModelName = Union[ChatModel, str]
+ """
+ Using this broader type for the model name instead of the ChatModel definition
+ allows this model to be used more easily with other model types (i.e., Ollama)
+ """


  @dataclass(init=False)
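Broadening the name type means both known OpenAI literals and arbitrary strings type-check, which is what lets `OllamaModel` reuse this class. A sketch (assumes `OPENAI_API_KEY` is set, or pass `api_key=...`):

from pydantic_ai.models.openai import OpenAIModel

known = OpenAIModel('gpt-4o')      # a ChatModel literal
custom = OpenAIModel('my-model')   # any str, e.g. for an OpenAI-compatible server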
@@ -53,12 +59,12 @@ class OpenAIModel(Model):
      Apart from `__init__`, all methods are private or match those of the base class.
      """

-     model_name: ChatModel
+     model_name: OpenAIModelName
      client: AsyncOpenAI = field(repr=False)

      def __init__(
          self,
-         model_name: ChatModel,
+         model_name: OpenAIModelName,
          *,
          api_key: str | None = None,
          openai_client: AsyncOpenAI | None = None,
@@ -77,7 +83,7 @@ class OpenAIModel(Model):
                  client to use; if provided, `api_key` and `http_client` must be `None`.
              http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
          """
-         self.model_name: ChatModel = model_name
+         self.model_name: OpenAIModelName = model_name
          if openai_client is not None:
              assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
              assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
@@ -89,13 +95,14 @@ class OpenAIModel(Model):

      async def agent_model(
          self,
-         function_tools: Mapping[str, AbstractToolDefinition],
+         *,
+         function_tools: list[ToolDefinition],
          allow_text_result: bool,
-         result_tools: Sequence[AbstractToolDefinition] | None,
+         result_tools: list[ToolDefinition],
      ) -> AgentModel:
          check_allow_model_requests()
-         tools = [self._map_tool_definition(r) for r in function_tools.values()]
-         if result_tools is not None:
+         tools = [self._map_tool_definition(r) for r in function_tools]
+         if result_tools:
              tools += [self._map_tool_definition(r) for r in result_tools]
          return OpenAIAgentModel(
              self.client,
108
115
  return f'openai:{self.model_name}'
109
116
 
110
117
  @staticmethod
111
- def _map_tool_definition(f: AbstractToolDefinition) -> chat.ChatCompletionToolParam:
118
+ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
112
119
  return {
113
120
  'type': 'function',
114
121
  'function': {
115
122
  'name': f.name,
116
123
  'description': f.description,
117
- 'parameters': f.json_schema,
124
+ 'parameters': f.parameters_json_schema,
118
125
  },
119
126
  }
120
127
 
@@ -124,7 +131,7 @@ class OpenAIAgentModel(AgentModel):
      """Implementation of `AgentModel` for OpenAI models."""

      client: AsyncOpenAI
-     model_name: ChatModel
+     model_name: OpenAIModelName
      allow_text_result: bool
      tools: list[chat.ChatCompletionToolParam]

@@ -188,33 +195,31 @@ class OpenAIAgentModel(AgentModel):
      @staticmethod
      async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
          """Process a streamed response, and prepare a streaming response to return."""
-         try:
-             first_chunk = await response.__anext__()
-         except StopAsyncIteration as e:  # pragma: no cover
-             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-         timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-         delta = first_chunk.choices[0].delta
-         start_cost = _map_cost(first_chunk)
-
-         # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-         while delta.tool_calls is None and delta.content is None:
+         timestamp: datetime | None = None
+         start_cost = Cost()
+         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
+         while True:
              try:
-                 next_chunk = await response.__anext__()
+                 chunk = await response.__anext__()
              except StopAsyncIteration as e:
                  raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-             delta = next_chunk.choices[0].delta
-             start_cost += _map_cost(next_chunk)

-         if delta.content is not None:
-             return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
-         else:
-             assert delta.tool_calls is not None, f'Expected delta with tool_calls, got {delta}'
-             return OpenAIStreamStructuredResponse(
-                 response,
-                 {c.index: c for c in delta.tool_calls},
-                 timestamp,
-                 start_cost,
-             )
+             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+             start_cost += _map_cost(chunk)
+
+             if chunk.choices:
+                 delta = chunk.choices[0].delta
+
+                 if delta.content is not None:
+                     return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                 elif delta.tool_calls is not None:
+                     return OpenAIStreamStructuredResponse(
+                         response,
+                         {c.index: c for c in delta.tool_calls},
+                         timestamp,
+                         start_cost,
+                     )
+                 # else continue until we get either delta.content or delta.tool_calls

      @staticmethod
      def _map_message(message: Message) -> chat.ChatCompletionMessageParam: