pydantic-ai-slim 0.0.30__py3-none-any.whl → 0.0.31__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim has been flagged as possibly problematic.

pydantic_ai/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from importlib.metadata import version

- from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
+ from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
  from .exceptions import (
      AgentRunError,
      FallbackExceptionGroup,
@@ -18,7 +18,7 @@ __all__ = (
      # agent
      'Agent',
      'EndStrategy',
-     'HandleResponseNode',
+     'CallToolsNode',
      'ModelRequestNode',
      'UserPromptNode',
      'capture_run_messages',
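Migration note: the only public API change in this file is the rename of HandleResponseNode to CallToolsNode. A minimal, hedged sketch of the import change for downstream code:

    # 0.0.30:
    # from pydantic_ai import HandleResponseNode
    # 0.0.31:
    from pydantic_ai import CallToolsNode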
pydantic_ai/_agent_graph.py CHANGED
@@ -23,6 +23,7 @@ from . import (
      result,
      usage as _usage,
  )
+ from .models.instrumented import InstrumentedModel
  from .result import ResultDataT
  from .settings import ModelSettings, merge_model_settings
  from .tools import (
@@ -36,7 +37,7 @@ __all__ = (
      'GraphAgentDeps',
      'UserPromptNode',
      'ModelRequestNode',
-     'HandleResponseNode',
+     'CallToolsNode',
      'build_run_context',
      'capture_run_messages',
  )
@@ -243,12 +244,12 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

      request: _messages.ModelRequest

-     _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+     _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
      _did_stream: bool = field(default=False, repr=False)

      async def run(
          self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+     ) -> CallToolsNode[DepsT, NodeRunEndT]:
          if self._result is not None:
              return self._result

@@ -286,39 +287,33 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
          assert not self._did_stream, 'stream() should only be called once per node'

          model_settings, model_request_parameters = await self._prepare_request(ctx)
-         with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-             async with ctx.deps.model.request_stream(
-                 ctx.state.message_history, model_settings, model_request_parameters
-             ) as streamed_response:
-                 self._did_stream = True
-                 ctx.state.usage.incr(_usage.Usage(), requests=1)
-                 yield streamed_response
-                 # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-                 # otherwise usage won't be properly counted:
-                 async for _ in streamed_response:
-                     pass
-                 model_response = streamed_response.get()
-                 request_usage = streamed_response.usage()
-             span.set_attribute('response', model_response)
-             span.set_attribute('usage', request_usage)
+         async with ctx.deps.model.request_stream(
+             ctx.state.message_history, model_settings, model_request_parameters
+         ) as streamed_response:
+             self._did_stream = True
+             ctx.state.usage.incr(_usage.Usage(), requests=1)
+             yield streamed_response
+             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+             # otherwise usage won't be properly counted:
+             async for _ in streamed_response:
+                 pass
+             model_response = streamed_response.get()
+             request_usage = streamed_response.usage()

          self._finish_handling(ctx, model_response, request_usage)
          assert self._result is not None  # this should be set by the previous line

      async def _make_request(
          self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+     ) -> CallToolsNode[DepsT, NodeRunEndT]:
          if self._result is not None:
              return self._result

          model_settings, model_request_parameters = await self._prepare_request(ctx)
-         with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-             model_response, request_usage = await ctx.deps.model.request(
-                 ctx.state.message_history, model_settings, model_request_parameters
-             )
-             ctx.state.usage.incr(_usage.Usage(), requests=1)
-             span.set_attribute('response', model_response)
-             span.set_attribute('usage', request_usage)
+         model_response, request_usage = await ctx.deps.model.request(
+             ctx.state.message_history, model_settings, model_request_parameters
+         )
+         ctx.state.usage.incr(_usage.Usage(), requests=1)

          return self._finish_handling(ctx, model_response, request_usage)

@@ -344,7 +339,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
          ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
          response: _messages.ModelResponse,
          usage: _usage.Usage,
-     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+     ) -> CallToolsNode[DepsT, NodeRunEndT]:
          # Update usage
          ctx.state.usage.incr(usage, requests=0)
          if ctx.deps.usage_limits:
@@ -354,13 +349,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
          ctx.state.message_history.append(response)

          # Set the `_result` attribute since we can't use `return` in an async iterator
-         self._result = HandleResponseNode(response)
+         self._result = CallToolsNode(response)

          return self._result


  @dataclasses.dataclass
- class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
+ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
      """Process a model response, and decide whether to end the run or make a new request."""

      model_response: _messages.ModelResponse
@@ -454,8 +449,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
          final_result: result.FinalResult[NodeRunEndT] | None = None
          parts: list[_messages.ModelRequestPart] = []
          if result_schema is not None:
-             if match := result_schema.find_tool(tool_calls):
-                 call, result_tool = match
+             for call, result_tool in result_schema.find_tool(tool_calls):
                  try:
                      result_data = result_tool.validate(call)
                      result_data = await _validate_result(result_data, ctx, call)
@@ -465,12 +459,17 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
                      ctx.state.increment_retries(ctx.deps.max_result_retries)
                      parts.append(e.tool_retry)
                  else:
-                     final_result = result.FinalResult(result_data, call.tool_name)
+                     final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
+                     break

          # Then build the other request parts based on end strategy
          tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
          async for event in process_function_tools(
-             tool_calls, final_result and final_result.tool_name, ctx, tool_responses
+             tool_calls,
+             final_result and final_result.tool_name,
+             final_result and final_result.tool_call_id,
+             ctx,
+             tool_responses,
          ):
              yield event

@@ -496,7 +495,10 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
              messages.append(_messages.ModelRequest(parts=tool_responses))

          run_span.set_attribute('usage', usage)
-         run_span.set_attribute('all_messages', messages)
+         run_span.set_attribute(
+             'all_messages_events',
+             [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)],
+         )

          # End the run with self.data
          return End(final_result)
@@ -518,7 +520,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
                  return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
              else:
                  # The following cast is safe because we know `str` is an allowed result type
-                 return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
+                 return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
          else:
              ctx.state.increment_retries(ctx.deps.max_result_retries)
              return ModelRequestNode[DepsT, NodeRunEndT](
@@ -547,6 +549,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
  async def process_function_tools(
      tool_calls: list[_messages.ToolCallPart],
      result_tool_name: str | None,
+     result_tool_call_id: str | None,
      ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
      output_parts: list[_messages.ModelRequestPart],
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,7 +569,11 @@ async def process_function_tools(
      calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
      call_index_to_event_id: dict[int, str] = {}
      for call in tool_calls:
-         if call.tool_name == result_tool_name and not found_used_result_tool:
+         if (
+             call.tool_name == result_tool_name
+             and call.tool_call_id == result_tool_call_id
+             and not found_used_result_tool
+         ):
              found_used_result_tool = True
              output_parts.append(
                  _messages.ToolReturnPart(
@@ -593,9 +600,14 @@ async def process_function_tools(
              # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
              # validation, we don't add another part here
              if result_tool_name is not None:
+                 if found_used_result_tool:
+                     content = 'Result tool not used - a final result was already processed.'
+                 else:
+                     # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
+                     content = 'Result tool not used - result failed validation.'
                  part = _messages.ToolReturnPart(
                      tool_name=call.tool_name,
-                     content='Result tool not used - a final result was already processed.',
+                     content=content,
                      tool_call_id=call.tool_call_id,
                  )
                  output_parts.append(part)
@@ -716,7 +728,7 @@ def build_agent_graph(
      nodes = (
          UserPromptNode[DepsT],
          ModelRequestNode[DepsT],
-         HandleResponseNode[DepsT],
+         CallToolsNode[DepsT],
      )
      graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
          nodes=nodes,
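The loop-with-break above means the first result tool call that validates becomes the final result, and its tool_call_id is threaded through to process_function_tools. A hedged sketch of the stricter matching rule (the helper function below is hypothetical; the attribute names come from the diff):

    def is_used_result_call(call, result_tool_name, result_tool_call_id):
        # A second call to the same result tool carries a different tool_call_id,
        # so it is no longer mistaken for the call that produced the final result.
        return call.tool_name == result_tool_name and call.tool_call_id == result_tool_call_id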
pydantic_ai/_result.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
  import inspect
  import sys
  import types
- from collections.abc import Awaitable, Iterable
+ from collections.abc import Awaitable, Iterable, Iterator
  from dataclasses import dataclass, field
  from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin

@@ -127,12 +127,12 @@ class ResultSchema(Generic[ResultDataT]):
      def find_tool(
          self,
          parts: Iterable[_messages.ModelResponsePart],
-     ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
+     ) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
          """Find a tool that matches one of the calls."""
          for part in parts:
              if isinstance(part, _messages.ToolCallPart):
                  if result := self.tools.get(part.tool_name):
-                     return part, result
+                     yield part, result

      def tool_names(self) -> list[str]:
          """Return the names of the tools."""
pydantic_ai/_utils.py CHANGED
@@ -48,6 +48,8 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:

      if schema.get('type') == 'object':
          return schema
+     elif schema.get('$ref') is not None:
+         return schema.get('$defs', {}).get(schema['$ref'][8:])  # This removes the initial "#/$defs/".
      else:
          raise UserError('Schema must be an object')

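The new branch unwraps a schema whose top level is a $ref into the matching $defs entry. A hedged illustration with a made-up schema (len('#/$defs/') == 8, so the [8:] slice leaves only the key):

    schema = {
        '$ref': '#/$defs/MyModel',
        '$defs': {'MyModel': {'type': 'object', 'properties': {'x': {'type': 'integer'}}}},
    }
    resolved = schema.get('$defs', {}).get(schema['$ref'][8:])
    assert resolved == {'type': 'object', 'properties': {'x': {'type': 'integer'}}}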
pydantic_ai/agent.py CHANGED
@@ -25,6 +25,7 @@ from . import (
      result,
      usage as _usage,
  )
+ from .models.instrumented import InstrumentedModel
  from .result import FinalResult, ResultDataT, StreamedRunResult
  from .settings import ModelSettings, merge_model_settings
  from .tools import (
@@ -42,7 +43,7 @@ from .tools import (
  # Re-exporting like this improves auto-import behavior in PyCharm
  capture_run_messages = _agent_graph.capture_run_messages
  EndStrategy = _agent_graph.EndStrategy
- HandleResponseNode = _agent_graph.HandleResponseNode
+ CallToolsNode = _agent_graph.CallToolsNode
  ModelRequestNode = _agent_graph.ModelRequestNode
  UserPromptNode = _agent_graph.UserPromptNode

@@ -52,7 +53,7 @@ __all__ = (
      'AgentRunResult',
      'capture_run_messages',
      'EndStrategy',
-     'HandleResponseNode',
+     'CallToolsNode',
      'ModelRequestNode',
      'UserPromptNode',
  )
@@ -294,7 +295,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
          """
          if infer_name and self.name is None:
              self._infer_name(inspect.currentframe())
-         with self.iter(
+         async with self.iter(
              user_prompt=user_prompt,
              result_type=result_type,
              message_history=message_history,
@@ -310,8 +311,8 @@
              assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
          return final_result

-     @contextmanager
-     def iter(
+     @asynccontextmanager
+     async def iter(
          self,
          user_prompt: str | Sequence[_messages.UserContent],
          *,
@@ -323,7 +324,7 @@
          usage_limits: _usage.UsageLimits | None = None,
          usage: _usage.Usage | None = None,
          infer_name: bool = True,
-     ) -> Iterator[AgentRun[AgentDepsT, Any]]:
+     ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
          """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.

          This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an
@@ -344,7 +345,7 @@

          async def main():
              nodes = []
-             with agent.iter('What is the capital of France?') as agent_run:
+             async with agent.iter('What is the capital of France?') as agent_run:
                  async for node in agent_run:
                      nodes.append(node)
              print(nodes)
@@ -362,7 +363,7 @@
                      kind='request',
                  )
              ),
-             HandleResponseNode(
+             CallToolsNode(
                  model_response=ModelResponse(
                      parts=[TextPart(content='Paris', part_kind='text')],
                      model_name='gpt-4o',
@@ -370,7 +371,7 @@
                      kind='response',
                  )
              ),
-             End(data=FinalResult(data='Paris', tool_name=None)),
+             End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
          ]
          '''
          print(agent_run.result.data)
@@ -454,7 +455,7 @@
              system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
          )

-         with graph.iter(
+         async with graph.iter(
              start_node,
              state=state,
              deps=graph_deps,
@@ -633,7 +634,7 @@
              self._infer_name(frame.f_back)

          yielded = False
-         with self.iter(
+         async with self.iter(
              user_prompt,
              result_type=result_type,
              message_history=message_history,
@@ -661,11 +662,10 @@
                  new_part = maybe_part_event.part
                  if isinstance(new_part, _messages.TextPart):
                      if _agent_graph.allow_text_result(result_schema):
-                         return FinalResult(s, None)
-                 elif isinstance(new_part, _messages.ToolCallPart):
-                     if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                         call, _ = match
-                         return FinalResult(s, call.tool_name)
+                         return FinalResult(s, None, None)
+                 elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
+                     for call, _ in result_schema.find_tool([new_part]):
+                         return FinalResult(s, call.tool_name, call.tool_call_id)
              return None

          final_result_details = await stream_to_final(streamed_response)
@@ -692,6 +692,7 @@
              async for _event in _agent_graph.process_function_tools(
                  tool_calls,
                  final_result_details.tool_name,
+                 final_result_details.tool_call_id,
                  graph_ctx,
                  parts,
              ):
@@ -1115,6 +1116,9 @@
          else:
              raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

+         if not isinstance(model_, InstrumentedModel):
+             model_ = InstrumentedModel(model_)
+
          return model_

      def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
@@ -1183,14 +1187,14 @@
          return isinstance(node, _agent_graph.ModelRequestNode)

      @staticmethod
-     def is_handle_response_node(
+     def is_call_tools_node(
          node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
-     ) -> TypeGuard[_agent_graph.HandleResponseNode[T, S]]:
-         """Check if the node is a `HandleResponseNode`, narrowing the type if it is.
+     ) -> TypeGuard[_agent_graph.CallToolsNode[T, S]]:
+         """Check if the node is a `CallToolsNode`, narrowing the type if it is.

          This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
          """
-         return isinstance(node, _agent_graph.HandleResponseNode)
+         return isinstance(node, _agent_graph.CallToolsNode)

      @staticmethod
      def is_user_prompt_node(
@@ -1217,7 +1221,7 @@
  class AgentRun(Generic[AgentDepsT, ResultDataT]):
      """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].

-     You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`.
+     You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.

      Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an
      [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result]
@@ -1232,7 +1236,7 @@
          async def main():
              nodes = []
              # Iterate through the run, recording each node along the way:
-             with agent.iter('What is the capital of France?') as agent_run:
+             async with agent.iter('What is the capital of France?') as agent_run:
                  async for node in agent_run:
                      nodes.append(node)
              print(nodes)
@@ -1250,7 +1254,7 @@
                      kind='request',
                  )
              ),
-             HandleResponseNode(
+             CallToolsNode(
                  model_response=ModelResponse(
                      parts=[TextPart(content='Paris', part_kind='text')],
                      model_name='gpt-4o',
@@ -1258,7 +1262,7 @@
                      kind='response',
                  )
              ),
-             End(data=FinalResult(data='Paris', tool_name=None)),
+             End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
          ]
          '''
          print(agent_run.result.data)
@@ -1346,7 +1350,7 @@
          agent = Agent('openai:gpt-4o')

          async def main():
-             with agent.iter('What is the capital of France?') as agent_run:
+             async with agent.iter('What is the capital of France?') as agent_run:
                  next_node = agent_run.next_node  # start with the first node
                  nodes = [next_node]
                  while not isinstance(next_node, End):
@@ -1374,7 +1378,7 @@
                      kind='request',
                  )
              ),
-             HandleResponseNode(
+             CallToolsNode(
                  model_response=ModelResponse(
                      parts=[TextPart(content='Paris', part_kind='text')],
                      model_name='gpt-4o',
@@ -1382,7 +1386,7 @@
                      kind='response',
                  )
              ),
-             End(data=FinalResult(data='Paris', tool_name=None)),
+             End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
          ]
          '''
          print('Final result:', agent_run.result.data)
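The breaking change in this file is that Agent.iter is now an async context manager, so every `with agent.iter(...)` call site becomes `async with`. A minimal migration sketch, following the docstring examples above:

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    async def main():
        # 0.0.30: `with agent.iter(...) as agent_run:`
        async with agent.iter('What is the capital of France?') as agent_run:
            async for node in agent_run:
                print(node)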
pydantic_ai/messages.py CHANGED
@@ -8,6 +8,7 @@ from typing import Annotated, Any, Literal, Union, cast, overload

  import pydantic
  import pydantic_core
+ from opentelemetry._events import Event
  from typing_extensions import TypeAlias

  from ._utils import now_utc as _now_utc
@@ -33,6 +34,9 @@ class SystemPromptPart:
      part_kind: Literal['system-prompt'] = 'system-prompt'
      """Part type identifier, this is available on all parts as a discriminator."""

+     def otel_event(self) -> Event:
+         return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
+

  @dataclass
  class AudioUrl:
@@ -138,6 +142,14 @@ class UserPromptPart:
      part_kind: Literal['user-prompt'] = 'user-prompt'
      """Part type identifier, this is available on all parts as a discriminator."""

+     def otel_event(self) -> Event:
+         if isinstance(self.content, str):
+             content = self.content
+         else:
+             # TODO figure out what to record for images and audio
+             content = [part if isinstance(part, str) else {'kind': part.kind} for part in self.content]
+         return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})
+

  tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))

@@ -176,6 +188,9 @@ class ToolReturnPart:
          else:
              return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}

+     def otel_event(self) -> Event:
+         return Event('gen_ai.tool.message', body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id})
+

  error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))

@@ -224,6 +239,14 @@ class RetryPromptPart:
              description = f'{len(self.content)} validation errors: {json_errors.decode()}'
          return f'{description}\n\nFix the errors and try again.'

+     def otel_event(self) -> Event:
+         if self.tool_name is None:
+             return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
+         else:
+             return Event(
+                 'gen_ai.tool.message', body={'content': self.model_response(), 'role': 'tool', 'id': self.tool_call_id}
+             )
+

  ModelRequestPart = Annotated[
      Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
@@ -329,6 +352,36 @@ class ModelResponse:
      kind: Literal['response'] = 'response'
      """Message type identifier, this is available on all parts as a discriminator."""

+     def otel_events(self) -> list[Event]:
+         """Return OpenTelemetry events for the response."""
+         result: list[Event] = []
+
+         def new_event_body():
+             new_body: dict[str, Any] = {'role': 'assistant'}
+             ev = Event('gen_ai.assistant.message', body=new_body)
+             result.append(ev)
+             return new_body
+
+         body = new_event_body()
+         for part in self.parts:
+             if isinstance(part, ToolCallPart):
+                 body.setdefault('tool_calls', []).append(
+                     {
+                         'id': part.tool_call_id,
+                         'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                         'function': {
+                             'name': part.tool_name,
+                             'arguments': part.args,
+                         },
+                     }
+                 )
+             elif isinstance(part, TextPart):
+                 if body.get('content'):
+                     body = new_event_body()
+                 body['content'] = part.content
+
+         return result
+

  ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
  """Any message sent to or returned by a model."""
@@ -539,6 +592,8 @@ class FinalResultEvent:

      tool_name: str | None
      """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
+     tool_call_id: str | None
+     """The tool call ID, if any, that this result is associated with."""
      event_kind: Literal['final_result'] = 'final_result'
      """Event type identifier, used as a discriminator."""
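Each request part now knows how to render itself as an OpenTelemetry event. A hedged sketch of the new hook (field names follow the diff; the prompt text is made up):

    from pydantic_ai.messages import SystemPromptPart

    event = SystemPromptPart(content='You are a helpful assistant.').otel_event()
    # event.name == 'gen_ai.system.message'
    # event.body == {'content': 'You are a helpful assistant.', 'role': 'system'}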
pydantic_ai/models/instrumented.py CHANGED
@@ -1,28 +1,21 @@
  from __future__ import annotations

  import json
- from collections.abc import AsyncIterator, Iterator
+ from collections.abc import AsyncIterator, Iterator, Mapping
  from contextlib import asynccontextmanager, contextmanager
  from dataclasses import dataclass, field
- from functools import partial
  from typing import Any, Callable, Literal

  import logfire_api
  from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
- from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
  from opentelemetry.util.types import AttributeValue
+ from pydantic import TypeAdapter

  from ..messages import (
      ModelMessage,
      ModelRequest,
-     ModelRequestPart,
      ModelResponse,
-     RetryPromptPart,
-     SystemPromptPart,
-     TextPart,
-     ToolCallPart,
-     ToolReturnPart,
-     UserPromptPart,
  )
  from ..settings import ModelSettings
  from ..usage import Usage
@@ -48,6 +41,8 @@ MODEL_SETTING_ATTRIBUTES: tuple[
      'frequency_penalty',
  )

+ ANY_ADAPTER = TypeAdapter[Any](Any)
+

  @dataclass
  class InstrumentedModel(WrapperModel):
@@ -115,7 +110,7 @@ class InstrumentedModel(WrapperModel):
              finish(response_stream.get(), response_stream.usage())

      @contextmanager
-     def _instrument(  # noqa: C901
+     def _instrument(
          self,
          messages: list[ModelMessage],
          model_settings: ModelSettings | None,
@@ -141,35 +136,24 @@
          if isinstance(value := model_settings.get(key), (float, int)):
              attributes[f'gen_ai.request.{key}'] = value

-         events_list = []
-         emit_event = partial(self._emit_event, system, events_list)
-
          with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
-             if span.is_recording():
-                 for message in messages:
-                     if isinstance(message, ModelRequest):
-                         for part in message.parts:
-                             event_name, body = _request_part_body(part)
-                             if event_name:
-                                 emit_event(event_name, body)
-                     elif isinstance(message, ModelResponse):
-                         for body in _response_bodies(message):
-                             emit_event('gen_ai.assistant.message', body)

              def finish(response: ModelResponse, usage: Usage):
                  if not span.is_recording():
                      return

-                 for response_body in _response_bodies(response):
-                     if response_body:
-                         emit_event(
+                 events = self.messages_to_otel_events(messages)
+                 for event in self.messages_to_otel_events([response]):
+                     events.append(
+                         Event(
                              'gen_ai.choice',
-                             {
+                             body={
                                  # TODO finish_reason
                                  'index': 0,
-                                 'message': response_body,
+                                 'message': event.body,
                              },
                          )
+                     )
                  span.set_attributes(
                      {
                          # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
@@ -178,67 +162,56 @@
                          **usage.opentelemetry_attributes(),
                      }
                  )
-                 if events_list:
-                     attr_name = 'events'
-                     span.set_attributes(
-                         {
-                             attr_name: json.dumps(events_list),
-                             'logfire.json_schema': json.dumps(
-                                 {
-                                     'type': 'object',
-                                     'properties': {attr_name: {'type': 'array'}},
-                                 }
-                             ),
-                         }
-                     )
+                 self._emit_events(system, span, events)

              yield finish

-     def _emit_event(
-         self, system: str, events_list: list[dict[str, Any]], event_name: str, body: dict[str, Any]
-     ) -> None:
-         attributes = {'gen_ai.system': system}
+     def _emit_events(self, system: str, span: Span, events: list[Event]) -> None:
+         for event in events:
+             event.attributes = {'gen_ai.system': system, **(event.attributes or {})}
          if self.event_mode == 'logs':
-             self.event_logger.emit(Event(event_name, body=body, attributes=attributes))
-         else:
-             events_list.append({'event.name': event_name, **body, **attributes})
-
-
- def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
-     if isinstance(part, SystemPromptPart):
-         return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
-     elif isinstance(part, UserPromptPart):
-         return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
-     elif isinstance(part, ToolReturnPart):
-         return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
-     elif isinstance(part, RetryPromptPart):
-         if part.tool_name is None:
-             return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+             for event in events:
+                 self.event_logger.emit(event)
          else:
-             return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
-     else:
-         return '', {}
-
-
- def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
-     body: dict[str, Any] = {'role': 'assistant'}
-     result = [body]
-     for part in message.parts:
-         if isinstance(part, ToolCallPart):
-             body.setdefault('tool_calls', []).append(
+             attr_name = 'events'
+             span.set_attributes(
                  {
-                     'id': part.tool_call_id,
-                     'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
-                     'function': {
-                         'name': part.tool_name,
-                         'arguments': part.args,
-                     },
+                     attr_name: json.dumps([self.event_to_dict(event) for event in events]),
+                     'logfire.json_schema': json.dumps(
+                         {
+                             'type': 'object',
+                             'properties': {attr_name: {'type': 'array'}},
+                         }
+                     ),
                  }
              )
-         elif isinstance(part, TextPart):
-             if body.get('content'):
-                 body = {'role': 'assistant'}
-                 result.append(body)
-             body['content'] = part.content

-     return result
+     @staticmethod
+     def event_to_dict(event: Event) -> dict[str, Any]:
+         if not event.body:
+             body = {}
+         elif isinstance(event.body, Mapping):
+             body = event.body  # type: ignore
+         else:
+             body = {'body': event.body}
+         return {**body, **(event.attributes or {})}
+
+     @staticmethod
+     def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
+         result: list[Event] = []
+         for message in messages:
+             if isinstance(message, ModelRequest):
+                 for part in message.parts:
+                     if hasattr(part, 'otel_event'):
+                         result.append(part.otel_event())
+             elif isinstance(message, ModelResponse):
+                 result.extend(message.otel_events())
+         for event in result:
+             try:
+                 event.body = ANY_ADAPTER.dump_python(event.body, mode='json')
+             except Exception:
+                 try:
+                     event.body = str(event.body)
+                 except Exception:
+                     event.body = 'Unable to serialize event body'
+         return result
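The two new static helpers replace the module-level _request_part_body and _response_bodies functions, and they are what _agent_graph.py now uses to build the all_messages_events span attribute. A hedged usage sketch (the message content is made up):

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models.instrumented import InstrumentedModel

    messages = [ModelRequest(parts=[UserPromptPart(content='hello')])]
    events = InstrumentedModel.messages_to_otel_events(messages)
    dicts = [InstrumentedModel.event_to_dict(e) for e in events]
    # e.g. [{'content': 'hello', 'role': 'user'}]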
pydantic_ai/result.py CHANGED
@@ -145,12 +145,14 @@ class AgentStream(Generic[AgentDepsT, ResultDataT]):
              if isinstance(e, _messages.PartStartEvent):
                  new_part = e.part
                  if isinstance(new_part, _messages.ToolCallPart):
-                     if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                         call, _ = match
-                         return _messages.FinalResultEvent(tool_name=call.tool_name)
+                     if result_schema:
+                         for call, _ in result_schema.find_tool([new_part]):
+                             return _messages.FinalResultEvent(
+                                 tool_name=call.tool_name, tool_call_id=call.tool_call_id
+                             )
                  elif allow_text_result:
                      assert_type(e, _messages.PartStartEvent)
-                     return _messages.FinalResultEvent(tool_name=None)
+                     return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)

          usage_checking_stream = _get_usage_checking_stream_response(
              self._raw_stream_response, self._usage_limits, self.usage
@@ -472,6 +474,8 @@ class FinalResult(Generic[ResultDataT]):
      """The final result data."""
      tool_name: str | None
      """Name of the final result tool; `None` if the result came from unstructured text content."""
+     tool_call_id: str | None
+     """ID of the tool call that produced the final result; `None` if the result came from unstructured text content."""


  def _get_usage_checking_stream_response(
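FinalResult gained a third field, so positional construction now takes (data, tool_name, tool_call_id). A hedged sketch (the tool name and call ID below are made up):

    from pydantic_ai.result import FinalResult

    FinalResult('Paris', None, None)  # plain-text result: no tool, no call ID
    FinalResult({'city': 'Paris'}, 'final_result', 'call_123')  # result produced by a tool call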
pydantic_ai_slim-0.0.30.dist-info/METADATA → pydantic_ai_slim-0.0.31.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai-slim
- Version: 0.0.30
+ Version: 0.0.31
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
  Author-email: Samuel Colvin <samuel@pydantic.dev>
  License-Expression: MIT
@@ -29,7 +29,8 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
  Requires-Dist: griffe>=1.3.2
  Requires-Dist: httpx>=0.27
  Requires-Dist: logfire-api>=1.2.0
- Requires-Dist: pydantic-graph==0.0.30
+ Requires-Dist: opentelemetry-api>=1.28.0
+ Requires-Dist: pydantic-graph==0.0.31
  Requires-Dist: pydantic>=2.10
  Provides-Extra: anthropic
  Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
pydantic_ai_slim-0.0.30.dist-info/RECORD → pydantic_ai_slim-0.0.31.dist-info/RECORD RENAMED
@@ -1,17 +1,17 @@
- pydantic_ai/__init__.py,sha256=Rmpjmorf8YY1PtlkXRRNN-J3ZoQDSh7chaibVGyqY0k,937
- pydantic_ai/_agent_graph.py,sha256=gvJQ17A2glk8p2w2TCSfHwvWNp0vla1sQb0EZXOZbxU,30284
+ pydantic_ai/__init__.py,sha256=xrSDxkBwpUVInbPtTVhReEecStk-mWZMttAPUAQR0Ic,927
+ pydantic_ai/_agent_graph.py,sha256=vvhV051rjVcPPRZ_TeL4pWwX-DptEzWgBBJnhybmIWg,30510
  pydantic_ai/_griffe.py,sha256=RYRKiLbgG97QxnazbAwlnc74XxevGHLQet-FGfq9qls,3960
  pydantic_ai/_parts_manager.py,sha256=ARfDQY1_5AIY5rNl_M2fAYHEFCe03ZxdhgjHf9qeIKw,11872
  pydantic_ai/_pydantic.py,sha256=dROz3Hmfdi0C2exq88FhefDRVo_8S3rtkXnoUHzsz0c,8753
- pydantic_ai/_result.py,sha256=tN1pVulf_EM4bkBvpNUWPnUXezLY-sBrJEVCFdy2nLU,10264
+ pydantic_ai/_result.py,sha256=mqj3YrUzr5OT00h0KfGJglwQZ6_7nV7355Pvucd08ak,10276
  pydantic_ai/_system_prompt.py,sha256=602c2jyle2R_SesOrITBDETZqsLk4BZ8Cbo8yEhmx04,1120
- pydantic_ai/_utils.py,sha256=w9BYSfFZiaX757fRtMRclOL1uYzyQnxV_lxqbU2WTPs,9435
- pydantic_ai/agent.py,sha256=FeKELTSFKDkt6-UlmkezKnQTdnx1in6VckivqsfzfA4,65382
+ pydantic_ai/_utils.py,sha256=nx4Suswk2qjLvzphx8uQntKzFi-IzvhX_H1L7t_kJlQ,9579
+ pydantic_ai/agent.py,sha256=jHQ99M-kwUrUSWHPjBDmWG2AepbDS9H3YUE1NugaWGg,65625
  pydantic_ai/exceptions.py,sha256=1ujJeB3jDDQ-pH5ydBYrgStvR35-GlEW0bYGTGEr4ME,3127
  pydantic_ai/format_as_xml.py,sha256=QE7eMlg5-YUMw1_2kcI3h0uKYPZZyGkgXFDtfZTMeeI,4480
- pydantic_ai/messages.py,sha256=k8sX-V1cTeqXh1u6oJbqExZPYt3E7F3UCIudxvjKRO8,21486
+ pydantic_ai/messages.py,sha256=Yny2hIuExXfw9fvHDSPgbvfN91IOdcLaDEAmaCAoTBs,23751
  pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pydantic_ai/result.py,sha256=Df_tPeqCQnLa0i0vVA-BGCJDx37ebD_3ojAmHnXE2yU,22767
+ pydantic_ai/result.py,sha256=Q--JTwDfPeJw1_Mk5EhI7R9V7GusG-oAx1m9pDH50zQ,23014
  pydantic_ai/settings.py,sha256=ntuWnke9UA18aByDxk9OIhN0tAgOaPdqCEkRf-wlp8Y,3059
  pydantic_ai/tools.py,sha256=IPZuZJCSQUppz1uyLVwpfFLGoMirB8YtKWXIDQGR444,13414
  pydantic_ai/usage.py,sha256=VmpU_o_RjFI65J81G1wfCwDIAYBclMjeWfLtslntFOw,5406
@@ -25,12 +25,12 @@ pydantic_ai/models/fallback.py,sha256=smHwNIpxu19JsgYYjY0nmzl3yox7yQRJ0Ir08zdhnk
  pydantic_ai/models/function.py,sha256=THIwVJ8qI3efYLNlYXlYze_J8hc7MHB-NMb3kpknq0g,11373
  pydantic_ai/models/gemini.py,sha256=2hDTMIMf899dp-MS0tLT7m1GkXsL9KIRMBklGM0VLB4,34223
  pydantic_ai/models/groq.py,sha256=Z4sZJDu5Yxa2tZiAPp9EjSVMz4uwLhS3fW7kFSc09gI,16406
- pydantic_ai/models/instrumented.py,sha256=xUZEn2VG8hP3hny0L5kZgXC5UnFdlUJ0DgXOxFmYhEo,9654
+ pydantic_ai/models/instrumented.py,sha256=npufEZJrR9m0_ZQB1inWFcuK3Nu5_2GdY1YtTYaIj3s,8366
  pydantic_ai/models/mistral.py,sha256=ZJ4xPcL9wJIQ5io34yP2fPyXy8GZrSvsW4itZiKPYFw,27448
  pydantic_ai/models/openai.py,sha256=koIcK_pDHmV-JFq_-VIzU-edAqGKOOzkSk5QSYWvfoc,20156
  pydantic_ai/models/test.py,sha256=Ux20cmuJFkhvI9L1N7ItHNFcd-j284TBEsrM53eWRag,16873
  pydantic_ai/models/vertexai.py,sha256=9Kp_1KMBlbP8_HRJTuFnrkkFmlJ7yFhADQYjxOgIh9Y,9523
  pydantic_ai/models/wrapper.py,sha256=Zr3fgiUBpt2N9gXds6iSwaMEtEsFKr9WwhpHjSoHa7o,1410
- pydantic_ai_slim-0.0.30.dist-info/METADATA,sha256=JDT77S9uw0w87WpAbXqK_c65849A7PeF1_dhJRGamiM,3062
- pydantic_ai_slim-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- pydantic_ai_slim-0.0.30.dist-info/RECORD,,
+ pydantic_ai_slim-0.0.31.dist-info/METADATA,sha256=dgkUKEU7r9OqgIkt3enzpISWt73KVAYL8gC2APlnpWg,3103
+ pydantic_ai_slim-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ pydantic_ai_slim-0.0.31.dist-info/RECORD,,