pydantic-ai-slim 0.0.55__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,116 +1,9 @@
1
- from __future__ import annotations as _annotations
1
+ from typing_extensions import deprecated
2
2
 
3
- from collections.abc import Iterable, Iterator, Mapping
4
- from dataclasses import asdict, dataclass, is_dataclass
5
- from datetime import date
6
- from typing import Any
7
- from xml.etree import ElementTree
3
+ from .format_prompt import format_as_xml as _format_as_xml
8
4
 
9
- from pydantic import BaseModel
10
5
 
11
- __all__ = ('format_as_xml',)
12
-
13
-
14
- def format_as_xml(
15
- obj: Any,
16
- root_tag: str = 'examples',
17
- item_tag: str = 'example',
18
- include_root_tag: bool = True,
19
- none_str: str = 'null',
20
- indent: str | None = ' ',
21
- ) -> str:
22
- """Format a Python object as XML.
23
-
24
- This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
25
- rather than JSON etc.
26
-
27
- Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
28
- `Iterable`, `dataclass`, and `BaseModel`.
29
-
30
- Args:
31
- obj: Python Object to serialize to XML.
32
- root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
33
- item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
34
- for dataclasses and Pydantic models.
35
- include_root_tag: Whether to include the root tag in the output
36
- (The root tag is always included if it includes a body - e.g. when the input is a simple value).
37
- none_str: String to use for `None` values.
38
- indent: Indentation string to use for pretty printing.
39
-
40
- Returns:
41
- XML representation of the object.
42
-
43
- Example:
44
- ```python {title="format_as_xml_example.py" lint="skip"}
45
- from pydantic_ai.format_as_xml import format_as_xml
46
-
47
- print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
48
- '''
49
- <user>
50
- <name>John</name>
51
- <height>6</height>
52
- <weight>200</weight>
53
- </user>
54
- '''
55
- ```
56
- """
57
- el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
58
- if not include_root_tag and el.text is None:
59
- join = '' if indent is None else '\n'
60
- return join.join(_rootless_xml_elements(el, indent))
61
- else:
62
- if indent is not None:
63
- ElementTree.indent(el, space=indent)
64
- return ElementTree.tostring(el, encoding='unicode')
65
-
66
-
67
- @dataclass
68
- class _ToXml:
69
- item_tag: str
70
- none_str: str
71
-
72
- def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
73
- element = ElementTree.Element(self.item_tag if tag is None else tag)
74
- if value is None:
75
- element.text = self.none_str
76
- elif isinstance(value, str):
77
- element.text = value
78
- elif isinstance(value, (bytes, bytearray)):
79
- element.text = value.decode(errors='ignore')
80
- elif isinstance(value, (bool, int, float)):
81
- element.text = str(value)
82
- elif isinstance(value, date):
83
- element.text = value.isoformat()
84
- elif isinstance(value, Mapping):
85
- self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType]
86
- elif is_dataclass(value) and not isinstance(value, type):
87
- if tag is None:
88
- element = ElementTree.Element(value.__class__.__name__)
89
- dc_dict = asdict(value)
90
- self._mapping_to_xml(element, dc_dict)
91
- elif isinstance(value, BaseModel):
92
- if tag is None:
93
- element = ElementTree.Element(value.__class__.__name__)
94
- self._mapping_to_xml(element, value.model_dump(mode='python'))
95
- elif isinstance(value, Iterable):
96
- for item in value: # pyright: ignore[reportUnknownVariableType]
97
- item_el = self.to_xml(item, None)
98
- element.append(item_el)
99
- else:
100
- raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
101
- return element
102
-
103
- def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
104
- for key, value in mapping.items():
105
- if isinstance(key, int):
106
- key = str(key)
107
- elif not isinstance(key, str):
108
- raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
109
- element.append(self.to_xml(value, key))
110
-
111
-
112
- def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
113
- for sub_element in root:
114
- if indent is not None:
115
- ElementTree.indent(sub_element, space=indent)
116
- yield ElementTree.tostring(sub_element, encoding='unicode')
6
+ @deprecated('`format_as_xml` has moved, import it via `from pydantic_ai import format_as_xml`')
7
+ def format_as_xml(prompt: str) -> str:
8
+ """`format_as_xml` has moved, import it via `from pydantic_ai import format_as_xml` instead."""
9
+ return _format_as_xml(prompt)
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from collections.abc import Iterable, Iterator, Mapping
4
+ from dataclasses import asdict, dataclass, is_dataclass
5
+ from datetime import date
6
+ from typing import Any
7
+ from xml.etree import ElementTree
8
+
9
+ from pydantic import BaseModel
10
+
11
+ __all__ = ('format_as_xml',)
12
+
13
+
14
+ def format_as_xml(
15
+ obj: Any,
16
+ root_tag: str = 'examples',
17
+ item_tag: str = 'example',
18
+ include_root_tag: bool = True,
19
+ none_str: str = 'null',
20
+ indent: str | None = ' ',
21
+ ) -> str:
22
+ """Format a Python object as XML.
23
+
24
+ This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
25
+ rather than JSON etc.
26
+
27
+ Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
28
+ `Iterable`, `dataclass`, and `BaseModel`.
29
+
30
+ Args:
31
+ obj: Python Object to serialize to XML.
32
+ root_tag: Outer tag to wrap the XML in, pass `include_root_tag=False` to omit the outer tag.
33
+ item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
34
+ for dataclasses and Pydantic models.
35
+ include_root_tag: Whether to include the root tag in the output
36
+ (The root tag is always included if it includes a body - e.g. when the input is a simple value).
37
+ none_str: String to use for `None` values.
38
+ indent: Indentation string to use for pretty printing.
39
+
40
+ Returns:
41
+ XML representation of the object.
42
+
43
+ Example:
44
+ ```python {title="format_as_xml_example.py" lint="skip"}
45
+ from pydantic_ai import format_as_xml
46
+
47
+ print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
48
+ '''
49
+ <user>
50
+ <name>John</name>
51
+ <height>6</height>
52
+ <weight>200</weight>
53
+ </user>
54
+ '''
55
+ ```
56
+ """
57
+ el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
58
+ if not include_root_tag and el.text is None:
59
+ join = '' if indent is None else '\n'
60
+ return join.join(_rootless_xml_elements(el, indent))
61
+ else:
62
+ if indent is not None:
63
+ ElementTree.indent(el, space=indent)
64
+ return ElementTree.tostring(el, encoding='unicode')
65
+
66
+
67
+ @dataclass
68
+ class _ToXml:
69
+ item_tag: str
70
+ none_str: str
71
+
72
+ def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
73
+ element = ElementTree.Element(self.item_tag if tag is None else tag)
74
+ if value is None:
75
+ element.text = self.none_str
76
+ elif isinstance(value, str):
77
+ element.text = value
78
+ elif isinstance(value, (bytes, bytearray)):
79
+ element.text = value.decode(errors='ignore')
80
+ elif isinstance(value, (bool, int, float)):
81
+ element.text = str(value)
82
+ elif isinstance(value, date):
83
+ element.text = value.isoformat()
84
+ elif isinstance(value, Mapping):
85
+ self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType]
86
+ elif is_dataclass(value) and not isinstance(value, type):
87
+ if tag is None:
88
+ element = ElementTree.Element(value.__class__.__name__)
89
+ dc_dict = asdict(value)
90
+ self._mapping_to_xml(element, dc_dict)
91
+ elif isinstance(value, BaseModel):
92
+ if tag is None:
93
+ element = ElementTree.Element(value.__class__.__name__)
94
+ self._mapping_to_xml(element, value.model_dump(mode='python'))
95
+ elif isinstance(value, Iterable):
96
+ for item in value: # pyright: ignore[reportUnknownVariableType]
97
+ item_el = self.to_xml(item, None)
98
+ element.append(item_el)
99
+ else:
100
+ raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
101
+ return element
102
+
103
+ def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
104
+ for key, value in mapping.items():
105
+ if isinstance(key, int):
106
+ key = str(key)
107
+ elif not isinstance(key, str):
108
+ raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
109
+ element.append(self.to_xml(value, key))
110
+
111
+
112
+ def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
113
+ for sub_element in root:
114
+ if indent is not None:
115
+ ElementTree.indent(sub_element, space=indent)
116
+ yield ElementTree.tostring(sub_element, encoding='unicode')
pydantic_ai/messages.py CHANGED
@@ -15,6 +15,34 @@ from typing_extensions import TypeAlias
15
15
  from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
16
16
  from .exceptions import UnexpectedModelBehavior
17
17
 
18
+ AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
19
+ ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
20
+ DocumentMediaType: TypeAlias = Literal[
21
+ 'application/pdf',
22
+ 'text/plain',
23
+ 'text/csv',
24
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
25
+ 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
26
+ 'text/html',
27
+ 'text/markdown',
28
+ 'application/vnd.ms-excel',
29
+ ]
30
+ VideoMediaType: TypeAlias = Literal[
31
+ 'video/x-matroska',
32
+ 'video/quicktime',
33
+ 'video/mp4',
34
+ 'video/webm',
35
+ 'video/x-flv',
36
+ 'video/mpeg',
37
+ 'video/x-ms-wmv',
38
+ 'video/3gpp',
39
+ ]
40
+
41
+ AudioFormat: TypeAlias = Literal['wav', 'mp3']
42
+ ImageFormat: TypeAlias = Literal['jpeg', 'png', 'gif', 'webp']
43
+ DocumentFormat: TypeAlias = Literal['csv', 'doc', 'docx', 'html', 'md', 'pdf', 'txt', 'xls', 'xlsx']
44
+ VideoFormat: TypeAlias = Literal['mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp']
45
+
18
46
 
19
47
  @dataclass
20
48
  class SystemPromptPart:
@@ -42,6 +70,47 @@ class SystemPromptPart:
42
70
  return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
43
71
 
44
72
 
73
+ @dataclass
74
+ class VideoUrl:
75
+ """A URL to a video."""
76
+
77
+ url: str
78
+ """The URL of the video."""
79
+
80
+ kind: Literal['video-url'] = 'video-url'
81
+ """Type identifier, this is available on all parts as a discriminator."""
82
+
83
+ @property
84
+ def media_type(self) -> VideoMediaType: # pragma: no cover
85
+ """Return the media type of the video, based on the url."""
86
+ if self.url.endswith('.mkv'):
87
+ return 'video/x-matroska'
88
+ elif self.url.endswith('.mov'):
89
+ return 'video/quicktime'
90
+ elif self.url.endswith('.mp4'):
91
+ return 'video/mp4'
92
+ elif self.url.endswith('.webm'):
93
+ return 'video/webm'
94
+ elif self.url.endswith('.flv'):
95
+ return 'video/x-flv'
96
+ elif self.url.endswith(('.mpeg', '.mpg')):
97
+ return 'video/mpeg'
98
+ elif self.url.endswith('.wmv'):
99
+ return 'video/x-ms-wmv'
100
+ elif self.url.endswith('.three_gp'):
101
+ return 'video/3gpp'
102
+ else:
103
+ raise ValueError(f'Unknown video file extension: {self.url}')
104
+
105
+ @property
106
+ def format(self) -> VideoFormat:
107
+ """The file format of the video.
108
+
109
+ The choice of supported formats was based on the Bedrock Converse API. Other APIs don't require a format.
110
+ """
111
+ return _video_format(self.media_type)
112
+
113
+
45
114
  @dataclass
46
115
  class AudioUrl:
47
116
  """A URL to an audio file."""
@@ -123,23 +192,6 @@ class DocumentUrl:
123
192
  return _document_format(self.media_type)
124
193
 
125
194
 
126
- AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
127
- ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
128
- DocumentMediaType: TypeAlias = Literal[
129
- 'application/pdf',
130
- 'text/plain',
131
- 'text/csv',
132
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
133
- 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
134
- 'text/html',
135
- 'text/markdown',
136
- 'application/vnd.ms-excel',
137
- ]
138
- AudioFormat: TypeAlias = Literal['wav', 'mp3']
139
- ImageFormat: TypeAlias = Literal['jpeg', 'png', 'gif', 'webp']
140
- DocumentFormat: TypeAlias = Literal['csv', 'doc', 'docx', 'html', 'md', 'pdf', 'txt', 'xls', 'xlsx']
141
-
142
-
143
195
  @dataclass
144
196
  class BinaryContent:
145
197
  """Binary content, e.g. an audio or image file."""
@@ -163,6 +215,11 @@ class BinaryContent:
163
215
  """Return `True` if the media type is an image type."""
164
216
  return self.media_type.startswith('image/')
165
217
 
218
+ @property
219
+ def is_video(self) -> bool:
220
+ """Return `True` if the media type is a video type."""
221
+ return self.media_type.startswith('video/')
222
+
166
223
  @property
167
224
  def is_document(self) -> bool:
168
225
  """Return `True` if the media type is a document type."""
@@ -189,10 +246,12 @@ class BinaryContent:
189
246
  return _image_format(self.media_type)
190
247
  elif self.is_document:
191
248
  return _document_format(self.media_type)
249
+ elif self.is_video:
250
+ return _video_format(self.media_type)
192
251
  raise ValueError(f'Unknown media type: {self.media_type}')
193
252
 
194
253
 
195
- UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | BinaryContent'
254
+ UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'
196
255
 
197
256
 
198
257
  def _document_format(media_type: str) -> DocumentFormat:
@@ -229,6 +288,27 @@ def _image_format(media_type: str) -> ImageFormat:
229
288
  raise ValueError(f'Unknown image media type: {media_type}')
230
289
 
231
290
 
291
+ def _video_format(media_type: str) -> VideoFormat:
292
+ if media_type == 'video/x-matroska':
293
+ return 'mkv'
294
+ elif media_type == 'video/quicktime':
295
+ return 'mov'
296
+ elif media_type == 'video/mp4':
297
+ return 'mp4'
298
+ elif media_type == 'video/webm':
299
+ return 'webm'
300
+ elif media_type == 'video/x-flv':
301
+ return 'flv'
302
+ elif media_type == 'video/mpeg':
303
+ return 'mpeg'
304
+ elif media_type == 'video/x-ms-wmv':
305
+ return 'wmv'
306
+ elif media_type == 'video/3gpp':
307
+ return 'three_gp'
308
+ else: # pragma: no cover
309
+ raise ValueError(f'Unknown video media type: {media_type}')
310
+
311
+
232
312
  @dataclass
233
313
  class UserPromptPart:
234
314
  """A user prompt, generally written by the end user.
@@ -315,7 +395,7 @@ class RetryPromptPart:
315
395
  * the model returned plain text when a structured response was expected
316
396
  * Pydantic validation of a structured response failed, here content is derived from a Pydantic
317
397
  [`ValidationError`][pydantic_core.ValidationError]
318
- * a result validator raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
398
+ * an output validator raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
319
399
  """
320
400
 
321
401
  content: list[pydantic_core.ErrorDetails] | str
@@ -377,6 +457,9 @@ class ModelRequest:
377
457
  parts: list[ModelRequestPart]
378
458
  """The parts of the user message."""
379
459
 
460
+ instructions: str | None = None
461
+ """The instructions for the model."""
462
+
380
463
  kind: Literal['request'] = 'request'
381
464
  """Message type identifier, this is available on all parts as a discriminator."""
382
465
 
@@ -691,10 +774,10 @@ class PartDeltaEvent:
691
774
 
692
775
  @dataclass
693
776
  class FinalResultEvent:
694
- """An event indicating the response to the current model request matches the result schema."""
777
+ """An event indicating the response to the current model request matches the output schema and will produce a result."""
695
778
 
696
779
  tool_name: str | None
697
- """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
780
+ """The name of the output tool that was called. `None` if the result is from text content and not from a tool."""
698
781
  tool_call_id: str | None
699
782
  """The tool call ID, if any, that this result is associated with."""
700
783
  event_kind: Literal['final_result'] = 'final_result'
@@ -19,7 +19,7 @@ from typing_extensions import Literal, TypeAliasType
19
19
 
20
20
  from .._parts_manager import ModelResponsePartsManager
21
21
  from ..exceptions import UserError
22
- from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
22
+ from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
23
23
  from ..settings import ModelSettings
24
24
  from ..usage import Usage
25
25
 
@@ -107,6 +107,7 @@ KnownModelName = TypeAliasType(
107
107
  'google-gla:gemini-2.0-flash-lite-preview-02-05',
108
108
  'google-gla:gemini-2.0-pro-exp-02-05',
109
109
  'google-gla:gemini-2.5-pro-exp-03-25',
110
+ 'google-gla:gemini-2.5-pro-preview-03-25',
110
111
  'google-vertex:gemini-1.0-pro',
111
112
  'google-vertex:gemini-1.5-flash',
112
113
  'google-vertex:gemini-1.5-flash-8b',
@@ -118,6 +119,7 @@ KnownModelName = TypeAliasType(
118
119
  'google-vertex:gemini-2.0-flash-lite-preview-02-05',
119
120
  'google-vertex:gemini-2.0-pro-exp-02-05',
120
121
  'google-vertex:gemini-2.5-pro-exp-03-25',
122
+ 'google-vertex:gemini-2.5-pro-preview-03-25',
121
123
  'gpt-3.5-turbo',
122
124
  'gpt-3.5-turbo-0125',
123
125
  'gpt-3.5-turbo-0301',
@@ -137,6 +139,12 @@ KnownModelName = TypeAliasType(
137
139
  'gpt-4-turbo-2024-04-09',
138
140
  'gpt-4-turbo-preview',
139
141
  'gpt-4-vision-preview',
142
+ 'gpt-4.1',
143
+ 'gpt-4.1-2025-04-14',
144
+ 'gpt-4.1-mini',
145
+ 'gpt-4.1-mini-2025-04-14',
146
+ 'gpt-4.1-nano',
147
+ 'gpt-4.1-nano-2025-04-14',
140
148
  'gpt-4o',
141
149
  'gpt-4o-2024-05-13',
142
150
  'gpt-4o-2024-08-06',
@@ -206,6 +214,12 @@ KnownModelName = TypeAliasType(
206
214
  'openai:gpt-4-turbo-2024-04-09',
207
215
  'openai:gpt-4-turbo-preview',
208
216
  'openai:gpt-4-vision-preview',
217
+ 'openai:gpt-4.1',
218
+ 'openai:gpt-4.1-2025-04-14',
219
+ 'openai:gpt-4.1-mini',
220
+ 'openai:gpt-4.1-mini-2025-04-14',
221
+ 'openai:gpt-4.1-nano',
222
+ 'openai:gpt-4.1-nano-2025-04-14',
209
223
  'openai:gpt-4o',
210
224
  'openai:gpt-4o-2024-05-13',
211
225
  'openai:gpt-4o-2024-08-06',
@@ -240,11 +254,11 @@ KnownModelName = TypeAliasType(
240
254
 
241
255
  @dataclass
242
256
  class ModelRequestParameters:
243
- """Configuration for an agent's request to a model, specifically related to tools and result handling."""
257
+ """Configuration for an agent's request to a model, specifically related to tools and output handling."""
244
258
 
245
259
  function_tools: list[ToolDefinition]
246
- allow_text_result: bool
247
- result_tools: list[ToolDefinition]
260
+ allow_text_output: bool
261
+ output_tools: list[ToolDefinition]
248
262
 
249
263
 
250
264
  class Model(ABC):
@@ -306,6 +320,12 @@ class Model(ABC):
306
320
  """The base URL for the provider API, if available."""
307
321
  return None
308
322
 
323
+ def _get_instructions(self, messages: list[ModelMessage]) -> str | None:
324
+ """Get instructions from the first ModelRequest found when iterating messages in reverse."""
325
+ for message in reversed(messages):
326
+ if isinstance(message, ModelRequest):
327
+ return message.instructions
328
+
309
329
 
310
330
  @dataclass
311
331
  class StreamedResponse(ABC):
@@ -0,0 +1,156 @@
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+ from typing import Any, Literal
6
+
7
+ from pydantic_ai.exceptions import UserError
8
+
9
+ JsonSchema = dict[str, Any]
10
+
11
+
12
+ @dataclass(init=False)
13
+ class WalkJsonSchema(ABC):
14
+ """Walks a JSON schema, applying transformations to it at each level.
15
+
16
+ Note: We may eventually want to rework tools to build the JSON schema from the type directly, using a subclass of
17
+ pydantic.json_schema.GenerateJsonSchema, rather than making use of this machinery.
18
+ """
19
+
20
+ def __init__(
21
+ self, schema: JsonSchema, *, prefer_inlined_defs: bool = False, simplify_nullable_unions: bool = False
22
+ ):
23
+ self.schema = deepcopy(schema)
24
+ self.prefer_inlined_defs = prefer_inlined_defs
25
+ self.simplify_nullable_unions = simplify_nullable_unions
26
+
27
+ self.defs: dict[str, JsonSchema] = self.schema.pop('$defs', {})
28
+ self.refs_stack = tuple[str, ...]()
29
+ self.recursive_refs = set[str]()
30
+
31
+ @abstractmethod
32
+ def transform(self, schema: JsonSchema) -> JsonSchema:
33
+ """Make changes to the schema."""
34
+ return schema
35
+
36
+ def walk(self) -> JsonSchema:
37
+ handled = self._handle(deepcopy(self.schema))
38
+
39
+ if not self.prefer_inlined_defs and self.defs:
40
+ handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()}
41
+
42
+ elif self.recursive_refs: # pragma: no cover
43
+ # If we are preferring inlined defs and there are recursive refs, we _have_ to use a $defs+$ref structure
44
+ # We try to use whatever the original root key was, but if it is already in use,
45
+ # we modify it to avoid collisions.
46
+ defs = {key: self.defs[key] for key in self.recursive_refs}
47
+ root_ref = self.schema.get('$ref')
48
+ root_key = None if root_ref is None else re.sub(r'^#/\$defs/', '', root_ref)
49
+ if root_key is None:
50
+ root_key = self.schema.get('title', 'root')
51
+ while root_key in defs:
52
+ # Modify the root key until it is not already in use
53
+ root_key = f'{root_key}_root'
54
+
55
+ defs[root_key] = handled
56
+ return {'$defs': defs, '$ref': f'#/$defs/{root_key}'}
57
+
58
+ return handled
59
+
60
+ def _handle(self, schema: JsonSchema) -> JsonSchema:
61
+ if self.prefer_inlined_defs:
62
+ while ref := schema.get('$ref'):
63
+ key = re.sub(r'^#/\$defs/', '', ref)
64
+ if key in self.refs_stack:
65
+ self.recursive_refs.add(key)
66
+ break # recursive ref can't be unpacked
67
+ self.refs_stack += (key,)
68
+ def_schema = self.defs.get(key)
69
+ if def_schema is None: # pragma: no cover
70
+ raise UserError(f'Could not find $ref definition for {key}')
71
+ schema = def_schema
72
+
73
+ # Handle the schema based on its type / structure
74
+ type_ = schema.get('type')
75
+ if type_ == 'object':
76
+ schema = self._handle_object(schema)
77
+ elif type_ == 'array':
78
+ schema = self._handle_array(schema)
79
+ elif type_ is None:
80
+ schema = self._handle_union(schema, 'anyOf')
81
+ schema = self._handle_union(schema, 'oneOf')
82
+
83
+ # Apply the base transform
84
+ schema = self.transform(schema)
85
+
86
+ return schema
87
+
88
+ def _handle_object(self, schema: JsonSchema) -> JsonSchema:
89
+ if properties := schema.get('properties'):
90
+ handled_properties = {}
91
+ for key, value in properties.items():
92
+ handled_properties[key] = self._handle(value)
93
+ schema['properties'] = handled_properties
94
+
95
+ if (additional_properties := schema.get('additionalProperties')) is not None:
96
+ if isinstance(additional_properties, bool):
97
+ schema['additionalProperties'] = additional_properties
98
+ else: # pragma: no cover
99
+ schema['additionalProperties'] = self._handle(additional_properties)
100
+
101
+ if (pattern_properties := schema.get('patternProperties')) is not None:
102
+ handled_pattern_properties = {}
103
+ for key, value in pattern_properties.items():
104
+ handled_pattern_properties[key] = self._handle(value)
105
+ schema['patternProperties'] = handled_pattern_properties
106
+
107
+ return schema
108
+
109
+ def _handle_array(self, schema: JsonSchema) -> JsonSchema:
110
+ if prefix_items := schema.get('prefixItems'):
111
+ schema['prefixItems'] = [self._handle(item) for item in prefix_items]
112
+
113
+ if items := schema.get('items'):
114
+ schema['items'] = self._handle(items)
115
+
116
+ return schema
117
+
118
+ def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf']) -> JsonSchema:
119
+ members = schema.get(union_kind)
120
+ if not members:
121
+ return schema
122
+
123
+ handled = [self._handle(member) for member in members]
124
+
125
+ # convert nullable unions to nullable types
126
+ if self.simplify_nullable_unions:
127
+ handled = self._simplify_nullable_union(handled)
128
+
129
+ if len(handled) == 1:
130
+ # In this case, no need to retain the union
131
+ return handled[0]
132
+
133
+ # If we have keys besides the union kind (such as title or discriminator), keep them without modifications
134
+ schema = schema.copy()
135
+ schema[union_kind] = handled
136
+ return schema
137
+
138
+ @staticmethod
139
+ def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]:
140
+ # TODO: Should we move this to relevant subclasses? Or is it worth keeping here to make reuse easier?
141
+ if len(cases) == 2 and {'type': 'null'} in cases:
142
+ # Find the non-null schema
143
+ non_null_schema = next(
144
+ (item for item in cases if item != {'type': 'null'}),
145
+ None,
146
+ )
147
+ if non_null_schema:
148
+ # Create a new schema based on the non-null part, mark as nullable
149
+ new_schema = deepcopy(non_null_schema)
150
+ new_schema['nullable'] = True
151
+ return [new_schema]
152
+ else: # pragma: no cover
153
+ # they are both null, so just return one of them
154
+ return [cases[0]]
155
+
156
+ return cases # pragma: no cover