pydantic-ai-slim 0.0.6a1__py3-none-any.whl


pydantic_ai/agent.py ADDED
@@ -0,0 +1,795 @@
+ from __future__ import annotations as _annotations
+
+ import asyncio
+ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+ from contextlib import asynccontextmanager, contextmanager
+ from dataclasses import dataclass, field
+ from typing import Any, Callable, Generic, cast, final, overload
+
+ import logfire_api
+ from typing_extensions import assert_never
+
+ from . import (
+     _result,
+     _retriever as _r,
+     _system_prompt,
+     _utils,
+     exceptions,
+     messages as _messages,
+     models,
+     result,
+ )
+ from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
+ from .result import ResultData
+
+ __all__ = ('Agent',)
+
+ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
+
+ NoneType = type(None)
+
+
+ @final
+ @dataclass(init=False)
+ class Agent(Generic[AgentDeps, ResultData]):
+     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
+
+     Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.dependencies.AgentDeps]
+     and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
+
+     By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
+
+     Minimal usage example:
+
+     ```py
+     from pydantic_ai import Agent
+
+     agent = Agent('openai:gpt-4o')
+     result = agent.run_sync('What is the capital of France?')
+     print(result.data)
+     #> Paris
+     ```
+     """
+
+     # dataclass fields mostly for my sanity — knowing what attributes are available
+     model: models.Model | models.KnownModelName | None
+     """The default model configured for this agent."""
+     _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
+     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
+     _allow_text_result: bool = field(repr=False)
+     _system_prompts: tuple[str, ...] = field(repr=False)
+     _retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = field(repr=False)
+     _default_retries: int = field(repr=False)
+     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
+     _deps_type: type[AgentDeps] = field(repr=False)
+     _max_result_retries: int = field(repr=False)
+     _current_result_retry: int = field(repr=False)
+     _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
+     _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
+     last_run_messages: list[_messages.Message] | None = None
+     """The messages from the last run, useful when a run raised an exception.
+
+     Note: these are not used by the agent (e.g. in future runs); they are just stored for developers' convenience.
+     """
+
+     def __init__(
+         self,
+         model: models.Model | models.KnownModelName | None = None,
+         result_type: type[ResultData] = str,
+         *,
+         system_prompt: str | Sequence[str] = (),
+         deps_type: type[AgentDeps] = NoneType,
+         retries: int = 1,
+         result_tool_name: str = 'final_result',
+         result_tool_description: str | None = None,
+         result_retries: int | None = None,
+         defer_model_check: bool = False,
+     ):
+         """Create an agent.
+
+         Args:
+             model: The default model to use for this agent; if not provided,
+                 you must provide the model when calling the agent.
+             result_type: The type of the result data, used to validate the result data, defaults to `str`.
+             system_prompt: Static system prompts to use for this agent; you can also register system
+                 prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
+             deps_type: The type used for dependency injection; this parameter exists solely to allow you to fully
+                 parameterize the agent, and therefore get the best out of static type checking.
+                 If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
+                 or add a type hint `: Agent[None, <return type>]`.
+             retries: The default number of retries to allow before raising an error.
+             result_tool_name: The name of the tool to use for the final result.
+             result_tool_description: The description of the final result tool.
+             result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
+             defer_model_check: By default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
+                 it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
+                 which checks for the necessary environment variables. Set this to `False`
+                 to defer the evaluation until the first run. Useful if you want to
+                 [override the model][pydantic_ai.Agent.override_model] for testing.
+         """
+         if model is None or defer_model_check:
+             self.model = model
+         else:
+             self.model = models.infer_model(model)
+
+         self._result_schema = _result.ResultSchema[result_type].build(
+             result_type, result_tool_name, result_tool_description
+         )
+         # if the result schema is None, or it allows `str`, we allow plain text results
+         self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
+
+         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
+         self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
+         self._deps_type = deps_type
+         self._default_retries = retries
+         self._system_prompt_functions = []
+         self._max_result_retries = result_retries if result_retries is not None else retries
+         self._current_result_retry = 0
+         self._result_validators = []
+
+     async def run(
+         self,
+         user_prompt: str,
+         *,
+         message_history: list[_messages.Message] | None = None,
+         model: models.Model | models.KnownModelName | None = None,
+         deps: AgentDeps = None,
+     ) -> result.RunResult[ResultData]:
+         """Run the agent with a user prompt in async mode.
+
+         Args:
+             user_prompt: User input to start/continue the conversation.
+             message_history: History of the conversation so far.
+             model: Optional model to use for this run, required if `model` was not set when creating the agent.
+             deps: Optional dependencies to use for this run.
+
+         Returns:
+             The result of the run.
+         """
+         model_used, custom_model, agent_model = await self._get_agent_model(model)
+
+         deps = self._get_deps(deps)
+
+         new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
+         self.last_run_messages = messages
+
+         for retriever in self._retrievers.values():
+             retriever.reset()
+
+         cost = result.Cost()
+
+         with _logfire.span(
+             'agent run {prompt=}',
+             prompt=user_prompt,
+             agent=self,
+             custom_model=custom_model,
+             model_name=model_used.name(),
+         ) as run_span:
+             run_step = 0
+             while True:
+                 run_step += 1
+                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                     model_response, request_cost = await agent_model.request(messages)
+                     model_req_span.set_attribute('response', model_response)
+                     model_req_span.set_attribute('cost', request_cost)
+                     model_req_span.message = f'model request -> {model_response.role}'
+
+                 messages.append(model_response)
+                 cost += request_cost
+
+                 with _logfire.span('handle model response') as handle_span:
+                     either = await self._handle_model_response(model_response, deps)
+
+                     if isinstance(either, _MarkFinalResult):
+                         # we have a final result, end the conversation
+                         result_data = either.data
+                         run_span.set_attribute('all_messages', messages)
+                         run_span.set_attribute('cost', cost)
+                         handle_span.set_attribute('result', result_data)
+                         handle_span.message = 'handle model response -> final result'
+                         return result.RunResult(messages, new_message_index, result_data, cost)
+                     else:
+                         # continue the conversation
+                         tool_responses = either
+                         handle_span.set_attribute('tool_responses', tool_responses)
+                         response_msgs = ' '.join(m.role for m in tool_responses)
+                         handle_span.message = f'handle model response -> {response_msgs}'
+                         messages.extend(tool_responses)
+
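For illustration (an editor's sketch, not part of the released file): the async counterpart of the `run_sync` example in the class docstring above, run from an event loop.

```py
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    # same conversation as the class docstring example, but awaited directly
    result = await agent.run('What is the capital of France?')
    print(result.data)
    #> Paris


asyncio.run(main())
```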
+     def run_sync(
+         self,
+         user_prompt: str,
+         *,
+         message_history: list[_messages.Message] | None = None,
+         model: models.Model | models.KnownModelName | None = None,
+         deps: AgentDeps = None,
+     ) -> result.RunResult[ResultData]:
+         """Run the agent with a user prompt synchronously.
+
+         This is a convenience method that wraps `self.run` with `asyncio.run()`.
+
+         Args:
+             user_prompt: User input to start/continue the conversation.
+             message_history: History of the conversation so far.
+             model: Optional model to use for this run, required if `model` was not set when creating the agent.
+             deps: Optional dependencies to use for this run.
+
+         Returns:
+             The result of the run.
+         """
+         return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
+
+     @asynccontextmanager
+     async def run_stream(
+         self,
+         user_prompt: str,
+         *,
+         message_history: list[_messages.Message] | None = None,
+         model: models.Model | models.KnownModelName | None = None,
+         deps: AgentDeps = None,
+     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
+         """Run the agent with a user prompt in async mode, returning a streamed response.
+
+         Args:
+             user_prompt: User input to start/continue the conversation.
+             message_history: History of the conversation so far.
+             model: Optional model to use for this run, required if `model` was not set when creating the agent.
+             deps: Optional dependencies to use for this run.
+
+         Returns:
+             The streamed result of the run.
+         """
+         model_used, custom_model, agent_model = await self._get_agent_model(model)
+
+         deps = self._get_deps(deps)
+
+         new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
+         self.last_run_messages = messages
+
+         for retriever in self._retrievers.values():
+             retriever.reset()
+
+         cost = result.Cost()
+
+         with _logfire.span(
+             'agent run stream {prompt=}',
+             prompt=user_prompt,
+             agent=self,
+             custom_model=custom_model,
+             model_name=model_used.name(),
+         ) as run_span:
+             run_step = 0
+             while True:
+                 run_step += 1
+                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                     async with agent_model.request_stream(messages) as model_response:
+                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
+                         # We want to end the "model request" span here, but we can't exit the context manager
+                         # in the traditional way
+                         model_req_span.__exit__(None, None, None)
+
+                         with _logfire.span('handle model response') as handle_span:
+                             either = await self._handle_streamed_model_response(model_response, deps)
+
+                             if isinstance(either, _MarkFinalResult):
+                                 result_stream = either.data
+                                 run_span.set_attribute('all_messages', messages)
+                                 handle_span.set_attribute('result_type', result_stream.__class__.__name__)
+                                 handle_span.message = 'handle model response -> final result'
+                                 yield result.StreamedRunResult(
+                                     messages,
+                                     new_message_index,
+                                     cost,
+                                     result_stream,
+                                     self._result_schema,
+                                     deps,
+                                     self._result_validators,
+                                 )
+                                 return
+                             else:
+                                 tool_responses = either
+                                 handle_span.set_attribute('tool_responses', tool_responses)
+                                 response_msgs = ' '.join(m.role for m in tool_responses)
+                                 handle_span.message = f'handle model response -> {response_msgs}'
+                                 messages.extend(tool_responses)
+                                 # the model_response should have been fully streamed by now, we can add its cost
+                                 cost += model_response.cost()
+
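A hedged sketch of consuming `run_stream`. It assumes `StreamedRunResult` exposes a `stream()` async iterator; that class lives in `pydantic_ai/result.py`, which this diff does not show.

```py
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    # run_stream is an async context manager yielding a StreamedRunResult
    async with agent.run_stream('What is the capital of France?') as response:
        # `stream()` is an assumption based on the StreamedRunResult API, not shown in this diff
        async for text in response.stream():
            print(text)


asyncio.run(main())
```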
+     @contextmanager
+     def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
+         """Context manager to temporarily override agent dependencies; this is particularly useful when testing.
+
+         Args:
+             overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
+         """
+         override_deps_before = self._override_deps
+         self._override_deps = _utils.Some(overriding_deps)
+         try:
+             yield
+         finally:
+             self._override_deps = override_deps_before
+
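A minimal sketch of `override_deps` in a test (an editor's illustration; it reuses the `'test'` model name that the docstring examples below rely on):

```py
from pydantic_ai import Agent, CallContext

agent = Agent('test', deps_type=str)


@agent.system_prompt
def from_deps(ctx: CallContext[str]) -> str:
    return f'Connection: {ctx.deps}'


# within the block, 'fake-conn' replaces whatever deps each run passes
with agent.override_deps('fake-conn'):
    result = agent.run_sync('hello', deps='real-conn')
```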
+     @contextmanager
+     def override_model(self, overriding_model: models.Model | models.KnownModelName) -> Iterator[None]:
+         """Context manager to temporarily override the model used by the agent.
+
+         Args:
+             overriding_model: The model to use instead of the model passed to the agent run.
+         """
+         override_model_before = self._override_model
+         self._override_model = _utils.Some(models.infer_model(overriding_model))
+         try:
+             yield
+         finally:
+             self._override_model = override_model_before
+
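Similarly, a sketch of `override_model` for testing. It pairs with `defer_model_check=True` from `__init__` above, and assumes `'test'` is a valid `KnownModelName` as the docstring examples in this file suggest:

```py
from pydantic_ai import Agent

# defer_model_check=True defers evaluating the named model (and checking
# its environment variables) until the first run
agent = Agent('openai:gpt-4o', defer_model_check=True)

with agent.override_model('test'):
    result = agent.run_sync('What is the capital of France?')
```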
+     @overload
+     def system_prompt(
+         self, func: Callable[[CallContext[AgentDeps]], str], /
+     ) -> Callable[[CallContext[AgentDeps]], str]: ...
+
+     @overload
+     def system_prompt(
+         self, func: Callable[[CallContext[AgentDeps]], Awaitable[str]], /
+     ) -> Callable[[CallContext[AgentDeps]], Awaitable[str]]: ...
+
+     @overload
+     def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
+
+     @overload
+     def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
+
+     def system_prompt(
+         self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
+     ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+         """Decorator to register a system prompt function.
+
+         Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as its only argument.
+         Can decorate a sync or async function.
+
+         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
+         the type of the function, see `tests/typed_agent.py` for tests.
+
+         Example:
+         ```py
+         from pydantic_ai import Agent, CallContext
+
+         agent = Agent('test', deps_type=str)
+
+         @agent.system_prompt
+         def simple_system_prompt() -> str:
+             return 'foobar'
+
+         @agent.system_prompt
+         async def async_system_prompt(ctx: CallContext[str]) -> str:
+             return f'{ctx.deps} is the best'
+
+         result = agent.run_sync('foobar', deps='spam')
+         print(result.data)
+         #> success (no retriever calls)
+         ```
+         """
+         self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
+         return func
+
+     @overload
+     def result_validator(
+         self, func: Callable[[CallContext[AgentDeps], ResultData], ResultData], /
+     ) -> Callable[[CallContext[AgentDeps], ResultData], ResultData]: ...
+
+     @overload
+     def result_validator(
+         self, func: Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]], /
+     ) -> Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
+
+     @overload
+     def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
+
+     @overload
+     def result_validator(
+         self, func: Callable[[ResultData], Awaitable[ResultData]], /
+     ) -> Callable[[ResultData], Awaitable[ResultData]]: ...
+
+     def result_validator(
+         self, func: _result.ResultValidatorFunc[AgentDeps, ResultData], /
+     ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
+         """Decorator to register a result validator function.
+
+         Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
+         Can decorate a sync or async function.
+
+         Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
+         the type of the function, see `tests/typed_agent.py` for tests.
+
+         Example:
+         ```py
+         from pydantic_ai import Agent, CallContext, ModelRetry
+
+         agent = Agent('test', deps_type=str)
+
+         @agent.result_validator
+         def result_validator_simple(data: str) -> str:
+             if 'wrong' in data:
+                 raise ModelRetry('wrong response')
+             return data
+
+         @agent.result_validator
+         async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
+             if ctx.deps in data:
+                 raise ModelRetry('wrong response')
+             return data
+
+         result = agent.run_sync('foobar', deps='spam')
+         print(result.data)
+         #> success (no retriever calls)
+         ```
+         """
+         self._result_validators.append(_result.ResultValidator(func))
+         return func
+
+     @overload
+     def retriever(
+         self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
+     ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
+
+     @overload
+     def retriever(
+         self, /, *, retries: int | None = None
+     ) -> Callable[
+         [RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
+     ]: ...
+
+     def retriever(
+         self,
+         func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None,
+         /,
+         *,
+         retries: int | None = None,
+     ) -> Any:
+         """Decorator to register a retriever function which takes
+         [`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
+
+         Can decorate a sync or async function.
+
+         The docstring is inspected to extract both the tool description and description of each parameter,
+         [learn more](../agents.md#retrievers-tools-and-schema).
+
+         We can't add overloads for every possible signature of retriever, since the return type is a recursive union,
+         so the signature of functions decorated with `@agent.retriever` is obscured.
+
+         Example:
+         ```py
+         from pydantic_ai import Agent, CallContext
+
+         agent = Agent('test', deps_type=int)
+
+         @agent.retriever
+         def foobar(ctx: CallContext[int], x: int) -> int:
+             return ctx.deps + x
+
+         @agent.retriever(retries=2)
+         async def spam(ctx: CallContext[int], y: float) -> float:
+             return ctx.deps + y
+
+         result = agent.run_sync('foobar', deps=1)
+         print(result.data)
+         #> {"foobar":1,"spam":1.0}
+         ```
+
+         Args:
+             func: The retriever function to register.
+             retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
+                 which defaults to 1.
+         """  # noqa: D205
+         if func is None:
+
+             def retriever_decorator(
+                 func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
+             ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
+                 # noinspection PyTypeChecker
+                 self._register_retriever(_utils.Either(left=func_), retries)
+                 return func_
+
+             return retriever_decorator
+         else:
+             # noinspection PyTypeChecker
+             self._register_retriever(_utils.Either(left=func), retries)
+             return func
+
+     @overload
+     def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...
+
+     @overload
+     def retriever_plain(
+         self, /, *, retries: int | None = None
+     ) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...
+
+     def retriever_plain(
+         self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
+     ) -> Any:
+         """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
+
+         Can decorate a sync or async function.
+
+         The docstring is inspected to extract both the tool description and description of each parameter,
+         [learn more](../agents.md#retrievers-tools-and-schema).
+
+         We can't add overloads for every possible signature of retriever, since the return type is a recursive union,
+         so the signature of functions decorated with `@agent.retriever_plain` is obscured.
+
+         Example:
+         ```py
+         from pydantic_ai import Agent
+
+         agent = Agent('test')
+
+         @agent.retriever_plain
+         def foobar() -> int:
+             return 123
+
+         @agent.retriever_plain(retries=2)
+         async def spam() -> float:
+             return 3.14
+
+         result = agent.run_sync('foobar')
+         print(result.data)
+         #> {"foobar":123,"spam":3.14}
+         ```
+
+         Args:
+             func: The retriever function to register.
+             retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
+                 which defaults to 1.
+         """
+         if func is None:
+
+             def retriever_decorator(
+                 func_: RetrieverPlainFunc[RetrieverParams],
+             ) -> RetrieverPlainFunc[RetrieverParams]:
+                 # noinspection PyTypeChecker
+                 self._register_retriever(_utils.Either(right=func_), retries)
+                 return func_
+
+             return retriever_decorator
+         else:
+             self._register_retriever(_utils.Either(right=func), retries)
+             return func
+
+     def _register_retriever(
+         self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None
+     ) -> None:
+         """Private utility to register a retriever function."""
+         retries_ = retries if retries is not None else self._default_retries
+         retriever = _r.Retriever[AgentDeps, RetrieverParams](func, retries_)
+
+         if self._result_schema and retriever.name in self._result_schema.tools:
+             raise ValueError(f'Retriever name conflicts with result schema name: {retriever.name!r}')
+
+         if retriever.name in self._retrievers:
+             raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}')
+
+         self._retrievers[retriever.name] = retriever
+
+     async def _get_agent_model(
+         self, model: models.Model | models.KnownModelName | None
+     ) -> tuple[models.Model, models.Model | None, models.AgentModel]:
+         """Create a model configured for this agent.
+
+         Args:
+             model: model to use for this run, required if `model` was not set when creating the agent.
+
+         Returns:
+             a tuple of `(model used, custom_model if any, agent_model)`
+         """
+         model_: models.Model
+         if some_model := self._override_model:
+             # we don't want `override_model()` to cover up errors from the model not being defined, hence this check
+             if model is None and self.model is None:
+                 raise exceptions.UserError(
+                     '`model` must be set either when creating the agent or when calling it. '
+                     '(Even when `override_model()` is customizing the model that will actually be called)'
+                 )
+             model_ = some_model.value
+             custom_model = None
+         elif model is not None:
+             custom_model = model_ = models.infer_model(model)
+         elif self.model is not None:
+             # noinspection PyTypeChecker
+             model_ = self.model = models.infer_model(self.model)
+             custom_model = None
+         else:
+             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
+
+         result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
+         agent_model = await model_.agent_model(self._retrievers, self._allow_text_result, result_tools)
+         return model_, custom_model, agent_model
+
+     async def _prepare_messages(
+         self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
+     ) -> tuple[int, list[_messages.Message]]:
+         # if message history includes system prompts, we don't want to regenerate them
+         if message_history and any(m.role == 'system' for m in message_history):
+             # shallow copy messages
+             messages = message_history.copy()
+         else:
+             messages = await self._init_messages(deps)
+             if message_history:
+                 messages += message_history
+
+         new_message_index = len(messages)
+         messages.append(_messages.UserPrompt(user_prompt))
+         return new_message_index, messages
+
+     async def _handle_model_response(
+         self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
+     ) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
+         """Process a non-streamed response from the model.
+
+         Returns:
+             A `_MarkFinalResult` wrapping the final result data, or a list of messages to send back to the model.
+         """
+         if model_response.role == 'model-text-response':
+             # plain string response
+             if self._allow_text_result:
+                 result_data_input = cast(ResultData, model_response.content)
+                 try:
+                     result_data = await self._validate_result(result_data_input, deps, None)
+                 except _result.ToolRetryError as e:
+                     self._incr_result_retry()
+                     return [e.tool_retry]
+                 else:
+                     return _MarkFinalResult(result_data)
+             else:
+                 self._incr_result_retry()
+                 response = _messages.RetryPrompt(
+                     content='Plain text responses are not permitted, please call one of the functions instead.',
+                 )
+                 return [response]
+         elif model_response.role == 'model-structured-response':
+             if self._result_schema is not None:
+                 # if there's a result schema, and any of the calls match one of its tools, return the result
+                 # NOTE: this means we ignore any other tools called here
+                 if match := self._result_schema.find_tool(model_response):
+                     call, result_tool = match
+                     try:
+                         result_data = result_tool.validate(call)
+                         result_data = await self._validate_result(result_data, deps, call)
+                     except _result.ToolRetryError as e:
+                         self._incr_result_retry()
+                         return [e.tool_retry]
+                     else:
+                         return _MarkFinalResult(result_data)
+
+             if not model_response.calls:
+                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
+
+             # otherwise we run all retriever functions in parallel
+             messages: list[_messages.Message] = []
+             tasks: list[asyncio.Task[_messages.Message]] = []
+             for call in model_response.calls:
+                 if retriever := self._retrievers.get(call.tool_name):
+                     tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
+                 else:
+                     messages.append(self._unknown_tool(call.tool_name))
+
+             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
+                 messages += await asyncio.gather(*tasks)
+             return messages
+         else:
+             assert_never(model_response)
+
+     async def _handle_streamed_model_response(
+         self, model_response: models.EitherStreamedResponse, deps: AgentDeps
+     ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
+         """Process a streamed response from the model.
+
+         TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
+         (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
+
+         Returns:
+             A `_MarkFinalResult` wrapping the streamed response, or a list of messages to send back to the model.
+         """
+         if isinstance(model_response, models.StreamTextResponse):
+             # plain string response
+             if self._allow_text_result:
+                 return _MarkFinalResult(model_response)
+             else:
+                 self._incr_result_retry()
+                 response = _messages.RetryPrompt(
+                     content='Plain text responses are not permitted, please call one of the functions instead.',
+                 )
+                 # stream the response, so cost is correct
+                 async for _ in model_response:
+                     pass
+
+                 return [response]
+         else:
+             assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
+             if self._result_schema is not None:
+                 # if there's a result schema, iterate over the stream until we find at least one tool
+                 # NOTE: this means we ignore any other tools called here
+                 structured_msg = model_response.get()
+                 while not structured_msg.calls:
+                     try:
+                         await model_response.__anext__()
+                     except StopAsyncIteration:
+                         break
+                     structured_msg = model_response.get()
+
+                 if self._result_schema.find_tool(structured_msg):
+                     return _MarkFinalResult(model_response)
+
+             # the model is calling a retriever function, consume the response to get the next message
+             async for _ in model_response:
+                 pass
+             structured_msg = model_response.get()
+             if not structured_msg.calls:
+                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
+             messages: list[_messages.Message] = [structured_msg]
+
+             # we now run all retriever functions in parallel
+             tasks: list[asyncio.Task[_messages.Message]] = []
+             for call in structured_msg.calls:
+                 if retriever := self._retrievers.get(call.tool_name):
+                     tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
+                 else:
+                     messages.append(self._unknown_tool(call.tool_name))
+
+             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
+                 messages += await asyncio.gather(*tasks)
+             return messages
+
+     async def _validate_result(
+         self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
+     ) -> ResultData:
+         for validator in self._result_validators:
+             result_data = await validator.validate(result_data, deps, self._current_result_retry, tool_call)
+         return result_data
+
+     def _incr_result_retry(self) -> None:
+         self._current_result_retry += 1
+         if self._current_result_retry > self._max_result_retries:
+             raise exceptions.UnexpectedModelBehavior(
+                 f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
+             )
+
+     async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
+         """Build the initial messages for the conversation."""
+         messages: list[_messages.Message] = [_messages.SystemPrompt(p) for p in self._system_prompts]
+         for sys_prompt_runner in self._system_prompt_functions:
+             prompt = await sys_prompt_runner.run(deps)
+             messages.append(_messages.SystemPrompt(prompt))
+         return messages
+
+     def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
+         self._incr_result_retry()
+         names = list(self._retrievers.keys())
+         if self._result_schema:
+             names.extend(self._result_schema.tool_names())
+         if names:
+             msg = f'Available tools: {", ".join(names)}'
+         else:
+             msg = 'No tools available.'
+         return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')
+
+     def _get_deps(self, deps: AgentDeps) -> AgentDeps:
+         """Get deps for a run.
+
+         If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
+
+         We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
+         """
+         if some_deps := self._override_deps:
+             return some_deps.value
+         else:
+             return deps
+
+
+ @dataclass
+ class _MarkFinalResult(Generic[ResultData]):
+     """Marker class to indicate that the result is the final result.
+
+     This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+     """
+
+     data: ResultData