pydantic-ai-slim 0.0.30__py3-none-any.whl → 0.0.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from importlib.metadata import version

-from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
+from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import (
     AgentRunError,
     FallbackExceptionGroup,
@@ -18,7 +18,7 @@ __all__ = (
     # agent
     'Agent',
     'EndStrategy',
-    'HandleResponseNode',
+    'CallToolsNode',
     'ModelRequestNode',
     'UserPromptNode',
     'capture_run_messages',
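The headline change in this release is the rename of `HandleResponseNode` to `CallToolsNode`; the public re-export changes accordingly. A minimal migration sketch, assuming pydantic-ai-slim 0.0.32 is installed:

```python
# 0.0.30:
#   from pydantic_ai import HandleResponseNode
# 0.0.32 - same node, new name:
from pydantic_ai import CallToolsNode
```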
pydantic_ai/_agent_graph.py CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations as _annotations

 import asyncio
 import dataclasses
+import json
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
 from typing import Any, Generic, Literal, Union, cast

-import logfire_api
+from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never

 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -23,6 +24,7 @@ from . import (
     result,
     usage as _usage,
 )
+from .models.instrumented import InstrumentedModel
 from .result import ResultDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -36,22 +38,11 @@ __all__ = (
     'GraphAgentDeps',
     'UserPromptNode',
     'ModelRequestNode',
-    'HandleResponseNode',
+    'CallToolsNode',
     'build_run_context',
     'capture_run_messages',
 )

-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

 T = TypeVar('T')
 S = TypeVar('S')
@@ -104,7 +95,8 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)

-    run_span: logfire_api.LogfireSpan
+    run_span: Span
+    tracer: Tracer


 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -243,12 +235,12 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

     request: _messages.ModelRequest

-    _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
     _did_stream: bool = field(default=False, repr=False)

     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
         if self._result is not None:
             return self._result

@@ -286,39 +278,33 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         assert not self._did_stream, 'stream() should only be called once per node'

         model_settings, model_request_parameters = await self._prepare_request(ctx)
-        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-            async with ctx.deps.model.request_stream(
-                ctx.state.message_history, model_settings, model_request_parameters
-            ) as streamed_response:
-                self._did_stream = True
-                ctx.state.usage.incr(_usage.Usage(), requests=1)
-                yield streamed_response
-                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-                # otherwise usage won't be properly counted:
-                async for _ in streamed_response:
-                    pass
-            model_response = streamed_response.get()
-            request_usage = streamed_response.usage()
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
+        async with ctx.deps.model.request_stream(
+            ctx.state.message_history, model_settings, model_request_parameters
+        ) as streamed_response:
+            self._did_stream = True
+            ctx.state.usage.incr(_usage.Usage(), requests=1)
+            yield streamed_response
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in streamed_response:
+                pass
+        model_response = streamed_response.get()
+        request_usage = streamed_response.usage()

         self._finish_handling(ctx, model_response, request_usage)
         assert self._result is not None  # this should be set by the previous line

     async def _make_request(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
         if self._result is not None:
             return self._result

         model_settings, model_request_parameters = await self._prepare_request(ctx)
-        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-            model_response, request_usage = await ctx.deps.model.request(
-                ctx.state.message_history, model_settings, model_request_parameters
-            )
-            ctx.state.usage.incr(_usage.Usage(), requests=1)
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
+        model_response, request_usage = await ctx.deps.model.request(
+            ctx.state.message_history, model_settings, model_request_parameters
+        )
+        ctx.state.usage.incr(_usage.Usage(), requests=1)

         return self._finish_handling(ctx, model_response, request_usage)

@@ -335,7 +321,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx.state.run_step += 1

         model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
+        with ctx.deps.tracer.start_as_current_span(
+            'preparing model request params', attributes=dict(run_step=ctx.state.run_step)
+        ):
             model_request_parameters = await _prepare_request_parameters(ctx)
         return model_settings, model_request_parameters

@@ -344,7 +332,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         response: _messages.ModelResponse,
         usage: _usage.Usage,
-    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
         # Update usage
         ctx.state.usage.incr(usage, requests=0)
         if ctx.deps.usage_limits:
@@ -354,13 +342,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx.state.message_history.append(response)

         # Set the `_result` attribute since we can't use `return` in an async iterator
-        self._result = HandleResponseNode(response)
+        self._result = CallToolsNode(response)

         return self._result


 @dataclasses.dataclass
-class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
+class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
     """Process a model response, and decide whether to end the run or make a new request."""

     model_response: _messages.ModelResponse
@@ -385,26 +373,12 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         """Process the model response and yield events for the start and end of each function tool call."""
-        with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
-            stream = self._run_stream(ctx)
-            yield stream
-
-            # Run the stream to completion if it was not finished:
-            async for _event in stream:
-                pass
+        stream = self._run_stream(ctx)
+        yield stream

-            # Set the next node based on the final state of the stream
-            next_node = self._next_node
-            if isinstance(next_node, End):
-                handle_span.set_attribute('result', next_node.data)
-                handle_span.message = 'handle model response -> final result'
-            elif tool_responses := self._tool_responses:
-                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
-                # I'm thinking it might be better to just create a span for the handling of each tool
-                # than to set an attribute here.
-                handle_span.set_attribute('tool_responses', tool_responses)
-                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                handle_span.message = f'handle model response -> {tool_responses_str}'
+        # Run the stream to completion if it was not finished:
+        async for _event in stream:
+            pass

     async def _run_stream(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -454,8 +428,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
         final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
         if result_schema is not None:
-            if match := result_schema.find_tool(tool_calls):
-                call, result_tool = match
+            for call, result_tool in result_schema.find_tool(tool_calls):
                 try:
                     result_data = result_tool.validate(call)
                     result_data = await _validate_result(result_data, ctx, call)
@@ -465,12 +438,17 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
                     ctx.state.increment_retries(ctx.deps.max_result_retries)
                     parts.append(e.tool_retry)
                 else:
-                    final_result = result.FinalResult(result_data, call.tool_name)
+                    final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
+                    break

         # Then build the other request parts based on end strategy
         tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
         async for event in process_function_tools(
-            tool_calls, final_result and final_result.tool_name, ctx, tool_responses
+            tool_calls,
+            final_result and final_result.tool_name,
+            final_result and final_result.tool_call_id,
+            ctx,
+            tool_responses,
         ):
             yield event

@@ -495,8 +473,30 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
         if tool_responses:
             messages.append(_messages.ModelRequest(parts=tool_responses))

-        run_span.set_attribute('usage', usage)
-        run_span.set_attribute('all_messages', messages)
+        run_span.set_attributes(
+            {
+                **usage.opentelemetry_attributes(),
+                'all_messages_events': json.dumps(
+                    [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
+                ),
+                'final_result': final_result.data
+                if isinstance(final_result.data, str)
+                else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
+            }
+        )
+        run_span.set_attributes(
+            {
+                'logfire.json_schema': json.dumps(
+                    {
+                        'type': 'object',
+                        'properties': {
+                            'all_messages_events': {'type': 'array'},
+                            'final_result': {'type': 'object'},
+                        },
+                    }
+                ),
+            }
+        )

         # End the run with self.data
         return End(final_result)
@@ -518,7 +518,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
             # The following cast is safe because we know `str` is an allowed result type
-            return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
+            return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
     else:
         ctx.state.increment_retries(ctx.deps.max_result_retries)
         return ModelRequestNode[DepsT, NodeRunEndT](
@@ -547,6 +547,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
 async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
     result_tool_name: str | None,
+    result_tool_call_id: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
     output_parts: list[_messages.ModelRequestPart],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,7 +567,11 @@ async def process_function_tools(
     calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
     call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
-        if call.tool_name == result_tool_name and not found_used_result_tool:
+        if (
+            call.tool_name == result_tool_name
+            and call.tool_call_id == result_tool_call_id
+            and not found_used_result_tool
+        ):
             found_used_result_tool = True
             output_parts.append(
                 _messages.ToolReturnPart(
@@ -593,9 +598,14 @@
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
             if result_tool_name is not None:
+                if found_used_result_tool:
+                    content = 'Result tool not used - a final result was already processed.'
+                else:
+                    # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
+                    content = 'Result tool not used - result failed validation.'
                 part = _messages.ToolReturnPart(
                     tool_name=call.tool_name,
-                    content='Result tool not used - a final result was already processed.',
+                    content=content,
                     tool_call_id=call.tool_call_id,
                 )
                 output_parts.append(part)
@@ -607,7 +617,10 @@

     # Run all tool tasks in parallel
     results_by_index: dict[int, _messages.ModelRequestPart] = {}
-    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
+    tool_names = [call.tool_name for _, call in calls_to_run]
+    with ctx.deps.tracer.start_as_current_span(
+        'running tools', attributes={'tools': tool_names, 'logfire.msg': f'running tools: {", ".join(tool_names)}'}
+    ):
         # TODO: Should we wrap each individual tool call in a dedicated span?
         tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
         pending = tasks
@@ -716,7 +729,7 @@ def build_agent_graph(
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
-        HandleResponseNode[DepsT],
+        CallToolsNode[DepsT],
    )
    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
        nodes=nodes,
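The pattern running through `_agent_graph.py` is the swap from `logfire_api` spans to plain OpenTelemetry primitives: spans are opened via an injected `Tracer` (`ctx.deps.tracer`), which is a `NoOpTracer` when instrumentation is off. A minimal sketch of that pattern, using only the `opentelemetry-api` package this release now depends on (the tool names are hypothetical):

```python
from opentelemetry.trace import NoOpTracer

# With instrumentation disabled, every span becomes a cheap no-op,
# so the tracing code paths need no conditionals around them.
tracer = NoOpTracer()

tool_names = ['roulette_wheel', 'weather']
with tracer.start_as_current_span(
    'running tools',
    attributes={'tools': tool_names, 'logfire.msg': f'running tools: {", ".join(tool_names)}'},
):
    pass  # tool tasks would run here, as in process_function_tools above
```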
pydantic_ai/_result.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import inspect
 import sys
 import types
-from collections.abc import Awaitable, Iterable
+from collections.abc import Awaitable, Iterable, Iterator
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin

@@ -127,12 +127,12 @@ class ResultSchema(Generic[ResultDataT]):
     def find_tool(
         self,
         parts: Iterable[_messages.ModelResponsePart],
-    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
+    ) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
         """Find a tool that matches one of the calls."""
         for part in parts:
             if isinstance(part, _messages.ToolCallPart):
                 if result := self.tools.get(part.tool_name):
-                    return part, result
+                    yield part, result

     def tool_names(self) -> list[str]:
         """Return the names of the tools."""
pydantic_ai/_utils.py CHANGED
@@ -48,6 +48,8 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:

     if schema.get('type') == 'object':
         return schema
+    elif schema.get('$ref') is not None:
+        return schema.get('$defs', {}).get(schema['$ref'][8:])  # This removes the initial "#/$defs/".
     else:
         raise UserError('Schema must be an object')

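The new `$ref` branch handles JSON schemas whose top level is just a reference into `$defs`, as Pydantic emits for e.g. recursive models. A minimal sketch of the string slicing involved (the dict is hypothetical):

```python
schema = {
    '$ref': '#/$defs/Recipe',
    '$defs': {'Recipe': {'type': 'object', 'properties': {'name': {'type': 'string'}}}},
}

# '#/$defs/' is exactly 8 characters, so schema['$ref'][8:] is the $defs key.
assert schema['$ref'][8:] == 'Recipe'
resolved = schema.get('$defs', {}).get(schema['$ref'][8:])
print(resolved)  # {'type': 'object', 'properties': {'name': {'type': 'string'}}}
```

Note the slice assumes the reference starts with `#/$defs/`; other reference forms would still fall through to the `UserError`.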
pydantic_ai/agent.py CHANGED
@@ -8,7 +8,7 @@ from copy import deepcopy
 from types import FrameType
 from typing import Any, Callable, Generic, cast, final, overload

-import logfire_api
+from opentelemetry.trace import NoOpTracer, use_span
 from typing_extensions import TypeGuard, TypeVar, deprecated

 from pydantic_graph import End, Graph, GraphRun, GraphRunContext
@@ -25,6 +25,7 @@ from . import (
     result,
     usage as _usage,
 )
+from .models.instrumented import InstrumentedModel
 from .result import FinalResult, ResultDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -42,7 +43,7 @@ from .tools import (
 # Re-exporting like this improves auto-import behavior in PyCharm
 capture_run_messages = _agent_graph.capture_run_messages
 EndStrategy = _agent_graph.EndStrategy
-HandleResponseNode = _agent_graph.HandleResponseNode
+CallToolsNode = _agent_graph.CallToolsNode
 ModelRequestNode = _agent_graph.ModelRequestNode
 UserPromptNode = _agent_graph.UserPromptNode

@@ -52,22 +53,11 @@ __all__ = (
     'AgentRunResult',
     'capture_run_messages',
     'EndStrategy',
-    'HandleResponseNode',
+    'CallToolsNode',
     'ModelRequestNode',
     'UserPromptNode',
 )

-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

 T = TypeVar('T')
 S = TypeVar('S')
@@ -122,6 +112,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     The type of the result data, used to validate the result data, defaults to `str`.
     """

+    instrument: bool
+    """Automatically instrument with OpenTelemetry. Will use Logfire if it's configured."""
+
     _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
     _result_tool_name: str = dataclasses.field(repr=False)
     _result_tool_description: str | None = dataclasses.field(repr=False)
@@ -154,6 +147,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
+        instrument: bool = False,
     ):
         """Create an agent.

@@ -183,6 +177,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 [override the model][pydantic_ai.Agent.override] for testing.
             end_strategy: Strategy for handling tool calls that are requested alongside a final result.
                 See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
+            instrument: Automatically instrument with OpenTelemetry. Will use Logfire if it's configured.
         """
         if model is None or defer_model_check:
             self.model = model
@@ -193,6 +188,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         self.name = name
         self.model_settings = model_settings
         self.result_type = result_type
+        self.instrument = instrument

         self._deps_type = deps_type

@@ -294,7 +290,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        with self.iter(
+        async with self.iter(
             user_prompt=user_prompt,
             result_type=result_type,
             message_history=message_history,
@@ -310,8 +306,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
         return final_result

-    @contextmanager
-    def iter(
+    @asynccontextmanager
+    async def iter(
         self,
         user_prompt: str | Sequence[_messages.UserContent],
         *,
@@ -323,7 +319,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> Iterator[AgentRun[AgentDepsT, Any]]:
+    ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
         """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.

         This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an
@@ -344,7 +340,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

         async def main():
             nodes = []
-            with agent.iter('What is the capital of France?') as agent_run:
+            async with agent.iter('What is the capital of France?') as agent_run:
                 async for node in agent_run:
                     nodes.append(node)
             print(nodes)
@@ -362,7 +358,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                         kind='request',
                     )
                 ),
-                HandleResponseNode(
+                CallToolsNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
                         model_name='gpt-4o',
@@ -370,7 +366,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                         kind='response',
                     )
                 ),
-                End(data=FinalResult(data='Paris', tool_name=None)),
+                End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
             ]
             '''
             print(agent_run.result.data)
@@ -395,6 +391,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
         model_used = self._get_model(model)
+        del model

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -424,14 +421,20 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         model_settings = merge_model_settings(self.model_settings, model_settings)
         usage_limits = usage_limits or _usage.UsageLimits()

-        # Build the deps object for the graph
-        run_span = _logfire.span(
-            '{agent_name} run {prompt=}',
-            prompt=user_prompt,
-            agent=self,
-            model_name=model_used.model_name if model_used else 'no-model',
-            agent_name=self.name or 'agent',
+        if isinstance(model_used, InstrumentedModel):
+            tracer = model_used.tracer
+        else:
+            tracer = NoOpTracer()
+        agent_name = self.name or 'agent'
+        run_span = tracer.start_span(
+            'agent run',
+            attributes={
+                'model_name': model_used.model_name if model_used else 'no-model',
+                'agent_name': agent_name,
+                'logfire.msg': f'{agent_name} run',
+            },
         )
+
         graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
             user_deps=deps,
             prompt=user_prompt,
@@ -446,6 +449,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             result_validators=result_validators,
             function_tools=self._function_tools,
             run_span=run_span,
+            tracer=tracer,
         )
         start_node = _agent_graph.UserPromptNode[AgentDepsT](
             user_prompt=user_prompt,
@@ -454,12 +458,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
         )

-        with graph.iter(
+        async with graph.iter(
             start_node,
             state=state,
             deps=graph_deps,
             infer_name=False,
-            span=run_span,
+            span=use_span(run_span, end_on_exit=True),
         ) as graph_run:
             yield AgentRun(graph_run)

@@ -633,7 +637,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             self._infer_name(frame.f_back)

         yielded = False
-        with self.iter(
+        async with self.iter(
             user_prompt,
             result_type=result_type,
             message_history=message_history,
@@ -661,11 +665,10 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                         new_part = maybe_part_event.part
                         if isinstance(new_part, _messages.TextPart):
                             if _agent_graph.allow_text_result(result_schema):
-                                return FinalResult(s, None)
-                        elif isinstance(new_part, _messages.ToolCallPart):
-                            if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                                call, _ = match
-                                return FinalResult(s, call.tool_name)
+                                return FinalResult(s, None, None)
+                        elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
+                            for call, _ in result_schema.find_tool([new_part]):
+                                return FinalResult(s, call.tool_name, call.tool_call_id)
                     return None

                 final_result_details = await stream_to_final(streamed_response)
@@ -692,6 +695,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                     async for _event in _agent_graph.process_function_tools(
                         tool_calls,
                         final_result_details.tool_name,
+                        final_result_details.tool_call_id,
                         graph_ctx,
                         parts,
                     ):
@@ -1115,6 +1119,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

+        if self.instrument and not isinstance(model_, InstrumentedModel):
+            model_ = InstrumentedModel(model_)
+
         return model_

     def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
@@ -1183,14 +1190,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         return isinstance(node, _agent_graph.ModelRequestNode)

     @staticmethod
-    def is_handle_response_node(
+    def is_call_tools_node(
         node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
-    ) -> TypeGuard[_agent_graph.HandleResponseNode[T, S]]:
-        """Check if the node is a `HandleResponseNode`, narrowing the type if it is.
+    ) -> TypeGuard[_agent_graph.CallToolsNode[T, S]]:
+        """Check if the node is a `CallToolsNode`, narrowing the type if it is.

         This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
         """
-        return isinstance(node, _agent_graph.HandleResponseNode)
+        return isinstance(node, _agent_graph.CallToolsNode)

     @staticmethod
     def is_user_prompt_node(
@@ -1217,7 +1224,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 class AgentRun(Generic[AgentDepsT, ResultDataT]):
     """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].

-    You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`.
+    You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.

     Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an
     [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result]
@@ -1232,7 +1239,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
     async def main():
         nodes = []
         # Iterate through the run, recording each node along the way:
-        with agent.iter('What is the capital of France?') as agent_run:
+        async with agent.iter('What is the capital of France?') as agent_run:
             async for node in agent_run:
                 nodes.append(node)
         print(nodes)
@@ -1250,7 +1257,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                     kind='request',
                 )
             ),
-            HandleResponseNode(
+            CallToolsNode(
                 model_response=ModelResponse(
                     parts=[TextPart(content='Paris', part_kind='text')],
                     model_name='gpt-4o',
@@ -1258,7 +1265,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                     kind='response',
                 )
             ),
-            End(data=FinalResult(data='Paris', tool_name=None)),
+            End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
         ]
         '''
         print(agent_run.result.data)
@@ -1346,7 +1353,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
     agent = Agent('openai:gpt-4o')

     async def main():
-        with agent.iter('What is the capital of France?') as agent_run:
+        async with agent.iter('What is the capital of France?') as agent_run:
             next_node = agent_run.next_node  # start with the first node
             nodes = [next_node]
             while not isinstance(next_node, End):
@@ -1374,7 +1381,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                 kind='request',
             )
         ),
-        HandleResponseNode(
+        CallToolsNode(
             model_response=ModelResponse(
                 parts=[TextPart(content='Paris', part_kind='text')],
                 model_name='gpt-4o',
@@ -1382,7 +1389,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                 kind='response',
             )
         ),
-        End(data=FinalResult(data='Paris', tool_name=None)),
+        End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
     ]
     '''
     print('Final result:', agent_run.result.data)
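Two user-visible changes in `agent.py` stand out: `Agent.iter` is now an *async* context manager, and the new `instrument` flag wraps the model in `InstrumentedModel`. A short usage sketch based on the docstrings above, assuming the OpenAI extra and credentials are configured:

```python
import asyncio

from pydantic_ai import Agent

# `instrument=True` is new in this release; spans go to whatever
# OpenTelemetry (or Logfire) setup is globally configured.
agent = Agent('openai:gpt-4o', instrument=True)

async def main():
    # `with agent.iter(...)` in 0.0.30 becomes `async with agent.iter(...)`:
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            print(type(node).__name__)  # UserPromptNode, ModelRequestNode, CallToolsNode, End
    print(agent_run.result.data)

asyncio.run(main())
```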
pydantic_ai/messages.py CHANGED
@@ -8,6 +8,7 @@ from typing import Annotated, Any, Literal, Union, cast, overload

 import pydantic
 import pydantic_core
+from opentelemetry._events import Event
 from typing_extensions import TypeAlias

 from ._utils import now_utc as _now_utc
@@ -33,6 +34,9 @@ class SystemPromptPart:
     part_kind: Literal['system-prompt'] = 'system-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""

+    def otel_event(self) -> Event:
+        return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
+

 @dataclass
 class AudioUrl:
@@ -138,6 +142,14 @@ class UserPromptPart:
     part_kind: Literal['user-prompt'] = 'user-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""

+    def otel_event(self) -> Event:
+        if isinstance(self.content, str):
+            content = self.content
+        else:
+            # TODO figure out what to record for images and audio
+            content = [part if isinstance(part, str) else {'kind': part.kind} for part in self.content]
+        return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})
+

 tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))

@@ -176,6 +188,9 @@ class ToolReturnPart:
         else:
             return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}

+    def otel_event(self) -> Event:
+        return Event('gen_ai.tool.message', body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id})
+

 error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))

@@ -224,6 +239,14 @@ class RetryPromptPart:
             description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'

+    def otel_event(self) -> Event:
+        if self.tool_name is None:
+            return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
+        else:
+            return Event(
+                'gen_ai.tool.message', body={'content': self.model_response(), 'role': 'tool', 'id': self.tool_call_id}
+            )
+

 ModelRequestPart = Annotated[
     Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
@@ -329,6 +352,36 @@ class ModelResponse:
     kind: Literal['response'] = 'response'
     """Message type identifier, this is available on all parts as a discriminator."""

+    def otel_events(self) -> list[Event]:
+        """Return OpenTelemetry events for the response."""
+        result: list[Event] = []
+
+        def new_event_body():
+            new_body: dict[str, Any] = {'role': 'assistant'}
+            ev = Event('gen_ai.assistant.message', body=new_body)
+            result.append(ev)
+            return new_body
+
+        body = new_event_body()
+        for part in self.parts:
+            if isinstance(part, ToolCallPart):
+                body.setdefault('tool_calls', []).append(
+                    {
+                        'id': part.tool_call_id,
+                        'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                        'function': {
+                            'name': part.tool_name,
+                            'arguments': part.args,
+                        },
+                    }
+                )
+            elif isinstance(part, TextPart):
+                if body.get('content'):
+                    body = new_event_body()
+                body['content'] = part.content
+
+        return result
+

 ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
 """Any message sent to or returned by a model."""
@@ -539,6 +592,8 @@

     tool_name: str | None
     """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
+    tool_call_id: str | None
+    """The tool call ID, if any, that this result is associated with."""
     event_kind: Literal['final_result'] = 'final_result'
     """Event type identifier, used as a discriminator."""

pydantic_ai/models/__init__.py CHANGED
@@ -28,9 +28,11 @@ if TYPE_CHECKING:


 KnownModelName = Literal[
+    'anthropic:claude-3-7-sonnet-latest',
     'anthropic:claude-3-5-haiku-latest',
     'anthropic:claude-3-5-sonnet-latest',
     'anthropic:claude-3-opus-latest',
+    'claude-3-7-sonnet-latest',
     'claude-3-5-haiku-latest',
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
@@ -56,6 +58,7 @@ KnownModelName = Literal[
     'google-gla:gemini-exp-1206',
     'google-gla:gemini-2.0-flash',
     'google-gla:gemini-2.0-flash-lite-preview-02-05',
+    'google-gla:gemini-2.0-pro-exp-02-05',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
@@ -65,6 +68,7 @@ KnownModelName = Literal[
     'google-vertex:gemini-exp-1206',
     'google-vertex:gemini-2.0-flash',
     'google-vertex:gemini-2.0-flash-lite-preview-02-05',
+    'google-vertex:gemini-2.0-pro-exp-02-05',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
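With `claude-3-7-sonnet-latest` and `gemini-2.0-pro-exp-02-05` added to `KnownModelName`, the new names type-check as plain strings passed to `Agent`. A one-line sketch, assuming the `anthropic` extra is installed:

```python
from pydantic_ai import Agent

# Newly recognised model name in 0.0.32:
agent = Agent('anthropic:claude-3-7-sonnet-latest')
```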
pydantic_ai/models/anthropic.py CHANGED
@@ -42,6 +42,7 @@ from . import (
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types import (
+        ContentBlock,
         ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,
@@ -69,6 +70,7 @@ except ImportError as _import_error:
     ) from _import_error

 LatestAnthropicModelNames = Literal[
+    'claude-3-7-sonnet-latest',
     'claude-3-5-haiku-latest',
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
@@ -423,7 +425,7 @@ class AnthropicStreamedResponse(StreamedResponse):
     _timestamp: datetime

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        current_block: TextBlock | ToolUseBlock | None = None
+        current_block: ContentBlock | None = None
         current_json: str = ''

         async for event in self._response:
pydantic_ai/models/gemini.py CHANGED
@@ -53,6 +53,7 @@ LatestGeminiModelNames = Literal[
     'gemini-exp-1206',
     'gemini-2.0-flash',
     'gemini-2.0-flash-lite-preview-02-05',
+    'gemini-2.0-pro-exp-02-05',
 ]
 """Latest Gemini models."""

pydantic_ai/models/instrumented.py CHANGED
@@ -1,28 +1,20 @@
 from __future__ import annotations

 import json
-from collections.abc import AsyncIterator, Iterator
+from collections.abc import AsyncIterator, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from functools import partial
 from typing import Any, Callable, Literal

-import logfire_api
 from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
-from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
+from pydantic import TypeAdapter

 from ..messages import (
     ModelMessage,
     ModelRequest,
-    ModelRequestPart,
     ModelResponse,
-    RetryPromptPart,
-    SystemPromptPart,
-    TextPart,
-    ToolCallPart,
-    ToolReturnPart,
-    UserPromptPart,
 )
 from ..settings import ModelSettings
 from ..usage import Usage
@@ -48,6 +40,8 @@ MODEL_SETTING_ATTRIBUTES: tuple[
     'frequency_penalty',
 )

+ANY_ADAPTER = TypeAdapter[Any](Any)
+

 @dataclass
 class InstrumentedModel(WrapperModel):
@@ -64,27 +58,15 @@ class InstrumentedModel(WrapperModel):
         event_logger_provider: EventLoggerProvider | None = None,
         event_mode: Literal['attributes', 'logs'] = 'attributes',
     ):
+        from pydantic_ai import __version__
+
         super().__init__(wrapped)
         tracer_provider = tracer_provider or get_tracer_provider()
         event_logger_provider = event_logger_provider or get_event_logger_provider()
-        self.tracer = tracer_provider.get_tracer('pydantic-ai')
-        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+        self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
+        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
         self.event_mode = event_mode

-    @classmethod
-    def from_logfire(
-        cls,
-        wrapped: Model | KnownModelName,
-        logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
-        event_mode: Literal['attributes', 'logs'] = 'attributes',
-    ) -> InstrumentedModel:
-        if hasattr(logfire_instance.config, 'get_event_logger_provider'):
-            event_provider = logfire_instance.config.get_event_logger_provider()
-        else:
-            event_provider = None
-        tracer_provider = logfire_instance.config.get_tracer_provider()
-        return cls(wrapped, tracer_provider, event_provider, event_mode)
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -115,7 +97,7 @@ class InstrumentedModel(WrapperModel):
             finish(response_stream.get(), response_stream.usage())

     @contextmanager
-    def _instrument(  # noqa: C901
+    def _instrument(
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
@@ -141,35 +123,24 @@ class InstrumentedModel(WrapperModel):
             if isinstance(value := model_settings.get(key), (float, int)):
                 attributes[f'gen_ai.request.{key}'] = value

-        events_list = []
-        emit_event = partial(self._emit_event, system, events_list)
-
         with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
-            if span.is_recording():
-                for message in messages:
-                    if isinstance(message, ModelRequest):
-                        for part in message.parts:
-                            event_name, body = _request_part_body(part)
-                            if event_name:
-                                emit_event(event_name, body)
-                    elif isinstance(message, ModelResponse):
-                        for body in _response_bodies(message):
-                            emit_event('gen_ai.assistant.message', body)

             def finish(response: ModelResponse, usage: Usage):
                 if not span.is_recording():
                     return

-                for response_body in _response_bodies(response):
-                    if response_body:
-                        emit_event(
+                events = self.messages_to_otel_events(messages)
+                for event in self.messages_to_otel_events([response]):
+                    events.append(
+                        Event(
                             'gen_ai.choice',
-                            {
+                            body={
                                 # TODO finish_reason
                                 'index': 0,
-                                'message': response_body,
+                                'message': event.body,
                             },
                         )
+                    )
                 span.set_attributes(
                     {
                         # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
@@ -178,67 +149,67 @@ class InstrumentedModel(WrapperModel):
                         **usage.opentelemetry_attributes(),
                     }
                 )
-                if events_list:
-                    attr_name = 'events'
-                    span.set_attributes(
-                        {
-                            attr_name: json.dumps(events_list),
-                            'logfire.json_schema': json.dumps(
-                                {
-                                    'type': 'object',
-                                    'properties': {attr_name: {'type': 'array'}},
-                                }
-                            ),
-                        }
-                    )
+                self._emit_events(system, span, events)

             yield finish

-    def _emit_event(
-        self, system: str, events_list: list[dict[str, Any]], event_name: str, body: dict[str, Any]
-    ) -> None:
-        attributes = {'gen_ai.system': system}
+    def _emit_events(self, system: str, span: Span, events: list[Event]) -> None:
+        for event in events:
+            event.attributes = {'gen_ai.system': system, **(event.attributes or {})}
         if self.event_mode == 'logs':
-            self.event_logger.emit(Event(event_name, body=body, attributes=attributes))
-        else:
-            events_list.append({'event.name': event_name, **body, **attributes})
-
-
-def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
-    if isinstance(part, SystemPromptPart):
-        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
-    elif isinstance(part, UserPromptPart):
-        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
-    elif isinstance(part, ToolReturnPart):
-        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
-    elif isinstance(part, RetryPromptPart):
-        if part.tool_name is None:
-            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+            for event in events:
+                self.event_logger.emit(event)
         else:
-            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
-    else:
-        return '', {}
-
-
-def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
-    body: dict[str, Any] = {'role': 'assistant'}
-    result = [body]
-    for part in message.parts:
-        if isinstance(part, ToolCallPart):
-            body.setdefault('tool_calls', []).append(
+            attr_name = 'events'
+            span.set_attributes(
                 {
-                    'id': part.tool_call_id,
-                    'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
-                    'function': {
-                        'name': part.tool_name,
-                        'arguments': part.args,
-                    },
+                    attr_name: json.dumps([self.event_to_dict(event) for event in events]),
+                    'logfire.json_schema': json.dumps(
+                        {
+                            'type': 'object',
+                            'properties': {attr_name: {'type': 'array'}},
+                        }
+                    ),
                 }
             )
-        elif isinstance(part, TextPart):
-            if body.get('content'):
-                body = {'role': 'assistant'}
-                result.append(body)
-            body['content'] = part.content

-    return result
+    @staticmethod
+    def event_to_dict(event: Event) -> dict[str, Any]:
+        if not event.body:
+            body = {}
+        elif isinstance(event.body, Mapping):
+            body = event.body  # type: ignore
+        else:
+            body = {'body': event.body}
+        return {**body, **(event.attributes or {})}
+
+    @staticmethod
+    def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
+        result: list[Event] = []
+        for message_index, message in enumerate(messages):
+            message_events: list[Event] = []
+            if isinstance(message, ModelRequest):
+                for part in message.parts:
+                    if hasattr(part, 'otel_event'):
+                        message_events.append(part.otel_event())
+            elif isinstance(message, ModelResponse):
+                message_events = message.otel_events()
+            for event in message_events:
+                event.attributes = {
+                    'gen_ai.message.index': message_index,
+                    **(event.attributes or {}),
+                }
+            result.extend(message_events)
+        for event in result:
+            event.body = InstrumentedModel.serialize_any(event.body)
+        return result
+
+    @staticmethod
+    def serialize_any(value: Any) -> str:
+        try:
+            return ANY_ADAPTER.dump_python(value, mode='json')
+        except Exception:
+            try:
+                return str(value)
+            except Exception as e:
+                return f'Unable to serialize: {e}'
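`InstrumentedModel` now exposes its event plumbing as static helpers (`messages_to_otel_events`, `event_to_dict`, `serialize_any`), which `_agent_graph.py` reuses for the run span. A sketch of calling them directly:

```python
from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.models.instrumented import InstrumentedModel

messages = [
    ModelRequest(parts=[UserPromptPart(content='hi')]),
    ModelResponse(parts=[TextPart(content='hello')]),
]
for event in InstrumentedModel.messages_to_otel_events(messages):
    # Each event carries a `gen_ai.message.index` attribute so the flat
    # event list can be mapped back to message positions.
    print(InstrumentedModel.event_to_dict(event))
```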
pydantic_ai/result.py CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast

-import logfire_api
 from typing_extensions import TypeVar, assert_type

 from . import _result, _utils, exceptions, messages as _messages, models
@@ -49,8 +48,6 @@ A function that always takes and returns the same type of data (which is the res
 Usage `ResultValidatorFunc[AgentDepsT, T]`.
 """

-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-

 @dataclass
 class AgentStream(Generic[AgentDepsT, ResultDataT]):
@@ -145,12 +142,14 @@ class AgentStream(Generic[AgentDepsT, ResultDataT]):
                 if isinstance(e, _messages.PartStartEvent):
                     new_part = e.part
                     if isinstance(new_part, _messages.ToolCallPart):
-                        if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                            call, _ = match
-                            return _messages.FinalResultEvent(tool_name=call.tool_name)
+                        if result_schema:
+                            for call, _ in result_schema.find_tool([new_part]):
+                                return _messages.FinalResultEvent(
+                                    tool_name=call.tool_name, tool_call_id=call.tool_call_id
+                                )
                     elif allow_text_result:
                         assert_type(e, _messages.PartStartEvent)
-                        return _messages.FinalResultEvent(tool_name=None)
+                        return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)

         usage_checking_stream = _get_usage_checking_stream_response(
             self._raw_stream_response, self._usage_limits, self.usage
@@ -300,17 +299,14 @@ class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
         if self._result_schema and not self._result_schema.allow_text_result:
             raise exceptions.UserError('stream_text() can only be used with text responses')

-        with _logfire.span('response stream text') as lf_span:
-            if delta:
-                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
-                    yield text
-            else:
-                combined_validated_text = ''
-                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
-                    combined_validated_text = await self._validate_text_result(text)
-                    yield combined_validated_text
-                lf_span.set_attribute('combined_text', combined_validated_text)
-            await self._marked_completed(self._stream_response.get())
+        if delta:
+            async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
+                yield text
+        else:
+            async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
+                combined_validated_text = await self._validate_text_result(text)
+                yield combined_validated_text
+        await self._marked_completed(self._stream_response.get())

     async def stream_structured(
         self, *, debounce_by: float | None = 0.1
@@ -325,22 +321,20 @@ class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
         Returns:
             An async iterable of the structured response message and whether that is the last message.
         """
-        with _logfire.span('response stream structured') as lf_span:
-            # if the message currently has any parts with content, yield before streaming
-            msg = self._stream_response.get()
-            for part in msg.parts:
-                if part.has_content():
-                    yield msg, False
-                    break
-
-            async for msg in self._stream_response_structured(debounce_by=debounce_by):
+        # if the message currently has any parts with content, yield before streaming
+        msg = self._stream_response.get()
+        for part in msg.parts:
+            if part.has_content():
                 yield msg, False
+                break

-            msg = self._stream_response.get()
-            yield msg, True
+        async for msg in self._stream_response_structured(debounce_by=debounce_by):
+            yield msg, False
+
+        msg = self._stream_response.get()
+        yield msg, True

-            lf_span.set_attribute('structured_response', msg)
-            await self._marked_completed(msg)
+        await self._marked_completed(msg)

     async def get_data(self) -> ResultDataT:
         """Stream the whole response, validate and return it."""
@@ -472,6 +466,8 @@ class FinalResult(Generic[ResultDataT]):
     """The final result data."""
     tool_name: str | None
     """Name of the final result tool; `None` if the result came from unstructured text content."""
+    tool_call_id: str | None
+    """ID of the tool call that produced the final result; `None` if the result came from unstructured text content."""


 def _get_usage_checking_stream_response(
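`FinalResult` here, and `FinalResultEvent` in `messages.py`, both gain a `tool_call_id` field, a breaking change for any code that constructed them directly. A sketch of the new constructor shapes:

```python
from pydantic_ai.messages import FinalResultEvent
from pydantic_ai.result import FinalResult

# 0.0.30: FinalResult(data='Paris', tool_name=None)
# 0.0.32: the tool call ID travels with the result (None for plain-text results):
final = FinalResult(data='Paris', tool_name=None, tool_call_id=None)
event = FinalResultEvent(tool_name='final_result', tool_call_id='call_123')
```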
pydantic_ai_slim-0.0.30.dist-info/METADATA → pydantic_ai_slim-0.0.32.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.30
+Version: 0.0.32
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -28,11 +28,11 @@ Requires-Dist: eval-type-backport>=0.2.0
 Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
-Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.30
+Requires-Dist: opentelemetry-api>=1.28.0
+Requires-Dist: pydantic-graph==0.0.32
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
-Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
+Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
 Provides-Extra: cohere
 Requires-Dist: cohere>=5.13.11; extra == 'cohere'
 Provides-Extra: duckduckgo
pydantic_ai_slim-0.0.30.dist-info/RECORD → pydantic_ai_slim-0.0.32.dist-info/RECORD RENAMED
@@ -1,36 +1,36 @@
-pydantic_ai/__init__.py,sha256=Rmpjmorf8YY1PtlkXRRNN-J3ZoQDSh7chaibVGyqY0k,937
-pydantic_ai/_agent_graph.py,sha256=gvJQ17A2glk8p2w2TCSfHwvWNp0vla1sQb0EZXOZbxU,30284
+pydantic_ai/__init__.py,sha256=xrSDxkBwpUVInbPtTVhReEecStk-mWZMttAPUAQR0Ic,927
+pydantic_ai/_agent_graph.py,sha256=wbhm3_5VNpx_Oy1_sQ_6b2hkaFjd9vd1v9g3Rw_8sJY,30127
 pydantic_ai/_griffe.py,sha256=RYRKiLbgG97QxnazbAwlnc74XxevGHLQet-FGfq9qls,3960
 pydantic_ai/_parts_manager.py,sha256=ARfDQY1_5AIY5rNl_M2fAYHEFCe03ZxdhgjHf9qeIKw,11872
 pydantic_ai/_pydantic.py,sha256=dROz3Hmfdi0C2exq88FhefDRVo_8S3rtkXnoUHzsz0c,8753
-pydantic_ai/_result.py,sha256=tN1pVulf_EM4bkBvpNUWPnUXezLY-sBrJEVCFdy2nLU,10264
+pydantic_ai/_result.py,sha256=mqj3YrUzr5OT00h0KfGJglwQZ6_7nV7355Pvucd08ak,10276
 pydantic_ai/_system_prompt.py,sha256=602c2jyle2R_SesOrITBDETZqsLk4BZ8Cbo8yEhmx04,1120
-pydantic_ai/_utils.py,sha256=w9BYSfFZiaX757fRtMRclOL1uYzyQnxV_lxqbU2WTPs,9435
-pydantic_ai/agent.py,sha256=FeKELTSFKDkt6-UlmkezKnQTdnx1in6VckivqsfzfA4,65382
+pydantic_ai/_utils.py,sha256=nx4Suswk2qjLvzphx8uQntKzFi-IzvhX_H1L7t_kJlQ,9579
+pydantic_ai/agent.py,sha256=zqzFPvRvgb0iPCwRK2IhyTPJfbGjVMxCrQq9uDw9RM4,65873
 pydantic_ai/exceptions.py,sha256=1ujJeB3jDDQ-pH5ydBYrgStvR35-GlEW0bYGTGEr4ME,3127
 pydantic_ai/format_as_xml.py,sha256=QE7eMlg5-YUMw1_2kcI3h0uKYPZZyGkgXFDtfZTMeeI,4480
-pydantic_ai/messages.py,sha256=k8sX-V1cTeqXh1u6oJbqExZPYt3E7F3UCIudxvjKRO8,21486
+pydantic_ai/messages.py,sha256=Yny2hIuExXfw9fvHDSPgbvfN91IOdcLaDEAmaCAoTBs,23751
 pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pydantic_ai/result.py,sha256=Df_tPeqCQnLa0i0vVA-BGCJDx37ebD_3ojAmHnXE2yU,22767
+pydantic_ai/result.py,sha256=LXKxRzy_rGMkdZ8xJ7yknPP3wGZtGNeZl-gh5opXbaQ,22542
 pydantic_ai/settings.py,sha256=ntuWnke9UA18aByDxk9OIhN0tAgOaPdqCEkRf-wlp8Y,3059
 pydantic_ai/tools.py,sha256=IPZuZJCSQUppz1uyLVwpfFLGoMirB8YtKWXIDQGR444,13414
 pydantic_ai/usage.py,sha256=VmpU_o_RjFI65J81G1wfCwDIAYBclMjeWfLtslntFOw,5406
 pydantic_ai/common_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pydantic_ai/common_tools/duckduckgo.py,sha256=-kSa1gGn5-NIYvtxFWrFcX2XdmfEmGxI3_wAqrb6jLI,2230
 pydantic_ai/common_tools/tavily.py,sha256=Lz35037ggkdWKa_Stj0yXBkiN_hygDefEevoRDUclF0,2560
-pydantic_ai/models/__init__.py,sha256=2A3CpdMnvllnVVX8PmlUcBs0HMGcG4RurOXsRKl0BPc,13886
-pydantic_ai/models/anthropic.py,sha256=bFtE6hku9L4l4pKJg8XER37T2ST2htArho5lPjEohAk,20637
+pydantic_ai/models/__init__.py,sha256=1TXjx6HIPy1keq9BGYzXHdR9YvRa3tYeB3I7JAJ5pWQ,14049
+pydantic_ai/models/anthropic.py,sha256=DxoaSSo-HZYJSqbOAR2p7gsW6kUXY-SV6aA1j-8gy6c,20679
 pydantic_ai/models/cohere.py,sha256=6F6eWPGVT7mpMXlRugbVbR-a8Q1zmb1SKS_fWOoBL80,11514
 pydantic_ai/models/fallback.py,sha256=smHwNIpxu19JsgYYjY0nmzl3yox7yQRJ0Ir08zdhnk0,4207
 pydantic_ai/models/function.py,sha256=THIwVJ8qI3efYLNlYXlYze_J8hc7MHB-NMb3kpknq0g,11373
-pydantic_ai/models/gemini.py,sha256=2hDTMIMf899dp-MS0tLT7m1GkXsL9KIRMBklGM0VLB4,34223
+pydantic_ai/models/gemini.py,sha256=IRLwvNcRiajiZzI5xDk13Fg_Q26uCbxc15ZrG8L9ufE,34255
 pydantic_ai/models/groq.py,sha256=Z4sZJDu5Yxa2tZiAPp9EjSVMz4uwLhS3fW7kFSc09gI,16406
-pydantic_ai/models/instrumented.py,sha256=xUZEn2VG8hP3hny0L5kZgXC5UnFdlUJ0DgXOxFmYhEo,9654
+pydantic_ai/models/instrumented.py,sha256=7LXQgMtKyU3VQ1ReC7QdYFms01gAivJbPEJXij6HPYE,8196
 pydantic_ai/models/mistral.py,sha256=ZJ4xPcL9wJIQ5io34yP2fPyXy8GZrSvsW4itZiKPYFw,27448
 pydantic_ai/models/openai.py,sha256=koIcK_pDHmV-JFq_-VIzU-edAqGKOOzkSk5QSYWvfoc,20156
 pydantic_ai/models/test.py,sha256=Ux20cmuJFkhvI9L1N7ItHNFcd-j284TBEsrM53eWRag,16873
 pydantic_ai/models/vertexai.py,sha256=9Kp_1KMBlbP8_HRJTuFnrkkFmlJ7yFhADQYjxOgIh9Y,9523
 pydantic_ai/models/wrapper.py,sha256=Zr3fgiUBpt2N9gXds6iSwaMEtEsFKr9WwhpHjSoHa7o,1410
-pydantic_ai_slim-0.0.30.dist-info/METADATA,sha256=JDT77S9uw0w87WpAbXqK_c65849A7PeF1_dhJRGamiM,3062
-pydantic_ai_slim-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydantic_ai_slim-0.0.30.dist-info/RECORD,,
+pydantic_ai_slim-0.0.32.dist-info/METADATA,sha256=MzmTjEJO4fYY3bQyY-wdxxKxoZYq-4iYlwv0trCixV8,3069
+pydantic_ai_slim-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.32.dist-info/RECORD,,