pydantic-ai-slim 0.1.7__tar.gz → 0.1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/PKG-INFO +3 -3
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_agent_graph.py +53 -3
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/agent.py +1 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/mcp.py +61 -6
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/messages.py +6 -1
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/anthropic.py +18 -8
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/mistral.py +14 -1
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/mistral.py +5 -2
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/.gitignore +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/README.md +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_output.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/format_prompt.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/_json_schema.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/bedrock.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/gemini.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/usage.py +0 -0
- {pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pyproject.toml +0 -0
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.1.7
+Version: 0.1.9
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.1.7
+Requires-Dist: pydantic-graph==0.1.9
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.1.7; extra == 'evals'
+Requires-Dist: pydantic-evals==0.1.9; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/_agent_graph.py

@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
+import hashlib
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -92,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
     mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
+    default_retries: int
 
     tracer: Tracer
 
@@ -546,7 +548,14 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
     )
 
 
-async def process_function_tools(
+def multi_modal_content_identifier(identifier: str | bytes) -> str:
+    """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
+    if isinstance(identifier, str):
+        identifier = identifier.encode('utf-8')
+    return hashlib.sha1(identifier).hexdigest()[:6]
+
+
+async def process_function_tools(  # noqa C901
     tool_calls: list[_messages.ToolCallPart],
     output_tool_name: str | None,
     output_tool_call_id: str | None,
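The new helper is deterministic: the same URL or byte payload always hashes to the same six-hex-character tag, so the model can match a file mentioned in a tool return against the content attached alongside it. A minimal standalone sketch of the same logic (stdlib only; `content_identifier` is a stand-in name, not part of the package API):

```python
from __future__ import annotations

import hashlib


def content_identifier(identifier: str | bytes) -> str:
    """Mirror of multi_modal_content_identifier: SHA-1 digest, truncated to 6 hex chars."""
    if isinstance(identifier, str):
        identifier = identifier.encode('utf-8')
    return hashlib.sha1(identifier).hexdigest()[:6]


# Stable across calls: the same input always yields the same tag
assert content_identifier('https://example.com/report.pdf') == content_identifier('https://example.com/report.pdf')
assert len(content_identifier(b'\x89PNG')) == 6
```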
@@ -632,6 +641,8 @@ async def process_function_tools(
     if not calls_to_run:
         return
 
+    user_parts: list[_messages.UserPromptPart] = []
+
     # Run all tool tasks in parallel
     results_by_index: dict[int, _messages.ModelRequestPart] = {}
     with ctx.deps.tracer.start_as_current_span(
@@ -645,6 +656,7 @@ async def process_function_tools(
             asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name)
             for tool, call in calls_to_run
         ]
+
         pending = tasks
         while pending:
             done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
@@ -652,7 +664,43 @@ async def process_function_tools(
                 index = tasks.index(task)
                 result = task.result()
                 yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
-                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+
+                if isinstance(result, _messages.RetryPromptPart):
+                    results_by_index[index] = result
+                elif isinstance(result, _messages.ToolReturnPart):
+                    contents: list[Any]
+                    single_content: bool
+                    if isinstance(result.content, list):
+                        contents = result.content  # type: ignore
+                        single_content = False
+                    else:
+                        contents = [result.content]
+                        single_content = True
+
+                    processed_contents: list[Any] = []
+                    for content in contents:
+                        if isinstance(content, _messages.MultiModalContentTypes):
+                            if isinstance(content, _messages.BinaryContent):
+                                identifier = multi_modal_content_identifier(content.data)
+                            else:
+                                identifier = multi_modal_content_identifier(content.url)
+
+                            user_parts.append(
+                                _messages.UserPromptPart(
+                                    content=[f'This is file {identifier}:', content],
+                                    timestamp=result.timestamp,
+                                    part_kind='user-prompt',
+                                )
+                            )
+                            processed_contents.append(f'See file {identifier}')
+                        else:
+                            processed_contents.append(content)
+
+                    if single_content:
+                        result.content = processed_contents[0]
+                    else:
+                        result.content = processed_contents
+
                     results_by_index[index] = result
                 else:
                     assert_never(result)
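In effect, a tool return containing multi-modal content is rewritten to a plain-text "See file {id}" marker, and the actual content is re-sent to the model as a follow-up user part. A simplified, self-contained sketch of that rewriting (the `Binary` dataclass is a hypothetical stand-in for `_messages.BinaryContent`):

```python
from __future__ import annotations

import hashlib
from dataclasses import dataclass
from typing import Any


@dataclass
class Binary:  # hypothetical stand-in for _messages.BinaryContent
    data: bytes
    media_type: str


def split_tool_return(content: Any) -> tuple[Any, list[str]]:
    """Replace binary items with 'See file <id>' markers; collect follow-up user messages."""
    single = not isinstance(content, list)
    items = [content] if single else content
    user_messages: list[str] = []
    processed: list[Any] = []
    for item in items:
        if isinstance(item, Binary):
            file_id = hashlib.sha1(item.data).hexdigest()[:6]
            user_messages.append(f'This is file {file_id}:')  # the real code attaches the content itself too
            processed.append(f'See file {file_id}')
        else:
            processed.append(item)
    return processed[0] if single else processed, user_messages


result, extras = split_tool_return(Binary(b'\x89PNG', 'image/png'))
assert result.startswith('See file ') and len(extras) == 1
```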
@@ -662,6 +710,8 @@ async def process_function_tools(
     for k in sorted(results_by_index):
         output_parts.append(results_by_index[k])
 
+    output_parts.extend(user_parts)
+
 
 async def _tool_from_mcp_server(
     tool_name: str,
@@ -688,7 +738,7 @@ async def _tool_from_mcp_server(
     for server in ctx.deps.mcp_servers:
         tools = await server.list_tools()
         if tool_name in {tool.name for tool in tools}:
-            return Tool(name=tool_name, function=run_tool, takes_ctx=True)
+            return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries)
     return None
 
 
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/agent.py

@@ -658,6 +658,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             output_validators=output_validators,
             function_tools=self._function_tools,
             mcp_servers=self._mcp_servers,
+            default_retries=self._default_retries,
             tracer=tracer,
             get_instructions=get_instructions,
         )
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/mcp.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import base64
+import json
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Sequence
 from contextlib import AsyncExitStack, asynccontextmanager
@@ -9,16 +11,25 @@ from types import TracebackType
 from typing import Any
 
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
-from mcp.types import JSONRPCMessage, LoggingLevel
-from typing_extensions import Self
-
+from mcp.types import (
+    BlobResourceContents,
+    EmbeddedResource,
+    ImageContent,
+    JSONRPCMessage,
+    LoggingLevel,
+    TextContent,
+    TextResourceContents,
+)
+from typing_extensions import Self, assert_never
+
+from pydantic_ai.exceptions import ModelRetry
+from pydantic_ai.messages import BinaryContent
 from pydantic_ai.tools import ToolDefinition
 
 try:
     from mcp.client.session import ClientSession
     from mcp.client.sse import sse_client
     from mcp.client.stdio import StdioServerParameters, stdio_client
-    from mcp.types import CallToolResult
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `mcp` package to use the MCP server, '
@@ -74,7 +85,9 @@ class MCPServer(ABC):
             for tool in tools.tools
         ]
 
-    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
+    async def call_tool(
+        self, tool_name: str, arguments: dict[str, Any]
+    ) -> str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]:
         """Call a tool on the server.
 
         Args:
@@ -83,8 +96,21 @@ class MCPServer(ABC):
 
         Returns:
             The result of the tool call.
+
+        Raises:
+            ModelRetry: If the tool call fails.
         """
-        return await self._client.call_tool(tool_name, arguments)
+        result = await self._client.call_tool(tool_name, arguments)
+
+        content = [self._map_tool_result_part(part) for part in result.content]
+
+        if result.isError:
+            text = '\n'.join(str(part) for part in content)
+            raise ModelRetry(text)
+
+        if len(content) == 1:
+            return content[0]
+        return content
 
     async def __aenter__(self) -> Self:
         self._exit_stack = AsyncExitStack()
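The shape of the return value therefore depends on how many content parts the server sends back, and a server-side error is surfaced as a retry rather than a normal result. A stripped-down sketch of just that tail logic (`ToolError` stands in for pydantic-ai's `ModelRetry`; `parts` plays the role of the already-mapped content list):

```python
from typing import Any


class ToolError(Exception):  # stand-in for pydantic_ai.exceptions.ModelRetry
    pass


def finish_call(parts: list[Any], is_error: bool) -> Any:
    if is_error:
        # Join the already-mapped parts into one message the model can retry on
        raise ToolError('\n'.join(str(p) for p in parts))
    return parts[0] if len(parts) == 1 else parts


assert finish_call(['42'], is_error=False) == '42'            # single part is unwrapped
assert finish_call(['a', 'b'], is_error=False) == ['a', 'b']  # multiple parts stay a list
```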
@@ -105,6 +131,35 @@ class MCPServer(ABC):
         await self._exit_stack.aclose()
         self.is_running = False
 
+    def _map_tool_result_part(
+        self, part: TextContent | ImageContent | EmbeddedResource
+    ) -> str | BinaryContent | dict[str, Any] | list[Any]:
+        # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
+
+        if isinstance(part, TextContent):
+            text = part.text
+            if text.startswith(('[', '{')):
+                try:
+                    return json.loads(text)
+                except ValueError:
+                    pass
+            return text
+        elif isinstance(part, ImageContent):
+            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
+        elif isinstance(part, EmbeddedResource):
+            resource = part.resource
+            if isinstance(resource, TextResourceContents):
+                return resource.text
+            elif isinstance(resource, BlobResourceContents):
+                return BinaryContent(
+                    data=base64.b64decode(resource.blob),
+                    media_type=resource.mimeType or 'application/octet-stream',
+                )
+            else:
+                assert_never(resource)
+        else:
+            assert_never(part)
+
 
 @dataclass
 class MCPServerStdio(MCPServer):
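The `TextContent` branch above opportunistically parses text that looks like JSON, falling back to the raw string when parsing fails. The same sniffing logic in isolation:

```python
import json
from typing import Any


def map_text(text: str) -> Any:
    """Parse as JSON only when the text plausibly is JSON; otherwise return it verbatim."""
    if text.startswith(('[', '{')):
        try:
            return json.loads(text)
        except ValueError:
            pass  # looked like JSON but wasn't; fall through to the raw string
    return text


assert map_text('{"a": 1}') == {'a': 1}
assert map_text('[1, 2]') == [1, 2]
assert map_text('{not json') == '{not json'
assert map_text('plain text') == 'plain text'
```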
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/messages.py

@@ -253,6 +253,9 @@ class BinaryContent:
 
 UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'
 
+# Ideally this would be a Union of types, but Python 3.9 requires it to be a string, and strings don't work with `isinstance`.
+MultiModalContentTypes = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)
+
 
 def _document_format(media_type: str) -> DocumentFormat:
     if media_type == 'application/pdf':
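The comment records the design choice: with deferred annotation evaluation on Python 3.9, a `Union` alias ends up as a string, which `isinstance` rejects, while a plain tuple of classes is accepted directly. An analogous illustration:

```python
# A tuple of classes works as an isinstance() target; a string-valued type
# alias (what a Union becomes under deferred annotations on 3.9) does not.
NumberTypes = (int, float)  # analogous to MultiModalContentTypes

assert isinstance(3, NumberTypes)
assert isinstance(2.5, NumberTypes)
assert not isinstance('3', NumberTypes)
```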
@@ -829,4 +832,6 @@ class FunctionToolResultEvent:
     """Event type identifier, used as a discriminator."""
 
 
-HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind')]
+HandleResponseEvent = Annotated[
+    Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind')
+]
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/anthropic.py

@@ -109,10 +109,6 @@ class AnthropicModel(Model):
     Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.
 
     Apart from `__init__`, all methods are private or match those of the base class.
-
-    !!! note
-        The `AnthropicModel` class does not yet support streaming responses.
-        We anticipate adding support for streaming responses in a near-term future release.
     """
 
     client: AsyncAnthropic = field(repr=False)
@@ -409,13 +405,27 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
     if response_usage is None:
         return usage.Usage()
 
-    request_tokens = getattr(response_usage, 'input_tokens', None)
+    # Store all integer-typed usage values in the details dict
+    response_usage_dict = response_usage.model_dump()
+    details: dict[str, int] = {}
+    for key, value in response_usage_dict.items():
+        if isinstance(value, int):
+            details[key] = value
+
+    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence the getattr call
+    # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
+    # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
+    request_tokens = (
+        getattr(response_usage, 'input_tokens', 0)
+        + (getattr(response_usage, 'cache_creation_input_tokens', 0) or 0)  # These can be missing, None, or int
+        + (getattr(response_usage, 'cache_read_input_tokens', 0) or 0)
+    )
 
     return usage.Usage(
-
-        request_tokens=request_tokens,
+        request_tokens=request_tokens or None,
         response_tokens=response_usage.output_tokens,
-        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
+        total_tokens=request_tokens + response_usage.output_tokens,
+        details=details or None,
     )
 
 
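The effect of the new accounting: uncached and cached input tokens are summed into `request_tokens`, while every integer-valued field of the raw usage object is preserved in `details`. A worked example with made-up numbers (a plain dict stands in for the Anthropic usage object):

```python
# Hypothetical usage payload; field names follow the diff, values are invented.
response_usage = {
    'input_tokens': 100,
    'cache_creation_input_tokens': 40,
    'cache_read_input_tokens': None,  # may be missing, None, or an int
    'output_tokens': 25,
}

details = {k: v for k, v in response_usage.items() if isinstance(v, int)}
request_tokens = (
    (response_usage.get('input_tokens', 0) or 0)
    + (response_usage.get('cache_creation_input_tokens', 0) or 0)
    + (response_usage.get('cache_read_input_tokens', 0) or 0)
)

assert request_tokens == 140                                    # 100 uncached + 40 cache-creation
assert request_tokens + response_usage['output_tokens'] == 165  # total_tokens
assert 'cache_read_input_tokens' not in details                 # None is not an int
```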
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/models/mistral.py

@@ -483,7 +483,20 @@ class MistralModel(Model):
             assert_never(message)
         if instructions := self._get_instructions(messages):
             mistral_messages.insert(0, MistralSystemMessage(content=instructions))
-        return mistral_messages
+
+        # Post-process messages to insert fake assistant message after tool message if followed by user message
+        # to work around `Unexpected role 'user' after role 'tool'` error.
+        processed_messages: list[MistralMessages] = []
+        for i, current_message in enumerate(mistral_messages):
+            processed_messages.append(current_message)
+
+            if isinstance(current_message, MistralToolMessage) and i + 1 < len(mistral_messages):
+                next_message = mistral_messages[i + 1]
+                if isinstance(next_message, MistralUserMessage):
+                    # Insert a dummy assistant message
+                    processed_messages.append(MistralAssistantMessage(content=[MistralTextChunk(text='OK')]))
+
+        return processed_messages
 
     def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
         content: str | list[MistralContentChunk]
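The work-around only touches adjacent pairs: wherever a tool message is immediately followed by a user message, a placeholder assistant turn is spliced between them. The reordering in miniature, using role strings instead of Mistral message objects:

```python
def insert_dummy_assistant(roles: list[str]) -> list[str]:
    """Splice an assistant turn between each tool -> user adjacency."""
    out: list[str] = []
    for i, role in enumerate(roles):
        out.append(role)
        if role == 'tool' and i + 1 < len(roles) and roles[i + 1] == 'user':
            out.append('assistant')  # content is a literal 'OK' in the real code
    return out


assert insert_dummy_assistant(['user', 'assistant', 'tool', 'user']) == [
    'user', 'assistant', 'tool', 'assistant', 'user'
]
assert insert_dummy_assistant(['user', 'assistant', 'tool']) == ['user', 'assistant', 'tool']
```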
{pydantic_ai_slim-0.1.7 → pydantic_ai_slim-0.1.9}/pydantic_ai/providers/mistral.py

@@ -44,6 +44,7 @@ class MistralProvider(Provider[Mistral]):
         *,
         api_key: str | None = None,
         mistral_client: Mistral | None = None,
+        base_url: str | None = None,
         http_client: AsyncHTTPClient | None = None,
     ) -> None:
         """Create a new Mistral provider.
@@ -52,11 +53,13 @@ class MistralProvider(Provider[Mistral]):
             api_key: The API key to use for authentication, if not provided, the `MISTRAL_API_KEY` environment variable
                 will be used if available.
             mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
+            base_url: The base url for the Mistral requests.
             http_client: An existing async client to use for making HTTP requests.
         """
         if mistral_client is not None:
             assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
+            assert base_url is None, 'Cannot provide both `mistral_client` and `base_url`'
             self._client = mistral_client
         else:
             api_key = api_key or os.environ.get('MISTRAL_API_KEY')
@@ -67,7 +70,7 @@ class MistralProvider(Provider[Mistral]):
                 'to use the Mistral provider.'
             )
         elif http_client is not None:
-            self._client = Mistral(api_key=api_key, async_client=http_client)
+            self._client = Mistral(api_key=api_key, async_client=http_client, server_url=base_url)
         else:
             http_client = cached_async_http_client(provider='mistral')
-            self._client = Mistral(api_key=api_key, async_client=http_client)
+            self._client = Mistral(api_key=api_key, async_client=http_client, server_url=base_url)
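With the new parameter, requests can be pointed at a gateway or self-hosted endpoint instead of Mistral's default. A usage sketch (the gateway URL is a placeholder, and this assumes the usual pydantic-ai model/provider wiring, which this diff doesn't change; note the diff forbids combining `base_url` with `mistral_client`):

```python
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel
from pydantic_ai.providers.mistral import MistralProvider

# base_url is the parameter added in this release; the URL below is hypothetical.
provider = MistralProvider(api_key='your-api-key', base_url='https://mistral-gateway.example.com')
model = MistralModel('mistral-small-latest', provider=provider)
agent = Agent(model)
```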