pydantic-ai-slim 0.1.6__tar.gz → 0.1.8__tar.gz

This diff covers publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (53)
  1. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/PKG-INFO +3 -3
  2. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_agent_graph.py +25 -33
  3. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_pydantic.py +1 -1
  4. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_utils.py +3 -5
  5. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/agent.py +72 -19
  6. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/messages.py +3 -0
  7. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/gemini.py +1 -3
  8. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/mistral.py +14 -1
  9. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/openai.py +4 -6
  10. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/.gitignore +0 -0
  11. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/README.md +0 -0
  12. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/__init__.py +0 -0
  13. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/__main__.py +0 -0
  14. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_cli.py +0 -0
  15. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_griffe.py +0 -0
  16. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_output.py +0 -0
  17. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_parts_manager.py +0 -0
  18. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_system_prompt.py +0 -0
  19. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/common_tools/__init__.py +0 -0
  20. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  21. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/common_tools/tavily.py +0 -0
  22. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/exceptions.py +0 -0
  23. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/format_as_xml.py +0 -0
  24. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/format_prompt.py +0 -0
  25. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/mcp.py +0 -0
  26. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/__init__.py +0 -0
  27. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/_json_schema.py +0 -0
  28. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/anthropic.py +0 -0
  29. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/bedrock.py +0 -0
  30. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/cohere.py +0 -0
  31. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/fallback.py +0 -0
  32. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/function.py +0 -0
  33. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/groq.py +0 -0
  34. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/instrumented.py +0 -0
  35. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/test.py +0 -0
  36. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/wrapper.py +0 -0
  37. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/__init__.py +0 -0
  38. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/anthropic.py +0 -0
  39. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/azure.py +0 -0
  40. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/bedrock.py +0 -0
  41. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/cohere.py +0 -0
  42. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/deepseek.py +0 -0
  43. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/google_gla.py +0 -0
  44. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/google_vertex.py +0 -0
  45. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/groq.py +0 -0
  46. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/mistral.py +0 -0
  47. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/providers/openai.py +0 -0
  48. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/py.typed +0 -0
  49. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/result.py +0 -0
  50. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/settings.py +0 -0
  51. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/tools.py +0 -0
  52. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/usage.py +0 -0
  53. {pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pyproject.toml +0 -0
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.1.6
+Version: 0.1.8
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.1.6
+Requires-Dist: pydantic-graph==0.1.8
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.1.6; extra == 'evals'
+Requires-Dist: pydantic-evals==0.1.8; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_agent_graph.py
@@ -2,14 +2,13 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
-import json
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
-from opentelemetry.trace import Span, Tracer
+from opentelemetry.trace import Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -24,7 +23,6 @@ from . import (
     result,
     usage as _usage,
 )
-from .models.instrumented import InstrumentedModel
 from .result import OutputDataT, ToolOutput
 from .settings import ModelSettings, merge_model_settings
 from .tools import RunContext, Tool, ToolDefinition
@@ -95,7 +93,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
     mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
 
-    run_span: Span
     tracer: Tracer
 
 
@@ -498,39 +495,12 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         final_result: result.FinalResult[NodeRunEndT],
         tool_responses: list[_messages.ModelRequestPart],
     ) -> End[result.FinalResult[NodeRunEndT]]:
-        run_span = ctx.deps.run_span
-        usage = ctx.state.usage
         messages = ctx.state.message_history
 
         # For backwards compatibility, append a new ModelRequest using the tool returns and retries
         if tool_responses:
             messages.append(_messages.ModelRequest(parts=tool_responses))
 
-        run_span.set_attributes(
-            {
-                **usage.opentelemetry_attributes(),
-                'all_messages_events': json.dumps(
-                    [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
-                ),
-                'final_result': final_result.output
-                if isinstance(final_result.output, str)
-                else json.dumps(InstrumentedModel.serialize_any(final_result.output)),
-            }
-        )
-        run_span.set_attributes(
-            {
-                'logfire.json_schema': json.dumps(
-                    {
-                        'type': 'object',
-                        'properties': {
-                            'all_messages_events': {'type': 'array'},
-                            'final_result': {'type': 'object'},
-                        },
-                    }
-                ),
-            }
-        )
-
         return End(final_result)
 
     async def _handle_text_response(
@@ -576,7 +546,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
     )
 
 
-async def process_function_tools(
+async def process_function_tools(  # noqa: C901
    tool_calls: list[_messages.ToolCallPart],
    output_tool_name: str | None,
    output_tool_call_id: str | None,
@@ -662,6 +632,8 @@ async def process_function_tools(
     if not calls_to_run:
         return
 
+    user_parts: list[_messages.UserPromptPart] = []
+
     # Run all tool tasks in parallel
     results_by_index: dict[int, _messages.ModelRequestPart] = {}
     with ctx.deps.tracer.start_as_current_span(
@@ -675,6 +647,9 @@ async def process_function_tools(
             asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name)
             for tool, call in calls_to_run
         ]
+
+        file_index = 1
+
         pending = tasks
         while pending:
             done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
@@ -682,7 +657,22 @@ async def process_function_tools(
                 index = tasks.index(task)
                 result = task.result()
                 yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
-                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+
+                if isinstance(result, _messages.RetryPromptPart):
+                    results_by_index[index] = result
+                elif isinstance(result, _messages.ToolReturnPart):
+                    if isinstance(result.content, _messages.MultiModalContentTypes):
+                        user_parts.append(
+                            _messages.UserPromptPart(
+                                content=[f'This is file {file_index}:', result.content],
+                                timestamp=result.timestamp,
+                                part_kind='user-prompt',
+                            )
+                        )
+
+                        result.content = f'See file {file_index}.'
+                        file_index += 1
+
                     results_by_index[index] = result
                 else:
                     assert_never(result)
@@ -692,6 +682,8 @@ async def process_function_tools(
     for k in sorted(results_by_index):
         output_parts.append(results_by_index[k])
 
+    output_parts.extend(user_parts)
+
 
 async def _tool_from_mcp_server(
     tool_name: str,
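
Net effect of the process_function_tools changes above: a tool that returns multimodal content (an image, audio clip, or document) no longer has that payload embedded in its tool-return message. The return is rewritten to a short 'See file N.' pointer and the payload itself is appended after all tool results as a user-prompt part. A minimal, framework-free sketch of that substitution pattern; the ToolReturn and UserPrompt dataclasses below are illustrative stand-ins, not pydantic-ai's real message classes:

from dataclasses import dataclass
from typing import Any

@dataclass
class ToolReturn:  # stand-in for _messages.ToolReturnPart
    tool_name: str
    content: Any

@dataclass
class UserPrompt:  # stand-in for _messages.UserPromptPart
    content: list[Any]

def split_multimodal(results: list[ToolReturn], binary_types: tuple[type, ...]) -> list[Any]:
    """Replace binary tool payloads with 'See file N.' pointers; append payloads as user parts."""
    user_parts: list[UserPrompt] = []
    file_index = 1
    for result in results:
        if isinstance(result.content, binary_types):
            user_parts.append(UserPrompt(content=[f'This is file {file_index}:', result.content]))
            result.content = f'See file {file_index}.'
            file_index += 1
    # tool returns first (now lightweight), then the binary payloads as user parts
    return [*results, *user_parts]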
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_pydantic.py
@@ -58,7 +58,7 @@ def function_schema(  # noqa: C901
     Returns:
         A `FunctionSchema` instance.
     """
-    config = ConfigDict(title=function.__name__)
+    config = ConfigDict(title=function.__name__, use_attribute_docstrings=True)
     config_wrapper = ConfigWrapper(config)
     gen_schema = _generate_schema.GenerateSchema(config_wrapper)
 
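use_attribute_docstrings is a standard pydantic (2.7+) config option: a docstring placed directly under a field becomes that field's description in the generated JSON schema, so tool parameters documented this way now reach the model. A standalone pydantic example, independent of pydantic-ai's internals:

from pydantic import BaseModel, ConfigDict

class Point(BaseModel):
    model_config = ConfigDict(use_attribute_docstrings=True)

    x: int
    """Horizontal coordinate, in pixels."""

# The attribute docstring is picked up as the field description:
print(Point.model_json_schema()['properties']['x']['description'])
# -> 'Horizontal coordinate, in pixels.'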
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/_utils.py
@@ -11,6 +11,7 @@ from functools import partial
 from types import GenericAlias
 from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
 
+from anyio.to_thread import run_sync
 from pydantic import BaseModel
 from pydantic.json_schema import JsonSchemaValue
 from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
@@ -31,11 +32,8 @@ _R = TypeVar('_R')
 
 
 async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
-    if kwargs:
-        # noinspection PyTypeChecker
-        return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs))
-    else:
-        return await asyncio.get_running_loop().run_in_executor(None, func, *args)  # type: ignore
+    wrapped_func = partial(func, *args, **kwargs)
+    return await run_sync(wrapped_func)
 
 
 def is_model_like(type_: Any) -> bool:
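
run_in_executor now delegates to anyio.to_thread.run_sync, collapsing the two asyncio code paths into one: run_sync only accepts positional arguments, so a single functools.partial handles every call shape (and the helper now works under trio as well as asyncio). A minimal usage sketch; blocking_work is a made-up stand-in for any blocking callable:

import time
from functools import partial

import anyio
from anyio.to_thread import run_sync

def blocking_work(x: int, *, scale: int = 2) -> int:
    time.sleep(0.1)  # stand-in for blocking I/O or CPU-bound work
    return x * scale

async def main() -> None:
    # partial() bakes in positional and keyword arguments alike,
    # so one code path covers every call shape.
    result = await run_sync(partial(blocking_work, 21, scale=2))
    print(result)  # 42

anyio.run(main)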
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/agent.py
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
 
 import dataclasses
 import inspect
+import json
 import warnings
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
@@ -152,7 +153,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         model: models.Model | models.KnownModelName | str | None = None,
         *,
         output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,
-        instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
+        instructions: str
+        | _system_prompt.SystemPromptFunc[AgentDepsT]
+        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
+        | None = None,
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
@@ -175,7 +179,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         model: models.Model | models.KnownModelName | str | None = None,
         *,
         result_type: type[OutputDataT] = str,
-        instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
+        instructions: str
+        | _system_prompt.SystemPromptFunc[AgentDepsT]
+        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
+        | None = None,
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
@@ -197,7 +204,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         *,
         # TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads
         output_type: Any = str,
-        instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
+        instructions: str
+        | _system_prompt.SystemPromptFunc[AgentDepsT]
+        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
+        | None = None,
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
@@ -296,10 +306,16 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         )
         self._output_validators = []
 
-        self._instructions_functions = (
-            [_system_prompt.SystemPromptRunner(instructions)] if callable(instructions) else []
-        )
-        self._instructions = instructions if isinstance(instructions, str) else None
+        self._instructions = ''
+        self._instructions_functions = []
+        if isinstance(instructions, (str, Callable)):
+            instructions = [instructions]
+        for instruction in instructions or []:
+            if isinstance(instruction, str):
+                self._instructions += instruction + '\n'
+            else:
+                self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction))
+        self._instructions = self._instructions.strip() or None
 
         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
         self._system_prompt_functions = []
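
With the constructor change above, instructions accepts a single string, a single function, or a sequence mixing both: string items are newline-joined into the static instructions, while callables are registered as per-run instruction functions. A hedged sketch of the new call shape; the model string and the deps scheme are placeholders, not part of the diff:

from pydantic_ai import Agent, RunContext

def add_user_name(ctx: RunContext[str]) -> str:
    # deps carries the user's name in this made-up example
    return f"The user's name is {ctx.deps}."

agent = Agent(
    'openai:gpt-4o',  # placeholder model name
    deps_type=str,
    instructions=[
        'Reply concisely.',  # static: folded into self._instructions
        add_user_name,       # dynamic: re-evaluated on every run
    ],
)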
@@ -585,9 +601,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         )
 
         # Build the initial state
+        usage = usage or _usage.Usage()
         state = _agent_graph.GraphAgentState(
             message_history=message_history[:] if message_history else [],
-            usage=usage or _usage.Usage(),
+            usage=usage,
             retries=0,
             run_step=0,
         )
@@ -625,8 +642,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
 
             instructions = self._instructions or ''
             for instructions_runner in self._instructions_functions:
-                instructions += await instructions_runner.run(run_context)
-            return instructions
+                instructions += '\n' + await instructions_runner.run(run_context)
+            return instructions.strip()
 
         graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
             user_deps=deps,
@@ -641,7 +658,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             output_validators=output_validators,
             function_tools=self._function_tools,
             mcp_servers=self._mcp_servers,
-            run_span=run_span,
             tracer=tracer,
             get_instructions=get_instructions,
         )
@@ -654,14 +670,51 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
         )
 
-        async with graph.iter(
-            start_node,
-            state=state,
-            deps=graph_deps,
-            span=use_span(run_span, end_on_exit=True) if run_span.is_recording() else None,
-            infer_name=False,
-        ) as graph_run:
-            yield AgentRun(graph_run)
+        try:
+            async with graph.iter(
+                start_node,
+                state=state,
+                deps=graph_deps,
+                span=use_span(run_span) if run_span.is_recording() else None,
+                infer_name=False,
+            ) as graph_run:
+                agent_run = AgentRun(graph_run)
+                yield agent_run
+                if (final_result := agent_run.result) is not None and run_span.is_recording():
+                    run_span.set_attribute(
+                        'final_result',
+                        (
+                            final_result.output
+                            if isinstance(final_result.output, str)
+                            else json.dumps(InstrumentedModel.serialize_any(final_result.output))
+                        ),
+                    )
+        finally:
+            try:
+                if run_span.is_recording():
+                    run_span.set_attributes(self._run_span_end_attributes(state, usage))
+            finally:
+                run_span.end()
+
+    def _run_span_end_attributes(self, state: _agent_graph.GraphAgentState, usage: _usage.Usage):
+        return {
+            **usage.opentelemetry_attributes(),
+            'all_messages_events': json.dumps(
+                [
+                    InstrumentedModel.event_to_dict(e)
+                    for e in InstrumentedModel.messages_to_otel_events(state.message_history)
+                ]
+            ),
+            'logfire.json_schema': json.dumps(
+                {
+                    'type': 'object',
+                    'properties': {
+                        'all_messages_events': {'type': 'array'},
+                        'final_result': {'type': 'object'},
+                    },
+                }
+            ),
+        }
 
     @overload
     def run_sync(
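
The try/finally restructure above follows the usual OpenTelemetry pattern for a span that must outlive a yielded resource: attributes are recorded and the span is ended in finally, so they are set even if the consumer of the run raises. Schematically, using only the public opentelemetry-api:

from opentelemetry import trace

tracer = trace.get_tracer(__name__)

def run_with_span() -> None:
    span = tracer.start_span('agent run')  # not ended by a context manager
    try:
        pass  # hand control to the caller while the span is open
    finally:
        try:
            if span.is_recording():
                span.set_attribute('final_result', 'ok')  # end-of-run attributes
        finally:
            span.end()  # always end the span exactly once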
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/messages.py
@@ -253,6 +253,9 @@ class BinaryContent:
 
 UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'
 
+# Ideally this would be a Union of types, but Python 3.9 requires it to be a string, and strings don't work with `isinstance`.
+MultiModalContentTypes = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)
+
 
 def _document_format(media_type: str) -> DocumentFormat:
     if media_type == 'application/pdf':
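
Because MultiModalContentTypes is a plain tuple of classes rather than a typing Union, it can be passed straight to isinstance, which is exactly how process_function_tools uses it above. For example:

from pydantic_ai.messages import BinaryContent, ImageUrl, MultiModalContentTypes

assert isinstance(ImageUrl(url='https://example.com/cat.png'), MultiModalContentTypes)
assert isinstance(BinaryContent(data=b'...', media_type='image/png'), MultiModalContentTypes)
assert not isinstance('plain text', MultiModalContentTypes)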
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/gemini.py
@@ -328,7 +328,7 @@ class GeminiModel(Model):
                content.append(
                    _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
                )
-            elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl)):
+            elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl, VideoUrl)):
                client = cached_async_http_client()
                response = await client.get(item.url, follow_redirects=True)
                response.raise_for_status()
@@ -337,8 +337,6 @@ class GeminiModel(Model):
                    inline_data={'data': base64.b64encode(response.content).decode('utf-8'), 'mime_type': mime_type}
                )
                content.append(inline_data)
-            elif isinstance(item, VideoUrl):  # pragma: no cover
-                raise NotImplementedError('VideoUrl is not supported for Gemini.')
            else:
                assert_never(item)
        return content
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/mistral.py
@@ -483,7 +483,20 @@ class MistralModel(Model):
                assert_never(message)
        if instructions := self._get_instructions(messages):
            mistral_messages.insert(0, MistralSystemMessage(content=instructions))
-        return mistral_messages
+
+        # Post-process messages to insert fake assistant message after tool message if followed by user message
+        # to work around `Unexpected role 'user' after role 'tool'` error.
+        processed_messages: list[MistralMessages] = []
+        for i, current_message in enumerate(mistral_messages):
+            processed_messages.append(current_message)
+
+            if isinstance(current_message, MistralToolMessage) and i + 1 < len(mistral_messages):
+                next_message = mistral_messages[i + 1]
+                if isinstance(next_message, MistralUserMessage):
+                    # Insert a dummy assistant message
+                    processed_messages.append(MistralAssistantMessage(content=[MistralTextChunk(text='OK')]))
+
+        return processed_messages
 
     def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
         content: str | list[MistralContentChunk]
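
The same workaround in isolation: whenever a tool message is immediately followed by a user message, a throwaway assistant message is spliced between them so the role sequence Mistral's API accepts (tool → assistant → user) is preserved. A framework-free sketch using plain role strings:

def interleave_roles(roles: list[str]) -> list[str]:
    """Insert a dummy 'assistant' between a 'tool' message and a following 'user' message."""
    out: list[str] = []
    for i, role in enumerate(roles):
        out.append(role)
        if role == 'tool' and i + 1 < len(roles) and roles[i + 1] == 'user':
            out.append('assistant')  # the real code appends MistralAssistantMessage(content='OK')
    return out

assert interleave_roles(['user', 'assistant', 'tool', 'user']) == [
    'user', 'assistant', 'tool', 'assistant', 'user'
]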
{pydantic_ai_slim-0.1.6 → pydantic_ai_slim-0.1.8}/pydantic_ai/models/openai.py
@@ -439,12 +439,13 @@ class OpenAIModel(Model):
                    )
                else:  # pragma: no cover
                    raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
-            elif isinstance(item, AudioUrl):  # pragma: no cover
+            elif isinstance(item, AudioUrl):
                client = cached_async_http_client()
                response = await client.get(item.url)
                response.raise_for_status()
                base64_encoded = base64.b64encode(response.content).decode('utf-8')
-                audio = InputAudio(data=base64_encoded, format=response.headers.get('content-type'))
+                audio_format: Any = response.headers['content-type'].removeprefix('audio/')
+                audio = InputAudio(data=base64_encoded, format=audio_format)
                content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
            elif isinstance(item, DocumentUrl):
                client = cached_async_http_client()
@@ -453,10 +454,7 @@ class OpenAIModel(Model):
                base64_encoded = base64.b64encode(response.content).decode('utf-8')
                media_type = response.headers.get('content-type').split(';')[0]
                file_data = f'data:{media_type};base64,{base64_encoded}'
-                file = File(
-                    file=FileFile(file_data=file_data, filename=f'filename.{item.format}'),
-                    type='file',
-                )
+                file = File(file=FileFile(file_data=file_data, filename=f'filename.{item.format}'), type='file')
                content.append(file)
            elif isinstance(item, VideoUrl):  # pragma: no cover
                raise NotImplementedError('VideoUrl is not supported for OpenAI')
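
The AudioUrl fix derives the format value the OpenAI input-audio API expects from the response's Content-Type header by stripping the 'audio/' prefix, rather than passing the full MIME type. The header mapping in isolation (str.removeprefix needs Python 3.9+):

def audio_format(content_type: str) -> str:
    """'audio/wav' -> 'wav', 'audio/mpeg' -> 'mpeg'."""
    return content_type.removeprefix('audio/')

assert audio_format('audio/wav') == 'wav'
assert audio_format('audio/mpeg') == 'mpeg'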