pydantic-ai-slim 0.0.44__tar.gz → 0.0.46__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim has been flagged as potentially problematic; see the original registry diff page for more details.

Files changed (51)
  1. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/PKG-INFO +2 -2
  2. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_parts_manager.py +7 -1
  3. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_utils.py +12 -6
  4. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/agent.py +2 -2
  5. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/exceptions.py +2 -2
  6. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/mcp.py +25 -1
  7. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/messages.py +15 -27
  8. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/__init__.py +15 -6
  9. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/anthropic.py +7 -46
  10. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/bedrock.py +7 -11
  11. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/cohere.py +10 -50
  12. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/gemini.py +18 -73
  13. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/groq.py +9 -53
  14. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/mistral.py +12 -51
  15. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/openai.py +15 -67
  16. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/anthropic.py +6 -6
  17. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/azure.py +9 -10
  18. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/bedrock.py +2 -1
  19. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/cohere.py +6 -8
  20. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/deepseek.py +6 -5
  21. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/google_gla.py +4 -3
  22. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/google_vertex.py +3 -4
  23. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/groq.py +6 -8
  24. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/mistral.py +6 -6
  25. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/openai.py +6 -8
  26. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pyproject.toml +2 -2
  27. pydantic_ai_slim-0.0.44/pydantic_ai/models/vertexai.py +0 -260
  28. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/.gitignore +0 -0
  29. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/README.md +0 -0
  30. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/__init__.py +0 -0
  31. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_agent_graph.py +0 -0
  32. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_cli.py +0 -0
  33. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_griffe.py +0 -0
  34. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_pydantic.py +0 -0
  35. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_result.py +0 -0
  36. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_system_prompt.py +0 -0
  37. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/common_tools/__init__.py +0 -0
  38. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  39. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/common_tools/tavily.py +0 -0
  40. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/format_as_xml.py +0 -0
  41. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/fallback.py +0 -0
  42. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/function.py +0 -0
  43. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/instrumented.py +0 -0
  44. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/test.py +0 -0
  45. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/wrapper.py +0 -0
  46. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/__init__.py +0 -0
  47. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/py.typed +0 -0
  48. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/result.py +0 -0
  49. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/settings.py +0 -0
  50. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/tools.py +0 -0
  51. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.44
3
+ Version: 0.0.46
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
31
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.0.44
32
+ Requires-Dist: pydantic-graph==0.0.46
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Requires-Dist: typing-inspection>=0.4.0
35
35
  Provides-Extra: anthropic
@@ -29,6 +29,8 @@ from pydantic_ai.messages import (
29
29
  ToolCallPartDelta,
30
30
  )
31
31
 
32
+ from ._utils import generate_tool_call_id as _generate_tool_call_id
33
+
32
34
  VendorId = Hashable
33
35
  """
34
36
  Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
@@ -221,7 +223,11 @@ class ModelResponsePartsManager:
221
223
  ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
222
224
  has been added to the manager, or replaced an existing part.
223
225
  """
224
- new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
226
+ new_part = ToolCallPart(
227
+ tool_name=tool_name,
228
+ args=args,
229
+ tool_call_id=tool_call_id or _generate_tool_call_id(),
230
+ )
225
231
  if vendor_part_id is None:
226
232
  # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
227
233
  new_part_index = len(self._parts)
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import asyncio
4
4
  import time
5
+ import uuid
5
6
  from collections.abc import AsyncIterable, AsyncIterator, Iterator
6
7
  from contextlib import asynccontextmanager, suppress
7
8
  from dataclasses import dataclass, is_dataclass
@@ -195,12 +196,17 @@ def now_utc() -> datetime:
195
196
  return datetime.now(tz=timezone.utc)
196
197
 
197
198
 
198
- def guard_tool_call_id(
199
- t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart, model_source: str
200
- ) -> str:
201
- """Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
202
- assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
203
- return t.tool_call_id
199
+ def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
200
+ """Type guard that either returns the tool call id or generates a new one if it's None."""
201
+ return t.tool_call_id or generate_tool_call_id()
202
+
203
+
204
+ def generate_tool_call_id() -> str:
205
+ """Generate a tool call id.
206
+
207
+ Ensure that the tool call id is unique.
208
+ """
209
+ return f'pyd_ai_{uuid.uuid4().hex}'
204
210
 
205
211
 
206
212
  class PeekableAsyncStream(Generic[T]):
@@ -13,7 +13,7 @@ from pydantic.json_schema import GenerateJsonSchema
13
13
  from typing_extensions import TypeGuard, TypeVar, deprecated
14
14
 
15
15
  from pydantic_graph import End, Graph, GraphRun, GraphRunContext
16
- from pydantic_graph._utils import run_until_complete
16
+ from pydantic_graph._utils import get_event_loop
17
17
 
18
18
  from . import (
19
19
  _agent_graph,
@@ -567,7 +567,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
567
567
  """
568
568
  if infer_name and self.name is None:
569
569
  self._infer_name(inspect.currentframe())
570
- return run_until_complete(
570
+ return get_event_loop().run_until_complete(
571
571
  self.run(
572
572
  user_prompt,
573
573
  result_type=result_type,
@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
3
3
  import json
4
4
  import sys
5
5
 
6
- if sys.version_info < (3, 11):
6
+ if sys.version_info < (3, 11): # pragma: no cover
7
7
  from exceptiongroup import ExceptionGroup
8
- else:
8
+ else: # pragma: no cover
9
9
  ExceptionGroup = ExceptionGroup
10
10
 
11
11
  __all__ = (
@@ -188,11 +188,35 @@ class MCPServerHTTP(MCPServer):
188
188
  For example for a server running locally, this might be `http://localhost:3001/sse`.
189
189
  """
190
190
 
191
+ headers: dict[str, Any] | None = None
192
+ """Optional HTTP headers to be sent with each request to the SSE endpoint.
193
+
194
+ These headers will be passed directly to the underlying `httpx.AsyncClient`.
195
+ Useful for authentication, custom headers, or other HTTP-specific configurations.
196
+ """
197
+
198
+ timeout: float = 5
199
+ """Initial connection timeout in seconds for establishing the SSE connection.
200
+
201
+ This timeout applies to the initial connection setup and handshake.
202
+ If the connection cannot be established within this time, the operation will fail.
203
+ """
204
+
205
+ sse_read_timeout: float = 60 * 5
206
+ """Maximum time in seconds to wait for new SSE messages before timing out.
207
+
208
+ This timeout applies to the long-lived SSE connection after it's established.
209
+ If no new messages are received within this time, the connection will be considered stale
210
+ and may be closed. Defaults to 5 minutes (300 seconds).
211
+ """
212
+
191
213
  @asynccontextmanager
192
214
  async def client_streams(
193
215
  self,
194
216
  ) -> AsyncIterator[
195
217
  tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
196
218
  ]: # pragma: no cover
197
- async with sse_client(url=self.url) as (read_stream, write_stream):
219
+ async with sse_client(
220
+ url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout
221
+ ) as (read_stream, write_stream):
198
222
  yield read_stream, write_stream
@@ -12,7 +12,7 @@ import pydantic_core
12
12
  from opentelemetry._events import Event
13
13
  from typing_extensions import TypeAlias
14
14
 
15
- from ._utils import now_utc as _now_utc
15
+ from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
16
16
  from .exceptions import UnexpectedModelBehavior
17
17
 
18
18
 
@@ -268,8 +268,8 @@ class ToolReturnPart:
268
268
  content: Any
269
269
  """The return value."""
270
270
 
271
- tool_call_id: str | None = None
272
- """Optional tool call identifier, this is used by some models including OpenAI."""
271
+ tool_call_id: str
272
+ """The tool call identifier, this is used by some models including OpenAI."""
273
273
 
274
274
  timestamp: datetime = field(default_factory=_now_utc)
275
275
  """The timestamp, when the tool returned."""
@@ -328,8 +328,11 @@ class RetryPromptPart:
328
328
  tool_name: str | None = None
329
329
  """The name of the tool that was called, if any."""
330
330
 
331
- tool_call_id: str | None = None
332
- """Optional tool call identifier, this is used by some models including OpenAI."""
331
+ tool_call_id: str = field(default_factory=_generate_tool_call_id)
332
+ """The tool call identifier, this is used by some models including OpenAI.
333
+
334
+ In case the tool call id is not provided by the model, PydanticAI will generate a random one.
335
+ """
333
336
 
334
337
  timestamp: datetime = field(default_factory=_now_utc)
335
338
  """The timestamp, when the retry was triggered."""
@@ -406,8 +409,11 @@ class ToolCallPart:
406
409
  This is stored either as a JSON string or a Python dictionary depending on how data was received.
407
410
  """
408
411
 
409
- tool_call_id: str | None = None
410
- """Optional tool call identifier, this is used by some models including OpenAI."""
412
+ tool_call_id: str = field(default_factory=_generate_tool_call_id)
413
+ """The tool call identifier, this is used by some models including OpenAI.
414
+
415
+ In case the tool call id is not provided by the model, PydanticAI will generate a random one.
416
+ """
411
417
 
412
418
  part_kind: Literal['tool-call'] = 'tool-call'
413
419
  """Part type identifier, this is available on all parts as a discriminator."""
@@ -564,11 +570,7 @@ class ToolCallPartDelta:
564
570
  if self.tool_name_delta is None or self.args_delta is None:
565
571
  return None
566
572
 
567
- return ToolCallPart(
568
- self.tool_name_delta,
569
- self.args_delta,
570
- self.tool_call_id,
571
- )
573
+ return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())
572
574
 
573
575
  @overload
574
576
  def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
@@ -620,20 +622,11 @@ class ToolCallPartDelta:
620
622
  delta = replace(delta, args_delta=updated_args_delta)
621
623
 
622
624
  if self.tool_call_id:
623
- # Set the tool_call_id if it wasn't present, otherwise error if it has changed
624
- if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
625
- raise UnexpectedModelBehavior(
626
- f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
627
- )
628
625
  delta = replace(delta, tool_call_id=self.tool_call_id)
629
626
 
630
627
  # If we now have enough data to create a full ToolCallPart, do so
631
628
  if delta.tool_name_delta is not None and delta.args_delta is not None:
632
- return ToolCallPart(
633
- delta.tool_name_delta,
634
- delta.args_delta,
635
- delta.tool_call_id,
636
- )
629
+ return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id())
637
630
 
638
631
  return delta
639
632
 
@@ -656,11 +649,6 @@ class ToolCallPartDelta:
656
649
  part = replace(part, args=updated_dict)
657
650
 
658
651
  if self.tool_call_id:
659
- # Replace the tool_call_id entirely if given
660
- if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
661
- raise UnexpectedModelBehavior(
662
- f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
663
- )
664
652
  part = replace(part, tool_call_id=self.tool_call_id)
665
653
  return part
666
654
 
@@ -431,32 +431,41 @@ def infer_model(model: Model | KnownModelName) -> Model:
431
431
  raise UserError(f'Unknown model: {model}')
432
432
 
433
433
 
434
- def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
435
- """Cached HTTPX async client so multiple agents and calls can share the same client.
434
+ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
435
+ """Cached HTTPX async client that creates a separate client for each provider.
436
+
437
+ The client is cached based on the provider parameter. If provider is None, it's used for non-provider specific
438
+ requests (like downloading images). Multiple agents and calls can share the same client when they use the same provider.
436
439
 
437
440
  There are good reasons why in production you should use a `httpx.AsyncClient` as an async context manager as
438
441
  described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing
439
- examples, it's very useful not to, this allows multiple Agents to use a single client.
442
+ examples, it's very useful not to.
440
443
 
441
444
  The default timeouts match those of OpenAI,
442
445
  see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
443
446
  """
444
- client = _cached_async_http_client(timeout=timeout, connect=connect)
447
+ client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
445
448
  if client.is_closed:
446
449
  # This happens if the context manager is used, so we need to create a new client.
447
450
  _cached_async_http_client.cache_clear()
448
- client = _cached_async_http_client(timeout=timeout, connect=connect)
451
+ client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
449
452
  return client
450
453
 
451
454
 
452
455
  @cache
453
- def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
456
+ def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
454
457
  return httpx.AsyncClient(
458
+ transport=_cached_async_http_transport(),
455
459
  timeout=httpx.Timeout(timeout=timeout, connect=connect),
456
460
  headers={'User-Agent': get_user_agent()},
457
461
  )
458
462
 
459
463
 
464
+ @cache
465
+ def _cached_async_http_transport() -> httpx.AsyncHTTPTransport:
466
+ return httpx.AsyncHTTPTransport()
467
+
468
+
460
469
  @cache
461
470
  def get_user_agent() -> str:
462
471
  """Get the user agent string for the HTTP client."""
@@ -10,8 +10,7 @@ from json import JSONDecodeError, loads as json_loads
10
10
  from typing import Any, Literal, Union, cast, overload
11
11
 
12
12
  from anthropic.types import DocumentBlockParam
13
- from httpx import AsyncClient as AsyncHTTPClient
14
- from typing_extensions import assert_never, deprecated
13
+ from typing_extensions import assert_never
15
14
 
16
15
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
17
16
  from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -112,34 +111,11 @@ class AnthropicModel(Model):
112
111
  _model_name: AnthropicModelName = field(repr=False)
113
112
  _system: str = field(default='anthropic', repr=False)
114
113
 
115
- @overload
116
114
  def __init__(
117
115
  self,
118
116
  model_name: AnthropicModelName,
119
117
  *,
120
118
  provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
121
- ) -> None: ...
122
-
123
- @deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
124
- @overload
125
- def __init__(
126
- self,
127
- model_name: AnthropicModelName,
128
- *,
129
- provider: None = None,
130
- api_key: str | None = None,
131
- anthropic_client: AsyncAnthropic | None = None,
132
- http_client: AsyncHTTPClient | None = None,
133
- ) -> None: ...
134
-
135
- def __init__(
136
- self,
137
- model_name: AnthropicModelName,
138
- *,
139
- provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
140
- api_key: str | None = None,
141
- anthropic_client: AsyncAnthropic | None = None,
142
- http_client: AsyncHTTPClient | None = None,
143
119
  ):
144
120
  """Initialize an Anthropic model.
145
121
 
@@ -148,27 +124,12 @@ class AnthropicModel(Model):
148
124
  [here](https://docs.anthropic.com/en/docs/about-claude/models).
149
125
  provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
150
126
  instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
151
- api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
152
- will be used if available.
153
- anthropic_client: An existing
154
- [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
155
- client to use, if provided, `api_key` and `http_client` must be `None`.
156
- http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
157
127
  """
158
128
  self._model_name = model_name
159
129
 
160
- if provider is not None:
161
- if isinstance(provider, str):
162
- provider = infer_provider(provider)
163
- self.client = provider.client
164
- elif anthropic_client is not None:
165
- assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
166
- assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
167
- self.client = anthropic_client
168
- elif http_client is not None:
169
- self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
170
- else:
171
- self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
130
+ if isinstance(provider, str):
131
+ provider = infer_provider(provider)
132
+ self.client = provider.client
172
133
 
173
134
  @property
174
135
  def base_url(self) -> str:
@@ -326,7 +287,7 @@ class AnthropicModel(Model):
326
287
  user_content_params.append(content)
327
288
  elif isinstance(request_part, ToolReturnPart):
328
289
  tool_result_block_param = ToolResultBlockParam(
329
- tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
290
+ tool_use_id=_guard_tool_call_id(t=request_part),
330
291
  type='tool_result',
331
292
  content=request_part.model_response_str(),
332
293
  is_error=False,
@@ -337,7 +298,7 @@ class AnthropicModel(Model):
337
298
  retry_param = TextBlockParam(type='text', text=request_part.model_response())
338
299
  else:
339
300
  retry_param = ToolResultBlockParam(
340
- tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
301
+ tool_use_id=_guard_tool_call_id(t=request_part),
341
302
  type='tool_result',
342
303
  content=request_part.model_response(),
343
304
  is_error=True,
@@ -351,7 +312,7 @@ class AnthropicModel(Model):
351
312
  assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
352
313
  else:
353
314
  tool_use_block_param = ToolUseBlockParam(
354
- id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
315
+ id=_guard_tool_call_id(t=response_part),
355
316
  type='tool_use',
356
317
  name=response_part.tool_name,
357
318
  input=response_part.args_as_dict(),
@@ -143,14 +143,15 @@ class BedrockConverseModel(Model):
143
143
  model_name: The name of the model to use.
144
144
  model_name: The name of the Bedrock model to use. List of model names available
145
145
  [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
146
- provider: The provider to use. Defaults to `'bedrock'`.
146
+ provider: The provider to use for authentication and API access. Can be either the string
147
+ 'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
148
+ created using the other parameters.
147
149
  """
148
150
  self._model_name = model_name
149
151
 
150
152
  if isinstance(provider, str):
151
- self.client = infer_provider(provider).client
152
- else:
153
- self.client = cast('BedrockRuntimeClient', provider.client)
153
+ provider = infer_provider(provider)
154
+ self.client = cast('BedrockRuntimeClient', provider.client)
154
155
 
155
156
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
156
157
  tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
@@ -345,7 +346,7 @@ class BedrockConverseModel(Model):
345
346
  content.append({'text': item.content})
346
347
  else:
347
348
  assert isinstance(item, ToolCallPart)
348
- content.append(self._map_tool_call(item)) # FIXME: MISSING key
349
+ content.append(self._map_tool_call(item))
349
350
  bedrock_messages.append({'role': 'assistant', 'content': content})
350
351
  else:
351
352
  assert_never(m)
@@ -394,13 +395,8 @@ class BedrockConverseModel(Model):
394
395
 
395
396
  @staticmethod
396
397
  def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
397
- assert t.tool_call_id is not None
398
398
  return {
399
- 'toolUse': {
400
- 'toolUseId': t.tool_call_id,
401
- 'name': t.tool_name,
402
- 'input': t.args_as_dict(),
403
- }
399
+ 'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
404
400
  }
405
401
 
406
402
 
@@ -3,14 +3,13 @@ from __future__ import annotations as _annotations
3
3
  from collections.abc import Iterable
4
4
  from dataclasses import dataclass, field
5
5
  from itertools import chain
6
- from typing import Literal, Union, cast, overload
6
+ from typing import Literal, Union, cast
7
7
 
8
8
  from cohere import TextAssistantMessageContentItem
9
- from httpx import AsyncClient as AsyncHTTPClient
10
- from typing_extensions import assert_never, deprecated
9
+ from typing_extensions import assert_never
11
10
 
12
11
  from .. import ModelHTTPError, result
13
- from .._utils import guard_tool_call_id as _guard_tool_call_id
12
+ from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
14
13
  from ..messages import (
15
14
  ModelMessage,
16
15
  ModelRequest,
@@ -29,7 +28,6 @@ from ..tools import ToolDefinition
29
28
  from . import (
30
29
  Model,
31
30
  ModelRequestParameters,
32
- cached_async_http_client,
33
31
  check_allow_model_requests,
34
32
  )
35
33
 
@@ -102,37 +100,11 @@ class CohereModel(Model):
102
100
  _model_name: CohereModelName = field(repr=False)
103
101
  _system: str = field(default='cohere', repr=False)
104
102
 
105
- @overload
106
103
  def __init__(
107
104
  self,
108
105
  model_name: CohereModelName,
109
106
  *,
110
107
  provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
111
- api_key: None = None,
112
- cohere_client: None = None,
113
- http_client: None = None,
114
- ) -> None: ...
115
-
116
- @deprecated('Use the `provider` parameter instead of `api_key`, `cohere_client`, and `http_client`.')
117
- @overload
118
- def __init__(
119
- self,
120
- model_name: CohereModelName,
121
- *,
122
- provider: None = None,
123
- api_key: str | None = None,
124
- cohere_client: AsyncClientV2 | None = None,
125
- http_client: AsyncHTTPClient | None = None,
126
- ) -> None: ...
127
-
128
- def __init__(
129
- self,
130
- model_name: CohereModelName,
131
- *,
132
- provider: Literal['cohere'] | Provider[AsyncClientV2] | None = None,
133
- api_key: str | None = None,
134
- cohere_client: AsyncClientV2 | None = None,
135
- http_client: AsyncHTTPClient | None = None,
136
108
  ):
137
109
  """Initialize an Cohere model.
138
110
 
@@ -142,24 +114,12 @@ class CohereModel(Model):
142
114
  provider: The provider to use for authentication and API access. Can be either the string
143
115
  'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
144
116
  created using the other parameters.
145
- api_key: The API key to use for authentication, if not provided, the
146
- `CO_API_KEY` environment variable will be used if available.
147
- cohere_client: An existing Cohere async client to use. If provided,
148
- `api_key` and `http_client` must be `None`.
149
- http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
150
117
  """
151
118
  self._model_name: CohereModelName = model_name
152
119
 
153
- if provider is not None:
154
- if isinstance(provider, str):
155
- provider = infer_provider(provider)
156
- self.client = provider.client
157
- elif cohere_client is not None:
158
- assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
159
- assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
160
- self.client = cohere_client
161
- else:
162
- self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client or cached_async_http_client())
120
+ if isinstance(provider, str):
121
+ provider = infer_provider(provider)
122
+ self.client = provider.client
163
123
 
164
124
  @property
165
125
  def base_url(self) -> str:
@@ -225,7 +185,7 @@ class CohereModel(Model):
225
185
  ToolCallPart(
226
186
  tool_name=c.function.name,
227
187
  args=c.function.arguments,
228
- tool_call_id=c.id,
188
+ tool_call_id=c.id or _generate_tool_call_id(),
229
189
  )
230
190
  )
231
191
  return ModelResponse(parts=parts, model_name=self._model_name)
@@ -262,7 +222,7 @@ class CohereModel(Model):
262
222
  @staticmethod
263
223
  def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
264
224
  return ToolCallV2(
265
- id=_guard_tool_call_id(t=t, model_source='Cohere'),
225
+ id=_guard_tool_call_id(t=t),
266
226
  type='function',
267
227
  function=ToolCallV2Function(
268
228
  name=t.tool_name,
@@ -294,7 +254,7 @@ class CohereModel(Model):
294
254
  elif isinstance(part, ToolReturnPart):
295
255
  yield ToolChatMessageV2(
296
256
  role='tool',
297
- tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
257
+ tool_call_id=_guard_tool_call_id(t=part),
298
258
  content=part.model_response_str(),
299
259
  )
300
260
  elif isinstance(part, RetryPromptPart):
@@ -303,7 +263,7 @@ class CohereModel(Model):
303
263
  else:
304
264
  yield ToolChatMessageV2(
305
265
  role='tool',
306
- tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
266
+ tool_call_id=_guard_tool_call_id(t=part),
307
267
  content=part.model_response(),
308
268
  )
309
269
  else: