pydantic-ai-slim 0.0.30__tar.gz → 0.0.32__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. (The original registry page links to further details about why this release was flagged.)

Files changed (37)
  1. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/PKG-INFO +4 -4
  2. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/__init__.py +2 -2
  3. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_agent_graph.py +86 -73
  4. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_result.py +3 -3
  5. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_utils.py +2 -0
  6. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/agent.py +54 -47
  7. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/messages.py +55 -0
  8. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/__init__.py +4 -0
  9. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/anthropic.py +3 -1
  10. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/gemini.py +1 -0
  11. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/instrumented.py +72 -101
  12. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/result.py +27 -31
  13. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pyproject.toml +4 -4
  14. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/.gitignore +0 -0
  15. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/README.md +0 -0
  16. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_griffe.py +0 -0
  17. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_parts_manager.py +0 -0
  18. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_pydantic.py +0 -0
  19. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/_system_prompt.py +0 -0
  20. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/common_tools/__init__.py +0 -0
  21. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  22. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/common_tools/tavily.py +0 -0
  23. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/exceptions.py +0 -0
  24. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/format_as_xml.py +0 -0
  25. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/cohere.py +0 -0
  26. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/fallback.py +0 -0
  27. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/function.py +0 -0
  28. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/groq.py +0 -0
  29. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/mistral.py +0 -0
  30. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/openai.py +0 -0
  31. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/test.py +0 -0
  32. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/vertexai.py +0 -0
  33. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/models/wrapper.py +0 -0
  34. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/py.typed +0 -0
  35. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/settings.py +0 -0
  36. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/tools.py +0 -0
  37. {pydantic_ai_slim-0.0.30 → pydantic_ai_slim-0.0.32}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.30
3
+ Version: 0.0.32
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -28,11 +28,11 @@ Requires-Dist: eval-type-backport>=0.2.0
28
28
  Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
- Requires-Dist: logfire-api>=1.2.0
32
- Requires-Dist: pydantic-graph==0.0.30
31
+ Requires-Dist: opentelemetry-api>=1.28.0
32
+ Requires-Dist: pydantic-graph==0.0.32
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Provides-Extra: anthropic
35
- Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
35
+ Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
36
36
  Provides-Extra: cohere
37
37
  Requires-Dist: cohere>=5.13.11; extra == 'cohere'
38
38
  Provides-Extra: duckduckgo
@@ -1,6 +1,6 @@
1
1
  from importlib.metadata import version
2
2
 
3
- from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
3
+ from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
4
4
  from .exceptions import (
5
5
  AgentRunError,
6
6
  FallbackExceptionGroup,
@@ -18,7 +18,7 @@ __all__ = (
18
18
  # agent
19
19
  'Agent',
20
20
  'EndStrategy',
21
- 'HandleResponseNode',
21
+ 'CallToolsNode',
22
22
  'ModelRequestNode',
23
23
  'UserPromptNode',
24
24
  'capture_run_messages',
@@ -2,13 +2,14 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import asyncio
4
4
  import dataclasses
5
+ import json
5
6
  from collections.abc import AsyncIterator, Iterator, Sequence
6
7
  from contextlib import asynccontextmanager, contextmanager
7
8
  from contextvars import ContextVar
8
9
  from dataclasses import field
9
10
  from typing import Any, Generic, Literal, Union, cast
10
11
 
11
- import logfire_api
12
+ from opentelemetry.trace import Span, Tracer
12
13
  from typing_extensions import TypeGuard, TypeVar, assert_never
13
14
 
14
15
  from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -23,6 +24,7 @@ from . import (
23
24
  result,
24
25
  usage as _usage,
25
26
  )
27
+ from .models.instrumented import InstrumentedModel
26
28
  from .result import ResultDataT
27
29
  from .settings import ModelSettings, merge_model_settings
28
30
  from .tools import (
@@ -36,22 +38,11 @@ __all__ = (
36
38
  'GraphAgentDeps',
37
39
  'UserPromptNode',
38
40
  'ModelRequestNode',
39
- 'HandleResponseNode',
41
+ 'CallToolsNode',
40
42
  'build_run_context',
41
43
  'capture_run_messages',
42
44
  )
43
45
 
44
- _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
45
-
46
- # while waiting for https://github.com/pydantic/logfire/issues/745
47
- try:
48
- import logfire._internal.stack_info
49
- except ImportError:
50
- pass
51
- else:
52
- from pathlib import Path
53
-
54
- logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
55
46
 
56
47
  T = TypeVar('T')
57
48
  S = TypeVar('S')
@@ -104,7 +95,8 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
104
95
 
105
96
  function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
106
97
 
107
- run_span: logfire_api.LogfireSpan
98
+ run_span: Span
99
+ tracer: Tracer
108
100
 
109
101
 
110
102
  class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -243,12 +235,12 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
243
235
 
244
236
  request: _messages.ModelRequest
245
237
 
246
- _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
238
+ _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
247
239
  _did_stream: bool = field(default=False, repr=False)
248
240
 
249
241
  async def run(
250
242
  self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
251
- ) -> HandleResponseNode[DepsT, NodeRunEndT]:
243
+ ) -> CallToolsNode[DepsT, NodeRunEndT]:
252
244
  if self._result is not None:
253
245
  return self._result
254
246
 
@@ -286,39 +278,33 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
286
278
  assert not self._did_stream, 'stream() should only be called once per node'
287
279
 
288
280
  model_settings, model_request_parameters = await self._prepare_request(ctx)
289
- with _logfire.span('model request', run_step=ctx.state.run_step) as span:
290
- async with ctx.deps.model.request_stream(
291
- ctx.state.message_history, model_settings, model_request_parameters
292
- ) as streamed_response:
293
- self._did_stream = True
294
- ctx.state.usage.incr(_usage.Usage(), requests=1)
295
- yield streamed_response
296
- # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
297
- # otherwise usage won't be properly counted:
298
- async for _ in streamed_response:
299
- pass
300
- model_response = streamed_response.get()
301
- request_usage = streamed_response.usage()
302
- span.set_attribute('response', model_response)
303
- span.set_attribute('usage', request_usage)
281
+ async with ctx.deps.model.request_stream(
282
+ ctx.state.message_history, model_settings, model_request_parameters
283
+ ) as streamed_response:
284
+ self._did_stream = True
285
+ ctx.state.usage.incr(_usage.Usage(), requests=1)
286
+ yield streamed_response
287
+ # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
288
+ # otherwise usage won't be properly counted:
289
+ async for _ in streamed_response:
290
+ pass
291
+ model_response = streamed_response.get()
292
+ request_usage = streamed_response.usage()
304
293
 
305
294
  self._finish_handling(ctx, model_response, request_usage)
306
295
  assert self._result is not None # this should be set by the previous line
307
296
 
308
297
  async def _make_request(
309
298
  self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
310
- ) -> HandleResponseNode[DepsT, NodeRunEndT]:
299
+ ) -> CallToolsNode[DepsT, NodeRunEndT]:
311
300
  if self._result is not None:
312
301
  return self._result
313
302
 
314
303
  model_settings, model_request_parameters = await self._prepare_request(ctx)
315
- with _logfire.span('model request', run_step=ctx.state.run_step) as span:
316
- model_response, request_usage = await ctx.deps.model.request(
317
- ctx.state.message_history, model_settings, model_request_parameters
318
- )
319
- ctx.state.usage.incr(_usage.Usage(), requests=1)
320
- span.set_attribute('response', model_response)
321
- span.set_attribute('usage', request_usage)
304
+ model_response, request_usage = await ctx.deps.model.request(
305
+ ctx.state.message_history, model_settings, model_request_parameters
306
+ )
307
+ ctx.state.usage.incr(_usage.Usage(), requests=1)
322
308
 
323
309
  return self._finish_handling(ctx, model_response, request_usage)
324
310
 
@@ -335,7 +321,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
335
321
  ctx.state.run_step += 1
336
322
 
337
323
  model_settings = merge_model_settings(ctx.deps.model_settings, None)
338
- with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
324
+ with ctx.deps.tracer.start_as_current_span(
325
+ 'preparing model request params', attributes=dict(run_step=ctx.state.run_step)
326
+ ):
339
327
  model_request_parameters = await _prepare_request_parameters(ctx)
340
328
  return model_settings, model_request_parameters
341
329
 
@@ -344,7 +332,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
344
332
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
345
333
  response: _messages.ModelResponse,
346
334
  usage: _usage.Usage,
347
- ) -> HandleResponseNode[DepsT, NodeRunEndT]:
335
+ ) -> CallToolsNode[DepsT, NodeRunEndT]:
348
336
  # Update usage
349
337
  ctx.state.usage.incr(usage, requests=0)
350
338
  if ctx.deps.usage_limits:
@@ -354,13 +342,13 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
354
342
  ctx.state.message_history.append(response)
355
343
 
356
344
  # Set the `_result` attribute since we can't use `return` in an async iterator
357
- self._result = HandleResponseNode(response)
345
+ self._result = CallToolsNode(response)
358
346
 
359
347
  return self._result
360
348
 
361
349
 
362
350
  @dataclasses.dataclass
363
- class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
351
+ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
364
352
  """Process a model response, and decide whether to end the run or make a new request."""
365
353
 
366
354
  model_response: _messages.ModelResponse
@@ -385,26 +373,12 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
385
373
  self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
386
374
  ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
387
375
  """Process the model response and yield events for the start and end of each function tool call."""
388
- with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
389
- stream = self._run_stream(ctx)
390
- yield stream
391
-
392
- # Run the stream to completion if it was not finished:
393
- async for _event in stream:
394
- pass
376
+ stream = self._run_stream(ctx)
377
+ yield stream
395
378
 
396
- # Set the next node based on the final state of the stream
397
- next_node = self._next_node
398
- if isinstance(next_node, End):
399
- handle_span.set_attribute('result', next_node.data)
400
- handle_span.message = 'handle model response -> final result'
401
- elif tool_responses := self._tool_responses:
402
- # TODO: We could drop `self._tool_responses` if we drop this set_attribute
403
- # I'm thinking it might be better to just create a span for the handling of each tool
404
- # than to set an attribute here.
405
- handle_span.set_attribute('tool_responses', tool_responses)
406
- tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
407
- handle_span.message = f'handle model response -> {tool_responses_str}'
379
+ # Run the stream to completion if it was not finished:
380
+ async for _event in stream:
381
+ pass
408
382
 
409
383
  async def _run_stream(
410
384
  self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -454,8 +428,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
454
428
  final_result: result.FinalResult[NodeRunEndT] | None = None
455
429
  parts: list[_messages.ModelRequestPart] = []
456
430
  if result_schema is not None:
457
- if match := result_schema.find_tool(tool_calls):
458
- call, result_tool = match
431
+ for call, result_tool in result_schema.find_tool(tool_calls):
459
432
  try:
460
433
  result_data = result_tool.validate(call)
461
434
  result_data = await _validate_result(result_data, ctx, call)
@@ -465,12 +438,17 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
465
438
  ctx.state.increment_retries(ctx.deps.max_result_retries)
466
439
  parts.append(e.tool_retry)
467
440
  else:
468
- final_result = result.FinalResult(result_data, call.tool_name)
441
+ final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
442
+ break
469
443
 
470
444
  # Then build the other request parts based on end strategy
471
445
  tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
472
446
  async for event in process_function_tools(
473
- tool_calls, final_result and final_result.tool_name, ctx, tool_responses
447
+ tool_calls,
448
+ final_result and final_result.tool_name,
449
+ final_result and final_result.tool_call_id,
450
+ ctx,
451
+ tool_responses,
474
452
  ):
475
453
  yield event
476
454
 
@@ -495,8 +473,30 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
495
473
  if tool_responses:
496
474
  messages.append(_messages.ModelRequest(parts=tool_responses))
497
475
 
498
- run_span.set_attribute('usage', usage)
499
- run_span.set_attribute('all_messages', messages)
476
+ run_span.set_attributes(
477
+ {
478
+ **usage.opentelemetry_attributes(),
479
+ 'all_messages_events': json.dumps(
480
+ [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
481
+ ),
482
+ 'final_result': final_result.data
483
+ if isinstance(final_result.data, str)
484
+ else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
485
+ }
486
+ )
487
+ run_span.set_attributes(
488
+ {
489
+ 'logfire.json_schema': json.dumps(
490
+ {
491
+ 'type': 'object',
492
+ 'properties': {
493
+ 'all_messages_events': {'type': 'array'},
494
+ 'final_result': {'type': 'object'},
495
+ },
496
+ }
497
+ ),
498
+ }
499
+ )
500
500
 
501
501
  # End the run with self.data
502
502
  return End(final_result)
@@ -518,7 +518,7 @@ class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
518
518
  return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
519
519
  else:
520
520
  # The following cast is safe because we know `str` is an allowed result type
521
- return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
521
+ return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
522
522
  else:
523
523
  ctx.state.increment_retries(ctx.deps.max_result_retries)
524
524
  return ModelRequestNode[DepsT, NodeRunEndT](
@@ -547,6 +547,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
547
547
  async def process_function_tools(
548
548
  tool_calls: list[_messages.ToolCallPart],
549
549
  result_tool_name: str | None,
550
+ result_tool_call_id: str | None,
550
551
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
551
552
  output_parts: list[_messages.ModelRequestPart],
552
553
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,7 +567,11 @@ async def process_function_tools(
566
567
  calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
567
568
  call_index_to_event_id: dict[int, str] = {}
568
569
  for call in tool_calls:
569
- if call.tool_name == result_tool_name and not found_used_result_tool:
570
+ if (
571
+ call.tool_name == result_tool_name
572
+ and call.tool_call_id == result_tool_call_id
573
+ and not found_used_result_tool
574
+ ):
570
575
  found_used_result_tool = True
571
576
  output_parts.append(
572
577
  _messages.ToolReturnPart(
@@ -593,9 +598,14 @@ async def process_function_tools(
593
598
  # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
594
599
  # validation, we don't add another part here
595
600
  if result_tool_name is not None:
601
+ if found_used_result_tool:
602
+ content = 'Result tool not used - a final result was already processed.'
603
+ else:
604
+ # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
605
+ content = 'Result tool not used - result failed validation.'
596
606
  part = _messages.ToolReturnPart(
597
607
  tool_name=call.tool_name,
598
- content='Result tool not used - a final result was already processed.',
608
+ content=content,
599
609
  tool_call_id=call.tool_call_id,
600
610
  )
601
611
  output_parts.append(part)
@@ -607,7 +617,10 @@ async def process_function_tools(
607
617
 
608
618
  # Run all tool tasks in parallel
609
619
  results_by_index: dict[int, _messages.ModelRequestPart] = {}
610
- with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
620
+ tool_names = [call.tool_name for _, call in calls_to_run]
621
+ with ctx.deps.tracer.start_as_current_span(
622
+ 'running tools', attributes={'tools': tool_names, 'logfire.msg': f'running tools: {", ".join(tool_names)}'}
623
+ ):
611
624
  # TODO: Should we wrap each individual tool call in a dedicated span?
612
625
  tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
613
626
  pending = tasks
@@ -716,7 +729,7 @@ def build_agent_graph(
716
729
  nodes = (
717
730
  UserPromptNode[DepsT],
718
731
  ModelRequestNode[DepsT],
719
- HandleResponseNode[DepsT],
732
+ CallToolsNode[DepsT],
720
733
  )
721
734
  graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
722
735
  nodes=nodes,
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
3
3
  import inspect
4
4
  import sys
5
5
  import types
6
- from collections.abc import Awaitable, Iterable
6
+ from collections.abc import Awaitable, Iterable, Iterator
7
7
  from dataclasses import dataclass, field
8
8
  from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
9
9
 
@@ -127,12 +127,12 @@ class ResultSchema(Generic[ResultDataT]):
127
127
  def find_tool(
128
128
  self,
129
129
  parts: Iterable[_messages.ModelResponsePart],
130
- ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
130
+ ) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
131
131
  """Find a tool that matches one of the calls."""
132
132
  for part in parts:
133
133
  if isinstance(part, _messages.ToolCallPart):
134
134
  if result := self.tools.get(part.tool_name):
135
- return part, result
135
+ yield part, result
136
136
 
137
137
  def tool_names(self) -> list[str]:
138
138
  """Return the names of the tools."""
@@ -48,6 +48,8 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
48
48
 
49
49
  if schema.get('type') == 'object':
50
50
  return schema
51
+ elif schema.get('$ref') is not None:
52
+ return schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/".
51
53
  else:
52
54
  raise UserError('Schema must be an object')
53
55