pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_pydantic.py +6 -4
- pydantic_ai/_result.py +18 -22
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +11 -6
- pydantic_ai/agent.py +146 -67
- pydantic_ai/messages.py +5 -2
- pydantic_ai/models/__init__.py +30 -37
- pydantic_ai/models/function.py +8 -14
- pydantic_ai/models/gemini.py +11 -10
- pydantic_ai/models/groq.py +31 -34
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +43 -38
- pydantic_ai/models/test.py +70 -49
- pydantic_ai/models/vertexai.py +7 -6
- pydantic_ai/tools.py +119 -34
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.12.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.12.dist-info}/WHEEL +0 -0
pydantic_ai/messages.py
CHANGED
@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations

-import json
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union
@@ -74,6 +73,9 @@ class ToolReturn:
         return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}


+ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
+
+
 @dataclass
 class RetryPrompt:
     """A message back to a model asking it to try again.
@@ -109,7 +111,8 @@ class RetryPrompt:
         if isinstance(self.content, str):
             description = self.content
         else:
-
+            json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
+            description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'

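The net effect of this change: `RetryPrompt.model_response()` now serializes `list[pydantic_core.ErrorDetails]` content with a lazy `TypeAdapter` (excluding each error's `ctx` key) instead of `json.dumps`. A minimal sketch of where content of that shape comes from; the `Whale` model below is hypothetical, and only pydantic's public `errors()` API is assumed:

from pydantic import BaseModel, ValidationError

class Whale(BaseModel):
    name: str
    length_m: float

try:
    Whale(name='orca', length_m='long')
except ValidationError as exc:
    # exc.errors() is list[pydantic_core.ErrorDetails], the shape
    # RetryPrompt.content can hold alongside plain strings.
    errors = exc.errors()
    # RetryPrompt.model_response() would render roughly:
    # "1 validation errors: [...]\n\nFix the errors and try again."
    print(len(errors), errors[0]['type'])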
pydantic_ai/models/__init__.py
CHANGED
@@ -7,11 +7,11 @@ specific LLM being used.
 from __future__ import annotations as _annotations

 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING, Literal,
+from typing import TYPE_CHECKING, Literal, Union

 import httpx

@@ -19,8 +19,8 @@ from ..exceptions import UserError
 from ..messages import Message, ModelAnyResponse, ModelStructuredResponse

 if TYPE_CHECKING:
-    from .._utils import ObjectJsonSchema
     from ..result import Cost
+    from ..tools import ToolDefinition


 KnownModelName = Literal[
@@ -49,6 +49,23 @@ KnownModelName = Literal[
     'gemini-1.5-pro',
     'vertexai:gemini-1.5-flash',
     'vertexai:gemini-1.5-pro',
+    'ollama:codellama',
+    'ollama:gemma',
+    'ollama:gemma2',
+    'ollama:llama3',
+    'ollama:llama3.1',
+    'ollama:llama3.2',
+    'ollama:llama3.2-vision',
+    'ollama:llama3.3',
+    'ollama:mistral',
+    'ollama:mistral-nemo',
+    'ollama:mixtral',
+    'ollama:phi3',
+    'ollama:qwq',
+    'ollama:qwen',
+    'ollama:qwen2',
+    'ollama:qwen2.5',
+    'ollama:starcoder2',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -63,11 +80,12 @@ class Model(ABC):
     @abstractmethod
     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        """Create an agent model.
+        """Create an agent model, this is called for each step of an agent run.

         This is async in case slow/async config checks need to be performed that can't be done in `__init__`.

@@ -87,7 +105,7 @@ class Model(ABC):


 class AgentModel(ABC):
-    """Model configured for
+    """Model configured for each step of an Agent run."""

     @abstractmethod
     async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
@@ -238,7 +256,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
     elif model.startswith('openai:'):
         from .openai import OpenAIModel

-        return OpenAIModel(model[7:])
+        return OpenAIModel(model[7:])
     elif model.startswith('gemini'):
         from .gemini import GeminiModel

@@ -252,39 +270,14 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .vertexai import VertexAIModel

         return VertexAIModel(model[9:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('ollama:'):
+        from .ollama import OllamaModel
+
+        return OllamaModel(model[7:])
     else:
         raise UserError(f'Unknown model: {model}')


-class AbstractToolDefinition(Protocol):
-    """Abstract definition of a function/tool.
-
-    This is used for both function tools and result tools.
-    """
-
-    @property
-    def name(self) -> str:
-        """The name of the tool."""
-        ...
-
-    @property
-    def description(self) -> str:
-        """The description of the tool."""
-        ...
-
-    @property
-    def json_schema(self) -> ObjectJsonSchema:
-        """The JSON schema for the tool's arguments."""
-        ...
-
-    @property
-    def outer_typed_dict_key(self) -> str | None:
-        """The key in the outer [TypedDict] that wraps a result tool.
-
-        This will only be set for result tools which don't have an `object` JSON schema.
-        """
-
-
 @cache
 def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     """Cached HTTPX async client so multiple agents and calls can share the same client.
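Taken together: the `AbstractToolDefinition` protocol moves out of this module (replaced by `ToolDefinition` in `pydantic_ai.tools`), `agent_model` becomes keyword-only, and `infer_model` now resolves `ollama:`-prefixed names, which `KnownModelName` advertises. A sketch of what that enables, assuming this release's public `Agent` API and a running Ollama server:

from pydantic_ai import Agent

# 'ollama:llama3.2' is now a valid KnownModelName; infer_model strips the
# 'ollama:' prefix and constructs OllamaModel('llama3.2') under the hood.
agent = Agent('ollama:llama3.2')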
pydantic_ai/models/function.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -14,14 +14,8 @@ from typing_extensions import TypeAlias, assert_never, overload

 from .. import _utils, result
 from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
-from . import (
-    AbstractToolDefinition,
-    AgentModel,
-    EitherStreamedResponse,
-    Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
-)
+from ..tools import ToolDefinition
+from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse


 @dataclass(init=False)
@@ -59,11 +53,11 @@ class FunctionModel(Model):

     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        result_tools = list(result_tools) if result_tools is not None else None
         return FunctionAgentModel(
             self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
         )
@@ -84,7 +78,7 @@ class AgentInfo:
    This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
    """

-    function_tools:
+    function_tools: list[ToolDefinition]
    """The function tools available on this agent.

    These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
@@ -92,7 +86,7 @@ class AgentInfo:
    """
    allow_text_result: bool
    """Whether a plain text result is allowed."""
-    result_tools: list[
+    result_tools: list[ToolDefinition]
    """The tools that can called as the final result of the run."""

pydantic_ai/models/gemini.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

 import os
 import re
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
@@ -25,8 +25,8 @@ from ..messages import (
     ToolCall,
     ToolReturn,
 )
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -90,9 +90,10 @@ class GeminiModel(Model):

     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
         return GeminiAgentModel(
             http_client=self.http_client,
@@ -142,13 +143,13 @@ class GeminiAgentModel(AgentModel):
         model_name: GeminiModelName,
         auth: AuthProtocol,
         url: str,
-        function_tools:
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ):
         check_allow_model_requests()
-        tools = [_function_from_abstract_tool(t) for t in function_tools
-        if result_tools
+        tools = [_function_from_abstract_tool(t) for t in function_tools]
+        if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]

         if allow_text_result:
@@ -504,8 +505,8 @@ class _GeminiFunction(TypedDict):
     """


-def _function_from_abstract_tool(tool:
-    json_schema = _GeminiJsonSchema(tool.
+def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
+    json_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
     f = _GeminiFunction(
         name=tool.name,
         description=tool.description,
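For reference, the `ToolDefinition` consumed by `_function_from_abstract_tool` is a concrete dataclass rather than the removed `AbstractToolDefinition` protocol. A hypothetical tool showing only the three fields the Gemini mapping reads above (`name`, `description`, `parameters_json_schema`), assuming any remaining `ToolDefinition` fields have defaults:

from pydantic_ai.tools import ToolDefinition

# a hypothetical tool definition for illustration
tool = ToolDefinition(
    name='get_capital',
    description='Return the capital city of a country.',
    parameters_json_schema={
        'type': 'object',
        'properties': {'country': {'type': 'string'}},
        'required': ['country'],
    },
)
# _function_from_abstract_tool(tool) simplifies parameters_json_schema via
# _GeminiJsonSchema and packs it into a _GeminiFunction declaration.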
pydantic_ai/models/groq.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations

-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -21,8 +21,8 @@ from ..messages import (
     ToolReturn,
 )
 from ..result import Cost
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -37,11 +37,11 @@ try:
     from groq.types import chat
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
     from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
-except ImportError as
+except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
         "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
-    ) from
+    ) from _import_error

 GroqModelName = Literal[
     'llama-3.1-70b-versatile',
@@ -109,13 +109,14 @@ class GroqModel(Model):

     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools
-        if result_tools
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
             tools += [self._map_tool_definition(r) for r in result_tools]
         return GroqAgentModel(
             self.client,
@@ -128,13 +129,13 @@ class GroqModel(Model):
         return f'groq:{self.model_name}'

     @staticmethod
-    def _map_tool_definition(f:
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
         return {
             'type': 'function',
             'function': {
                 'name': f.name,
                 'description': f.description,
-                'parameters': f.
+                'parameters': f.parameters_json_schema,
             },
         }

@@ -208,33 +209,29 @@ class GroqAgentModel(AgentModel):
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-
-
-
-
-        timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-        delta = first_chunk.choices[0].delta
-        start_cost = _map_cost(first_chunk)
-
-        # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-        while delta.tool_calls is None and delta.content is None:
+        timestamp: datetime | None = None
+        start_cost = Cost()
+        # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
+        while True:
             try:
-
+                chunk = await response.__anext__()
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-
-            start_cost += _map_cost(
-
-
-
-
-
-
-
-
-
-
+            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+            start_cost += _map_cost(chunk)
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+
+                if delta.content is not None:
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                elif delta.tool_calls is not None:
+                    return GroqStreamStructuredResponse(
+                        response,
+                        {c.index: c for c in delta.tool_calls},
+                        timestamp,
+                        start_cost,
+                    )

     @staticmethod
     def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
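The streaming rewrite above (mirrored in `openai.py` below) replaces the old "grab the first chunk, then spin until a delta appears" logic with a single loop that also tolerates chunks whose `choices` list is empty. A self-contained sketch of the pattern with a fake stream; all names here are illustrative stand-ins, not groq's API:

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field

@dataclass
class Delta:
    content: str | None = None
    tool_calls: list | None = None

@dataclass
class Choice:
    delta: Delta

@dataclass
class Chunk:
    choices: list[Choice] = field(default_factory=list)

async def fake_stream():
    yield Chunk()                                   # e.g. a role-only/empty chunk
    yield Chunk([Choice(Delta(content='Hello'))])   # first chunk with real content

async def first_content(stream) -> str:
    # same shape as the new loop: skip chunks until content or tool calls arrive
    async for chunk in stream:
        if chunk.choices:
            delta = chunk.choices[0].delta
            if delta.content is not None:
                return delta.content
    raise RuntimeError('Streamed response ended without content or tool calls')

print(asyncio.run(first_content(fake_stream())))  # -> Hello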
pydantic_ai/models/ollama.py
ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations as _annotations
+
+from dataclasses import dataclass
+from typing import Literal, Union
+
+from httpx import AsyncClient as AsyncHTTPClient
+
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    cached_async_http_client,
+)
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as e:
+    raise ImportError(
+        'Please install `openai` to use the OpenAI model, '
+        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+    ) from e
+
+
+from .openai import OpenAIModel
+
+CommonOllamaModelNames = Literal[
+    'codellama',
+    'gemma',
+    'gemma2',
+    'llama3',
+    'llama3.1',
+    'llama3.2',
+    'llama3.2-vision',
+    'llama3.3',
+    'mistral',
+    'mistral-nemo',
+    'mixtral',
+    'phi3',
+    'qwq',
+    'qwen',
+    'qwen2',
+    'qwen2.5',
+    'starcoder2',
+]
+"""This contains just the most common ollama models.
+
+For a full list see [ollama.com/library](https://ollama.com/library).
+"""
+OllamaModelName = Union[CommonOllamaModelNames, str]
+"""Possible ollama models.
+
+Since Ollama supports hundreds of models, we explicitly list the most common models but
+allow any name in the type hints.
+"""
+
+
+@dataclass(init=False)
+class OllamaModel(Model):
+    """A model that implements Ollama using the OpenAI API.
+
+    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: OllamaModelName
+    openai_model: OpenAIModel
+
+    def __init__(
+        self,
+        model_name: OllamaModelName,
+        *,
+        base_url: str | None = 'http://localhost:11434/v1/',
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Ollama model.
+
+        Ollama has built-in compatibility for the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the
+        [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here.
+
+        Args:
+            model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
+                You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
+            base_url: The base url for the ollama requests. The default value is the ollama default
+            openai_client: An existing
+                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                client to use, if provided, `base_url` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if openai_client is not None:
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client)
+        else:
+            # API key is not required for ollama but a value is required to create the client
+            http_client_ = http_client or cached_async_http_client()
+            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        return await self.openai_model.agent_model(
+            function_tools=function_tools,
+            allow_text_result=allow_text_result,
+            result_tools=result_tools,
+        )
+
+    def name(self) -> str:
+        return f'ollama:{self.model_name}'
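A short usage sketch for the new model class, assuming a local Ollama server on the default port with the model already pulled (`ollama pull llama3.2`):

from pydantic_ai import Agent
from pydantic_ai.models.ollama import OllamaModel

model = OllamaModel('llama3.2')  # equivalent to Agent('ollama:llama3.2')
agent = Agent(model)
result = agent.run_sync('Where were the 2012 Olympics held?')
print(result.data)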
pydantic_ai/models/openai.py
CHANGED
@@ -1,10 +1,10 @@
 from __future__ import annotations as _annotations

-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Literal, overload
+from typing import Literal, Union, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -21,8 +21,8 @@ from ..messages import (
     ToolReturn,
 )
 from ..result import Cost
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -37,11 +37,17 @@ try:
     from openai.types import ChatModel, chat
     from openai.types.chat import ChatCompletionChunk
     from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
-except ImportError as
+except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
         "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
-    ) from
+    ) from _import_error
+
+OpenAIModelName = Union[ChatModel, str]
+"""
+Using this more broad type for the model name instead of the ChatModel definition
+allows this model to be used more easily with other model types (ie, Ollama)
+"""


 @dataclass(init=False)
@@ -53,12 +59,12 @@ class OpenAIModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    model_name:
+    model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)

     def __init__(
         self,
-        model_name:
+        model_name: OpenAIModelName,
         *,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
@@ -77,7 +83,7 @@ class OpenAIModel(Model):
                 client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name:
+        self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
@@ -89,13 +95,14 @@ class OpenAIModel(Model):

     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools
-        if result_tools
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
             tools += [self._map_tool_definition(r) for r in result_tools]
         return OpenAIAgentModel(
             self.client,
@@ -108,13 +115,13 @@ class OpenAIModel(Model):
         return f'openai:{self.model_name}'

     @staticmethod
-    def _map_tool_definition(f:
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
         return {
             'type': 'function',
             'function': {
                 'name': f.name,
                 'description': f.description,
-                'parameters': f.
+                'parameters': f.parameters_json_schema,
             },
         }

@@ -124,7 +131,7 @@ class OpenAIAgentModel(AgentModel):
     """Implementation of `AgentModel` for OpenAI models."""

     client: AsyncOpenAI
-    model_name:
+    model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]

@@ -188,33 +195,31 @@ class OpenAIAgentModel(AgentModel):
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-
-
-
-
-        timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-        delta = first_chunk.choices[0].delta
-        start_cost = _map_cost(first_chunk)
-
-        # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-        while delta.tool_calls is None and delta.content is None:
+        timestamp: datetime | None = None
+        start_cost = Cost()
+        # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
+        while True:
             try:
-
+                chunk = await response.__anext__()
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-            delta = next_chunk.choices[0].delta
-            start_cost += _map_cost(next_chunk)

-
-
-
-
-
-
-
-
-
-
+            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+            start_cost += _map_cost(chunk)
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+
+                if delta.content is not None:
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                elif delta.tool_calls is not None:
+                    return OpenAIStreamStructuredResponse(
+                        response,
+                        {c.index: c for c in delta.tool_calls},
+                        timestamp,
+                        start_cost,
+                    )
+            # else continue until we get either delta.content or delta.tool_calls

     @staticmethod
     def _map_message(message: Message) -> chat.ChatCompletionMessageParam: