pydantic-ai-slim 0.0.44__tar.gz → 0.0.45__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Review the file-by-file changes below for details.

Files changed (51)
  1. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/PKG-INFO +2 -2
  2. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_parts_manager.py +7 -1
  3. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_utils.py +12 -6
  4. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/exceptions.py +2 -2
  5. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/messages.py +15 -27
  6. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/anthropic.py +7 -46
  7. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/bedrock.py +7 -11
  8. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/cohere.py +10 -50
  9. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/gemini.py +18 -73
  10. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/groq.py +9 -53
  11. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/mistral.py +12 -51
  12. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/openai.py +15 -67
  13. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/anthropic.py +4 -5
  14. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/azure.py +8 -9
  15. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/bedrock.py +2 -1
  16. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/cohere.py +4 -5
  17. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/deepseek.py +4 -4
  18. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/google_gla.py +3 -2
  19. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/google_vertex.py +2 -3
  20. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/groq.py +4 -5
  21. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/mistral.py +4 -5
  22. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/openai.py +5 -8
  23. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pyproject.toml +2 -2
  24. pydantic_ai_slim-0.0.44/pydantic_ai/models/vertexai.py +0 -260
  25. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/.gitignore +0 -0
  26. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/README.md +0 -0
  27. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/__init__.py +0 -0
  28. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_agent_graph.py +0 -0
  29. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_cli.py +0 -0
  30. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_griffe.py +0 -0
  31. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_pydantic.py +0 -0
  32. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_result.py +0 -0
  33. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_system_prompt.py +0 -0
  34. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/agent.py +0 -0
  35. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/__init__.py +0 -0
  36. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  37. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/tavily.py +0 -0
  38. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/format_as_xml.py +0 -0
  39. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/mcp.py +0 -0
  40. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/__init__.py +0 -0
  41. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/fallback.py +0 -0
  42. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/function.py +0 -0
  43. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/instrumented.py +0 -0
  44. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/test.py +0 -0
  45. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/wrapper.py +0 -0
  46. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/__init__.py +0 -0
  47. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/py.typed +0 -0
  48. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/result.py +0 -0
  49. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/settings.py +0 -0
  50. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/tools.py +0 -0
  51. {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.44
3
+ Version: 0.0.45
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
31
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.0.44
32
+ Requires-Dist: pydantic-graph==0.0.45
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Requires-Dist: typing-inspection>=0.4.0
35
35
  Provides-Extra: anthropic
@@ -29,6 +29,8 @@ from pydantic_ai.messages import (
29
29
  ToolCallPartDelta,
30
30
  )
31
31
 
32
+ from ._utils import generate_tool_call_id as _generate_tool_call_id
33
+
32
34
  VendorId = Hashable
33
35
  """
34
36
  Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
@@ -221,7 +223,11 @@ class ModelResponsePartsManager:
221
223
  ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
222
224
  has been added to the manager, or replaced an existing part.
223
225
  """
224
- new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
226
+ new_part = ToolCallPart(
227
+ tool_name=tool_name,
228
+ args=args,
229
+ tool_call_id=tool_call_id or _generate_tool_call_id(),
230
+ )
225
231
  if vendor_part_id is None:
226
232
  # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
227
233
  new_part_index = len(self._parts)
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import asyncio
4
4
  import time
5
+ import uuid
5
6
  from collections.abc import AsyncIterable, AsyncIterator, Iterator
6
7
  from contextlib import asynccontextmanager, suppress
7
8
  from dataclasses import dataclass, is_dataclass
@@ -195,12 +196,17 @@ def now_utc() -> datetime:
195
196
  return datetime.now(tz=timezone.utc)
196
197
 
197
198
 
198
- def guard_tool_call_id(
199
- t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart, model_source: str
200
- ) -> str:
201
- """Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
202
- assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
203
- return t.tool_call_id
199
+ def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
200
+ """Type guard that either returns the tool call id or generates a new one if it's None."""
201
+ return t.tool_call_id or generate_tool_call_id()
202
+
203
+
204
+ def generate_tool_call_id() -> str:
205
+ """Generate a tool call id.
206
+
207
+ Ensure that the tool call id is unique.
208
+ """
209
+ return f'pyd_ai_{uuid.uuid4().hex}'
204
210
 
205
211
 
206
212
  class PeekableAsyncStream(Generic[T]):
@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
3
3
  import json
4
4
  import sys
5
5
 
6
- if sys.version_info < (3, 11):
6
+ if sys.version_info < (3, 11): # pragma: no cover
7
7
  from exceptiongroup import ExceptionGroup
8
- else:
8
+ else: # pragma: no cover
9
9
  ExceptionGroup = ExceptionGroup
10
10
 
11
11
  __all__ = (
@@ -12,7 +12,7 @@ import pydantic_core
12
12
  from opentelemetry._events import Event
13
13
  from typing_extensions import TypeAlias
14
14
 
15
- from ._utils import now_utc as _now_utc
15
+ from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
16
16
  from .exceptions import UnexpectedModelBehavior
17
17
 
18
18
 
@@ -268,8 +268,8 @@ class ToolReturnPart:
268
268
  content: Any
269
269
  """The return value."""
270
270
 
271
- tool_call_id: str | None = None
272
- """Optional tool call identifier, this is used by some models including OpenAI."""
271
+ tool_call_id: str
272
+ """The tool call identifier, this is used by some models including OpenAI."""
273
273
 
274
274
  timestamp: datetime = field(default_factory=_now_utc)
275
275
  """The timestamp, when the tool returned."""
@@ -328,8 +328,11 @@ class RetryPromptPart:
328
328
  tool_name: str | None = None
329
329
  """The name of the tool that was called, if any."""
330
330
 
331
- tool_call_id: str | None = None
332
- """Optional tool call identifier, this is used by some models including OpenAI."""
331
+ tool_call_id: str = field(default_factory=_generate_tool_call_id)
332
+ """The tool call identifier, this is used by some models including OpenAI.
333
+
334
+ In case the tool call id is not provided by the model, PydanticAI will generate a random one.
335
+ """
333
336
 
334
337
  timestamp: datetime = field(default_factory=_now_utc)
335
338
  """The timestamp, when the retry was triggered."""
@@ -406,8 +409,11 @@ class ToolCallPart:
406
409
  This is stored either as a JSON string or a Python dictionary depending on how data was received.
407
410
  """
408
411
 
409
- tool_call_id: str | None = None
410
- """Optional tool call identifier, this is used by some models including OpenAI."""
412
+ tool_call_id: str = field(default_factory=_generate_tool_call_id)
413
+ """The tool call identifier, this is used by some models including OpenAI.
414
+
415
+ In case the tool call id is not provided by the model, PydanticAI will generate a random one.
416
+ """
411
417
 
412
418
  part_kind: Literal['tool-call'] = 'tool-call'
413
419
  """Part type identifier, this is available on all parts as a discriminator."""
@@ -564,11 +570,7 @@ class ToolCallPartDelta:
564
570
  if self.tool_name_delta is None or self.args_delta is None:
565
571
  return None
566
572
 
567
- return ToolCallPart(
568
- self.tool_name_delta,
569
- self.args_delta,
570
- self.tool_call_id,
571
- )
573
+ return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())
572
574
 
573
575
  @overload
574
576
  def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
@@ -620,20 +622,11 @@ class ToolCallPartDelta:
620
622
  delta = replace(delta, args_delta=updated_args_delta)
621
623
 
622
624
  if self.tool_call_id:
623
- # Set the tool_call_id if it wasn't present, otherwise error if it has changed
624
- if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
625
- raise UnexpectedModelBehavior(
626
- f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
627
- )
628
625
  delta = replace(delta, tool_call_id=self.tool_call_id)
629
626
 
630
627
  # If we now have enough data to create a full ToolCallPart, do so
631
628
  if delta.tool_name_delta is not None and delta.args_delta is not None:
632
- return ToolCallPart(
633
- delta.tool_name_delta,
634
- delta.args_delta,
635
- delta.tool_call_id,
636
- )
629
+ return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id())
637
630
 
638
631
  return delta
639
632
 
@@ -656,11 +649,6 @@ class ToolCallPartDelta:
656
649
  part = replace(part, args=updated_dict)
657
650
 
658
651
  if self.tool_call_id:
659
- # Replace the tool_call_id entirely if given
660
- if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
661
- raise UnexpectedModelBehavior(
662
- f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
663
- )
664
652
  part = replace(part, tool_call_id=self.tool_call_id)
665
653
  return part
666
654
 
@@ -10,8 +10,7 @@ from json import JSONDecodeError, loads as json_loads
10
10
  from typing import Any, Literal, Union, cast, overload
11
11
 
12
12
  from anthropic.types import DocumentBlockParam
13
- from httpx import AsyncClient as AsyncHTTPClient
14
- from typing_extensions import assert_never, deprecated
13
+ from typing_extensions import assert_never
15
14
 
16
15
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
17
16
  from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -112,34 +111,11 @@ class AnthropicModel(Model):
112
111
  _model_name: AnthropicModelName = field(repr=False)
113
112
  _system: str = field(default='anthropic', repr=False)
114
113
 
115
- @overload
116
114
  def __init__(
117
115
  self,
118
116
  model_name: AnthropicModelName,
119
117
  *,
120
118
  provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
121
- ) -> None: ...
122
-
123
- @deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
124
- @overload
125
- def __init__(
126
- self,
127
- model_name: AnthropicModelName,
128
- *,
129
- provider: None = None,
130
- api_key: str | None = None,
131
- anthropic_client: AsyncAnthropic | None = None,
132
- http_client: AsyncHTTPClient | None = None,
133
- ) -> None: ...
134
-
135
- def __init__(
136
- self,
137
- model_name: AnthropicModelName,
138
- *,
139
- provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
140
- api_key: str | None = None,
141
- anthropic_client: AsyncAnthropic | None = None,
142
- http_client: AsyncHTTPClient | None = None,
143
119
  ):
144
120
  """Initialize an Anthropic model.
145
121
 
@@ -148,27 +124,12 @@ class AnthropicModel(Model):
148
124
  [here](https://docs.anthropic.com/en/docs/about-claude/models).
149
125
  provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
150
126
  instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
151
- api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
152
- will be used if available.
153
- anthropic_client: An existing
154
- [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
155
- client to use, if provided, `api_key` and `http_client` must be `None`.
156
- http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
157
127
  """
158
128
  self._model_name = model_name
159
129
 
160
- if provider is not None:
161
- if isinstance(provider, str):
162
- provider = infer_provider(provider)
163
- self.client = provider.client
164
- elif anthropic_client is not None:
165
- assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
166
- assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
167
- self.client = anthropic_client
168
- elif http_client is not None:
169
- self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
170
- else:
171
- self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
130
+ if isinstance(provider, str):
131
+ provider = infer_provider(provider)
132
+ self.client = provider.client
172
133
 
173
134
  @property
174
135
  def base_url(self) -> str:
@@ -326,7 +287,7 @@ class AnthropicModel(Model):
326
287
  user_content_params.append(content)
327
288
  elif isinstance(request_part, ToolReturnPart):
328
289
  tool_result_block_param = ToolResultBlockParam(
329
- tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
290
+ tool_use_id=_guard_tool_call_id(t=request_part),
330
291
  type='tool_result',
331
292
  content=request_part.model_response_str(),
332
293
  is_error=False,
@@ -337,7 +298,7 @@ class AnthropicModel(Model):
337
298
  retry_param = TextBlockParam(type='text', text=request_part.model_response())
338
299
  else:
339
300
  retry_param = ToolResultBlockParam(
340
- tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
301
+ tool_use_id=_guard_tool_call_id(t=request_part),
341
302
  type='tool_result',
342
303
  content=request_part.model_response(),
343
304
  is_error=True,
@@ -351,7 +312,7 @@ class AnthropicModel(Model):
351
312
  assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
352
313
  else:
353
314
  tool_use_block_param = ToolUseBlockParam(
354
- id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
315
+ id=_guard_tool_call_id(t=response_part),
355
316
  type='tool_use',
356
317
  name=response_part.tool_name,
357
318
  input=response_part.args_as_dict(),
@@ -143,14 +143,15 @@ class BedrockConverseModel(Model):
143
143
  model_name: The name of the model to use.
144
144
  model_name: The name of the Bedrock model to use. List of model names available
145
145
  [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
146
- provider: The provider to use. Defaults to `'bedrock'`.
146
+ provider: The provider to use for authentication and API access. Can be either the string
147
+ 'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
148
+ created using the other parameters.
147
149
  """
148
150
  self._model_name = model_name
149
151
 
150
152
  if isinstance(provider, str):
151
- self.client = infer_provider(provider).client
152
- else:
153
- self.client = cast('BedrockRuntimeClient', provider.client)
153
+ provider = infer_provider(provider)
154
+ self.client = cast('BedrockRuntimeClient', provider.client)
154
155
 
155
156
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
156
157
  tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
@@ -345,7 +346,7 @@ class BedrockConverseModel(Model):
345
346
  content.append({'text': item.content})
346
347
  else:
347
348
  assert isinstance(item, ToolCallPart)
348
- content.append(self._map_tool_call(item)) # FIXME: MISSING key
349
+ content.append(self._map_tool_call(item))
349
350
  bedrock_messages.append({'role': 'assistant', 'content': content})
350
351
  else:
351
352
  assert_never(m)
@@ -394,13 +395,8 @@ class BedrockConverseModel(Model):
394
395
 
395
396
  @staticmethod
396
397
  def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
397
- assert t.tool_call_id is not None
398
398
  return {
399
- 'toolUse': {
400
- 'toolUseId': t.tool_call_id,
401
- 'name': t.tool_name,
402
- 'input': t.args_as_dict(),
403
- }
399
+ 'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
404
400
  }
405
401
 
406
402
 
@@ -3,14 +3,13 @@ from __future__ import annotations as _annotations
3
3
  from collections.abc import Iterable
4
4
  from dataclasses import dataclass, field
5
5
  from itertools import chain
6
- from typing import Literal, Union, cast, overload
6
+ from typing import Literal, Union, cast
7
7
 
8
8
  from cohere import TextAssistantMessageContentItem
9
- from httpx import AsyncClient as AsyncHTTPClient
10
- from typing_extensions import assert_never, deprecated
9
+ from typing_extensions import assert_never
11
10
 
12
11
  from .. import ModelHTTPError, result
13
- from .._utils import guard_tool_call_id as _guard_tool_call_id
12
+ from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
14
13
  from ..messages import (
15
14
  ModelMessage,
16
15
  ModelRequest,
@@ -29,7 +28,6 @@ from ..tools import ToolDefinition
29
28
  from . import (
30
29
  Model,
31
30
  ModelRequestParameters,
32
- cached_async_http_client,
33
31
  check_allow_model_requests,
34
32
  )
35
33
 
@@ -102,37 +100,11 @@ class CohereModel(Model):
102
100
  _model_name: CohereModelName = field(repr=False)
103
101
  _system: str = field(default='cohere', repr=False)
104
102
 
105
- @overload
106
103
  def __init__(
107
104
  self,
108
105
  model_name: CohereModelName,
109
106
  *,
110
107
  provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
111
- api_key: None = None,
112
- cohere_client: None = None,
113
- http_client: None = None,
114
- ) -> None: ...
115
-
116
- @deprecated('Use the `provider` parameter instead of `api_key`, `cohere_client`, and `http_client`.')
117
- @overload
118
- def __init__(
119
- self,
120
- model_name: CohereModelName,
121
- *,
122
- provider: None = None,
123
- api_key: str | None = None,
124
- cohere_client: AsyncClientV2 | None = None,
125
- http_client: AsyncHTTPClient | None = None,
126
- ) -> None: ...
127
-
128
- def __init__(
129
- self,
130
- model_name: CohereModelName,
131
- *,
132
- provider: Literal['cohere'] | Provider[AsyncClientV2] | None = None,
133
- api_key: str | None = None,
134
- cohere_client: AsyncClientV2 | None = None,
135
- http_client: AsyncHTTPClient | None = None,
136
108
  ):
137
109
  """Initialize an Cohere model.
138
110
 
@@ -142,24 +114,12 @@ class CohereModel(Model):
142
114
  provider: The provider to use for authentication and API access. Can be either the string
143
115
  'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
144
116
  created using the other parameters.
145
- api_key: The API key to use for authentication, if not provided, the
146
- `CO_API_KEY` environment variable will be used if available.
147
- cohere_client: An existing Cohere async client to use. If provided,
148
- `api_key` and `http_client` must be `None`.
149
- http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
150
117
  """
151
118
  self._model_name: CohereModelName = model_name
152
119
 
153
- if provider is not None:
154
- if isinstance(provider, str):
155
- provider = infer_provider(provider)
156
- self.client = provider.client
157
- elif cohere_client is not None:
158
- assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
159
- assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
160
- self.client = cohere_client
161
- else:
162
- self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client or cached_async_http_client())
120
+ if isinstance(provider, str):
121
+ provider = infer_provider(provider)
122
+ self.client = provider.client
163
123
 
164
124
  @property
165
125
  def base_url(self) -> str:
@@ -225,7 +185,7 @@ class CohereModel(Model):
225
185
  ToolCallPart(
226
186
  tool_name=c.function.name,
227
187
  args=c.function.arguments,
228
- tool_call_id=c.id,
188
+ tool_call_id=c.id or _generate_tool_call_id(),
229
189
  )
230
190
  )
231
191
  return ModelResponse(parts=parts, model_name=self._model_name)
@@ -262,7 +222,7 @@ class CohereModel(Model):
262
222
  @staticmethod
263
223
  def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
264
224
  return ToolCallV2(
265
- id=_guard_tool_call_id(t=t, model_source='Cohere'),
225
+ id=_guard_tool_call_id(t=t),
266
226
  type='function',
267
227
  function=ToolCallV2Function(
268
228
  name=t.tool_name,
@@ -294,7 +254,7 @@ class CohereModel(Model):
294
254
  elif isinstance(part, ToolReturnPart):
295
255
  yield ToolChatMessageV2(
296
256
  role='tool',
297
- tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
257
+ tool_call_id=_guard_tool_call_id(t=part),
298
258
  content=part.model_response_str(),
299
259
  )
300
260
  elif isinstance(part, RetryPromptPart):
@@ -303,7 +263,7 @@ class CohereModel(Model):
303
263
  else:
304
264
  yield ToolChatMessageV2(
305
265
  role='tool',
306
- tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
266
+ tool_call_id=_guard_tool_call_id(t=part),
307
267
  content=part.model_response(),
308
268
  )
309
269
  else:
@@ -1,19 +1,19 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import base64
4
- import os
5
4
  import re
6
5
  from collections.abc import AsyncIterator, Sequence
7
6
  from contextlib import asynccontextmanager
8
7
  from copy import deepcopy
9
8
  from dataclasses import dataclass, field
10
9
  from datetime import datetime
11
- from typing import Annotated, Any, Literal, Protocol, Union, cast, overload
10
+ from typing import Annotated, Any, Literal, Protocol, Union, cast
12
11
  from uuid import uuid4
13
12
 
13
+ import httpx
14
14
  import pydantic
15
- from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
16
- from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
15
+ from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
16
+ from typing_extensions import NotRequired, TypedDict, assert_never
17
17
 
18
18
  from pydantic_ai.providers import Provider, infer_provider
19
19
 
@@ -85,78 +85,36 @@ class GeminiModel(Model):
85
85
  Apart from `__init__`, all methods are private or match those of the base class.
86
86
  """
87
87
 
88
- client: AsyncHTTPClient = field(repr=False)
88
+ client: httpx.AsyncClient = field(repr=False)
89
89
 
90
90
  _model_name: GeminiModelName = field(repr=False)
91
- _provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
91
+ _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
92
92
  _auth: AuthProtocol | None = field(repr=False)
93
93
  _url: str | None = field(repr=False)
94
94
  _system: str = field(default='gemini', repr=False)
95
95
 
96
- @overload
97
96
  def __init__(
98
97
  self,
99
98
  model_name: GeminiModelName,
100
99
  *,
101
- provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] = 'google-gla',
102
- ) -> None: ...
103
-
104
- @deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
105
- @overload
106
- def __init__(
107
- self,
108
- model_name: GeminiModelName,
109
- *,
110
- provider: None = None,
111
- api_key: str | None = None,
112
- http_client: AsyncHTTPClient | None = None,
113
- url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
114
- ) -> None: ...
115
-
116
- def __init__(
117
- self,
118
- model_name: GeminiModelName,
119
- *,
120
- provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
121
- api_key: str | None = None,
122
- http_client: AsyncHTTPClient | None = None,
123
- url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
100
+ provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
124
101
  ):
125
102
  """Initialize a Gemini model.
126
103
 
127
104
  Args:
128
105
  model_name: The name of the model to use.
129
- provider: The provider to use for the model.
130
- api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
131
- will be used if available.
132
- http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
133
- url_template: The URL template to use for making requests, you shouldn't need to change this,
134
- docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
135
- `model` is substituted with the model name, and `function` is added to the end of the URL.
106
+ provider: The provider to use for authentication and API access. Can be either the string
107
+ 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
108
+ If not provided, a new provider will be created using the other parameters.
136
109
  """
137
110
  self._model_name = model_name
138
111
  self._provider = provider
139
112
 
140
- if provider is not None:
141
- if isinstance(provider, str):
142
- provider = infer_provider(provider)
143
- self._system = provider.name
144
- self.client = provider.client
145
- self._url = str(self.client.base_url)
146
- else:
147
- if api_key is None:
148
- if env_api_key := os.getenv('GEMINI_API_KEY'):
149
- api_key = env_api_key
150
- else:
151
- raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
152
- self.client = http_client or cached_async_http_client()
153
- self._auth = ApiKeyAuth(api_key)
154
- self._url = url_template.format(model=model_name)
155
-
156
- @property
157
- def auth(self) -> AuthProtocol:
158
- assert self._auth is not None, 'Auth not initialized'
159
- return self._auth
113
+ if isinstance(provider, str):
114
+ provider = infer_provider(provider)
115
+ self._system = provider.name
116
+ self.client = provider.client
117
+ self._url = str(self.client.base_url)
160
118
 
161
119
  @property
162
120
  def base_url(self) -> str:
@@ -252,18 +210,10 @@ class GeminiModel(Model):
252
210
  if generation_config:
253
211
  request_data['generation_config'] = generation_config
254
212
 
255
- headers = {
256
- 'Content-Type': 'application/json',
257
- 'User-Agent': get_user_agent(),
258
- }
259
- if self._provider is None: # pragma: no cover
260
- url = self.base_url + ('streamGenerateContent' if streamed else 'generateContent')
261
- headers.update(await self.auth.headers())
262
- else:
263
- url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
213
+ headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
214
+ url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
264
215
 
265
216
  request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
266
-
267
217
  async with self.client.stream(
268
218
  'POST',
269
219
  url,
@@ -603,12 +553,7 @@ def _process_response_from_parts(
603
553
  if 'text' in part:
604
554
  items.append(TextPart(content=part['text']))
605
555
  elif 'function_call' in part:
606
- items.append(
607
- ToolCallPart(
608
- tool_name=part['function_call']['name'],
609
- args=part['function_call']['args'],
610
- )
611
- )
556
+ items.append(ToolCallPart(tool_name=part['function_call']['name'], args=part['function_call']['args']))
612
557
  elif 'function_response' in part:
613
558
  raise UnexpectedModelBehavior(
614
559
  f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'