langchain-core 1.0.5__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. langchain_core/callbacks/manager.py +14 -14
  2. langchain_core/callbacks/usage.py +1 -1
  3. langchain_core/indexing/api.py +2 -0
  4. langchain_core/language_models/__init__.py +15 -5
  5. langchain_core/language_models/_utils.py +1 -0
  6. langchain_core/language_models/chat_models.py +74 -94
  7. langchain_core/language_models/llms.py +5 -3
  8. langchain_core/language_models/model_profile.py +84 -0
  9. langchain_core/load/load.py +14 -1
  10. langchain_core/messages/ai.py +12 -4
  11. langchain_core/messages/base.py +6 -6
  12. langchain_core/messages/block_translators/anthropic.py +27 -8
  13. langchain_core/messages/block_translators/bedrock_converse.py +18 -8
  14. langchain_core/messages/block_translators/google_genai.py +25 -10
  15. langchain_core/messages/content.py +1 -1
  16. langchain_core/messages/tool.py +28 -27
  17. langchain_core/messages/utils.py +45 -18
  18. langchain_core/output_parsers/openai_tools.py +9 -7
  19. langchain_core/output_parsers/pydantic.py +1 -1
  20. langchain_core/output_parsers/string.py +27 -1
  21. langchain_core/prompts/chat.py +22 -17
  22. langchain_core/prompts/string.py +29 -9
  23. langchain_core/prompts/structured.py +7 -1
  24. langchain_core/runnables/base.py +174 -160
  25. langchain_core/runnables/branch.py +1 -1
  26. langchain_core/runnables/config.py +25 -20
  27. langchain_core/runnables/fallbacks.py +1 -2
  28. langchain_core/runnables/graph.py +3 -2
  29. langchain_core/runnables/graph_mermaid.py +5 -1
  30. langchain_core/runnables/passthrough.py +2 -2
  31. langchain_core/tools/base.py +46 -2
  32. langchain_core/tools/convert.py +16 -0
  33. langchain_core/tools/retriever.py +29 -58
  34. langchain_core/tools/structured.py +14 -0
  35. langchain_core/tracers/event_stream.py +9 -4
  36. langchain_core/utils/aiter.py +3 -1
  37. langchain_core/utils/function_calling.py +7 -2
  38. langchain_core/utils/json_schema.py +29 -21
  39. langchain_core/utils/mustache.py +24 -9
  40. langchain_core/utils/pydantic.py +7 -7
  41. langchain_core/utils/uuid.py +54 -0
  42. langchain_core/vectorstores/base.py +26 -18
  43. langchain_core/version.py +1 -1
  44. {langchain_core-1.0.5.dist-info → langchain_core-1.2.1.dist-info}/METADATA +2 -1
  45. {langchain_core-1.0.5.dist-info → langchain_core-1.2.1.dist-info}/RECORD +46 -44
  46. {langchain_core-1.0.5.dist-info → langchain_core-1.2.1.dist-info}/WHEEL +1 -1
@@ -33,7 +33,7 @@ from langchain_core.runnables.utils import (
     AddableDict,
     ConfigurableFieldSpec,
 )
-from langchain_core.utils.aiter import atee, py_anext
+from langchain_core.utils.aiter import atee
 from langchain_core.utils.iter import safetee
 from langchain_core.utils.pydantic import create_model_v2

@@ -614,7 +614,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
         )
         # start map output stream
         first_map_chunk_task: asyncio.Task = asyncio.create_task(
-            py_anext(map_output, None),  # type: ignore[arg-type]
+            anext(map_output, None),
         )
         # consume passthrough stream
         async for chunk in for_passthrough:
@@ -386,6 +386,8 @@ class ToolException(Exception):  # noqa: N818

 ArgsSchema = TypeBaseModel | dict[str, Any]

+_EMPTY_SET: frozenset[str] = frozenset()
+

 class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
     """Base class for all LangChain tools.
@@ -494,6 +496,24 @@ class ChildTool(BaseTool):
         two-tuple corresponding to the `(content, artifact)` of a `ToolMessage`.
     """

+    extras: dict[str, Any] | None = None
+    """Optional provider-specific extra fields for the tool.
+
+    This is used to pass provider-specific configuration that doesn't fit into
+    standard tool fields.
+
+    Example:
+        Anthropic-specific fields like [`cache_control`](https://docs.langchain.com/oss/python/integrations/chat/anthropic#prompt-caching),
+        [`defer_loading`](https://docs.langchain.com/oss/python/integrations/chat/anthropic#tool-search),
+        or `input_examples`.
+
+        ```python
+        @tool(extras={"defer_loading": True, "cache_control": {"type": "ephemeral"}})
+        def my_tool(x: str) -> str:
+            return x
+        ```
+    """
+
     def __init__(self, **kwargs: Any) -> None:
         """Initialize the tool.

@@ -569,6 +589,11 @@ class ChildTool(BaseTool):
             self.name, full_schema, fields, fn_description=self.description
         )

+    @functools.cached_property
+    def _injected_args_keys(self) -> frozenset[str]:
+        # base implementation doesn't manage injected args
+        return _EMPTY_SET
+
     # --- Runnable ---

     @override
@@ -649,6 +674,7 @@ class ChildTool(BaseTool):
            if isinstance(input_args, dict):
                return tool_input
            if issubclass(input_args, BaseModel):
+                # Check args_schema for InjectedToolCallId
                for k, v in get_all_basemodel_annotations(input_args).items():
                    if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
                        if tool_call_id is None:
@@ -664,6 +690,7 @@ class ChildTool(BaseTool):
                result = input_args.model_validate(tool_input)
                result_dict = result.model_dump()
            elif issubclass(input_args, BaseModelV1):
+                # Check args_schema for InjectedToolCallId
                for k, v in get_all_basemodel_annotations(input_args).items():
                    if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
                        if tool_call_id is None:
@@ -683,9 +710,24 @@ class ChildTool(BaseTool):
                    f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
                )
                raise NotImplementedError(msg)
-            return {
-                k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
+            validated_input = {
+                k: getattr(result, k) for k in result_dict if k in tool_input
            }
+            for k in self._injected_args_keys:
+                if k in tool_input:
+                    validated_input[k] = tool_input[k]
+                elif k == "tool_call_id":
+                    if tool_call_id is None:
+                        msg = (
+                            "When tool includes an InjectedToolCallId "
+                            "argument, tool must always be invoked with a full "
+                            "model ToolCall of the form: {'args': {...}, "
+                            "'name': '...', 'type': 'tool_call', "
+                            "'tool_call_id': '...'}"
+                        )
+                        raise ValueError(msg)
+                    validated_input[k] = tool_call_id
+            return validated_input
        return tool_input

    @abstractmethod
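The validation loop above is what enforces the "full ToolCall" calling convention for injected `tool_call_id` arguments. A minimal sketch of that convention, with the ToolCall key names taken from the error message in the hunk (the `confirm` tool and its payload are illustrative):

```python
from typing import Annotated

from langchain_core.tools import InjectedToolCallId, tool


@tool
def confirm(note: str, tool_call_id: Annotated[str, InjectedToolCallId]) -> str:
    """Acknowledge a note."""
    return f"ack {note} (for call {tool_call_id})"


# Invoking with a full ToolCall dict supplies the injected id; invoking with
# bare {"note": "hi"} args would raise the ValueError shown above.
result = confirm.invoke(
    {"args": {"note": "hi"}, "name": "confirm", "type": "tool_call", "tool_call_id": "call_1"}
)
print(result)
```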
@@ -853,6 +895,7 @@ class ChildTool(BaseTool):
            name=run_name,
            run_id=run_id,
            inputs=filtered_tool_input,
+            tool_call_id=tool_call_id,
            **kwargs,
        )

@@ -980,6 +1023,7 @@ class ChildTool(BaseTool):
            name=run_name,
            run_id=run_id,
            inputs=filtered_tool_input,
+            tool_call_id=tool_call_id,
            **kwargs,
        )
        content = None
@@ -23,6 +23,7 @@ def tool(
     response_format: Literal["content", "content_and_artifact"] = "content",
     parse_docstring: bool = False,
     error_on_invalid_docstring: bool = True,
+    extras: dict[str, Any] | None = None,
 ) -> Callable[[Callable | Runnable], BaseTool]: ...


@@ -38,6 +39,7 @@ def tool(
     response_format: Literal["content", "content_and_artifact"] = "content",
     parse_docstring: bool = False,
     error_on_invalid_docstring: bool = True,
+    extras: dict[str, Any] | None = None,
 ) -> BaseTool: ...


@@ -52,6 +54,7 @@ def tool(
     response_format: Literal["content", "content_and_artifact"] = "content",
     parse_docstring: bool = False,
     error_on_invalid_docstring: bool = True,
+    extras: dict[str, Any] | None = None,
 ) -> BaseTool: ...


@@ -66,6 +69,7 @@ def tool(
     response_format: Literal["content", "content_and_artifact"] = "content",
     parse_docstring: bool = False,
     error_on_invalid_docstring: bool = True,
+    extras: dict[str, Any] | None = None,
 ) -> Callable[[Callable | Runnable], BaseTool]: ...


@@ -80,6 +84,7 @@ def tool(
     response_format: Literal["content", "content_and_artifact"] = "content",
     parse_docstring: bool = False,
     error_on_invalid_docstring: bool = True,
+    extras: dict[str, Any] | None = None,
 ) -> BaseTool | Callable[[Callable | Runnable], BaseTool]:
     """Convert Python functions and `Runnables` to LangChain tools.

@@ -130,6 +135,15 @@ def tool(
             parse parameter descriptions from Google Style function docstrings.
         error_on_invalid_docstring: If `parse_docstring` is provided, configure
             whether to raise `ValueError` on invalid Google Style docstrings.
+        extras: Optional provider-specific extra fields for the tool.
+
+            Used to pass configuration that doesn't fit into standard tool fields.
+            Chat models should process known extras when constructing model payloads.
+
+            !!! example
+
+                For example, Anthropic-specific fields like `cache_control`,
+                `defer_loading`, or `input_examples`.

     Raises:
         ValueError: If too many positional arguments are provided (e.g. violating the
@@ -292,6 +306,7 @@ def tool(
            response_format=response_format,
            parse_docstring=parse_docstring,
            error_on_invalid_docstring=error_on_invalid_docstring,
+            extras=extras,
        )
        # If someone doesn't want a schema applied, we must treat it as
        # a simple string->string function
@@ -308,6 +323,7 @@ def tool(
            return_direct=return_direct,
            coroutine=coroutine,
            response_format=response_format,
+            extras=extras,
        )

    return _tool_factory
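Taken together with the new `BaseTool.extras` field above, the decorator change means provider-specific hints declared at definition time ride along on the tool object. A small sketch (the tool body is illustrative):

```python
from langchain_core.tools import tool


@tool(extras={"cache_control": {"type": "ephemeral"}})
def search_docs(query: str) -> str:
    """Search internal documentation."""
    return f"results for {query}"


# The extras dict is stored on the resulting StructuredTool; integrations that
# recognize these keys can fold them into their provider request payloads.
print(search_docs.extras)  # {'cache_control': {'type': 'ephemeral'}}
```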
@@ -2,22 +2,21 @@

 from __future__ import annotations

-from functools import partial
 from typing import TYPE_CHECKING, Literal

 from pydantic import BaseModel, Field

+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import Document
 from langchain_core.prompts import (
     BasePromptTemplate,
     PromptTemplate,
     aformat_document,
     format_document,
 )
-from langchain_core.tools.simple import Tool
+from langchain_core.tools.structured import StructuredTool

 if TYPE_CHECKING:
-    from langchain_core.callbacks import Callbacks
-    from langchain_core.documents import Document
     from langchain_core.retrievers import BaseRetriever


@@ -27,43 +26,6 @@ class RetrieverInput(BaseModel):
     query: str = Field(description="query to look up in retriever")


-def _get_relevant_documents(
-    query: str,
-    retriever: BaseRetriever,
-    document_prompt: BasePromptTemplate,
-    document_separator: str,
-    callbacks: Callbacks = None,
-    response_format: Literal["content", "content_and_artifact"] = "content",
-) -> str | tuple[str, list[Document]]:
-    docs = retriever.invoke(query, config={"callbacks": callbacks})
-    content = document_separator.join(
-        format_document(doc, document_prompt) for doc in docs
-    )
-    if response_format == "content_and_artifact":
-        return (content, docs)
-
-    return content
-
-
-async def _aget_relevant_documents(
-    query: str,
-    retriever: BaseRetriever,
-    document_prompt: BasePromptTemplate,
-    document_separator: str,
-    callbacks: Callbacks = None,
-    response_format: Literal["content", "content_and_artifact"] = "content",
-) -> str | tuple[str, list[Document]]:
-    docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
-    content = document_separator.join(
-        [await aformat_document(doc, document_prompt) for doc in docs]
-    )
-
-    if response_format == "content_and_artifact":
-        return (content, docs)
-
-    return content
-
-
 def create_retriever_tool(
     retriever: BaseRetriever,
     name: str,
@@ -72,7 +34,7 @@ def create_retriever_tool(
     document_prompt: BasePromptTemplate | None = None,
     document_separator: str = "\n\n",
     response_format: Literal["content", "content_and_artifact"] = "content",
-) -> Tool:
+) -> StructuredTool:
     r"""Create a tool to do retrieval of documents.

     Args:
@@ -93,22 +55,31 @@ def create_retriever_tool(
     Returns:
         Tool class to pass to an agent.
     """
-    document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
-    func = partial(
-        _get_relevant_documents,
-        retriever=retriever,
-        document_prompt=document_prompt,
-        document_separator=document_separator,
-        response_format=response_format,
-    )
-    afunc = partial(
-        _aget_relevant_documents,
-        retriever=retriever,
-        document_prompt=document_prompt,
-        document_separator=document_separator,
-        response_format=response_format,
-    )
-    return Tool(
+    document_prompt_ = document_prompt or PromptTemplate.from_template("{page_content}")
+
+    def func(
+        query: str, callbacks: Callbacks = None
+    ) -> str | tuple[str, list[Document]]:
+        docs = retriever.invoke(query, config={"callbacks": callbacks})
+        content = document_separator.join(
+            format_document(doc, document_prompt_) for doc in docs
+        )
+        if response_format == "content_and_artifact":
+            return (content, docs)
+        return content
+
+    async def afunc(
+        query: str, callbacks: Callbacks = None
+    ) -> str | tuple[str, list[Document]]:
+        docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
+        content = document_separator.join(
+            [await aformat_document(doc, document_prompt_) for doc in docs]
+        )
+        if response_format == "content_and_artifact":
+            return (content, docs)
+        return content
+
+    return StructuredTool(
         name=name,
         description=description,
         func=func,
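With the `partial`-based helpers gone, `create_retriever_tool` now returns a `StructuredTool` whose sync and async closures capture the retriever, prompt, and separator. A minimal sketch of the call site, using a toy in-memory retriever (`TinyRetriever` is illustrative):

```python
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import create_retriever_tool


class TinyRetriever(BaseRetriever):
    """Toy retriever that always returns a single stub document."""

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> list[Document]:
        return [Document(page_content=f"Stub result for: {query}")]


retriever_tool = create_retriever_tool(
    TinyRetriever(),
    name="search_notes",
    description="Search a tiny in-memory note store.",
)
print(type(retriever_tool).__name__)  # StructuredTool
print(retriever_tool.invoke({"query": "langchain"}))
```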
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import functools
 import textwrap
 from collections.abc import Awaitable, Callable
 from inspect import signature
@@ -21,10 +22,12 @@ from langchain_core.callbacks import (
 )
 from langchain_core.runnables import RunnableConfig, run_in_executor
 from langchain_core.tools.base import (
+    _EMPTY_SET,
     FILTERED_ARGS,
     ArgsSchema,
     BaseTool,
     _get_runnable_config_param,
+    _is_injected_arg_type,
     create_schema_from_function,
 )
 from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -241,6 +244,17 @@ class StructuredTool(BaseTool):
            **kwargs,
        )

+    @functools.cached_property
+    def _injected_args_keys(self) -> frozenset[str]:
+        fn = self.func or self.coroutine
+        if fn is None:
+            return _EMPTY_SET
+        return frozenset(
+            k
+            for k, v in signature(fn).parameters.items()
+            if _is_injected_arg_type(v.annotation)
+        )
+

def _filter_schema_args(func: Callable) -> list[str]:
    filter_args = list(FILTERED_ARGS)
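The new `_injected_args_keys` property scans the wrapped function's signature for injected parameters so their values can pass through validation. A small sketch of the kind of tool it applies to (`lookup` and `user_id` are illustrative):

```python
from typing import Annotated

from langchain_core.tools import InjectedToolArg, tool


@tool
def lookup(query: str, user_id: Annotated[str, InjectedToolArg]) -> str:
    """Look up a record for a user."""
    return f"{user_id}: {query}"


# Injected arguments are hidden from the model-facing tool-call schema but must
# still be present in the runtime input; these are exactly the names that
# `_injected_args_keys` precomputes from the signature.
print(lookup.invoke({"query": "quarterly plan", "user_id": "u-42"}))
```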
@@ -12,7 +12,7 @@ from typing import (
     TypeVar,
     cast,
 )
-from uuid import UUID, uuid4
+from uuid import UUID

 from typing_extensions import NotRequired, override

@@ -42,7 +42,8 @@ from langchain_core.tracers.log_stream import (
     _astream_log_implementation,
 )
 from langchain_core.tracers.memory_stream import _MemoryStream
-from langchain_core.utils.aiter import aclosing, py_anext
+from langchain_core.utils.aiter import aclosing
+from langchain_core.utils.uuid import uuid7

 if TYPE_CHECKING:
     from collections.abc import AsyncIterator, Iterator, Sequence
@@ -188,7 +189,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHandler):
         # atomic check and set
         tap = self.is_tapped.setdefault(run_id, sentinel)
         # wait for first chunk
-        first = await py_anext(output, default=sentinel)
+        first = await anext(output, sentinel)
         if first is sentinel:
             return
         # get run info
@@ -1006,7 +1007,11 @@ async def _astream_events_implementation_v2(

     # Assign the stream handler to the config
     config = ensure_config(config)
-    run_id = cast("UUID", config.setdefault("run_id", uuid4()))
+    if "run_id" in config:
+        run_id = cast("UUID", config["run_id"])
+    else:
+        run_id = uuid7()
+        config["run_id"] = run_id
     callbacks = config.get("callbacks")
     if callbacks is None:
         config["callbacks"] = [event_streamer]
@@ -26,13 +26,15 @@ from typing import (

 from typing_extensions import override

+from langchain_core._api.deprecation import deprecated
+
 T = TypeVar("T")

 _no_default = object()


 # https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
-# before 3.10, the builtin anext() was not available
+@deprecated(since="1.1.2", removal="2.0.0")
 def py_anext(
     iterator: AsyncIterator[T], default: T | Any = _no_default
 ) -> Awaitable[T | Any | None]:
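`py_anext` predates Python 3.10's builtin `anext`; the deprecation above pairs with the call sites earlier in this diff switching to the builtin. A minimal sketch of the replacement:

```python
import asyncio


async def agen():
    yield 1


async def main() -> None:
    it = agen()
    # The builtin anext() takes an optional default returned once the iterator
    # is exhausted, mirroring the deprecated py_anext(iterator, default=...).
    print(await anext(it, None))  # 1
    print(await anext(it, None))  # None


asyncio.run(main())
```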
@@ -8,6 +8,7 @@ import logging
 import types
 import typing
 import uuid
+from collections.abc import Mapping
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -327,7 +328,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:


 def convert_to_openai_function(
-    function: dict[str, Any] | type | Callable | BaseTool,
+    function: Mapping[str, Any] | type | Callable | BaseTool,
     *,
     strict: bool | None = None,
 ) -> dict[str, Any]:
@@ -353,6 +354,7 @@ def convert_to_openai_function(
         ValueError: If function is not in a supported format.

     !!! warning "Behavior changed in `langchain-core` 0.3.16"
+
         `description` and `parameters` keys are now optional. Only `name` is
         required and guaranteed to be part of the output.
     """
@@ -453,7 +455,7 @@ _WellKnownOpenAITools = (


 def convert_to_openai_tool(
-    tool: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
+    tool: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
     *,
     strict: bool | None = None,
 ) -> dict[str, Any]:
@@ -477,15 +479,18 @@ def convert_to_openai_tool(
         OpenAI tool-calling API.

     !!! warning "Behavior changed in `langchain-core` 0.3.16"
+
         `description` and `parameters` keys are now optional. Only `name` is
         required and guaranteed to be part of the output.

     !!! warning "Behavior changed in `langchain-core` 0.3.44"
+
         Return OpenAI Responses API-style tools unchanged. This includes
         any dict with `"type"` in `"file_search"`, `"function"`,
         `"computer_use_preview"`, `"web_search_preview"`.

     !!! warning "Behavior changed in `langchain-core` 0.3.63"
+
         Added support for OpenAI's image generation built-in tool.
     """
     # Import locally to prevent circular import
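Widening the `function`/`tool` parameter from `dict` to `Mapping` means read-only mappings are now valid inputs. The sketch below assumes the implementation copies the mapping into the plain-dict result, as the annotation change suggests:

```python
from types import MappingProxyType

from langchain_core.utils.function_calling import convert_to_openai_tool

spec = MappingProxyType(
    {
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
)
# A read-only Mapping is accepted where only a dict was allowed before; the
# output is still a plain OpenAI-style tool dict.
print(convert_to_openai_tool(spec))
```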
@@ -170,28 +170,33 @@ def dereference_refs(
     full_schema: dict | None = None,
     skip_keys: Sequence[str] | None = None,
 ) -> dict:
-    """Resolve and inline JSON Schema $ref references in a schema object.
+    """Resolve and inline JSON Schema `$ref` references in a schema object.

-    This function processes a JSON Schema and resolves all $ref references by replacing
-    them with the actual referenced content. It handles both simple references and
-    complex cases like circular references and mixed $ref objects that contain
-    additional properties alongside the $ref.
+    This function processes a JSON Schema and resolves all `$ref` references by
+    replacing them with the actual referenced content.
+
+    Handles both simple references and complex cases like circular references and mixed
+    `$ref` objects that contain additional properties alongside the `$ref`.

     Args:
-        schema_obj: The JSON Schema object or fragment to process. This can be a
-            complete schema or just a portion of one.
-        full_schema: The complete schema containing all definitions that $refs might
-            point to. If not provided, defaults to schema_obj (useful when the
-            schema is self-contained).
-        skip_keys: Controls recursion behavior and reference resolution depth:
-            - If `None` (Default): Only recurse under '$defs' and use shallow reference
-              resolution (break cycles but don't deep-inline nested refs)
-            - If provided (even as []): Recurse under all keys and use deep reference
-              resolution (fully inline all nested references)
+        schema_obj: The JSON Schema object or fragment to process.
+
+            This can be a complete schema or just a portion of one.
+        full_schema: The complete schema containing all definitions that `$refs` might
+            point to.
+
+            If not provided, defaults to `schema_obj` (useful when the schema is
+            self-contained).
+        skip_keys: Controls recursion behavior and reference resolution depth.
+
+            - If `None` (Default): Only recurse under `'$defs'` and use shallow
+              reference resolution (break cycles but don't deep-inline nested refs)
+            - If provided (even as `[]`): Recurse under all keys and use deep reference
+              resolution (fully inline all nested references)

     Returns:
-        A new dictionary with all $ref references resolved and inlined. The original
-        schema_obj is not modified.
+        A new dictionary with all $ref references resolved and inlined.
+        The original `schema_obj` is not modified.

     Examples:
         Basic reference resolution:
@@ -203,7 +208,8 @@ def dereference_refs(
         >>> result = dereference_refs(schema)
         >>> result["properties"]["name"]  # {"type": "string"}

-        Mixed $ref with additional properties:
+        Mixed `$ref` with additional properties:
+
         >>> schema = {
         ...     "properties": {
         ...         "name": {"$ref": "#/$defs/base", "description": "User name"}
@@ -215,6 +221,7 @@ def dereference_refs(
         # {"type": "string", "minLength": 1, "description": "User name"}

         Handling circular references:
+
         >>> schema = {
         ...     "properties": {"user": {"$ref": "#/$defs/User"}},
         ...     "$defs": {
@@ -227,10 +234,11 @@ def dereference_refs(
         >>> result = dereference_refs(schema)  # Won't cause infinite recursion

     !!! note
+
         - Circular references are handled gracefully by breaking cycles
-        - Mixed $ref objects (with both $ref and other properties) are supported
-        - Additional properties in mixed $refs override resolved properties
-        - The $defs section is preserved in the output by default
+        - Mixed `$ref` objects (with both `$ref` and other properties) are supported
+        - Additional properties in mixed `$refs` override resolved properties
+        - The `$defs` section is preserved in the output by default
     """
     full = full_schema or schema_obj
     keys_to_skip = list(skip_keys) if skip_keys is not None else ["$defs"]
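A runnable version of the docstring's `skip_keys` distinction (the nested `Person`/`Address` schema is illustrative):

```python
from langchain_core.utils.json_schema import dereference_refs

schema = {
    "type": "object",
    "properties": {"person": {"$ref": "#/$defs/Person"}},
    "$defs": {
        "Person": {
            "type": "object",
            "properties": {"address": {"$ref": "#/$defs/Address"}},
        },
        "Address": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}

# Default (skip_keys=None): shallow resolution as described above; the input
# schema itself is not mutated.
shallow = dereference_refs(schema)
# skip_keys=[] opts into deep resolution, fully inlining nested refs.
deep = dereference_refs(schema, skip_keys=[])
print(shallow["properties"]["person"])
print(deep["properties"]["person"])
```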
@@ -374,15 +374,29 @@ def _get_key(
         if resolved_scope in (0, False):
             return resolved_scope
         # Move into the scope
-        try:
-            # Try subscripting (Normal dictionaries)
-            resolved_scope = cast("dict[str, Any]", resolved_scope)[child]
-        except (TypeError, AttributeError):
+        if isinstance(resolved_scope, dict):
             try:
-                resolved_scope = getattr(resolved_scope, child)
-            except (TypeError, AttributeError):
-                # Try as a list
-                resolved_scope = resolved_scope[int(child)]  # type: ignore[index]
+                resolved_scope = resolved_scope[child]
+            except (KeyError, TypeError):
+                # Key not found - will be caught by outer try-except
+                msg = f"Key {child!r} not found in dict"
+                raise KeyError(msg) from None
+        elif isinstance(resolved_scope, (list, tuple)):
+            try:
+                resolved_scope = resolved_scope[int(child)]
+            except (ValueError, IndexError, TypeError):
+                # Invalid index - will be caught by outer try-except
+                msg = f"Invalid index {child!r} for list/tuple"
+                raise IndexError(msg) from None
+        else:
+            # Reject everything else for security
+            # This prevents traversing into arbitrary Python objects
+            msg = (
+                f"Cannot traverse into {type(resolved_scope).__name__}. "
+                "Mustache templates only support dict, list, and tuple. "
+                f"Got: {type(resolved_scope)}"
+            )
+            raise TypeError(msg)  # noqa: TRY301

         try:
             # This allows for custom falsy data types
@@ -393,8 +407,9 @@ def _get_key(
             if resolved_scope in (0, False):
                 return resolved_scope
             return resolved_scope or ""
-        except (AttributeError, KeyError, IndexError, ValueError):
+        except (AttributeError, KeyError, IndexError, ValueError, TypeError):
             # We couldn't find the key in the current scope
+            # TypeError: Attempted to traverse into non-dict/list type
             # We'll try again on the next pass
             pass

@@ -88,18 +88,18 @@ def is_pydantic_v2_subclass(cls: type) -> bool:
     """Check if the given class is Pydantic v2-like.

     Returns:
-        `True` if the given class is a subclass of Pydantic BaseModel 2.x.
+        `True` if the given class is a subclass of Pydantic `BaseModel` 2.x.
     """
     return issubclass(cls, BaseModel)


 def is_basemodel_subclass(cls: type) -> bool:
-    """Check if the given class is a subclass of Pydantic BaseModel.
+    """Check if the given class is a subclass of Pydantic `BaseModel`.

     Check if the given class is a subclass of any of the following:

-    * pydantic.BaseModel in Pydantic 2.x
-    * pydantic.v1.BaseModel in Pydantic 2.x
+    * `pydantic.BaseModel` in Pydantic 2.x
+    * `pydantic.v1.BaseModel` in Pydantic 2.x

     Returns:
         `True` if the given class is a subclass of Pydantic `BaseModel`.
@@ -112,12 +112,12 @@ def is_basemodel_subclass(cls: type) -> bool:


 def is_basemodel_instance(obj: Any) -> bool:
-    """Check if the given class is an instance of Pydantic BaseModel.
+    """Check if the given class is an instance of Pydantic `BaseModel`.

     Check if the given class is an instance of any of the following:

-    * pydantic.BaseModel in Pydantic 2.x
-    * pydantic.v1.BaseModel in Pydantic 2.x
+    * `pydantic.BaseModel` in Pydantic 2.x
+    * `pydantic.v1.BaseModel` in Pydantic 2.x

     Returns:
         `True` if the given class is an instance of Pydantic `BaseModel`.
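These helpers treat both native Pydantic 2.x models and `pydantic.v1` compatibility models as `BaseModel`s, as the docstrings now spell out. A quick sketch:

```python
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1

from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass


class V2Thing(BaseModel):
    x: int


class V1Thing(BaseModelV1):
    x: int


# Both the native 2.x model and the v1 compatibility model are recognized.
print(is_basemodel_subclass(V2Thing), is_basemodel_subclass(V1Thing))  # True True
print(is_basemodel_instance(V2Thing(x=1)), is_basemodel_instance(V1Thing(x=1)))  # True True
```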