langchain-core 0.4.0.dev0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-core might be problematic; see the registry's advisory page for more details.

Files changed (172)
  1. langchain_core/__init__.py +1 -1
  2. langchain_core/_api/__init__.py +3 -4
  3. langchain_core/_api/beta_decorator.py +45 -70
  4. langchain_core/_api/deprecation.py +80 -80
  5. langchain_core/_api/path.py +22 -8
  6. langchain_core/_import_utils.py +10 -4
  7. langchain_core/agents.py +25 -21
  8. langchain_core/caches.py +53 -63
  9. langchain_core/callbacks/__init__.py +1 -8
  10. langchain_core/callbacks/base.py +341 -348
  11. langchain_core/callbacks/file.py +55 -44
  12. langchain_core/callbacks/manager.py +546 -683
  13. langchain_core/callbacks/stdout.py +29 -30
  14. langchain_core/callbacks/streaming_stdout.py +35 -36
  15. langchain_core/callbacks/usage.py +65 -70
  16. langchain_core/chat_history.py +48 -55
  17. langchain_core/document_loaders/base.py +46 -21
  18. langchain_core/document_loaders/langsmith.py +39 -36
  19. langchain_core/documents/__init__.py +0 -1
  20. langchain_core/documents/base.py +96 -74
  21. langchain_core/documents/compressor.py +12 -9
  22. langchain_core/documents/transformers.py +29 -28
  23. langchain_core/embeddings/fake.py +56 -57
  24. langchain_core/env.py +2 -3
  25. langchain_core/example_selectors/base.py +12 -0
  26. langchain_core/example_selectors/length_based.py +1 -1
  27. langchain_core/example_selectors/semantic_similarity.py +21 -25
  28. langchain_core/exceptions.py +15 -9
  29. langchain_core/globals.py +4 -163
  30. langchain_core/indexing/api.py +132 -125
  31. langchain_core/indexing/base.py +64 -67
  32. langchain_core/indexing/in_memory.py +26 -6
  33. langchain_core/language_models/__init__.py +15 -27
  34. langchain_core/language_models/_utils.py +267 -117
  35. langchain_core/language_models/base.py +92 -177
  36. langchain_core/language_models/chat_models.py +547 -407
  37. langchain_core/language_models/fake.py +11 -11
  38. langchain_core/language_models/fake_chat_models.py +72 -118
  39. langchain_core/language_models/llms.py +168 -242
  40. langchain_core/load/dump.py +8 -11
  41. langchain_core/load/load.py +32 -28
  42. langchain_core/load/mapping.py +2 -4
  43. langchain_core/load/serializable.py +50 -56
  44. langchain_core/messages/__init__.py +36 -51
  45. langchain_core/messages/ai.py +377 -150
  46. langchain_core/messages/base.py +239 -47
  47. langchain_core/messages/block_translators/__init__.py +111 -0
  48. langchain_core/messages/block_translators/anthropic.py +470 -0
  49. langchain_core/messages/block_translators/bedrock.py +94 -0
  50. langchain_core/messages/block_translators/bedrock_converse.py +297 -0
  51. langchain_core/messages/block_translators/google_genai.py +530 -0
  52. langchain_core/messages/block_translators/google_vertexai.py +21 -0
  53. langchain_core/messages/block_translators/groq.py +143 -0
  54. langchain_core/messages/block_translators/langchain_v0.py +301 -0
  55. langchain_core/messages/block_translators/openai.py +1010 -0
  56. langchain_core/messages/chat.py +2 -3
  57. langchain_core/messages/content.py +1423 -0
  58. langchain_core/messages/function.py +7 -7
  59. langchain_core/messages/human.py +44 -38
  60. langchain_core/messages/modifier.py +3 -2
  61. langchain_core/messages/system.py +40 -27
  62. langchain_core/messages/tool.py +160 -58
  63. langchain_core/messages/utils.py +527 -638
  64. langchain_core/output_parsers/__init__.py +1 -14
  65. langchain_core/output_parsers/base.py +68 -104
  66. langchain_core/output_parsers/json.py +13 -17
  67. langchain_core/output_parsers/list.py +11 -33
  68. langchain_core/output_parsers/openai_functions.py +56 -74
  69. langchain_core/output_parsers/openai_tools.py +68 -109
  70. langchain_core/output_parsers/pydantic.py +15 -13
  71. langchain_core/output_parsers/string.py +6 -2
  72. langchain_core/output_parsers/transform.py +17 -60
  73. langchain_core/output_parsers/xml.py +34 -44
  74. langchain_core/outputs/__init__.py +1 -1
  75. langchain_core/outputs/chat_generation.py +26 -11
  76. langchain_core/outputs/chat_result.py +1 -3
  77. langchain_core/outputs/generation.py +17 -6
  78. langchain_core/outputs/llm_result.py +15 -8
  79. langchain_core/prompt_values.py +29 -123
  80. langchain_core/prompts/__init__.py +3 -27
  81. langchain_core/prompts/base.py +48 -63
  82. langchain_core/prompts/chat.py +259 -288
  83. langchain_core/prompts/dict.py +19 -11
  84. langchain_core/prompts/few_shot.py +84 -90
  85. langchain_core/prompts/few_shot_with_templates.py +14 -12
  86. langchain_core/prompts/image.py +19 -14
  87. langchain_core/prompts/loading.py +6 -8
  88. langchain_core/prompts/message.py +7 -8
  89. langchain_core/prompts/prompt.py +42 -43
  90. langchain_core/prompts/string.py +37 -16
  91. langchain_core/prompts/structured.py +43 -46
  92. langchain_core/rate_limiters.py +51 -60
  93. langchain_core/retrievers.py +52 -192
  94. langchain_core/runnables/base.py +1727 -1683
  95. langchain_core/runnables/branch.py +52 -73
  96. langchain_core/runnables/config.py +89 -103
  97. langchain_core/runnables/configurable.py +128 -130
  98. langchain_core/runnables/fallbacks.py +93 -82
  99. langchain_core/runnables/graph.py +127 -127
  100. langchain_core/runnables/graph_ascii.py +63 -41
  101. langchain_core/runnables/graph_mermaid.py +87 -70
  102. langchain_core/runnables/graph_png.py +31 -36
  103. langchain_core/runnables/history.py +145 -161
  104. langchain_core/runnables/passthrough.py +141 -144
  105. langchain_core/runnables/retry.py +84 -68
  106. langchain_core/runnables/router.py +33 -37
  107. langchain_core/runnables/schema.py +79 -72
  108. langchain_core/runnables/utils.py +95 -139
  109. langchain_core/stores.py +85 -131
  110. langchain_core/structured_query.py +11 -15
  111. langchain_core/sys_info.py +31 -32
  112. langchain_core/tools/__init__.py +1 -14
  113. langchain_core/tools/base.py +221 -247
  114. langchain_core/tools/convert.py +144 -161
  115. langchain_core/tools/render.py +10 -10
  116. langchain_core/tools/retriever.py +12 -19
  117. langchain_core/tools/simple.py +52 -29
  118. langchain_core/tools/structured.py +56 -60
  119. langchain_core/tracers/__init__.py +1 -9
  120. langchain_core/tracers/_streaming.py +6 -7
  121. langchain_core/tracers/base.py +103 -112
  122. langchain_core/tracers/context.py +29 -48
  123. langchain_core/tracers/core.py +142 -105
  124. langchain_core/tracers/evaluation.py +30 -34
  125. langchain_core/tracers/event_stream.py +162 -117
  126. langchain_core/tracers/langchain.py +34 -36
  127. langchain_core/tracers/log_stream.py +87 -49
  128. langchain_core/tracers/memory_stream.py +3 -3
  129. langchain_core/tracers/root_listeners.py +18 -34
  130. langchain_core/tracers/run_collector.py +8 -20
  131. langchain_core/tracers/schemas.py +0 -125
  132. langchain_core/tracers/stdout.py +3 -3
  133. langchain_core/utils/__init__.py +1 -4
  134. langchain_core/utils/_merge.py +47 -9
  135. langchain_core/utils/aiter.py +70 -66
  136. langchain_core/utils/env.py +12 -9
  137. langchain_core/utils/function_calling.py +139 -206
  138. langchain_core/utils/html.py +7 -8
  139. langchain_core/utils/input.py +6 -6
  140. langchain_core/utils/interactive_env.py +6 -2
  141. langchain_core/utils/iter.py +48 -45
  142. langchain_core/utils/json.py +14 -4
  143. langchain_core/utils/json_schema.py +159 -43
  144. langchain_core/utils/mustache.py +32 -25
  145. langchain_core/utils/pydantic.py +67 -40
  146. langchain_core/utils/strings.py +5 -5
  147. langchain_core/utils/usage.py +1 -1
  148. langchain_core/utils/utils.py +104 -62
  149. langchain_core/vectorstores/base.py +131 -179
  150. langchain_core/vectorstores/in_memory.py +113 -182
  151. langchain_core/vectorstores/utils.py +23 -17
  152. langchain_core/version.py +1 -1
  153. langchain_core-1.0.0.dist-info/METADATA +68 -0
  154. langchain_core-1.0.0.dist-info/RECORD +172 -0
  155. {langchain_core-0.4.0.dev0.dist-info → langchain_core-1.0.0.dist-info}/WHEEL +1 -1
  156. langchain_core/beta/__init__.py +0 -1
  157. langchain_core/beta/runnables/__init__.py +0 -1
  158. langchain_core/beta/runnables/context.py +0 -448
  159. langchain_core/memory.py +0 -116
  160. langchain_core/messages/content_blocks.py +0 -1435
  161. langchain_core/prompts/pipeline.py +0 -133
  162. langchain_core/pydantic_v1/__init__.py +0 -30
  163. langchain_core/pydantic_v1/dataclasses.py +0 -23
  164. langchain_core/pydantic_v1/main.py +0 -23
  165. langchain_core/tracers/langchain_v1.py +0 -23
  166. langchain_core/utils/loading.py +0 -31
  167. langchain_core/v1/__init__.py +0 -1
  168. langchain_core/v1/chat_models.py +0 -1047
  169. langchain_core/v1/messages.py +0 -755
  170. langchain_core-0.4.0.dev0.dist-info/METADATA +0 -108
  171. langchain_core-0.4.0.dev0.dist-info/RECORD +0 -177
  172. langchain_core-0.4.0.dev0.dist-info/entry_points.txt +0 -4
@@ -8,16 +8,14 @@ import json
8
8
  import typing
9
9
  import warnings
10
10
  from abc import ABC, abstractmethod
11
+ from collections.abc import Callable
11
12
  from inspect import signature
12
13
  from typing import (
13
14
  TYPE_CHECKING,
14
15
  Annotated,
15
16
  Any,
16
- Callable,
17
17
  Literal,
18
- Optional,
19
18
  TypeVar,
20
- Union,
21
19
  cast,
22
20
  get_args,
23
21
  get_origin,
@@ -31,7 +29,6 @@ from pydantic import (
31
29
  PydanticDeprecationWarning,
32
30
  SkipValidation,
33
31
  ValidationError,
34
- model_validator,
35
32
  validate_arguments,
36
33
  )
37
34
  from pydantic.v1 import BaseModel as BaseModelV1
@@ -39,10 +36,8 @@ from pydantic.v1 import ValidationError as ValidationErrorV1
39
36
  from pydantic.v1 import validate_arguments as validate_arguments_v1
40
37
  from typing_extensions import override
41
38
 
42
- from langchain_core._api import deprecated
43
39
  from langchain_core.callbacks import (
44
40
  AsyncCallbackManager,
45
- BaseCallbackManager,
46
41
  CallbackManager,
47
42
  Callbacks,
48
43
  )
@@ -68,14 +63,22 @@ from langchain_core.utils.pydantic import (
68
63
  is_pydantic_v1_subclass,
69
64
  is_pydantic_v2_subclass,
70
65
  )
71
- from langchain_core.v1.messages import ToolMessage as ToolMessageV1
72
66
 
73
67
  if TYPE_CHECKING:
74
68
  import uuid
75
69
  from collections.abc import Sequence
76
70
 
77
71
  FILTERED_ARGS = ("run_manager", "callbacks")
78
- TOOL_MESSAGE_BLOCK_TYPES = ("text", "image_url", "image", "json", "search_result")
72
+ TOOL_MESSAGE_BLOCK_TYPES = (
73
+ "text",
74
+ "image_url",
75
+ "image",
76
+ "json",
77
+ "search_result",
78
+ "custom_tool_call_output",
79
+ "document",
80
+ "file",
81
+ )
79
82
 
80
83
 
81
84
  class SchemaAnnotationError(TypeError):
@@ -89,7 +92,7 @@ def _is_annotated_type(typ: type[Any]) -> bool:
89
92
  typ: The type to check.
90
93
 
91
94
  Returns:
92
- True if the type is an Annotated type, False otherwise.
95
+ `True` if the type is an Annotated type, `False` otherwise.
93
96
  """
94
97
  return get_origin(typ) is typing.Annotated
95
98
 
@@ -223,7 +226,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
223
226
  pydantic_version: The Pydantic version to check against ("v1" or "v2").
224
227
 
225
228
  Returns:
226
- True if the annotation is a Pydantic model, False otherwise.
229
+ `True` if the annotation is a Pydantic model, `False` otherwise.
227
230
  """
228
231
  base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
229
232
  try:
@@ -242,7 +245,7 @@ def _function_annotations_are_pydantic_v1(
242
245
  func: The function being checked.
243
246
 
244
247
  Returns:
245
- True if all Pydantic annotations are from V1, False otherwise.
248
+ True if all Pydantic annotations are from V1, `False` otherwise.
246
249
 
247
250
  Raises:
248
251
  NotImplementedError: If the function contains mixed V1 and V2 annotations.
@@ -265,44 +268,40 @@ def _function_annotations_are_pydantic_v1(
265
268
 
266
269
 
267
270
  class _SchemaConfig:
268
- """Configuration for Pydantic models generated from function signatures.
269
-
270
- Attributes:
271
- extra: Whether to allow extra fields in the model.
272
- arbitrary_types_allowed: Whether to allow arbitrary types in the model.
273
- """
271
+ """Configuration for Pydantic models generated from function signatures."""
274
272
 
275
273
  extra: str = "forbid"
274
+ """Whether to allow extra fields in the model."""
276
275
  arbitrary_types_allowed: bool = True
276
+ """Whether to allow arbitrary types in the model."""
277
277
 
278
278
 
279
279
  def create_schema_from_function(
280
280
  model_name: str,
281
281
  func: Callable,
282
282
  *,
283
- filter_args: Optional[Sequence[str]] = None,
283
+ filter_args: Sequence[str] | None = None,
284
284
  parse_docstring: bool = False,
285
285
  error_on_invalid_docstring: bool = False,
286
286
  include_injected: bool = True,
287
287
  ) -> type[BaseModel]:
288
- """Create a pydantic schema from a function's signature.
288
+ """Create a Pydantic schema from a function's signature.
289
289
 
290
290
  Args:
291
- model_name: Name to assign to the generated pydantic schema.
291
+ model_name: Name to assign to the generated Pydantic schema.
292
292
  func: Function to generate the schema from.
293
293
  filter_args: Optional list of arguments to exclude from the schema.
294
- Defaults to FILTERED_ARGS.
294
+ Defaults to `FILTERED_ARGS`.
295
295
  parse_docstring: Whether to parse the function's docstring for descriptions
296
- for each argument. Defaults to False.
297
- error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
298
- whether to raise ValueError on invalid Google Style docstrings.
299
- Defaults to False.
296
+ for each argument.
297
+ error_on_invalid_docstring: if `parse_docstring` is provided, configure
298
+ whether to raise `ValueError` on invalid Google Style docstrings.
300
299
  include_injected: Whether to include injected arguments in the schema.
301
- Defaults to True, since we want to include them in the schema
300
+ Defaults to `True`, since we want to include them in the schema
302
301
  when *validating* tool inputs.
303
302
 
304
303
  Returns:
305
- A pydantic model with the same arguments as the function.
304
+ A Pydantic model with the same arguments as the function.
306
305
  """
307
306
  sig = inspect.signature(func)
308
307
 
@@ -312,7 +311,7 @@ def create_schema_from_function(
312
311
  # https://docs.pydantic.dev/latest/usage/validation_decorator/
313
312
  with warnings.catch_warnings():
314
313
  # We are using deprecated functionality here.
315
- # This code should be re-written to simply construct a pydantic model
314
+ # This code should be re-written to simply construct a Pydantic model
316
315
  # using inspect.signature and create_model.
317
316
  warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
318
317
  validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
@@ -385,10 +384,10 @@ class ToolException(Exception): # noqa: N818
385
384
  """
386
385
 
387
386
 
388
- ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
387
+ ArgsSchema = TypeBaseModel | dict[str, Any]
389
388
 
390
389
 
391
- class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
390
+ class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
392
391
  """Base class for all LangChain tools.
393
392
 
394
393
  This abstract class defines the interface that all LangChain tools must implement.
@@ -436,7 +435,7 @@ class ChildTool(BaseTool):
436
435
  You can provide few-shot examples as a part of the description.
437
436
  """
438
437
 
439
- args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
438
+ args_schema: Annotated[ArgsSchema | None, SkipValidation()] = Field(
440
439
  default=None, description="The tool schema."
441
440
  )
442
441
  """Pydantic model class to validate and parse the tool's input arguments.
@@ -459,56 +458,42 @@ class ChildTool(BaseTool):
459
458
  callbacks: Callbacks = Field(default=None, exclude=True)
460
459
  """Callbacks to be called during tool execution."""
461
460
 
462
- callback_manager: Optional[BaseCallbackManager] = deprecated(
463
- name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
464
- )(
465
- Field(
466
- default=None,
467
- exclude=True,
468
- description="Callback manager to add to the run trace.",
469
- )
470
- )
471
- tags: Optional[list[str]] = None
472
- """Optional list of tags associated with the tool. Defaults to None.
461
+ tags: list[str] | None = None
462
+ """Optional list of tags associated with the tool.
473
463
  These tags will be associated with each call to this tool,
474
464
  and passed as arguments to the handlers defined in `callbacks`.
475
465
  You can use these to eg identify a specific instance of a tool with its use case.
476
466
  """
477
- metadata: Optional[dict[str, Any]] = None
478
- """Optional metadata associated with the tool. Defaults to None.
467
+ metadata: dict[str, Any] | None = None
468
+ """Optional metadata associated with the tool.
479
469
  This metadata will be associated with each call to this tool,
480
470
  and passed as arguments to the handlers defined in `callbacks`.
481
471
  You can use these to eg identify a specific instance of a tool with its use case.
482
472
  """
483
473
 
484
- handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = (
485
- False
486
- )
474
+ handle_tool_error: bool | str | Callable[[ToolException], str] | None = False
487
475
  """Handle the content of the ToolException thrown."""
488
476
 
489
- handle_validation_error: Optional[
490
- Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]]
491
- ] = False
477
+ handle_validation_error: (
478
+ bool | str | Callable[[ValidationError | ValidationErrorV1], str] | None
479
+ ) = False
492
480
  """Handle the content of the ValidationError thrown."""
493
481
 
494
482
  response_format: Literal["content", "content_and_artifact"] = "content"
495
- """The tool response format. Defaults to 'content'.
483
+ """The tool response format.
496
484
 
497
- If "content" then the output of the tool is interpreted as the contents of a
498
- ToolMessage. If "content_and_artifact" then the output is expected to be a
499
- two-tuple corresponding to the (content, artifact) of a ToolMessage.
500
- """
501
-
502
- message_version: Literal["v0", "v1"] = "v0"
503
- """Version of ToolMessage to return given
504
- :class:`~langchain_core.messages.content_blocks.ToolCall` input.
505
-
506
- If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
507
- If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
485
+ If `"content"` then the output of the tool is interpreted as the contents of a
486
+ ToolMessage. If `"content_and_artifact"` then the output is expected to be a
487
+ two-tuple corresponding to the (content, artifact) of a `ToolMessage`.
508
488
  """
509
489
 
510
490
  def __init__(self, **kwargs: Any) -> None:
511
- """Initialize the tool."""
491
+ """Initialize the tool.
492
+
493
+ Raises:
494
+ TypeError: If `args_schema` is not a subclass of pydantic `BaseModel` or
495
+ dict.
496
+ """
512
497
  if (
513
498
  "args_schema" in kwargs
514
499
  and kwargs["args_schema"] is not None
@@ -531,7 +516,7 @@ class ChildTool(BaseTool):
531
516
  """Check if the tool accepts only a single input argument.
532
517
 
533
518
  Returns:
534
- True if the tool has only one input argument, False otherwise.
519
+ `True` if the tool has only one input argument, `False` otherwise.
535
520
  """
536
521
  keys = {k for k in self.args if k != "kwargs"}
537
522
  return len(keys) == 1
@@ -545,6 +530,8 @@ class ChildTool(BaseTool):
545
530
  """
546
531
  if isinstance(self.args_schema, dict):
547
532
  json_schema = self.args_schema
533
+ elif self.args_schema and issubclass(self.args_schema, BaseModelV1):
534
+ json_schema = self.args_schema.schema()
548
535
  else:
549
536
  input_schema = self.get_input_schema()
550
537
  json_schema = input_schema.model_json_schema()
@@ -578,9 +565,7 @@ class ChildTool(BaseTool):
578
565
  # --- Runnable ---
579
566
 
580
567
  @override
581
- def get_input_schema(
582
- self, config: Optional[RunnableConfig] = None
583
- ) -> type[BaseModel]:
568
+ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
584
569
  """The tool's input schema.
585
570
 
586
571
  Args:
@@ -598,8 +583,8 @@ class ChildTool(BaseTool):
598
583
  @override
599
584
  def invoke(
600
585
  self,
601
- input: Union[str, dict, ToolCall],
602
- config: Optional[RunnableConfig] = None,
586
+ input: str | dict | ToolCall,
587
+ config: RunnableConfig | None = None,
603
588
  **kwargs: Any,
604
589
  ) -> Any:
605
590
  tool_input, kwargs = _prep_run_args(input, config, **kwargs)
@@ -608,8 +593,8 @@ class ChildTool(BaseTool):
608
593
  @override
609
594
  async def ainvoke(
610
595
  self,
611
- input: Union[str, dict, ToolCall],
612
- config: Optional[RunnableConfig] = None,
596
+ input: str | dict | ToolCall,
597
+ config: RunnableConfig | None = None,
613
598
  **kwargs: Any,
614
599
  ) -> Any:
615
600
  tool_input, kwargs = _prep_run_args(input, config, **kwargs)
@@ -618,8 +603,8 @@ class ChildTool(BaseTool):
618
603
  # --- Tool ---
619
604
 
620
605
  def _parse_input(
621
- self, tool_input: Union[str, dict], tool_call_id: Optional[str]
622
- ) -> Union[str, dict[str, Any]]:
606
+ self, tool_input: str | dict, tool_call_id: str | None
607
+ ) -> str | dict[str, Any]:
623
608
  """Parse and validate tool input using the args schema.
624
609
 
625
610
  Args:
@@ -630,9 +615,10 @@ class ChildTool(BaseTool):
630
615
  The parsed and validated input.
631
616
 
632
617
  Raises:
633
- ValueError: If string input is provided with JSON schema or if
634
- InjectedToolCallId is required but not provided.
635
- NotImplementedError: If args_schema is not a supported type.
618
+ ValueError: If `string` input is provided with JSON schema `args_schema`.
619
+ ValueError: If InjectedToolCallId is required but `tool_call_id` is not
620
+ provided.
621
+ TypeError: If args_schema is not a Pydantic `BaseModel` or dict.
636
622
  """
637
623
  input_args = self.args_schema
638
624
  if isinstance(tool_input, str):
@@ -657,10 +643,7 @@ class ChildTool(BaseTool):
657
643
  return tool_input
658
644
  if issubclass(input_args, BaseModel):
659
645
  for k, v in get_all_basemodel_annotations(input_args).items():
660
- if (
661
- _is_injected_arg_type(v, injected_type=InjectedToolCallId)
662
- and k not in tool_input
663
- ):
646
+ if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
664
647
  if tool_call_id is None:
665
648
  msg = (
666
649
  "When tool includes an InjectedToolCallId "
@@ -675,10 +658,7 @@ class ChildTool(BaseTool):
675
658
  result_dict = result.model_dump()
676
659
  elif issubclass(input_args, BaseModelV1):
677
660
  for k, v in get_all_basemodel_annotations(input_args).items():
678
- if (
679
- _is_injected_arg_type(v, injected_type=InjectedToolCallId)
680
- and k not in tool_input
681
- ):
661
+ if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
682
662
  if tool_call_id is None:
683
663
  msg = (
684
664
  "When tool includes an InjectedToolCallId "
@@ -701,39 +681,25 @@ class ChildTool(BaseTool):
701
681
  }
702
682
  return tool_input
703
683
 
704
- @model_validator(mode="before")
705
- @classmethod
706
- def raise_deprecation(cls, values: dict) -> Any:
707
- """Raise deprecation warning if callback_manager is used.
708
-
709
- Args:
710
- values: The values to validate.
711
-
712
- Returns:
713
- The validated values.
714
- """
715
- if values.get("callback_manager") is not None:
716
- warnings.warn(
717
- "callback_manager is deprecated. Please use callbacks instead.",
718
- DeprecationWarning,
719
- stacklevel=6,
720
- )
721
- values["callbacks"] = values.pop("callback_manager", None)
722
- return values
723
-
724
684
  @abstractmethod
725
685
  def _run(self, *args: Any, **kwargs: Any) -> Any:
726
686
  """Use the tool.
727
687
 
728
- Add run_manager: Optional[CallbackManagerForToolRun] = None
729
- to child implementations to enable tracing.
688
+ Add `run_manager: CallbackManagerForToolRun | None = None` to child
689
+ implementations to enable tracing.
690
+
691
+ Returns:
692
+ The result of the tool execution.
730
693
  """
731
694
 
732
695
  async def _arun(self, *args: Any, **kwargs: Any) -> Any:
733
696
  """Use the tool asynchronously.
734
697
 
735
- Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
736
- to child implementations to enable tracing.
698
+ Add `run_manager: AsyncCallbackManagerForToolRun | None = None` to child
699
+ implementations to enable tracing.
700
+
701
+ Returns:
702
+ The result of the tool execution.
737
703
  """
738
704
  if kwargs.get("run_manager") and signature(self._run).parameters.get(
739
705
  "run_manager"
@@ -742,7 +708,7 @@ class ChildTool(BaseTool):
742
708
  return await run_in_executor(None, self._run, *args, **kwargs)
743
709
 
744
710
  def _to_args_and_kwargs(
745
- self, tool_input: Union[str, dict], tool_call_id: Optional[str]
711
+ self, tool_input: str | dict, tool_call_id: str | None
746
712
  ) -> tuple[tuple, dict]:
747
713
  """Convert tool input to positional and keyword arguments.
748
714
 
@@ -782,35 +748,35 @@ class ChildTool(BaseTool):
782
748
 
783
749
  def run(
784
750
  self,
785
- tool_input: Union[str, dict[str, Any]],
786
- verbose: Optional[bool] = None, # noqa: FBT001
787
- start_color: Optional[str] = "green",
788
- color: Optional[str] = "green",
751
+ tool_input: str | dict[str, Any],
752
+ verbose: bool | None = None, # noqa: FBT001
753
+ start_color: str | None = "green",
754
+ color: str | None = "green",
789
755
  callbacks: Callbacks = None,
790
756
  *,
791
- tags: Optional[list[str]] = None,
792
- metadata: Optional[dict[str, Any]] = None,
793
- run_name: Optional[str] = None,
794
- run_id: Optional[uuid.UUID] = None,
795
- config: Optional[RunnableConfig] = None,
796
- tool_call_id: Optional[str] = None,
757
+ tags: list[str] | None = None,
758
+ metadata: dict[str, Any] | None = None,
759
+ run_name: str | None = None,
760
+ run_id: uuid.UUID | None = None,
761
+ config: RunnableConfig | None = None,
762
+ tool_call_id: str | None = None,
797
763
  **kwargs: Any,
798
764
  ) -> Any:
799
765
  """Run the tool.
800
766
 
801
767
  Args:
802
768
  tool_input: The input to the tool.
803
- verbose: Whether to log the tool's progress. Defaults to None.
804
- start_color: The color to use when starting the tool. Defaults to 'green'.
805
- color: The color to use when ending the tool. Defaults to 'green'.
806
- callbacks: Callbacks to be called during tool execution. Defaults to None.
807
- tags: Optional list of tags associated with the tool. Defaults to None.
808
- metadata: Optional metadata associated with the tool. Defaults to None.
809
- run_name: The name of the run. Defaults to None.
810
- run_id: The id of the run. Defaults to None.
811
- config: The configuration for the tool. Defaults to None.
812
- tool_call_id: The id of the tool call. Defaults to None.
813
- kwargs: Keyword arguments to be passed to tool callbacks (event handler)
769
+ verbose: Whether to log the tool's progress.
770
+ start_color: The color to use when starting the tool.
771
+ color: The color to use when ending the tool.
772
+ callbacks: Callbacks to be called during tool execution.
773
+ tags: Optional list of tags associated with the tool.
774
+ metadata: Optional metadata associated with the tool.
775
+ run_name: The name of the run.
776
+ run_id: The id of the run.
777
+ config: The configuration for the tool.
778
+ tool_call_id: The id of the tool call.
779
+ **kwargs: Keyword arguments to be passed to tool callbacks (event handler)
814
780
 
815
781
  Returns:
816
782
  The output of the tool.
@@ -844,8 +810,8 @@ class ChildTool(BaseTool):
844
810
 
845
811
  content = None
846
812
  artifact = None
847
- status: Literal["success", "error"] = "success"
848
- error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
813
+ status = "success"
814
+ error_to_raise: Exception | KeyboardInterrupt | None = None
849
815
  try:
850
816
  child_config = patch_config(config, callbacks=run_manager.get_child())
851
817
  with set_config_context(child_config) as context:
@@ -888,48 +854,41 @@ class ChildTool(BaseTool):
888
854
  if error_to_raise:
889
855
  run_manager.on_tool_error(error_to_raise)
890
856
  raise error_to_raise
891
- output = _format_output(
892
- content,
893
- artifact,
894
- tool_call_id,
895
- self.name,
896
- status,
897
- message_version=self.message_version,
898
- )
857
+ output = _format_output(content, artifact, tool_call_id, self.name, status)
899
858
  run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
900
859
  return output
901
860
 
902
861
  async def arun(
903
862
  self,
904
- tool_input: Union[str, dict],
905
- verbose: Optional[bool] = None, # noqa: FBT001
906
- start_color: Optional[str] = "green",
907
- color: Optional[str] = "green",
863
+ tool_input: str | dict,
864
+ verbose: bool | None = None, # noqa: FBT001
865
+ start_color: str | None = "green",
866
+ color: str | None = "green",
908
867
  callbacks: Callbacks = None,
909
868
  *,
910
- tags: Optional[list[str]] = None,
911
- metadata: Optional[dict[str, Any]] = None,
912
- run_name: Optional[str] = None,
913
- run_id: Optional[uuid.UUID] = None,
914
- config: Optional[RunnableConfig] = None,
915
- tool_call_id: Optional[str] = None,
869
+ tags: list[str] | None = None,
870
+ metadata: dict[str, Any] | None = None,
871
+ run_name: str | None = None,
872
+ run_id: uuid.UUID | None = None,
873
+ config: RunnableConfig | None = None,
874
+ tool_call_id: str | None = None,
916
875
  **kwargs: Any,
917
876
  ) -> Any:
918
877
  """Run the tool asynchronously.
919
878
 
920
879
  Args:
921
880
  tool_input: The input to the tool.
922
- verbose: Whether to log the tool's progress. Defaults to None.
923
- start_color: The color to use when starting the tool. Defaults to 'green'.
924
- color: The color to use when ending the tool. Defaults to 'green'.
925
- callbacks: Callbacks to be called during tool execution. Defaults to None.
926
- tags: Optional list of tags associated with the tool. Defaults to None.
927
- metadata: Optional metadata associated with the tool. Defaults to None.
928
- run_name: The name of the run. Defaults to None.
929
- run_id: The id of the run. Defaults to None.
930
- config: The configuration for the tool. Defaults to None.
931
- tool_call_id: The id of the tool call. Defaults to None.
932
- kwargs: Keyword arguments to be passed to tool callbacks
881
+ verbose: Whether to log the tool's progress.
882
+ start_color: The color to use when starting the tool.
883
+ color: The color to use when ending the tool.
884
+ callbacks: Callbacks to be called during tool execution.
885
+ tags: Optional list of tags associated with the tool.
886
+ metadata: Optional metadata associated with the tool.
887
+ run_name: The name of the run.
888
+ run_id: The id of the run.
889
+ config: The configuration for the tool.
890
+ tool_call_id: The id of the tool call.
891
+ **kwargs: Keyword arguments to be passed to tool callbacks
933
892
 
934
893
  Returns:
935
894
  The output of the tool.
@@ -961,8 +920,8 @@ class ChildTool(BaseTool):
961
920
  )
962
921
  content = None
963
922
  artifact = None
964
- status: Literal["success", "error"] = "success"
965
- error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
923
+ status = "success"
924
+ error_to_raise: Exception | KeyboardInterrupt | None = None
966
925
  try:
967
926
  tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
968
927
  child_config = patch_config(config, callbacks=run_manager.get_child())
@@ -1009,30 +968,10 @@ class ChildTool(BaseTool):
1009
968
  await run_manager.on_tool_error(error_to_raise)
1010
969
  raise error_to_raise
1011
970
 
1012
- output = _format_output(
1013
- content,
1014
- artifact,
1015
- tool_call_id,
1016
- self.name,
1017
- status,
1018
- message_version=self.message_version,
1019
- )
971
+ output = _format_output(content, artifact, tool_call_id, self.name, status)
1020
972
  await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
1021
973
  return output
1022
974
 
1023
- @deprecated("0.1.47", alternative="invoke", removal="1.0")
1024
- def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
1025
- """Make tool callable (deprecated).
1026
-
1027
- Args:
1028
- tool_input: The input to the tool.
1029
- callbacks: Callbacks to use during execution.
1030
-
1031
- Returns:
1032
- The tool's output.
1033
- """
1034
- return self.run(tool_input, callbacks=callbacks)
1035
-
1036
975
 
1037
976
  def _is_tool_call(x: Any) -> bool:
1038
977
  """Check if the input is a tool call dictionary.
@@ -1041,17 +980,15 @@ def _is_tool_call(x: Any) -> bool:
1041
980
  x: The input to check.
1042
981
 
1043
982
  Returns:
1044
- True if the input is a tool call, False otherwise.
983
+ `True` if the input is a tool call, `False` otherwise.
1045
984
  """
1046
985
  return isinstance(x, dict) and x.get("type") == "tool_call"
1047
986
 
1048
987
 
1049
988
  def _handle_validation_error(
1050
- e: Union[ValidationError, ValidationErrorV1],
989
+ e: ValidationError | ValidationErrorV1,
1051
990
  *,
1052
- flag: Union[
1053
- Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
1054
- ],
991
+ flag: Literal[True] | str | Callable[[ValidationError | ValidationErrorV1], str],
1055
992
  ) -> str:
1056
993
  """Handle validation errors based on the configured flag.
1057
994
 
@@ -1083,7 +1020,7 @@ def _handle_validation_error(
1083
1020
  def _handle_tool_error(
1084
1021
  e: ToolException,
1085
1022
  *,
1086
- flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
1023
+ flag: Literal[True] | str | Callable[[ToolException], str] | None,
1087
1024
  ) -> str:
1088
1025
  """Handle tool execution errors based on the configured flag.
1089
1026
 
@@ -1113,10 +1050,10 @@ def _handle_tool_error(
1113
1050
 
1114
1051
 
1115
1052
  def _prep_run_args(
1116
- value: Union[str, dict, ToolCall],
1117
- config: Optional[RunnableConfig],
1053
+ value: str | dict | ToolCall,
1054
+ config: RunnableConfig | None,
1118
1055
  **kwargs: Any,
1119
- ) -> tuple[Union[str, dict], dict]:
1056
+ ) -> tuple[str | dict, dict]:
1120
1057
  """Prepare arguments for tool execution.
1121
1058
 
1122
1059
  Args:
@@ -1129,11 +1066,11 @@ def _prep_run_args(
1129
1066
  """
1130
1067
  config = ensure_config(config)
1131
1068
  if _is_tool_call(value):
1132
- tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
1133
- tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
1069
+ tool_call_id: str | None = cast("ToolCall", value)["id"]
1070
+ tool_input: str | dict = cast("ToolCall", value)["args"].copy()
1134
1071
  else:
1135
1072
  tool_call_id = None
1136
- tool_input = cast("Union[str, dict]", value)
1073
+ tool_input = cast("str | dict", value)
1137
1074
  return (
1138
1075
  tool_input,
1139
1076
  dict(
@@ -1152,12 +1089,10 @@ def _prep_run_args(
1152
1089
  def _format_output(
1153
1090
  content: Any,
1154
1091
  artifact: Any,
1155
- tool_call_id: Optional[str],
1092
+ tool_call_id: str | None,
1156
1093
  name: str,
1157
- status: Literal["success", "error"],
1158
- *,
1159
- message_version: Literal["v0", "v1"] = "v0",
1160
- ) -> Union[ToolOutputMixin, Any]:
1094
+ status: str,
1095
+ ) -> ToolOutputMixin | Any:
1161
1096
  """Format tool output as a ToolMessage if appropriate.
1162
1097
 
1163
1098
  Args:
@@ -1166,7 +1101,6 @@ def _format_output(
1166
1101
  tool_call_id: The ID of the tool call.
1167
1102
  name: The name of the tool.
1168
1103
  status: The execution status.
1169
- message_version: The version of the ToolMessage to return.
1170
1104
 
1171
1105
  Returns:
1172
1106
  The formatted output, either as a ToolMessage or the original content.
@@ -1175,15 +1109,7 @@ def _format_output(
1175
1109
  return content
1176
1110
  if not _is_message_content_type(content):
1177
1111
  content = _stringify(content)
1178
- if message_version == "v0":
1179
- return ToolMessage(
1180
- content,
1181
- artifact=artifact,
1182
- tool_call_id=tool_call_id,
1183
- name=name,
1184
- status=status,
1185
- )
1186
- return ToolMessageV1(
1112
+ return ToolMessage(
1187
1113
  content,
1188
1114
  artifact=artifact,
1189
1115
  tool_call_id=tool_call_id,
@@ -1201,7 +1127,7 @@ def _is_message_content_type(obj: Any) -> bool:
1201
1127
  obj: The object to check.
1202
1128
 
1203
1129
  Returns:
1204
- True if the object is valid message content, False otherwise.
1130
+ `True` if the object is valid message content, `False` otherwise.
1205
1131
  """
1206
1132
  return isinstance(obj, str) or (
1207
1133
  isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
@@ -1217,7 +1143,7 @@ def _is_message_content_block(obj: Any) -> bool:
1217
1143
  obj: The object to check.
1218
1144
 
1219
1145
  Returns:
1220
- True if the object is a valid content block, False otherwise.
1146
+ `True` if the object is a valid content block, `False` otherwise.
1221
1147
  """
1222
1148
  if isinstance(obj, str):
1223
1149
  return True
@@ -1241,7 +1167,7 @@ def _stringify(content: Any) -> str:
1241
1167
  return str(content)
1242
1168
 
1243
1169
 
1244
- def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
1170
+ def _get_type_hints(func: Callable) -> dict[str, type] | None:
1245
1171
  """Get type hints from a function, handling partial functions.
1246
1172
 
1247
1173
  Args:
@@ -1258,7 +1184,7 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
1258
1184
  return None
1259
1185
 
1260
1186
 
1261
- def _get_runnable_config_param(func: Callable) -> Optional[str]:
1187
+ def _get_runnable_config_param(func: Callable) -> str | None:
1262
1188
  """Find the parameter name for RunnableConfig in a function.
1263
1189
 
1264
1190
  Args:
@@ -1284,35 +1210,73 @@ class InjectedToolArg:
1284
1210
  """
1285
1211
 
1286
1212
 
1213
+ class _DirectlyInjectedToolArg:
1214
+ """Annotation for tool arguments that are injected at runtime.
1215
+
1216
+ Injected via direct type annotation, rather than annotated metadata.
1217
+
1218
+ For example, ToolRuntime is a directly injected argument.
1219
+ Note the direct annotation rather than the verbose alternative:
1220
+ Annotated[ToolRuntime, InjectedRuntime]
1221
+ ```python
1222
+ from langchain_core.tools import tool, ToolRuntime
1223
+
1224
+
1225
+ @tool
1226
+ def foo(x: int, runtime: ToolRuntime) -> str:
1227
+ # use runtime.state, runtime.context, runtime.store, etc.
1228
+ ...
1229
+ ```
1230
+ """
1231
+
1232
+
1287
1233
  class InjectedToolCallId(InjectedToolArg):
1288
1234
  """Annotation for injecting the tool call ID.
1289
1235
 
1290
1236
  This annotation is used to mark a tool parameter that should receive
1291
1237
  the tool call ID at runtime.
1292
1238
 
1293
- .. code-block:: python
1294
-
1295
- from typing_extensions import Annotated
1296
- from langchain_core.messages import ToolMessage
1297
- from langchain_core.tools import tool, InjectedToolCallId
1298
-
1299
- @tool
1300
- def foo(
1301
- x: int, tool_call_id: Annotated[str, InjectedToolCallId]
1302
- ) -> ToolMessage:
1303
- \"\"\"Return x.\"\"\"
1304
- return ToolMessage(
1305
- str(x),
1306
- artifact=x,
1307
- name="foo",
1308
- tool_call_id=tool_call_id
1309
- )
1239
+ ```python
1240
+ from typing import Annotated
1241
+ from langchain_core.messages import ToolMessage
1242
+ from langchain_core.tools import tool, InjectedToolCallId
1243
+
1244
+ @tool
1245
+ def foo(
1246
+ x: int, tool_call_id: Annotated[str, InjectedToolCallId]
1247
+ ) -> ToolMessage:
1248
+ \"\"\"Return x.\"\"\"
1249
+ return ToolMessage(
1250
+ str(x),
1251
+ artifact=x,
1252
+ name="foo",
1253
+ tool_call_id=tool_call_id
1254
+ )
1255
+
1256
+ ```
1257
+ """
1310
1258
 
1259
+
1260
+ def _is_directly_injected_arg_type(type_: Any) -> bool:
1261
+ """Check if a type annotation indicates a directly injected argument.
1262
+
1263
+ This is currently only used for ToolRuntime.
1264
+ Checks if either the annotation itself is a subclass of _DirectlyInjectedToolArg
1265
+ or the origin of the annotation is a subclass of _DirectlyInjectedToolArg.
1266
+
1267
+ Ex: ToolRuntime or ToolRuntime[ContextT, StateT] would both return True.
1311
1268
  """
1269
+ return (
1270
+ isinstance(type_, type) and issubclass(type_, _DirectlyInjectedToolArg)
1271
+ ) or (
1272
+ (origin := get_origin(type_)) is not None
1273
+ and isinstance(origin, type)
1274
+ and issubclass(origin, _DirectlyInjectedToolArg)
1275
+ )
1312
1276
 
1313
1277
 
1314
1278
  def _is_injected_arg_type(
1315
- type_: type, injected_type: Optional[type[InjectedToolArg]] = None
1279
+ type_: type | TypeVar, injected_type: type[InjectedToolArg] | None = None
1316
1280
  ) -> bool:
1317
1281
  """Check if a type annotation indicates an injected argument.
1318
1282
 
@@ -1321,9 +1285,17 @@ def _is_injected_arg_type(
1321
1285
  injected_type: The specific injected type to check for.
1322
1286
 
1323
1287
  Returns:
1324
- True if the type is an injected argument, False otherwise.
1288
+ `True` if the type is an injected argument, `False` otherwise.
1325
1289
  """
1326
- injected_type = injected_type or InjectedToolArg
1290
+ if injected_type is None:
1291
+ # if no injected type is specified,
1292
+ # check if the type is a directly injected argument
1293
+ if _is_directly_injected_arg_type(type_):
1294
+ return True
1295
+ injected_type = InjectedToolArg
1296
+
1297
+ # if the type is an Annotated type, check if annotated metadata
1298
+ # is an intance or subclass of the injected type
1327
1299
  return any(
1328
1300
  isinstance(arg, injected_type)
1329
1301
  or (isinstance(arg, type) and issubclass(arg, injected_type))
@@ -1332,21 +1304,23 @@ def _is_injected_arg_type(
1332
1304
 
1333
1305
 
1334
1306
  def get_all_basemodel_annotations(
1335
- cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
1336
- ) -> dict[str, type]:
1307
+ cls: TypeBaseModel | Any, *, default_to_bound: bool = True
1308
+ ) -> dict[str, type | TypeVar]:
1337
1309
  """Get all annotations from a Pydantic BaseModel and its parents.
1338
1310
 
1339
1311
  Args:
1340
1312
  cls: The Pydantic BaseModel class.
1341
1313
  default_to_bound: Whether to default to the bound of a TypeVar if it exists.
1314
+
1315
+ Returns:
1316
+ A dictionary of field names to their type annotations.
1342
1317
  """
1343
1318
  # cls has no subscript: cls = FooBar
1344
1319
  if isinstance(cls, type):
1345
- # Gather pydantic field objects (v2: model_fields / v1: __fields__)
1346
- fields = getattr(cls, "model_fields", {}) or getattr(cls, "__fields__", {})
1320
+ fields = get_fields(cls)
1347
1321
  alias_map = {field.alias: name for name, field in fields.items() if field.alias}
1348
1322
 
1349
- annotations: dict[str, type] = {}
1323
+ annotations: dict[str, type | TypeVar] = {}
1350
1324
  for name, param in inspect.signature(cls).parameters.items():
1351
1325
  # Exclude hidden init args added by pydantic Config. For example if
1352
1326
  # BaseModel(extra="allow") then "extra_data" will part of init sig.
@@ -1382,12 +1356,12 @@ def get_all_basemodel_annotations(
1382
1356
  continue
1383
1357
 
1384
1358
  # if class = FooBar inherits from Baz[str]:
1385
- # parent = Baz[str],
1386
- # parent_origin = Baz,
1359
+ # parent = class Baz[str],
1360
+ # parent_origin = class Baz,
1387
1361
  # generic_type_vars = (type vars in Baz)
1388
1362
  # generic_map = {type var in Baz: str}
1389
1363
  generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
1390
- generic_map = dict(zip(generic_type_vars, get_args(parent)))
1364
+ generic_map = dict(zip(generic_type_vars, get_args(parent), strict=False))
1391
1365
  for field in getattr(parent_origin, "__annotations__", {}):
1392
1366
  annotations[field] = _replace_type_vars(
1393
1367
  annotations[field], generic_map, default_to_bound=default_to_bound
@@ -1400,11 +1374,11 @@ def get_all_basemodel_annotations(
1400
1374
 
1401
1375
 
1402
1376
  def _replace_type_vars(
1403
- type_: type,
1404
- generic_map: Optional[dict[TypeVar, type]] = None,
1377
+ type_: type | TypeVar,
1378
+ generic_map: dict[TypeVar, type] | None = None,
1405
1379
  *,
1406
1380
  default_to_bound: bool = True,
1407
- ) -> type:
1381
+ ) -> type | TypeVar:
1408
1382
  """Replace TypeVars in a type annotation with concrete types.
1409
1383
 
1410
1384
  Args:
@@ -1420,7 +1394,7 @@ def _replace_type_vars(
1420
1394
  if type_ in generic_map:
1421
1395
  return generic_map[type_]
1422
1396
  if default_to_bound:
1423
- return type_.__bound__ or Any
1397
+ return type_.__bound__ if type_.__bound__ is not None else Any
1424
1398
  return type_
1425
1399
  if (origin := get_origin(type_)) and (args := get_args(type_)):
1426
1400
  new_args = tuple(