pydantic-ai-slim 1.0.0b1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,8 +7,11 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Literal, cast, overload

+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never

+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
@@ -48,7 +51,7 @@ from . import (
 )

 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -169,9 +172,24 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-        response = await self._completions_create(
-            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
-        )
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response

@@ -228,6 +246,18 @@ class GroqModel(Model):

         groq_messages = self._map_messages(messages)

+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +270,7 @@ class GroqModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -385,6 +416,19 @@ class GroqModel(Model):
             },
         }

+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        async for chunk in self._response:
-            self._usage += _map_usage(chunk)
-
-            try:
-                choice = chunk.choices[0]
-            except IndexError:
-                continue
-
-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    yield maybe_event
-
-            # Handle the tool calls
-            for dtc in choice.delta.tool_calls or []:
-                maybe_event = self._parts_manager.handle_tool_call_delta(
-                    vendor_part_id=dtc.index,
-                    tool_name=dtc.function and dtc.function.name,
-                    args=dtc.function and dtc.function.arguments,
-                    tool_call_id=dtc.id,
-                )
-                if maybe_event is not None:
-                    yield maybe_event
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover

     @property
     def model_name(self) -> GroqModelName:
@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+    # Example payload from `exception.body`:
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
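
Taken together, the GroqModel and GroqStreamedResponse hunks above add schema-constrained output (the json_schema and json_object response formats) and convert the SDK's tool_use_failed errors back into tool calls the agent can retry. Below is a minimal usage sketch of the new native output mode; it is not part of the diff, and the Agent and NativeOutput names, the groq: model-string prefix, and the model id are assumptions based on the pydantic-ai 1.x public API.

    from pydantic import BaseModel

    from pydantic_ai import Agent
    from pydantic_ai.output import NativeOutput


    class CityInfo(BaseModel):
        city: str
        country: str


    # Illustrative model id; any Groq model whose profile sets supports_json_schema_output should work.
    agent = Agent(
        'groq:meta-llama/llama-4-scout-17b-16e-instruct',
        output_type=NativeOutput(CityInfo),  # mapped to the json_schema response_format added above
    )
    # result = agent.run_sync('Tell me about Paris')
    # result.output would be a CityInfo instance parsed from the constrained JSON response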
@@ -420,10 +420,15 @@ class InstrumentedModel(WrapperModel):
                     return

                 self.instrumentation_settings.handle_messages(messages, response, system, span)
+                try:
+                    cost_attributes = {'operation.cost': float(response.cost().total_price)}
+                except LookupError:
+                    cost_attributes = {}
                 span.set_attributes(
                     {
                         **response.usage.opentelemetry_attributes(),
                         'gen_ai.response.model': response_model,
+                        **cost_attributes,
                     }
                 )
                 span.update_name(f'{operation} {request_model}')
@@ -225,6 +225,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -252,6 +253,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -278,6 +280,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -606,7 +609,7 @@ class OpenAIChatModel(Model):
     def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
         response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = {  # pyright: ignore[reportPrivateImportUsage]
             'type': 'json_schema',
-            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True},
+            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
         }
         if o.description:
             response_format_param['json_schema']['description'] = o.description
@@ -1171,6 +1174,10 @@ class OpenAIStreamedResponse(StreamedResponse):
             except IndexError:
                 continue

+            # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
+            if choice.delta is None:  # pyright: ignore[reportUnnecessaryComparison]
+                continue
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
@@ -1270,12 +1277,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                         tool_call_id=chunk.item.call_id,
                     )
                 elif isinstance(chunk.item, responses.ResponseReasoningItem):
-                    content = chunk.item.summary[0].text if chunk.item.summary else ''
-                    yield self._parts_manager.handle_thinking_delta(
-                        vendor_part_id=chunk.item.id,
-                        content=content,
-                        signature=chunk.item.id,
-                    )
+                    pass
                 elif isinstance(chunk.item, responses.ResponseOutputMessage):
                     pass
                 elif isinstance(chunk.item, responses.ResponseFunctionWebSearch):
@@ -1291,7 +1293,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
-                pass  # there's nothing we need to do here
+                yield self._parts_manager.handle_thinking_delta(
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
+                    content=chunk.part.text,
+                    id=chunk.item_id,
+                )

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
                 pass  # there's nothing we need to do here
@@ -1301,9 +1307,9 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):

             elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
                 yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item_id,
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
                     content=chunk.delta,
-                    signature=chunk.item_id,
+                    id=chunk.item_id,
                 )

             # TODO(Marcelo): We should support annotations in the future.
@@ -1311,9 +1317,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass  # there's nothing we need to do here

             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id=chunk.content_index, content=chunk.delta
-                )
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=chunk.item_id, content=chunk.delta)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event

@@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
         from .github import GitHubProvider

         return GitHubProvider
+    elif provider == 'litellm':
+        from .litellm import LiteLLMProvider
+
+        return LiteLLMProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations

 import functools
+from asyncio import Lock
 from collections.abc import AsyncGenerator, Mapping
 from pathlib import Path
 from typing import Literal, overload
@@ -118,7 +119,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
 class _VertexAIAuth(httpx.Auth):
     """Auth class for Vertex AI API."""

-    _refresh_lock: anyio.Lock = anyio.Lock()
+    _refresh_lock: Lock = Lock()

     credentials: BaseCredentials | ServiceAccountCredentials | None

@@ -14,6 +14,7 @@ from pydantic_ai.profiles.groq import groq_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
 from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import openai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider

@@ -26,6 +27,23 @@ except ImportError as _import_error: # pragma: no cover
     ) from _import_error


+def groq_moonshotai_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a MoonshotAI model used with the Groq provider."""
+    return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+        moonshotai_model_profile(model_name)
+    )
+
+
+def meta_groq_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a Meta model used with the Groq provider."""
+    if model_name in {'llama-4-maverick-17b-128e-instruct', 'llama-4-scout-17b-16e-instruct'}:
+        return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+            meta_model_profile(model_name)
+        )
+    else:
+        return meta_model_profile(model_name)
+
+
 class GroqProvider(Provider[AsyncGroq]):
     """Provider for Groq API."""

@@ -44,13 +62,14 @@ class GroqProvider(Provider[AsyncGroq]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         prefix_to_profile = {
             'llama': meta_model_profile,
-            'meta-llama/': meta_model_profile,
+            'meta-llama/': meta_groq_model_profile,
             'gemma': google_model_profile,
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
-            'moonshotai/': moonshotai_model_profile,
+            'moonshotai/': groq_moonshotai_model_profile,
             'compound-': groq_model_profile,
+            'openai/': openai_model_profile,
         }

         for prefix, profile_func in prefix_to_profile.items():
@@ -0,0 +1,134 @@
+from __future__ import annotations as _annotations
+
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.amazon import amazon_model_profile
+from pydantic_ai.profiles.anthropic import anthropic_model_profile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.grok import grok_model_profile
+from pydantic_ai.profiles.groq import groq_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the LiteLLM provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class LiteLLMProvider(Provider[AsyncOpenAI]):
+    """Provider for LiteLLM API."""
+
+    @property
+    def name(self) -> str:
+        return 'litellm'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        # Map provider prefixes to their profile functions
+        provider_to_profile = {
+            'anthropic': anthropic_model_profile,
+            'openai': openai_model_profile,
+            'google': google_model_profile,
+            'mistralai': mistral_model_profile,
+            'mistral': mistral_model_profile,
+            'cohere': cohere_model_profile,
+            'amazon': amazon_model_profile,
+            'bedrock': amazon_model_profile,
+            'meta-llama': meta_model_profile,
+            'meta': meta_model_profile,
+            'groq': groq_model_profile,
+            'deepseek': deepseek_model_profile,
+            'moonshotai': moonshotai_model_profile,
+            'x-ai': grok_model_profile,
+            'qwen': qwen_model_profile,
+        }
+
+        profile = None
+
+        # Check if model name contains a provider prefix (e.g., "anthropic/claude-3")
+        if '/' in model_name:
+            provider_prefix, model_suffix = model_name.split('/', 1)
+            if provider_prefix in provider_to_profile:
+                profile = provider_to_profile[provider_prefix](model_suffix)
+
+        # If no profile found, default to OpenAI profile
+        if profile is None:
+            profile = openai_model_profile(model_name)
+
+        # As LiteLLMProvider is used with OpenAIModel, which uses OpenAIJsonSchemaTransformer,
+        # we maintain that behavior
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    @overload
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+        http_client: AsyncHTTPClient,
+    ) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        """Initialize a LiteLLM provider.
+
+        Args:
+            api_key: API key for the model provider. If None, LiteLLM will try to get it from environment variables.
+            api_base: Base URL for the model provider. Use this for custom endpoints or self-hosted models.
+            openai_client: Pre-configured OpenAI client. If provided, other parameters are ignored.
+            http_client: Custom HTTP client to use.
+        """
+        if openai_client is not None:
+            self._client = openai_client
+            return
+
+        # Create OpenAI client that will be used with LiteLLM's completion function
+        # The actual API calls will be intercepted and routed through LiteLLM
+        if http_client is not None:
+            self._client = AsyncOpenAI(
+                base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
+            )
+        else:
+            http_client = cached_async_http_client(provider='litellm')
+            self._client = AsyncOpenAI(
+                base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
+            )
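
The new provider plugs into the existing OpenAIChatModel via the 'litellm' provider name added above. A minimal usage sketch follows; it is not part of the diff, the import paths are assumptions based on the infer_provider_class hunk (`from .litellm import LiteLLMProvider`) and the package's module layout, and the proxy URL, API key, and model id are placeholders.

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIChatModel
    from pydantic_ai.providers.litellm import LiteLLMProvider

    # Placeholder endpoint and key for a self-hosted LiteLLM proxy.
    provider = LiteLLMProvider(api_base='http://localhost:4000', api_key='my-litellm-key')
    # The 'anthropic/' prefix makes model_profile() pick the Anthropic profile; unknown prefixes fall back to the OpenAI profile.
    model = OpenAIChatModel('anthropic/claude-3-5-sonnet-20241022', provider=provider)
    agent = Agent(model)
    # result = agent.run_sync('Hello!')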
pydantic_ai/retries.py CHANGED
@@ -13,6 +13,8 @@ The module includes:

 from __future__ import annotations

+from types import TracebackType
+
 from httpx import (
     AsyncBaseTransport,
     AsyncHTTPTransport,
@@ -185,11 +187,30 @@ class TenacityTransport(BaseTransport):
             response.request = req

             if self.validate_response:
-                self.validate_response(response)
+                try:
+                    self.validate_response(response)
+                except Exception:
+                    response.close()
+                    raise
             return response

         return handle_request(request)

+    def __enter__(self) -> TenacityTransport:
+        self.wrapped.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
+    ) -> None:
+        self.wrapped.__exit__(exc_type, exc_value, traceback)
+
+    def close(self) -> None:
+        self.wrapped.close()  # pragma: no cover
+

 class AsyncTenacityTransport(AsyncBaseTransport):
     """Asynchronous HTTP transport with tenacity-based retry functionality.
@@ -263,11 +284,30 @@ class AsyncTenacityTransport(AsyncBaseTransport):
             response.request = req

             if self.validate_response:
-                self.validate_response(response)
+                try:
+                    self.validate_response(response)
+                except Exception:
+                    await response.aclose()
+                    raise
             return response

         return await handle_async_request(request)

+    async def __aenter__(self) -> AsyncTenacityTransport:
+        await self.wrapped.__aenter__()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
+    ) -> None:
+        await self.wrapped.__aexit__(exc_type, exc_value, traceback)
+
+    async def aclose(self) -> None:
+        await self.wrapped.aclose()
+

 def wait_retry_after(
     fallback_strategy: Callable[[RetryCallState], float] | None = None, max_wait: float = 300
pydantic_ai/tools.py CHANGED
@@ -70,7 +70,7 @@ Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
 ToolPrepareFunc: TypeAlias = Callable[[RunContext[AgentDepsT], 'ToolDefinition'], Awaitable['ToolDefinition | None']]
 """Definition of a function that can prepare a tool definition at call time.

-See [tool docs](../tools.md#tool-prepare) for more information.
+See [tool docs](../tools-advanced.md#tool-prepare) for more information.

 Example — here `only_if_42` is valid as a `ToolPrepareFunc`:

@@ -140,7 +140,7 @@ class DeferredToolRequests:

     Results can be passed to the next agent run using a [`DeferredToolResults`][pydantic_ai.tools.DeferredToolResults] object with the same tool call IDs.

-    See [deferred tools docs](../tools.md#deferred-tools) for more information.
+    See [deferred tools docs](../deferred-tools.md#deferred-tools) for more information.
     """

     calls: list[ToolCallPart] = field(default_factory=list)
@@ -204,7 +204,7 @@ class DeferredToolResults:

     The tool call IDs need to match those from the [`DeferredToolRequests`][pydantic_ai.output.DeferredToolRequests] output object from the previous run.

-    See [deferred tools docs](../tools.md#deferred-tools) for more information.
+    See [deferred tools docs](../deferred-tools.md#deferred-tools) for more information.
     """
     calls: dict[str, DeferredToolCallResult | Any] = field(default_factory=dict)
@@ -328,7 +328,7 @@ class Tool(Generic[AgentDepsT]):
             strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                 See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
             requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
-                See the [tools documentation](../tools.md#human-in-the-loop-tool-approval) for more info.
+                See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
             function_schema: The function schema to use for the tool. If not provided, it will be generated.
         """
         self.function = function
@@ -472,16 +472,16 @@ class ToolDefinition:
     - `'function'`: a tool that will be executed by Pydantic AI during an agent run and has its result returned to the model
     - `'output'`: a tool that passes through an output value that ends the run
     - `'external'`: a tool whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running.
-      See the [tools documentation](../tools.md#deferred-tools) for more info.
+      See the [tools documentation](../deferred-tools.md#deferred-tools) for more info.
     - `'unapproved'`: a tool that requires human-in-the-loop approval.
-      See the [tools documentation](../tools.md#human-in-the-loop-tool-approval) for more info.
+      See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
     """

     @property
     def defer(self) -> bool:
         """Whether calls to this tool will be deferred.

-        See the [tools documentation](../tools.md#deferred-tools) for more info.
+        See the [tools documentation](../deferred-tools.md#deferred-tools) for more info.
         """
         return self.kind in ('external', 'unapproved')