pydantic-ai-slim 0.2.14__tar.gz → 0.2.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim has been flagged as potentially problematic; see the registry's advisory for this release for more details.

Files changed (76):
  1. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/PKG-INFO +4 -4
  2. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_agent_graph.py +0 -4
  3. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_function_schema.py +4 -4
  4. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_output.py +1 -1
  5. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_utils.py +5 -1
  6. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/agent.py +5 -6
  7. pydantic_ai_slim-0.2.16/pydantic_ai/ext/langchain.py +61 -0
  8. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/__init__.py +6 -1
  9. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/gemini.py +3 -1
  10. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/google.py +5 -2
  11. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/groq.py +4 -4
  12. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/mistral.py +5 -5
  13. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/openai.py +7 -7
  14. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/__init__.py +5 -1
  15. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/google_vertex.py +1 -1
  16. pydantic_ai_slim-0.2.16/pydantic_ai/providers/heroku.py +82 -0
  17. pydantic_ai_slim-0.2.16/pydantic_ai/py.typed +0 -0
  18. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/settings.py +1 -0
  19. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/tools.py +53 -6
  20. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/.gitignore +0 -0
  21. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/LICENSE +0 -0
  22. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/README.md +0 -0
  23. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/__init__.py +0 -0
  24. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/__main__.py +0 -0
  25. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_a2a.py +0 -0
  26. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_cli.py +0 -0
  27. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_griffe.py +0 -0
  28. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_parts_manager.py +0 -0
  29. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/_system_prompt.py +0 -0
  30. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/common_tools/__init__.py +0 -0
  31. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  32. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/common_tools/tavily.py +0 -0
  33. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/direct.py +0 -0
  34. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/exceptions.py +0 -0
  35. /pydantic_ai_slim-0.2.14/pydantic_ai/py.typed → /pydantic_ai_slim-0.2.16/pydantic_ai/ext/__init__.py +0 -0
  36. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/format_as_xml.py +0 -0
  37. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/format_prompt.py +0 -0
  38. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/mcp.py +0 -0
  39. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/messages.py +0 -0
  40. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/anthropic.py +0 -0
  41. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/bedrock.py +0 -0
  42. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/cohere.py +0 -0
  43. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/fallback.py +0 -0
  44. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/function.py +0 -0
  45. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/instrumented.py +0 -0
  46. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/test.py +0 -0
  47. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/models/wrapper.py +0 -0
  48. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/__init__.py +0 -0
  49. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/_json_schema.py +0 -0
  50. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/amazon.py +0 -0
  51. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/anthropic.py +0 -0
  52. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/cohere.py +0 -0
  53. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/deepseek.py +0 -0
  54. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/google.py +0 -0
  55. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/grok.py +0 -0
  56. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/meta.py +0 -0
  57. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/mistral.py +0 -0
  58. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/openai.py +0 -0
  59. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/profiles/qwen.py +0 -0
  60. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/anthropic.py +0 -0
  61. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/azure.py +0 -0
  62. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/bedrock.py +0 -0
  63. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/cohere.py +0 -0
  64. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/deepseek.py +0 -0
  65. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/fireworks.py +0 -0
  66. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/google.py +0 -0
  67. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/google_gla.py +0 -0
  68. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/grok.py +0 -0
  69. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/groq.py +0 -0
  70. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/mistral.py +0 -0
  71. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/openai.py +0 -0
  72. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/openrouter.py +0 -0
  73. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/providers/together.py +0 -0
  74. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/result.py +0 -0
  75. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pydantic_ai/usage.py +0 -0
  76. {pydantic_ai_slim-0.2.14 → pydantic_ai_slim-0.2.16}/pyproject.toml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.2.14
3
+ Version: 0.2.16
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
6
6
  License-Expression: MIT
@@ -30,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
30
30
  Requires-Dist: griffe>=1.3.2
31
31
  Requires-Dist: httpx>=0.27
32
32
  Requires-Dist: opentelemetry-api>=1.28.0
33
- Requires-Dist: pydantic-graph==0.2.14
33
+ Requires-Dist: pydantic-graph==0.2.16
34
34
  Requires-Dist: pydantic>=2.10
35
35
  Requires-Dist: typing-inspection>=0.4.0
36
36
  Provides-Extra: a2a
37
- Requires-Dist: fasta2a==0.2.14; extra == 'a2a'
37
+ Requires-Dist: fasta2a==0.2.16; extra == 'a2a'
38
38
  Provides-Extra: anthropic
39
39
  Requires-Dist: anthropic>=0.52.0; extra == 'anthropic'
40
40
  Provides-Extra: bedrock
@@ -48,7 +48,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
48
48
  Provides-Extra: duckduckgo
49
49
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
50
50
  Provides-Extra: evals
51
- Requires-Dist: pydantic-evals==0.2.14; extra == 'evals'
51
+ Requires-Dist: pydantic-evals==0.2.16; extra == 'evals'
52
52
  Provides-Extra: google
53
53
  Requires-Dist: google-genai>=1.15.0; extra == 'google'
54
54
  Provides-Extra: groq
@@ -151,10 +151,6 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
151
151
  ctx.state.message_history = history
152
152
  run_context.messages = history
153
153
 
154
- # TODO: We need to make it so that function_tools are not shared between runs
155
- # See comment on the current_retry field of `Tool` for more details.
156
- for tool in ctx.deps.function_tools.values():
157
- tool.current_retry = 0
158
154
  return next_message
159
155
 
160
156
  async def _prepare_messages(
@@ -7,7 +7,7 @@ from __future__ import annotations as _annotations
7
7
 
8
8
  import inspect
9
9
  from collections.abc import Awaitable
10
- from dataclasses import dataclass
10
+ from dataclasses import dataclass, field
11
11
  from inspect import Parameter, signature
12
12
  from typing import TYPE_CHECKING, Any, Callable, cast
13
13
 
@@ -43,9 +43,9 @@ class FunctionSchema:
43
43
  # if not None, the function takes a single by that name (besides potentially `info`)
44
44
  takes_ctx: bool
45
45
  is_async: bool
46
- single_arg_name: str | None
47
- positional_fields: list[str]
48
- var_positional_field: str | None
46
+ single_arg_name: str | None = None
47
+ positional_fields: list[str] = field(default_factory=list)
48
+ var_positional_field: str | None = None
49
49
 
50
50
  async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
51
51
  args, kwargs = self._call_args(args_dict, ctx)
@@ -138,7 +138,7 @@ class ToolOutput(Generic[OutputDataT]):
138
138
  T_co = TypeVar('T_co', covariant=True)
139
139
  # output_type=Type or output_type=function or output_type=object.method
140
140
  SimpleOutputType = TypeAliasType(
141
- 'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,)
141
+ 'SimpleOutputType', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,)
142
142
  )
143
143
  # output_type=ToolOutput(<see above>) or <see above>
144
144
  SimpleOutputTypeOrMarker = TypeAliasType(
@@ -12,7 +12,7 @@ from types import GenericAlias
12
12
  from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
13
13
 
14
14
  from anyio.to_thread import run_sync
15
- from pydantic import BaseModel
15
+ from pydantic import BaseModel, TypeAdapter
16
16
  from pydantic.json_schema import JsonSchemaValue
17
17
  from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
18
18
 
@@ -298,3 +298,7 @@ def dataclasses_no_defaults_repr(self: Any) -> str:
298
298
  f'{f.name}={getattr(self, f.name)!r}' for f in fields(self) if f.repr and getattr(self, f.name) != f.default
299
299
  )
300
300
  return f'{self.__class__.__qualname__}({", ".join(kv_pairs)})'
301
+
302
+
303
+ def number_to_datetime(x: int | float) -> datetime:
304
+ return TypeAdapter(datetime).validate_python(x)
@@ -646,11 +646,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
646
646
  # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
647
647
  output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
648
648
 
649
- # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
650
- # runs. Requires some changes to `Tool` to make them copyable though.
651
- for v in self._function_tools.values():
652
- v.current_retry = 0
653
-
654
649
  model_settings = merge_model_settings(self.model_settings, model_settings)
655
650
  usage_limits = usage_limits or _usage.UsageLimits()
656
651
 
@@ -679,6 +674,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
679
674
  instructions += '\n' + await instructions_runner.run(run_context)
680
675
  return instructions.strip()
681
676
 
677
+ # Copy the function tools so that retry state is agent-run-specific
678
+ # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
679
+ run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()}
680
+
682
681
  graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
683
682
  user_deps=deps,
684
683
  prompt=user_prompt,
@@ -690,7 +689,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
690
689
  end_strategy=self.end_strategy,
691
690
  output_schema=output_schema,
692
691
  output_validators=output_validators,
693
- function_tools=self._function_tools,
692
+ function_tools=run_function_tools,
694
693
  mcp_servers=self._mcp_servers,
695
694
  default_retries=self._default_retries,
696
695
  tracer=tracer,
@@ -0,0 +1,61 @@
1
+ from typing import Any, Protocol
2
+
3
+ from pydantic.json_schema import JsonSchemaValue
4
+
5
+ from pydantic_ai.tools import Tool
6
+
7
+
8
class LangChainTool(Protocol):
    """Structural interface for the parts of a LangChain tool that `tool_from_langchain` uses.

    Any object exposing these members satisfies the protocol; nothing needs to subclass it.
    """

    @property
    def name(self) -> str:
        """Name under which the tool is exposed to the model."""
        ...

    @property
    def description(self) -> str:
        """Description of what the tool does, shown to the model."""
        ...

    @property
    def args(self) -> dict[str, JsonSchemaValue]:
        """JSON-schema fragment for each argument, keyed by argument name.

        For example::

            {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'},
             'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}}
        """
        ...

    def get_input_jsonschema(self) -> JsonSchemaValue:
        """JSON schema for the tool's whole input object."""
        ...

    def run(self, *args: Any, **kwargs: Any) -> str:
        """Execute the tool and return its textual result."""
        ...


__all__ = ('tool_from_langchain',)
27
+
28
+
29
def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
    """Creates a Pydantic tool proxy from a LangChain tool.

    The LangChain tool's input JSON schema is reused as the tool's parameter schema.
    `additionalProperties` defaults to `False` when the schema does not set it, and
    every argument whose entry in `langchain_tool.args` has no `default` is listed as
    required. At call time, any omitted argument that has a LangChain default is
    filled in before the call is forwarded to `langchain_tool.run`.

    Args:
        langchain_tool: The LangChain tool to wrap.

    Returns:
        A Pydantic tool that corresponds to the LangChain tool.
    """
    function_name = langchain_tool.name
    function_description = langchain_tool.description
    inputs = langchain_tool.args
    # Argument names are dict keys, so they are already unique; just sort for a stable order.
    required = sorted(name for name, detail in inputs.items() if 'default' not in detail)

    # Shallow-copy before adding keys so we never mutate a dict owned by the LangChain
    # tool (it may be shared or cached on its side).
    schema: JsonSchemaValue = dict(langchain_tool.get_input_jsonschema())
    schema.setdefault('additionalProperties', False)
    if required:
        schema['required'] = required

    defaults = {name: detail['default'] for name, detail in inputs.items() if 'default' in detail}

    # restructures the arguments to match langchain tool run
    def proxy(*args: Any, **kwargs: Any) -> str:
        assert not args, 'This should always be called with kwargs'
        # Arguments supplied by the model take precedence over LangChain defaults.
        return langchain_tool.run(defaults | kwargs)

    return Tool.from_schema(
        function=proxy,
        name=function_name,
        description=function_description,
        json_schema=schema,
    )
@@ -211,6 +211,11 @@ KnownModelName = TypeAliasType(
211
211
  'groq:llama-3.2-3b-preview',
212
212
  'groq:llama-3.2-11b-vision-preview',
213
213
  'groq:llama-3.2-90b-vision-preview',
214
+ 'heroku:claude-3-5-haiku',
215
+ 'heroku:claude-3-5-sonnet-latest',
216
+ 'heroku:claude-3-7-sonnet',
217
+ 'heroku:claude-4-sonnet',
218
+ 'heroku:claude-3-haiku',
214
219
  'mistral:codestral-latest',
215
220
  'mistral:mistral-large-latest',
216
221
  'mistral:mistral-moderation-latest',
@@ -543,7 +548,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
543
548
  from .cohere import CohereModel
544
549
 
545
550
  return CohereModel(model_name, provider=provider)
546
- elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'):
551
+ elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku'):
547
552
  from .openai import OpenAIModel
548
553
 
549
554
  return OpenAIModel(model_name, provider=provider)
@@ -228,7 +228,7 @@ class GeminiModel(Model):
228
228
 
229
229
  if gemini_labels := model_settings.get('gemini_labels'):
230
230
  if self._system == 'google-vertex':
231
- request_data['labels'] = gemini_labels
231
+ request_data['labels'] = gemini_labels # pragma: lax no cover
232
232
 
233
233
  headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
234
234
  url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
@@ -366,6 +366,8 @@ def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _Gemi
366
366
  config: _GeminiGenerationConfig = {}
367
367
  if (max_tokens := model_settings.get('max_tokens')) is not None:
368
368
  config['max_output_tokens'] = max_tokens
369
+ if (stop_sequences := model_settings.get('stop_sequences')) is not None:
370
+ config['stop_sequences'] = stop_sequences # pragma: no cover
369
371
  if (temperature := model_settings.get('temperature')) is not None:
370
372
  config['temperature'] = temperature
371
373
  if (top_p := model_settings.get('top_p')) is not None:
@@ -260,6 +260,7 @@ class GoogleModel(Model):
260
260
  temperature=model_settings.get('temperature'),
261
261
  top_p=model_settings.get('top_p'),
262
262
  max_output_tokens=model_settings.get('max_tokens'),
263
+ stop_sequences=model_settings.get('stop_sequences'),
263
264
  presence_penalty=model_settings.get('presence_penalty'),
264
265
  frequency_penalty=model_settings.get('frequency_penalty'),
265
266
  safety_settings=model_settings.get('google_safety_settings'),
@@ -346,8 +347,10 @@ class GoogleModel(Model):
346
347
  else:
347
348
  assert_never(part)
348
349
 
349
- if message_parts: # pragma: no branch
350
- contents.append({'role': 'user', 'parts': message_parts})
350
+ # Google GenAI requires at least one part in the message.
351
+ if not message_parts:
352
+ message_parts = [{'text': ''}]
353
+ contents.append({'role': 'user', 'parts': message_parts})
351
354
  elif isinstance(m, ModelResponse):
352
355
  contents.append(_content_model_response(m))
353
356
  else:
@@ -4,13 +4,13 @@ import base64
4
4
  from collections.abc import AsyncIterable, AsyncIterator, Iterable
5
5
  from contextlib import asynccontextmanager
6
6
  from dataclasses import dataclass, field
7
- from datetime import datetime, timezone
7
+ from datetime import datetime
8
8
  from typing import Literal, Union, cast, overload
9
9
 
10
10
  from typing_extensions import assert_never
11
11
 
12
12
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
13
- from .._utils import guard_tool_call_id as _guard_tool_call_id
13
+ from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
14
14
  from ..messages import (
15
15
  BinaryContent,
16
16
  DocumentUrl,
@@ -246,7 +246,7 @@ class GroqModel(Model):
246
246
 
247
247
  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
248
248
  """Process a non-streamed response, and prepare a message to return."""
249
- timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
249
+ timestamp = number_to_datetime(response.created)
250
250
  choice = response.choices[0]
251
251
  items: list[ModelResponsePart] = []
252
252
  if choice.message.content is not None:
@@ -270,7 +270,7 @@ class GroqModel(Model):
270
270
  return GroqStreamedResponse(
271
271
  _response=peekable_response,
272
272
  _model_name=self._model_name,
273
- _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
273
+ _timestamp=number_to_datetime(first_chunk.created),
274
274
  )
275
275
 
276
276
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -4,7 +4,7 @@ import base64
4
4
  from collections.abc import AsyncIterable, AsyncIterator, Iterable
5
5
  from contextlib import asynccontextmanager
6
6
  from dataclasses import dataclass, field
7
- from datetime import datetime, timezone
7
+ from datetime import datetime
8
8
  from typing import Any, Literal, Union, cast
9
9
 
10
10
  import pydantic_core
@@ -12,7 +12,7 @@ from httpx import Timeout
12
12
  from typing_extensions import assert_never
13
13
 
14
14
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
15
- from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
15
+ from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
16
16
  from ..messages import (
17
17
  BinaryContent,
18
18
  DocumentUrl,
@@ -312,7 +312,7 @@ class MistralModel(Model):
312
312
  assert response.choices, 'Unexpected empty response choice.'
313
313
 
314
314
  if response.created:
315
- timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
315
+ timestamp = number_to_datetime(response.created)
316
316
  else:
317
317
  timestamp = _now_utc()
318
318
 
@@ -347,9 +347,9 @@ class MistralModel(Model):
347
347
  )
348
348
 
349
349
  if first_chunk.data.created:
350
- timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
350
+ timestamp = number_to_datetime(first_chunk.data.created)
351
351
  else:
352
- timestamp = datetime.now(tz=timezone.utc)
352
+ timestamp = _now_utc()
353
353
 
354
354
  return MistralStreamedResponse(
355
355
  _response=peekable_response,
@@ -5,7 +5,7 @@ import warnings
5
5
  from collections.abc import AsyncIterable, AsyncIterator, Sequence
6
6
  from contextlib import asynccontextmanager
7
7
  from dataclasses import dataclass, field
8
- from datetime import datetime, timezone
8
+ from datetime import datetime
9
9
  from typing import Any, Literal, Union, cast, overload
10
10
 
11
11
  from typing_extensions import assert_never
@@ -14,7 +14,7 @@ from pydantic_ai.profiles.openai import OpenAIModelProfile
14
14
  from pydantic_ai.providers import Provider, infer_provider
15
15
 
16
16
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
17
- from .._utils import guard_tool_call_id as _guard_tool_call_id
17
+ from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
18
18
  from ..messages import (
19
19
  AudioUrl,
20
20
  BinaryContent,
@@ -170,7 +170,7 @@ class OpenAIModel(Model):
170
170
  self,
171
171
  model_name: OpenAIModelName,
172
172
  *,
173
- provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
173
+ provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku']
174
174
  | Provider[AsyncOpenAI] = 'openai',
175
175
  profile: ModelProfileSpec | None = None,
176
176
  system_prompt_role: OpenAISystemPromptRole | None = None,
@@ -308,7 +308,7 @@ class OpenAIModel(Model):
308
308
 
309
309
  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
310
310
  """Process a non-streamed response, and prepare a message to return."""
311
- timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
311
+ timestamp = number_to_datetime(response.created)
312
312
  choice = response.choices[0]
313
313
  items: list[ModelResponsePart] = []
314
314
  vendor_details: dict[str, Any] | None = None
@@ -358,7 +358,7 @@ class OpenAIModel(Model):
358
358
  return OpenAIStreamedResponse(
359
359
  _model_name=self._model_name,
360
360
  _response=peekable_response,
361
- _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
361
+ _timestamp=number_to_datetime(first_chunk.created),
362
362
  )
363
363
 
364
364
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -593,7 +593,7 @@ class OpenAIResponsesModel(Model):
593
593
 
594
594
  def _process_response(self, response: responses.Response) -> ModelResponse:
595
595
  """Process a non-streamed response, and prepare a message to return."""
596
- timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
596
+ timestamp = number_to_datetime(response.created_at)
597
597
  items: list[ModelResponsePart] = []
598
598
  items.append(TextPart(response.output_text))
599
599
  for item in response.output:
@@ -614,7 +614,7 @@ class OpenAIResponsesModel(Model):
614
614
  return OpenAIResponsesStreamedResponse(
615
615
  _model_name=self._model_name,
616
616
  _response=peekable_response,
617
- _timestamp=datetime.fromtimestamp(first_chunk.response.created_at, tz=timezone.utc),
617
+ _timestamp=number_to_datetime(first_chunk.response.created_at),
618
618
  )
619
619
 
620
620
  @overload
@@ -48,7 +48,7 @@ class Provider(ABC, Generic[InterfaceClient]):
48
48
  return None # pragma: no cover
49
49
 
50
50
 
51
- def infer_provider(provider: str) -> Provider[Any]:
51
+ def infer_provider(provider: str) -> Provider[Any]: # noqa: C901
52
52
  """Infer the provider from the provider name."""
53
53
  if provider == 'openai':
54
54
  from .openai import OpenAIProvider
@@ -107,5 +107,9 @@ def infer_provider(provider: str) -> Provider[Any]:
107
107
  from .together import TogetherProvider
108
108
 
109
109
  return TogetherProvider()
110
+ elif provider == 'heroku':
111
+ from .heroku import HerokuProvider
112
+
113
+ return HerokuProvider()
110
114
  else: # pragma: no cover
111
115
  raise ValueError(f'Unknown provider: {provider}')
@@ -50,7 +50,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
50
50
  return self._client
51
51
 
52
52
  def model_profile(self, model_name: str) -> ModelProfile | None:
53
- return google_model_profile(model_name)
53
+ return google_model_profile(model_name) # pragma: lax no cover
54
54
 
55
55
  @overload
56
56
  def __init__(
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import overload
5
+
6
+ from httpx import AsyncClient as AsyncHTTPClient
7
+ from openai import AsyncOpenAI
8
+
9
+ from pydantic_ai.exceptions import UserError
10
+ from pydantic_ai.models import cached_async_http_client
11
+ from pydantic_ai.profiles import ModelProfile
12
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
13
+ from pydantic_ai.providers import Provider
14
+
15
+ try:
16
+ from openai import AsyncOpenAI
17
+ except ImportError as _import_error: # pragma: no cover
18
+ raise ImportError(
19
+ 'Please install the `openai` package to use the Heroku provider, '
20
+ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
21
+ ) from _import_error
22
+
23
+
24
class HerokuProvider(Provider[AsyncOpenAI]):
    """Provider for the Heroku Inference API.

    The Heroku API is OpenAI-compatible, so requests go through an `AsyncOpenAI` client
    whose base URL is the Heroku inference root with `/v1` appended.
    """

    @property
    def name(self) -> str:
        return 'heroku'

    @property
    def base_url(self) -> str:
        return str(self.client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        # As the Heroku API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer.
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(
        self, *, base_url: str, api_key: str | None = None, http_client: AsyncHTTPClient | None = None
    ) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a Heroku provider.

        Args:
            base_url: Root URL of the Heroku inference API; `/v1` is appended to it. Falls back to
                the `HEROKU_INFERENCE_URL` environment variable, then to `https://us.inference.heroku.com`.
                Ignored when `openai_client` is given.
            api_key: API key. Falls back to the `HEROKU_INFERENCE_KEY` environment variable.
                Must not be combined with `openai_client`.
            openai_client: A pre-configured `AsyncOpenAI` client, used as-is.
            http_client: Custom `httpx.AsyncClient` for the `AsyncOpenAI` client. Must not be combined
                with `openai_client`. Defaults to `cached_async_http_client(provider='heroku')`.

        Raises:
            UserError: If no API key is passed and `HEROKU_INFERENCE_KEY` is unset or empty.
        """
        if openai_client is not None:
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self._client = openai_client
            return

        api_key = api_key or os.environ.get('HEROKU_INFERENCE_KEY')
        if not api_key:
            raise UserError(
                'Set the `HEROKU_INFERENCE_KEY` environment variable or pass it via `HerokuProvider(api_key=...)` '
                'to use the Heroku provider.'
            )

        # Use `or` (not a `get` default) so an empty env var falls back to the public endpoint
        # instead of producing the bogus base URL '/v1'.
        base_url = base_url or os.environ.get('HEROKU_INFERENCE_URL') or 'https://us.inference.heroku.com'
        base_url = base_url.rstrip('/') + '/v1'

        if http_client is None:
            http_client = cached_async_http_client(provider='heroku')
        self._client = AsyncOpenAI(api_key=api_key, http_client=http_client, base_url=base_url)
File without changes
@@ -139,6 +139,7 @@ class ModelSettings(TypedDict, total=False):
139
139
  * Mistral
140
140
  * Groq
141
141
  * Cohere
142
+ * Google
142
143
  """
143
144
 
144
145
  extra_headers: dict[str, str]
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
+ import asyncio
3
4
  import dataclasses
4
5
  import json
5
6
  from collections.abc import Awaitable, Sequence
@@ -9,8 +10,8 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union
9
10
  from opentelemetry.trace import Tracer
10
11
  from pydantic import ValidationError
11
12
  from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
12
- from pydantic_core import core_schema
13
- from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
13
+ from pydantic_core import SchemaValidator, core_schema
14
+ from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar
14
15
 
15
16
  from . import _function_schema, _utils, messages as _messages
16
17
  from .exceptions import ModelRetry, UnexpectedModelBehavior
@@ -63,7 +64,9 @@ class RunContext(Generic[AgentDepsT]):
63
64
  """The current step in the run."""
64
65
 
65
66
  def replace_with(
66
- self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
67
+ self,
68
+ retry: int | None = None,
69
+ tool_name: str | None | _utils.Unset = _utils.UNSET,
67
70
  ) -> RunContext[AgentDepsT]:
68
71
  # Create a new `RunContext` a new `retry` value and `tool_name`.
69
72
  kwargs = {}
@@ -215,8 +218,10 @@ class Tool(Generic[AgentDepsT]):
215
218
  This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
216
219
  """
217
220
 
218
- # TODO: Move this state off the Tool class, which is otherwise stateless.
219
- # This should be tracked inside a specific agent run, not the tool.
221
+ # TODO: Consider moving this current_retry state to live on something other than the tool.
222
+ # We've worked around this for now by copying instances of the tool when creating new runs,
223
+ # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things
224
+ # up, though is also likely a larger effort to refactor.
220
225
  current_retry: int = field(default=0, init=False)
221
226
 
222
227
  def __init__(
@@ -304,6 +309,45 @@ class Tool(Generic[AgentDepsT]):
304
309
  self.require_parameter_descriptions = require_parameter_descriptions
305
310
  self.strict = strict
306
311
 
312
+ @classmethod
313
+ def from_schema(
314
+ cls,
315
+ function: Callable[..., Any],
316
+ name: str,
317
+ description: str,
318
+ json_schema: JsonSchemaValue,
319
+ ) -> Self:
320
+ """Creates a Pydantic tool from a function and a JSON schema.
321
+
322
+ Args:
323
+ function: The function to call.
324
+ This will be called with keywords only, and no validation of
325
+ the arguments will be performed.
326
+ name: The unique name of the tool that clearly communicates its purpose
327
+ description: Used to tell the model how/when/why to use the tool.
328
+ You can provide few-shot examples as a part of the description.
329
+ json_schema: The schema for the function arguments
330
+
331
+ Returns:
332
+ A Pydantic tool that calls the function
333
+ """
334
+ function_schema = _function_schema.FunctionSchema(
335
+ function=function,
336
+ description=description,
337
+ validator=SchemaValidator(schema=core_schema.any_schema()),
338
+ json_schema=json_schema,
339
+ takes_ctx=False,
340
+ is_async=asyncio.iscoroutinefunction(function),
341
+ )
342
+
343
+ return cls(
344
+ function,
345
+ takes_ctx=False,
346
+ name=name,
347
+ description=description,
348
+ function_schema=function_schema,
349
+ )
350
+
307
351
  async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
308
352
  """Get the tool definition.
309
353
 
@@ -325,7 +369,10 @@ class Tool(Generic[AgentDepsT]):
325
369
  return tool_def
326
370
 
327
371
  async def run(
328
- self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer
372
+ self,
373
+ message: _messages.ToolCallPart,
374
+ run_context: RunContext[AgentDepsT],
375
+ tracer: Tracer,
329
376
  ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
330
377
  """Run the tool function asynchronously.
331
378