pydantic-ai-slim 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package's registry page for more details.

pydantic_ai/_griffe.py CHANGED
@@ -4,8 +4,7 @@ import re
4
4
  from inspect import Signature
5
5
  from typing import Any, Callable, Literal, cast
6
6
 
7
- from _griffe.enumerations import DocstringSectionKind
8
- from _griffe.models import Docstring, Object as GriffeObject
7
+ from griffe import Docstring, DocstringSectionKind, Object as GriffeObject
9
8
 
10
9
  DocstringStyle = Literal['google', 'numpy', 'sphinx']
11
10
 
pydantic_ai/agent.py CHANGED
@@ -6,7 +6,6 @@ import inspect
6
6
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
7
7
  from contextlib import asynccontextmanager, contextmanager
8
8
  from contextvars import ContextVar
9
- from dataclasses import dataclass, field
10
9
  from types import FrameType
11
10
  from typing import Any, Callable, Generic, Literal, cast, final, overload
12
11
 
@@ -40,6 +39,16 @@ __all__ = 'Agent', 'capture_run_messages', 'EndStrategy'
40
39
 
41
40
  _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
42
41
 
42
+ # while waiting for https://github.com/pydantic/logfire/issues/745
43
+ try:
44
+ import logfire._internal.stack_info
45
+ except ImportError:
46
+ pass
47
+ else:
48
+ from pathlib import Path
49
+
50
+ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
51
+
43
52
  NoneType = type(None)
44
53
  EndStrategy = Literal['early', 'exhaustive']
45
54
  """The strategy for handling multiple tool calls when a final result is found.
@@ -50,7 +59,7 @@ EndStrategy = Literal['early', 'exhaustive']
50
59
 
51
60
 
52
61
  @final
53
- @dataclass(init=False)
62
+ @dataclasses.dataclass(init=False)
54
63
  class Agent(Generic[AgentDeps, ResultData]):
55
64
  """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
56
65
 
@@ -90,17 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]):
90
99
  be merged with this value, with the runtime argument taking priority.
91
100
  """
92
101
 
93
- _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
94
- _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
95
- _allow_text_result: bool = field(repr=False)
96
- _system_prompts: tuple[str, ...] = field(repr=False)
97
- _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
98
- _default_retries: int = field(repr=False)
99
- _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
100
- _deps_type: type[AgentDeps] = field(repr=False)
101
- _max_result_retries: int = field(repr=False)
102
- _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
103
- _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
102
+ _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
103
+ _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
104
+ _allow_text_result: bool = dataclasses.field(repr=False)
105
+ _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
106
+ _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
107
+ _default_retries: int = dataclasses.field(repr=False)
108
+ _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
109
+ _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
110
+ _max_result_retries: int = dataclasses.field(repr=False)
111
+ _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
112
+ _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
104
113
 
105
114
  def __init__(
106
115
  self,
@@ -184,6 +193,7 @@ class Agent(Generic[AgentDeps, ResultData]):
184
193
  deps: AgentDeps = None,
185
194
  model_settings: ModelSettings | None = None,
186
195
  usage_limits: UsageLimits | None = None,
196
+ usage: result.Usage | None = None,
187
197
  infer_name: bool = True,
188
198
  ) -> result.RunResult[ResultData]:
189
199
  """Run the agent with a user prompt in async mode.
@@ -206,6 +216,7 @@ class Agent(Generic[AgentDeps, ResultData]):
206
216
  deps: Optional dependencies to use for this run.
207
217
  model_settings: Optional settings to use for this model's request.
208
218
  usage_limits: Optional limits on model request count or token usage.
219
+ usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
209
220
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
210
221
 
211
222
  Returns:
@@ -213,7 +224,7 @@ class Agent(Generic[AgentDeps, ResultData]):
213
224
  """
214
225
  if infer_name and self.name is None:
215
226
  self._infer_name(inspect.currentframe())
216
- model_used, mode_selection = await self._get_model(model)
227
+ model_used = await self._get_model(model)
217
228
 
218
229
  deps = self._get_deps(deps)
219
230
  new_message_index = len(message_history) if message_history else 0
@@ -222,40 +233,36 @@ class Agent(Generic[AgentDeps, ResultData]):
222
233
  '{agent_name} run {prompt=}',
223
234
  prompt=user_prompt,
224
235
  agent=self,
225
- mode_selection=mode_selection,
226
236
  model_name=model_used.name(),
227
237
  agent_name=self.name or 'agent',
228
238
  ) as run_span:
229
- run_context = RunContext(deps, 0, [], None, model_used)
239
+ run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
230
240
  messages = await self._prepare_messages(user_prompt, message_history, run_context)
231
241
  run_context.messages = messages
232
242
 
233
243
  for tool in self._function_tools.values():
234
244
  tool.current_retry = 0
235
245
 
236
- usage = result.Usage(requests=0)
237
246
  model_settings = merge_model_settings(self.model_settings, model_settings)
238
247
  usage_limits = usage_limits or UsageLimits()
239
248
 
240
- run_step = 0
241
249
  while True:
242
- usage_limits.check_before_request(usage)
250
+ usage_limits.check_before_request(run_context.usage)
243
251
 
244
- run_step += 1
245
- with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
252
+ run_context.run_step += 1
253
+ with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
246
254
  agent_model = await self._prepare_model(run_context)
247
255
 
248
- with _logfire.span('model request', run_step=run_step) as model_req_span:
256
+ with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
249
257
  model_response, request_usage = await agent_model.request(messages, model_settings)
250
258
  model_req_span.set_attribute('response', model_response)
251
259
  model_req_span.set_attribute('usage', request_usage)
252
260
 
253
261
  messages.append(model_response)
254
- usage += request_usage
255
- usage.requests += 1
256
- usage_limits.check_tokens(request_usage)
262
+ run_context.usage.incr(request_usage, requests=1)
263
+ usage_limits.check_tokens(run_context.usage)
257
264
 
258
- with _logfire.span('handle model response', run_step=run_step) as handle_span:
265
+ with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
259
266
  final_result, tool_responses = await self._handle_model_response(model_response, run_context)
260
267
 
261
268
  if tool_responses:
@@ -266,10 +273,10 @@ class Agent(Generic[AgentDeps, ResultData]):
266
273
  if final_result is not None:
267
274
  result_data = final_result.data
268
275
  run_span.set_attribute('all_messages', messages)
269
- run_span.set_attribute('usage', usage)
276
+ run_span.set_attribute('usage', run_context.usage)
270
277
  handle_span.set_attribute('result', result_data)
271
278
  handle_span.message = 'handle model response -> final result'
272
- return result.RunResult(messages, new_message_index, result_data, usage)
279
+ return result.RunResult(messages, new_message_index, result_data, run_context.usage)
273
280
  else:
274
281
  # continue the conversation
275
282
  handle_span.set_attribute('tool_responses', tool_responses)
@@ -285,6 +292,7 @@ class Agent(Generic[AgentDeps, ResultData]):
285
292
  deps: AgentDeps = None,
286
293
  model_settings: ModelSettings | None = None,
287
294
  usage_limits: UsageLimits | None = None,
295
+ usage: result.Usage | None = None,
288
296
  infer_name: bool = True,
289
297
  ) -> result.RunResult[ResultData]:
290
298
  """Run the agent with a user prompt synchronously.
@@ -311,6 +319,7 @@ class Agent(Generic[AgentDeps, ResultData]):
311
319
  deps: Optional dependencies to use for this run.
312
320
  model_settings: Optional settings to use for this model's request.
313
321
  usage_limits: Optional limits on model request count or token usage.
322
+ usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
314
323
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
315
324
 
316
325
  Returns:
@@ -326,6 +335,7 @@ class Agent(Generic[AgentDeps, ResultData]):
326
335
  deps=deps,
327
336
  model_settings=model_settings,
328
337
  usage_limits=usage_limits,
338
+ usage=usage,
329
339
  infer_name=False,
330
340
  )
331
341
  )
@@ -340,6 +350,7 @@ class Agent(Generic[AgentDeps, ResultData]):
340
350
  deps: AgentDeps = None,
341
351
  model_settings: ModelSettings | None = None,
342
352
  usage_limits: UsageLimits | None = None,
353
+ usage: result.Usage | None = None,
343
354
  infer_name: bool = True,
344
355
  ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
345
356
  """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -363,6 +374,7 @@ class Agent(Generic[AgentDeps, ResultData]):
363
374
  deps: Optional dependencies to use for this run.
364
375
  model_settings: Optional settings to use for this model's request.
365
376
  usage_limits: Optional limits on model request count or token usage.
377
+ usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
366
378
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
367
379
 
368
380
  Returns:
@@ -372,7 +384,7 @@ class Agent(Generic[AgentDeps, ResultData]):
372
384
  # f_back because `asynccontextmanager` adds one frame
373
385
  if frame := inspect.currentframe(): # pragma: no branch
374
386
  self._infer_name(frame.f_back)
375
- model_used, mode_selection = await self._get_model(model)
387
+ model_used = await self._get_model(model)
376
388
 
377
389
  deps = self._get_deps(deps)
378
390
  new_message_index = len(message_history) if message_history else 0
@@ -381,32 +393,29 @@ class Agent(Generic[AgentDeps, ResultData]):
381
393
  '{agent_name} run stream {prompt=}',
382
394
  prompt=user_prompt,
383
395
  agent=self,
384
- mode_selection=mode_selection,
385
396
  model_name=model_used.name(),
386
397
  agent_name=self.name or 'agent',
387
398
  ) as run_span:
388
- run_context = RunContext(deps, 0, [], None, model_used)
399
+ run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
389
400
  messages = await self._prepare_messages(user_prompt, message_history, run_context)
390
401
  run_context.messages = messages
391
402
 
392
403
  for tool in self._function_tools.values():
393
404
  tool.current_retry = 0
394
405
 
395
- usage = result.Usage()
396
406
  model_settings = merge_model_settings(self.model_settings, model_settings)
397
407
  usage_limits = usage_limits or UsageLimits()
398
408
 
399
- run_step = 0
400
409
  while True:
401
- run_step += 1
402
- usage_limits.check_before_request(usage)
410
+ run_context.run_step += 1
411
+ usage_limits.check_before_request(run_context.usage)
403
412
 
404
- with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
413
+ with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
405
414
  agent_model = await self._prepare_model(run_context)
406
415
 
407
- with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
416
+ with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
408
417
  async with agent_model.request_stream(messages, model_settings) as model_response:
409
- usage.requests += 1
418
+ run_context.usage.requests += 1
410
419
  model_req_span.set_attribute('response_type', model_response.__class__.__name__)
411
420
  # We want to end the "model request" span here, but we can't exit the context manager
412
421
  # in the traditional way
@@ -442,7 +451,6 @@ class Agent(Generic[AgentDeps, ResultData]):
442
451
  yield result.StreamedRunResult(
443
452
  messages,
444
453
  new_message_index,
445
- usage,
446
454
  usage_limits,
447
455
  result_stream,
448
456
  self._result_schema,
@@ -466,8 +474,8 @@ class Agent(Generic[AgentDeps, ResultData]):
466
474
  handle_span.message = f'handle model response -> {tool_responses_str}'
467
475
  # the model_response should have been fully streamed by now, we can add its usage
468
476
  model_response_usage = model_response.usage()
469
- usage += model_response_usage
470
- usage_limits.check_tokens(usage)
477
+ run_context.usage.incr(model_response_usage)
478
+ usage_limits.check_tokens(run_context.usage)
471
479
 
472
480
  @contextmanager
473
481
  def override(
@@ -778,14 +786,14 @@ class Agent(Generic[AgentDeps, ResultData]):
778
786
 
779
787
  self._function_tools[tool.name] = tool
780
788
 
781
- async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
789
+ async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
782
790
  """Create a model configured for this agent.
783
791
 
784
792
  Args:
785
793
  model: model to use for this run, required if `model` was not set when creating the agent.
786
794
 
787
795
  Returns:
788
- a tuple of `(model used, how the model was selected)`
796
+ The model used
789
797
  """
790
798
  model_: models.Model
791
799
  if some_model := self._override_model:
@@ -796,18 +804,15 @@ class Agent(Generic[AgentDeps, ResultData]):
796
804
  '(Even when `override(model=...)` is customizing the model that will actually be called)'
797
805
  )
798
806
  model_ = some_model.value
799
- mode_selection = 'override-model'
800
807
  elif model is not None:
801
808
  model_ = models.infer_model(model)
802
- mode_selection = 'custom'
803
809
  elif self.model is not None:
804
810
  # noinspection PyTypeChecker
805
811
  model_ = self.model = models.infer_model(self.model)
806
- mode_selection = 'from-agent'
807
812
  else:
808
813
  raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
809
814
 
810
- return model_, mode_selection
815
+ return model_
811
816
 
812
817
  async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
813
818
  """Build tools and create an agent model."""
@@ -830,15 +835,15 @@ class Agent(Generic[AgentDeps, ResultData]):
830
835
  self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
831
836
  ) -> list[_messages.ModelMessage]:
832
837
  try:
833
- messages = _messages_ctx_var.get()
838
+ ctx_messages = _messages_ctx_var.get()
834
839
  except LookupError:
835
- messages = []
840
+ messages: list[_messages.ModelMessage] = []
836
841
  else:
837
- if messages:
838
- raise exceptions.UserError(
839
- 'The capture_run_messages() context manager may only be used to wrap '
840
- 'one call to run(), run_sync(), or run_stream().'
841
- )
842
+ if ctx_messages.used:
843
+ messages = []
844
+ else:
845
+ messages = ctx_messages.messages
846
+ ctx_messages.used = True
842
847
 
843
848
  if message_history:
844
849
  # shallow copy messages
@@ -1132,7 +1137,13 @@ class Agent(Generic[AgentDeps, ResultData]):
1132
1137
  raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
1133
1138
 
1134
1139
 
1135
- _messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
1140
+ @dataclasses.dataclass
1141
+ class _RunMessages:
1142
+ messages: list[_messages.ModelMessage]
1143
+ used: bool = False
1144
+
1145
+
1146
+ _messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
1136
1147
 
1137
1148
 
1138
1149
  @contextmanager
@@ -1156,21 +1167,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
1156
1167
  ```
1157
1168
 
1158
1169
  !!! note
1159
- You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
1160
- If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
1170
+ If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
1171
+ `messages` will represent the messages exchanged during the first call only.
1161
1172
  """
1162
1173
  try:
1163
- yield _messages_ctx_var.get()
1174
+ yield _messages_ctx_var.get().messages
1164
1175
  except LookupError:
1165
1176
  messages: list[_messages.ModelMessage] = []
1166
- token = _messages_ctx_var.set(messages)
1177
+ token = _messages_ctx_var.set(_RunMessages(messages))
1167
1178
  try:
1168
1179
  yield messages
1169
1180
  finally:
1170
1181
  _messages_ctx_var.reset(token)
1171
1182
 
1172
1183
 
1173
- @dataclass
1184
+ @dataclasses.dataclass
1174
1185
  class _MarkFinalResult(Generic[ResultData]):
1175
1186
  """Marker class to indicate that the result is the final result.
1176
1187
 
pydantic_ai/models/gemini.py CHANGED
@@ -444,7 +444,8 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
444
444
  if isinstance(item, ToolCallPart):
445
445
  parts.append(_function_call_part_from_call(item))
446
446
  elif isinstance(item, TextPart):
447
- parts.append(_GeminiTextPart(text=item.content))
447
+ if item.content:
448
+ parts.append(_GeminiTextPart(text=item.content))
448
449
  else:
449
450
  assert_never(item)
450
451
  return _GeminiContent(role='model', parts=parts)
@@ -701,7 +702,7 @@ class _GeminiJsonSchema:
701
702
 
702
703
  def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
703
704
  schema.pop('title', None)
704
- schema.pop('default', None)
705
+ default = schema.pop('default', _utils.UNSET)
705
706
  if ref := schema.pop('$ref', None):
706
707
  # noinspection PyTypeChecker
707
708
  key = re.sub(r'^#/\$defs/', '', ref)
@@ -714,8 +715,14 @@ class _GeminiJsonSchema:
714
715
  return
715
716
 
716
717
  if any_of := schema.get('anyOf'):
717
- for schema in any_of:
718
- self._simplify(schema, refs_stack)
718
+ for item_schema in any_of:
719
+ self._simplify(item_schema, refs_stack)
720
+ if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
721
+ for item_schema in any_of:
722
+ if item_schema != {'type': 'null'}:
723
+ schema.clear()
724
+ schema.update(item_schema)
725
+ return
719
726
 
720
727
  type_ = schema.get('type')
721
728
 
pydantic_ai/models/ollama.py CHANGED
@@ -71,6 +71,7 @@ class OllamaModel(Model):
71
71
  model_name: OllamaModelName,
72
72
  *,
73
73
  base_url: str | None = 'http://localhost:11434/v1/',
74
+ api_key: str = 'ollama',
74
75
  openai_client: AsyncOpenAI | None = None,
75
76
  http_client: AsyncHTTPClient | None = None,
76
77
  ):
@@ -83,6 +84,8 @@ class OllamaModel(Model):
83
84
  model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
84
85
  You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
85
86
  base_url: The base url for the ollama requests. The default value is the ollama default
87
+ api_key: The API key to use for authentication. Defaults to 'ollama' for local instances,
88
+ but can be customized for proxy setups that require authentication
86
89
  openai_client: An existing
87
90
  [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
88
91
  client to use, if provided, `base_url` and `http_client` must be `None`.
@@ -96,7 +99,7 @@ class OllamaModel(Model):
96
99
  else:
97
100
  # API key is not required for ollama but a value is required to create the client
98
101
  http_client_ = http_client or cached_async_http_client()
99
- oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
102
+ oai_client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client_)
100
103
  self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
101
104
 
102
105
  async def agent_model(
pydantic_ai/models/test.py CHANGED
@@ -16,6 +16,7 @@ from ..messages import (
16
16
  ModelMessage,
17
17
  ModelRequest,
18
18
  ModelResponse,
19
+ ModelResponsePart,
19
20
  RetryPromptPart,
20
21
  TextPart,
21
22
  ToolCallPart,
@@ -177,13 +178,23 @@ class TestAgentModel(AgentModel):
177
178
  # check if there are any retry prompts, if so retry them
178
179
  new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
179
180
  if new_retry_names:
180
- return ModelResponse(
181
- parts=[
182
- ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
183
- for name, args in self.tool_calls
184
- if name in new_retry_names
185
- ]
186
- )
181
+ # Handle retries for both function tools and result tools
182
+ # Check function tools first
183
+ retry_parts: list[ModelResponsePart] = [
184
+ ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
185
+ for name, args in self.tool_calls
186
+ if name in new_retry_names
187
+ ]
188
+ # Check result tools
189
+ if self.result_tools:
190
+ retry_parts.extend(
191
+ [
192
+ ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
193
+ for tool in self.result_tools
194
+ if tool.name in new_retry_names
195
+ ]
196
+ )
197
+ return ModelResponse(parts=retry_parts)
187
198
 
188
199
  if response_text := self.result.left:
189
200
  if response_text.value is None:
pydantic_ai/result.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from collections.abc import AsyncIterator, Awaitable, Callable
5
+ from copy import copy
5
6
  from dataclasses import dataclass, field
6
7
  from datetime import datetime
7
8
  from typing import Generic, Union, cast
@@ -63,25 +64,33 @@ class Usage:
63
64
  details: dict[str, int] | None = None
64
65
  """Any extra details returned by the model."""
65
66
 
66
- def __add__(self, other: Usage) -> Usage:
67
- """Add two Usages together.
67
+ def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
68
+ """Increment the usage in place.
68
69
 
69
- This is provided so it's trivial to sum usage information from multiple requests and runs.
70
+ Args:
71
+ incr_usage: The usage to increment by.
72
+ requests: The number of requests to increment by in addition to `incr_usage.requests`.
70
73
  """
71
- counts: dict[str, int] = {}
74
+ self.requests += requests
72
75
  for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
73
76
  self_value = getattr(self, f)
74
- other_value = getattr(other, f)
77
+ other_value = getattr(incr_usage, f)
75
78
  if self_value is not None or other_value is not None:
76
- counts[f] = (self_value or 0) + (other_value or 0)
79
+ setattr(self, f, (self_value or 0) + (other_value or 0))
77
80
 
78
- details = self.details.copy() if self.details is not None else None
79
- if other.details is not None:
80
- details = details or {}
81
- for key, value in other.details.items():
82
- details[key] = details.get(key, 0) + value
81
+ if incr_usage.details:
82
+ self.details = self.details or {}
83
+ for key, value in incr_usage.details.items():
84
+ self.details[key] = self.details.get(key, 0) + value
83
85
 
84
- return Usage(**counts, details=details or None)
86
+ def __add__(self, other: Usage) -> Usage:
87
+ """Add two Usages together.
88
+
89
+ This is provided so it's trivial to sum usage information from multiple requests and runs.
90
+ """
91
+ new_usage = copy(self)
92
+ new_usage.incr(other)
93
+ return new_usage
85
94
 
86
95
 
87
96
  @dataclass
@@ -136,8 +145,6 @@ class RunResult(_BaseRunResult[ResultData]):
136
145
  class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
137
146
  """Result of a streamed run that returns structured data via a tool call."""
138
147
 
139
- usage_so_far: Usage
140
- """Usage of the run up until the last request."""
141
148
  _usage_limits: UsageLimits | None
142
149
  _stream_response: models.EitherStreamedResponse
143
150
  _result_schema: _result.ResultSchema[ResultData] | None
@@ -306,7 +313,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
306
313
  !!! note
307
314
  This won't return the full usage until the stream is finished.
308
315
  """
309
- return self.usage_so_far + self._stream_response.usage()
316
+ return self._run_ctx.usage + self._stream_response.usage()
310
317
 
311
318
  def timestamp(self) -> datetime:
312
319
  """Get the timestamp of the response."""
pydantic_ai/settings.py CHANGED
@@ -136,6 +136,6 @@ class UsageLimits:
136
136
  f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
137
137
  )
138
138
 
139
- total_tokens = request_tokens + response_tokens
139
+ total_tokens = usage.total_tokens or 0
140
140
  if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
141
141
  raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
pydantic_ai/tools.py CHANGED
@@ -4,7 +4,7 @@ import dataclasses
4
4
  import inspect
5
5
  from collections.abc import Awaitable
6
6
  from dataclasses import dataclass, field
7
- from typing import Any, Callable, Generic, TypeVar, Union, cast
7
+ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
8
8
 
9
9
  from pydantic import ValidationError
10
10
  from pydantic_core import SchemaValidator
@@ -13,6 +13,9 @@ from typing_extensions import Concatenate, ParamSpec, TypeAlias
13
13
  from . import _pydantic, _utils, messages as _messages, models
14
14
  from .exceptions import ModelRetry, UnexpectedModelBehavior
15
15
 
16
+ if TYPE_CHECKING:
17
+ from .result import Usage
18
+
16
19
  __all__ = (
17
20
  'AgentDeps',
18
21
  'RunContext',
@@ -37,14 +40,20 @@ class RunContext(Generic[AgentDeps]):
37
40
 
38
41
  deps: AgentDeps
39
42
  """Dependencies for the agent."""
40
- retry: int
41
- """Number of retries so far."""
42
- messages: list[_messages.ModelMessage]
43
- """Messages exchanged in the conversation so far."""
44
- tool_name: str | None
45
- """Name of the tool being called."""
46
43
  model: models.Model
47
44
  """The model used in this run."""
45
+ usage: Usage
46
+ """LLM usage associated with the run."""
47
+ prompt: str
48
+ """The original user prompt passed to the run."""
49
+ messages: list[_messages.ModelMessage] = field(default_factory=list)
50
+ """Messages exchanged in the conversation so far."""
51
+ tool_name: str | None = None
52
+ """Name of the tool being called."""
53
+ retry: int = 0
54
+ """Number of retries so far."""
55
+ run_step: int = 0
56
+ """The current step in the run."""
48
57
 
49
58
  def replace_with(
50
59
  self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.15
3
+ Version: 0.0.16
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -1,26 +1,26 @@
1
1
  pydantic_ai/__init__.py,sha256=FbYetEgT6OO25u2KF5ZnFxKpz5DtnSpfckRXP4mjl8E,489
2
- pydantic_ai/_griffe.py,sha256=pRjCJ6B1hhx6k46XJgl9zF6aRYxRmqEZKFok8unp4Iw,3449
2
+ pydantic_ai/_griffe.py,sha256=Wqk3AuyeWuPwE5s1GbMeCsERelx1B4QcU9uYZSoko8s,3409
3
3
  pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
4
4
  pydantic_ai/_result.py,sha256=cUSugZQV0n5Z4fFHiMqua-2xs_0S6m-rr-yd6QS3nFE,10317
5
5
  pydantic_ai/_system_prompt.py,sha256=MZJWksIoS5GM3Au5lznlcQnC-h7eqwtE7oI5WFgRcOg,1090
6
6
  pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
7
- pydantic_ai/agent.py,sha256=qa3Ox5pXEDzxcTJgwN0gebV37qQKizVc0PW-1q5MMn4,51662
7
+ pydantic_ai/agent.py,sha256=NJTcPSlqb4Fd-x9pDPuoXGCwFGF1GHcHevutoB0Busw,52333
8
8
  pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
9
9
  pydantic_ai/messages.py,sha256=ImbWY8Ft3mxInUQ08EmIWywf4nJBvTiJhmsECRYDkSQ,8968
10
10
  pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- pydantic_ai/result.py,sha256=n2cEFwm8WhFzHuT6KRhZ2itQVPShGUd7ECbOPmRIIoM,15335
12
- pydantic_ai/settings.py,sha256=R71rBg2u2SgjxKWcpUdSzm7icV5_apF3b0BlBqa2lpA,4927
13
- pydantic_ai/tools.py,sha256=wNYzfdp1XjIVw_8bqh5GP3x3k_12gHTC74HWNvHAAwI,11447
11
+ pydantic_ai/result.py,sha256=LbZVHZnJnQwgegSz5PtwS9r_ifrJnLRpsa9xjYnHg1g,15549
12
+ pydantic_ai/settings.py,sha256=W8krcFsujjhE03rwckrz39F4Dz_9RwdBSeEF3izK0-Y,4918
13
+ pydantic_ai/tools.py,sha256=mnh3Lvs0Ri0FkqpV1MUooExNN4epTCcBKw6DyZvNSQ8,11745
14
14
  pydantic_ai/models/__init__.py,sha256=XHt02IDQAircb-lEkIbIcuabSAIh5_UKnz2V1xN0Glw,10926
15
15
  pydantic_ai/models/anthropic.py,sha256=EUZgmvT0jhMDbooBp_jfW0z2cM5jTMuAhVws1XKgaNs,13451
16
16
  pydantic_ai/models/function.py,sha256=i7qkS_31aHrTbYVh6OzQ7Cwucz44F5PjT2EJK3GMphw,10573
17
- pydantic_ai/models/gemini.py,sha256=8vdcW4izL9NUGFj6lcD9yIPaakCtsmHauTvKwlTzD14,28207
17
+ pydantic_ai/models/gemini.py,sha256=Sr19D2hN8iEAcoLlzv5883pto90TgEr_xiGlV8hMOwA,28572
18
18
  pydantic_ai/models/groq.py,sha256=ZoPkuWJrf78JPnTRfZhi7v0ETgxJKNN5dH8BLWagGGk,15770
19
19
  pydantic_ai/models/mistral.py,sha256=xGVI6-b8-9vnFickPPI2cRaHEWLc0jKKUM_vMjipf-U,25894
20
- pydantic_ai/models/ollama.py,sha256=i3mMXkXu9xL6f4c52Eyx3j4aHKfYoloFondlGHPtkS4,3971
20
+ pydantic_ai/models/ollama.py,sha256=ELqxhcNcnvQBnadd3gukS01zprUp6v8N_h1P5K-uf6c,4188
21
21
  pydantic_ai/models/openai.py,sha256=qFFInL3NbgfGcsAWigxMP5mscp76hC-jJimHc9woU6Y,16518
22
- pydantic_ai/models/test.py,sha256=pty5qaudHsSDvdE89HqMj-kmd4UMV9VJI2YGtdfOX1o,15960
22
+ pydantic_ai/models/test.py,sha256=u2pdZd9OLXQ_jI6CaVt96udXuIcv0Hfnfqd3pFGmeJM,16514
23
23
  pydantic_ai/models/vertexai.py,sha256=DBCBfpvpIhZaMG7cKvRl5rugCZqJqqEFm74uBc45weo,9259
24
- pydantic_ai_slim-0.0.15.dist-info/METADATA,sha256=CM_cQ6RRb9PFJVXKVa0JIVsp_bCrdCKWGnEu0KBiD0c,2730
25
- pydantic_ai_slim-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
- pydantic_ai_slim-0.0.15.dist-info/RECORD,,
24
+ pydantic_ai_slim-0.0.16.dist-info/METADATA,sha256=4udd7j2erIuMC0ekYgmgQAqsKfhA5sLsKzTcD_QyOeo,2730
25
+ pydantic_ai_slim-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
26
+ pydantic_ai_slim-0.0.16.dist-info/RECORD,,