langchain-core 1.0.0a6__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. langchain_core/__init__.py +1 -1
  2. langchain_core/_api/__init__.py +3 -4
  3. langchain_core/_api/beta_decorator.py +23 -26
  4. langchain_core/_api/deprecation.py +51 -64
  5. langchain_core/_api/path.py +3 -6
  6. langchain_core/_import_utils.py +3 -4
  7. langchain_core/agents.py +55 -48
  8. langchain_core/caches.py +65 -66
  9. langchain_core/callbacks/__init__.py +1 -8
  10. langchain_core/callbacks/base.py +321 -336
  11. langchain_core/callbacks/file.py +44 -44
  12. langchain_core/callbacks/manager.py +454 -514
  13. langchain_core/callbacks/stdout.py +29 -30
  14. langchain_core/callbacks/streaming_stdout.py +32 -32
  15. langchain_core/callbacks/usage.py +60 -57
  16. langchain_core/chat_history.py +53 -68
  17. langchain_core/document_loaders/base.py +27 -25
  18. langchain_core/document_loaders/blob_loaders.py +1 -1
  19. langchain_core/document_loaders/langsmith.py +44 -48
  20. langchain_core/documents/__init__.py +23 -3
  21. langchain_core/documents/base.py +102 -94
  22. langchain_core/documents/compressor.py +10 -10
  23. langchain_core/documents/transformers.py +34 -35
  24. langchain_core/embeddings/fake.py +50 -54
  25. langchain_core/example_selectors/length_based.py +2 -2
  26. langchain_core/example_selectors/semantic_similarity.py +28 -32
  27. langchain_core/exceptions.py +21 -20
  28. langchain_core/globals.py +3 -151
  29. langchain_core/indexing/__init__.py +1 -1
  30. langchain_core/indexing/api.py +121 -126
  31. langchain_core/indexing/base.py +73 -75
  32. langchain_core/indexing/in_memory.py +4 -6
  33. langchain_core/language_models/__init__.py +14 -29
  34. langchain_core/language_models/_utils.py +58 -61
  35. langchain_core/language_models/base.py +82 -172
  36. langchain_core/language_models/chat_models.py +329 -402
  37. langchain_core/language_models/fake.py +11 -11
  38. langchain_core/language_models/fake_chat_models.py +42 -36
  39. langchain_core/language_models/llms.py +189 -269
  40. langchain_core/load/dump.py +9 -12
  41. langchain_core/load/load.py +18 -28
  42. langchain_core/load/mapping.py +2 -4
  43. langchain_core/load/serializable.py +42 -40
  44. langchain_core/messages/__init__.py +10 -16
  45. langchain_core/messages/ai.py +148 -148
  46. langchain_core/messages/base.py +53 -51
  47. langchain_core/messages/block_translators/__init__.py +19 -22
  48. langchain_core/messages/block_translators/anthropic.py +6 -6
  49. langchain_core/messages/block_translators/bedrock_converse.py +5 -5
  50. langchain_core/messages/block_translators/google_genai.py +10 -7
  51. langchain_core/messages/block_translators/google_vertexai.py +4 -32
  52. langchain_core/messages/block_translators/groq.py +117 -21
  53. langchain_core/messages/block_translators/langchain_v0.py +5 -5
  54. langchain_core/messages/block_translators/openai.py +11 -11
  55. langchain_core/messages/chat.py +2 -6
  56. langchain_core/messages/content.py +339 -330
  57. langchain_core/messages/function.py +6 -10
  58. langchain_core/messages/human.py +24 -31
  59. langchain_core/messages/modifier.py +2 -2
  60. langchain_core/messages/system.py +19 -29
  61. langchain_core/messages/tool.py +74 -90
  62. langchain_core/messages/utils.py +484 -510
  63. langchain_core/output_parsers/__init__.py +13 -10
  64. langchain_core/output_parsers/base.py +61 -61
  65. langchain_core/output_parsers/format_instructions.py +9 -4
  66. langchain_core/output_parsers/json.py +12 -10
  67. langchain_core/output_parsers/list.py +21 -23
  68. langchain_core/output_parsers/openai_functions.py +49 -47
  69. langchain_core/output_parsers/openai_tools.py +30 -23
  70. langchain_core/output_parsers/pydantic.py +13 -14
  71. langchain_core/output_parsers/string.py +5 -5
  72. langchain_core/output_parsers/transform.py +15 -17
  73. langchain_core/output_parsers/xml.py +35 -34
  74. langchain_core/outputs/__init__.py +1 -1
  75. langchain_core/outputs/chat_generation.py +18 -18
  76. langchain_core/outputs/chat_result.py +1 -3
  77. langchain_core/outputs/generation.py +16 -16
  78. langchain_core/outputs/llm_result.py +10 -10
  79. langchain_core/prompt_values.py +13 -19
  80. langchain_core/prompts/__init__.py +3 -27
  81. langchain_core/prompts/base.py +81 -86
  82. langchain_core/prompts/chat.py +308 -351
  83. langchain_core/prompts/dict.py +6 -6
  84. langchain_core/prompts/few_shot.py +81 -88
  85. langchain_core/prompts/few_shot_with_templates.py +11 -13
  86. langchain_core/prompts/image.py +12 -14
  87. langchain_core/prompts/loading.py +4 -6
  88. langchain_core/prompts/message.py +7 -7
  89. langchain_core/prompts/prompt.py +24 -39
  90. langchain_core/prompts/string.py +26 -10
  91. langchain_core/prompts/structured.py +49 -53
  92. langchain_core/rate_limiters.py +51 -60
  93. langchain_core/retrievers.py +61 -198
  94. langchain_core/runnables/base.py +1551 -1656
  95. langchain_core/runnables/branch.py +68 -70
  96. langchain_core/runnables/config.py +72 -89
  97. langchain_core/runnables/configurable.py +145 -161
  98. langchain_core/runnables/fallbacks.py +102 -96
  99. langchain_core/runnables/graph.py +91 -97
  100. langchain_core/runnables/graph_ascii.py +27 -28
  101. langchain_core/runnables/graph_mermaid.py +42 -51
  102. langchain_core/runnables/graph_png.py +43 -16
  103. langchain_core/runnables/history.py +175 -177
  104. langchain_core/runnables/passthrough.py +151 -167
  105. langchain_core/runnables/retry.py +46 -51
  106. langchain_core/runnables/router.py +30 -35
  107. langchain_core/runnables/schema.py +75 -80
  108. langchain_core/runnables/utils.py +60 -67
  109. langchain_core/stores.py +85 -121
  110. langchain_core/structured_query.py +8 -8
  111. langchain_core/sys_info.py +29 -29
  112. langchain_core/tools/__init__.py +1 -14
  113. langchain_core/tools/base.py +306 -245
  114. langchain_core/tools/convert.py +160 -155
  115. langchain_core/tools/render.py +10 -10
  116. langchain_core/tools/retriever.py +12 -11
  117. langchain_core/tools/simple.py +19 -24
  118. langchain_core/tools/structured.py +32 -39
  119. langchain_core/tracers/__init__.py +1 -9
  120. langchain_core/tracers/base.py +97 -99
  121. langchain_core/tracers/context.py +29 -52
  122. langchain_core/tracers/core.py +49 -53
  123. langchain_core/tracers/evaluation.py +11 -11
  124. langchain_core/tracers/event_stream.py +65 -64
  125. langchain_core/tracers/langchain.py +21 -21
  126. langchain_core/tracers/log_stream.py +45 -45
  127. langchain_core/tracers/memory_stream.py +3 -3
  128. langchain_core/tracers/root_listeners.py +16 -16
  129. langchain_core/tracers/run_collector.py +2 -4
  130. langchain_core/tracers/schemas.py +0 -129
  131. langchain_core/tracers/stdout.py +3 -3
  132. langchain_core/utils/__init__.py +1 -4
  133. langchain_core/utils/_merge.py +2 -2
  134. langchain_core/utils/aiter.py +57 -61
  135. langchain_core/utils/env.py +9 -9
  136. langchain_core/utils/function_calling.py +94 -188
  137. langchain_core/utils/html.py +7 -8
  138. langchain_core/utils/input.py +9 -6
  139. langchain_core/utils/interactive_env.py +1 -1
  140. langchain_core/utils/iter.py +36 -40
  141. langchain_core/utils/json.py +4 -3
  142. langchain_core/utils/json_schema.py +9 -9
  143. langchain_core/utils/mustache.py +8 -10
  144. langchain_core/utils/pydantic.py +35 -37
  145. langchain_core/utils/strings.py +6 -9
  146. langchain_core/utils/usage.py +1 -1
  147. langchain_core/utils/utils.py +66 -62
  148. langchain_core/vectorstores/base.py +182 -216
  149. langchain_core/vectorstores/in_memory.py +101 -176
  150. langchain_core/vectorstores/utils.py +5 -5
  151. langchain_core/version.py +1 -1
  152. langchain_core-1.0.4.dist-info/METADATA +69 -0
  153. langchain_core-1.0.4.dist-info/RECORD +172 -0
  154. {langchain_core-1.0.0a6.dist-info → langchain_core-1.0.4.dist-info}/WHEEL +1 -1
  155. langchain_core/memory.py +0 -120
  156. langchain_core/messages/block_translators/ollama.py +0 -47
  157. langchain_core/prompts/pipeline.py +0 -138
  158. langchain_core/pydantic_v1/__init__.py +0 -30
  159. langchain_core/pydantic_v1/dataclasses.py +0 -23
  160. langchain_core/pydantic_v1/main.py +0 -23
  161. langchain_core/tracers/langchain_v1.py +0 -31
  162. langchain_core/utils/loading.py +0 -35
  163. langchain_core-1.0.0a6.dist-info/METADATA +0 -67
  164. langchain_core-1.0.0a6.dist-info/RECORD +0 -181
  165. langchain_core-1.0.0a6.dist-info/entry_points.txt +0 -4
@@ -8,16 +8,14 @@ import json
8
8
  import typing
9
9
  import warnings
10
10
  from abc import ABC, abstractmethod
11
+ from collections.abc import Callable
11
12
  from inspect import signature
12
13
  from typing import (
13
14
  TYPE_CHECKING,
14
15
  Annotated,
15
16
  Any,
16
- Callable,
17
17
  Literal,
18
- Optional,
19
18
  TypeVar,
20
- Union,
21
19
  cast,
22
20
  get_args,
23
21
  get_origin,
@@ -31,7 +29,6 @@ from pydantic import (
31
29
  PydanticDeprecationWarning,
32
30
  SkipValidation,
33
31
  ValidationError,
34
- model_validator,
35
32
  validate_arguments,
36
33
  )
37
34
  from pydantic.v1 import BaseModel as BaseModelV1
@@ -39,10 +36,8 @@ from pydantic.v1 import ValidationError as ValidationErrorV1
39
36
  from pydantic.v1 import validate_arguments as validate_arguments_v1
40
37
  from typing_extensions import override
41
38
 
42
- from langchain_core._api import deprecated
43
39
  from langchain_core.callbacks import (
44
40
  AsyncCallbackManager,
45
- BaseCallbackManager,
46
41
  CallbackManager,
47
42
  Callbacks,
48
43
  )
@@ -97,7 +92,7 @@ def _is_annotated_type(typ: type[Any]) -> bool:
97
92
  typ: The type to check.
98
93
 
99
94
  Returns:
100
- True if the type is an Annotated type, False otherwise.
95
+ `True` if the type is an Annotated type, `False` otherwise.
101
96
  """
102
97
  return get_origin(typ) is typing.Annotated
103
98
 
@@ -231,7 +226,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
231
226
  pydantic_version: The Pydantic version to check against ("v1" or "v2").
232
227
 
233
228
  Returns:
234
- True if the annotation is a Pydantic model, False otherwise.
229
+ `True` if the annotation is a Pydantic model, `False` otherwise.
235
230
  """
236
231
  base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
237
232
  try:
@@ -250,7 +245,7 @@ def _function_annotations_are_pydantic_v1(
250
245
  func: The function being checked.
251
246
 
252
247
  Returns:
253
- True if all Pydantic annotations are from V1, False otherwise.
248
+ True if all Pydantic annotations are from V1, `False` otherwise.
254
249
 
255
250
  Raises:
256
251
  NotImplementedError: If the function contains mixed V1 and V2 annotations.
@@ -285,29 +280,28 @@ def create_schema_from_function(
285
280
  model_name: str,
286
281
  func: Callable,
287
282
  *,
288
- filter_args: Optional[Sequence[str]] = None,
283
+ filter_args: Sequence[str] | None = None,
289
284
  parse_docstring: bool = False,
290
285
  error_on_invalid_docstring: bool = False,
291
286
  include_injected: bool = True,
292
287
  ) -> type[BaseModel]:
293
- """Create a pydantic schema from a function's signature.
288
+ """Create a Pydantic schema from a function's signature.
294
289
 
295
290
  Args:
296
- model_name: Name to assign to the generated pydantic schema.
291
+ model_name: Name to assign to the generated Pydantic schema.
297
292
  func: Function to generate the schema from.
298
293
  filter_args: Optional list of arguments to exclude from the schema.
299
- Defaults to FILTERED_ARGS.
294
+ Defaults to `FILTERED_ARGS`.
300
295
  parse_docstring: Whether to parse the function's docstring for descriptions
301
- for each argument. Defaults to False.
302
- error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
303
- whether to raise ValueError on invalid Google Style docstrings.
304
- Defaults to False.
296
+ for each argument.
297
+ error_on_invalid_docstring: if `parse_docstring` is provided, configure
298
+ whether to raise `ValueError` on invalid Google Style docstrings.
305
299
  include_injected: Whether to include injected arguments in the schema.
306
- Defaults to True, since we want to include them in the schema
300
+ Defaults to `True`, since we want to include them in the schema
307
301
  when *validating* tool inputs.
308
302
 
309
303
  Returns:
310
- A pydantic model with the same arguments as the function.
304
+ A Pydantic model with the same arguments as the function.
311
305
  """
312
306
  sig = inspect.signature(func)
313
307
 
@@ -317,7 +311,7 @@ def create_schema_from_function(
317
311
  # https://docs.pydantic.dev/latest/usage/validation_decorator/
318
312
  with warnings.catch_warnings():
319
313
  # We are using deprecated functionality here.
320
- # This code should be re-written to simply construct a pydantic model
314
+ # This code should be re-written to simply construct a Pydantic model
321
315
  # using inspect.signature and create_model.
322
316
  warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
323
317
  validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
@@ -390,13 +384,14 @@ class ToolException(Exception): # noqa: N818
390
384
  """
391
385
 
392
386
 
393
- ArgsSchema = Union[TypeBaseModel, dict[str, Any]]
387
+ ArgsSchema = TypeBaseModel | dict[str, Any]
394
388
 
395
389
 
396
- class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
390
+ class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
397
391
  """Base class for all LangChain tools.
398
392
 
399
393
  This abstract class defines the interface that all LangChain tools must implement.
394
+
400
395
  Tools are components that can be called by agents to perform specific actions.
401
396
  """
402
397
 
@@ -407,7 +402,7 @@ class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
407
402
  **kwargs: Additional keyword arguments passed to the parent class.
408
403
 
409
404
  Raises:
410
- SchemaAnnotationError: If args_schema has incorrect type annotation.
405
+ SchemaAnnotationError: If `args_schema` has incorrect type annotation.
411
406
  """
412
407
  super().__init_subclass__(**kwargs)
413
408
 
@@ -441,22 +436,22 @@ class ChildTool(BaseTool):
441
436
  You can provide few-shot examples as a part of the description.
442
437
  """
443
438
 
444
- args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field(
439
+ args_schema: Annotated[ArgsSchema | None, SkipValidation()] = Field(
445
440
  default=None, description="The tool schema."
446
441
  )
447
442
  """Pydantic model class to validate and parse the tool's input arguments.
448
443
 
449
444
  Args schema should be either:
450
445
 
451
- - A subclass of pydantic.BaseModel.
452
- - A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
453
- - a JSON schema dict
446
+ - A subclass of `pydantic.BaseModel`.
447
+ - A subclass of `pydantic.v1.BaseModel` if accessing v1 namespace in pydantic 2
448
+ - A JSON schema dict
454
449
  """
455
450
  return_direct: bool = False
456
451
  """Whether to return the tool's output directly.
457
452
 
458
- Setting this to True means
459
- that after the tool is called, the AgentExecutor will stop looping.
453
+ Setting this to `True` means that after the tool is called, the `AgentExecutor` will
454
+ stop looping.
460
455
  """
461
456
  verbose: bool = False
462
457
  """Whether to log the tool's progress."""
@@ -464,52 +459,47 @@ class ChildTool(BaseTool):
464
459
  callbacks: Callbacks = Field(default=None, exclude=True)
465
460
  """Callbacks to be called during tool execution."""
466
461
 
467
- callback_manager: Optional[BaseCallbackManager] = deprecated(
468
- name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
469
- )(
470
- Field(
471
- default=None,
472
- exclude=True,
473
- description="Callback manager to add to the run trace.",
474
- )
475
- )
476
- tags: Optional[list[str]] = None
477
- """Optional list of tags associated with the tool. Defaults to None.
462
+ tags: list[str] | None = None
463
+ """Optional list of tags associated with the tool.
464
+
478
465
  These tags will be associated with each call to this tool,
479
466
  and passed as arguments to the handlers defined in `callbacks`.
480
- You can use these to eg identify a specific instance of a tool with its use case.
467
+
468
+ You can use these to, e.g., identify a specific instance of a tool with its use
469
+ case.
481
470
  """
482
- metadata: Optional[dict[str, Any]] = None
483
- """Optional metadata associated with the tool. Defaults to None.
471
+ metadata: dict[str, Any] | None = None
472
+ """Optional metadata associated with the tool.
473
+
484
474
  This metadata will be associated with each call to this tool,
485
475
  and passed as arguments to the handlers defined in `callbacks`.
486
- You can use these to eg identify a specific instance of a tool with its use case.
476
+
477
+ You can use these to, e.g., identify a specific instance of a tool with its use
478
+ case.
487
479
  """
488
480
 
489
- handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = (
490
- False
491
- )
492
- """Handle the content of the ToolException thrown."""
481
+ handle_tool_error: bool | str | Callable[[ToolException], str] | None = False
482
+ """Handle the content of the `ToolException` thrown."""
493
483
 
494
- handle_validation_error: Optional[
495
- Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]]
496
- ] = False
497
- """Handle the content of the ValidationError thrown."""
484
+ handle_validation_error: (
485
+ bool | str | Callable[[ValidationError | ValidationErrorV1], str] | None
486
+ ) = False
487
+ """Handle the content of the `ValidationError` thrown."""
498
488
 
499
489
  response_format: Literal["content", "content_and_artifact"] = "content"
500
- """The tool response format. Defaults to 'content'.
490
+ """The tool response format.
501
491
 
502
- If "content" then the output of the tool is interpreted as the contents of a
503
- ToolMessage. If "content_and_artifact" then the output is expected to be a
504
- two-tuple corresponding to the (content, artifact) of a ToolMessage.
492
+ If `'content'` then the output of the tool is interpreted as the contents of a
493
+ `ToolMessage`. If `'content_and_artifact'` then the output is expected to be a
494
+ two-tuple corresponding to the `(content, artifact)` of a `ToolMessage`.
505
495
  """
506
496
 
507
497
  def __init__(self, **kwargs: Any) -> None:
508
498
  """Initialize the tool.
509
499
 
510
500
  Raises:
511
- TypeError: If ``args_schema`` is not a subclass of pydantic ``BaseModel`` or
512
- dict.
501
+ TypeError: If `args_schema` is not a subclass of pydantic `BaseModel` or
502
+ `dict`.
513
503
  """
514
504
  if (
515
505
  "args_schema" in kwargs
@@ -533,7 +523,7 @@ class ChildTool(BaseTool):
533
523
  """Check if the tool accepts only a single input argument.
534
524
 
535
525
  Returns:
536
- True if the tool has only one input argument, False otherwise.
526
+ `True` if the tool has only one input argument, `False` otherwise.
537
527
  """
538
528
  keys = {k for k in self.args if k != "kwargs"}
539
529
  return len(keys) == 1
@@ -543,7 +533,7 @@ class ChildTool(BaseTool):
543
533
  """Get the tool's input arguments schema.
544
534
 
545
535
  Returns:
546
- Dictionary containing the tool's argument properties.
536
+ `dict` containing the tool's argument properties.
547
537
  """
548
538
  if isinstance(self.args_schema, dict):
549
539
  json_schema = self.args_schema
@@ -582,9 +572,7 @@ class ChildTool(BaseTool):
582
572
  # --- Runnable ---
583
573
 
584
574
  @override
585
- def get_input_schema(
586
- self, config: Optional[RunnableConfig] = None
587
- ) -> type[BaseModel]:
575
+ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
588
576
  """The tool's input schema.
589
577
 
590
578
  Args:
@@ -602,8 +590,8 @@ class ChildTool(BaseTool):
602
590
  @override
603
591
  def invoke(
604
592
  self,
605
- input: Union[str, dict, ToolCall],
606
- config: Optional[RunnableConfig] = None,
593
+ input: str | dict | ToolCall,
594
+ config: RunnableConfig | None = None,
607
595
  **kwargs: Any,
608
596
  ) -> Any:
609
597
  tool_input, kwargs = _prep_run_args(input, config, **kwargs)
@@ -612,8 +600,8 @@ class ChildTool(BaseTool):
612
600
  @override
613
601
  async def ainvoke(
614
602
  self,
615
- input: Union[str, dict, ToolCall],
616
- config: Optional[RunnableConfig] = None,
603
+ input: str | dict | ToolCall,
604
+ config: RunnableConfig | None = None,
617
605
  **kwargs: Any,
618
606
  ) -> Any:
619
607
  tool_input, kwargs = _prep_run_args(input, config, **kwargs)
@@ -622,8 +610,8 @@ class ChildTool(BaseTool):
622
610
  # --- Tool ---
623
611
 
624
612
  def _parse_input(
625
- self, tool_input: Union[str, dict], tool_call_id: Optional[str]
626
- ) -> Union[str, dict[str, Any]]:
613
+ self, tool_input: str | dict, tool_call_id: str | None
614
+ ) -> str | dict[str, Any]:
627
615
  """Parse and validate tool input using the args schema.
628
616
 
629
617
  Args:
@@ -634,10 +622,10 @@ class ChildTool(BaseTool):
634
622
  The parsed and validated input.
635
623
 
636
624
  Raises:
637
- ValueError: If string input is provided with JSON schema ``args_schema``.
638
- ValueError: If InjectedToolCallId is required but ``tool_call_id`` is not
625
+ ValueError: If `string` input is provided with JSON schema `args_schema`.
626
+ ValueError: If `InjectedToolCallId` is required but `tool_call_id` is not
639
627
  provided.
640
- TypeError: If args_schema is not a Pydantic ``BaseModel`` or dict.
628
+ TypeError: If `args_schema` is not a Pydantic `BaseModel` or dict.
641
629
  """
642
630
  input_args = self.args_schema
643
631
  if isinstance(tool_input, str):
@@ -700,32 +688,12 @@ class ChildTool(BaseTool):
700
688
  }
701
689
  return tool_input
702
690
 
703
- @model_validator(mode="before")
704
- @classmethod
705
- def raise_deprecation(cls, values: dict) -> Any:
706
- """Raise deprecation warning if callback_manager is used.
707
-
708
- Args:
709
- values: The values to validate.
710
-
711
- Returns:
712
- The validated values.
713
- """
714
- if values.get("callback_manager") is not None:
715
- warnings.warn(
716
- "callback_manager is deprecated. Please use callbacks instead.",
717
- DeprecationWarning,
718
- stacklevel=6,
719
- )
720
- values["callbacks"] = values.pop("callback_manager", None)
721
- return values
722
-
723
691
  @abstractmethod
724
692
  def _run(self, *args: Any, **kwargs: Any) -> Any:
725
693
  """Use the tool.
726
694
 
727
- Add run_manager: Optional[CallbackManagerForToolRun] = None
728
- to child implementations to enable tracing.
695
+ Add `run_manager: CallbackManagerForToolRun | None = None` to child
696
+ implementations to enable tracing.
729
697
 
730
698
  Returns:
731
699
  The result of the tool execution.
@@ -734,8 +702,8 @@ class ChildTool(BaseTool):
734
702
  async def _arun(self, *args: Any, **kwargs: Any) -> Any:
735
703
  """Use the tool asynchronously.
736
704
 
737
- Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
738
- to child implementations to enable tracing.
705
+ Add `run_manager: AsyncCallbackManagerForToolRun | None = None` to child
706
+ implementations to enable tracing.
739
707
 
740
708
  Returns:
741
709
  The result of the tool execution.
@@ -746,8 +714,37 @@ class ChildTool(BaseTool):
746
714
  kwargs["run_manager"] = kwargs["run_manager"].get_sync()
747
715
  return await run_in_executor(None, self._run, *args, **kwargs)
748
716
 
717
+ def _filter_injected_args(self, tool_input: dict) -> dict:
718
+ """Filter out injected tool arguments from the input dictionary.
719
+
720
+ Injected arguments are those annotated with `InjectedToolArg` or its
721
+ subclasses, or arguments in `FILTERED_ARGS` like `run_manager` and callbacks.
722
+
723
+ Args:
724
+ tool_input: The tool input dictionary to filter.
725
+
726
+ Returns:
727
+ A filtered dictionary with injected arguments removed.
728
+ """
729
+ # Start with filtered args from the constant
730
+ filtered_keys = set[str](FILTERED_ARGS)
731
+
732
+ # If we have an args_schema, use it to identify injected args
733
+ if self.args_schema is not None:
734
+ try:
735
+ annotations = get_all_basemodel_annotations(self.args_schema)
736
+ for field_name, field_type in annotations.items():
737
+ if _is_injected_arg_type(field_type):
738
+ filtered_keys.add(field_name)
739
+ except Exception: # noqa: S110
740
+ # If we can't get annotations, just use FILTERED_ARGS
741
+ pass
742
+
743
+ # Filter out the injected keys from tool_input
744
+ return {k: v for k, v in tool_input.items() if k not in filtered_keys}
745
+
749
746
  def _to_args_and_kwargs(
750
- self, tool_input: Union[str, dict], tool_call_id: Optional[str]
747
+ self, tool_input: str | dict, tool_call_id: str | None
751
748
  ) -> tuple[tuple, dict]:
752
749
  """Convert tool input to positional and keyword arguments.
753
750
 
@@ -756,7 +753,7 @@ class ChildTool(BaseTool):
756
753
  tool_call_id: The ID of the tool call, if available.
757
754
 
758
755
  Returns:
759
- A tuple of (positional_args, keyword_args) for the tool.
756
+ A tuple of `(positional_args, keyword_args)` for the tool.
760
757
 
761
758
  Raises:
762
759
  TypeError: If the tool input type is invalid.
@@ -787,35 +784,35 @@ class ChildTool(BaseTool):
787
784
 
788
785
  def run(
789
786
  self,
790
- tool_input: Union[str, dict[str, Any]],
791
- verbose: Optional[bool] = None, # noqa: FBT001
792
- start_color: Optional[str] = "green",
793
- color: Optional[str] = "green",
787
+ tool_input: str | dict[str, Any],
788
+ verbose: bool | None = None, # noqa: FBT001
789
+ start_color: str | None = "green",
790
+ color: str | None = "green",
794
791
  callbacks: Callbacks = None,
795
792
  *,
796
- tags: Optional[list[str]] = None,
797
- metadata: Optional[dict[str, Any]] = None,
798
- run_name: Optional[str] = None,
799
- run_id: Optional[uuid.UUID] = None,
800
- config: Optional[RunnableConfig] = None,
801
- tool_call_id: Optional[str] = None,
793
+ tags: list[str] | None = None,
794
+ metadata: dict[str, Any] | None = None,
795
+ run_name: str | None = None,
796
+ run_id: uuid.UUID | None = None,
797
+ config: RunnableConfig | None = None,
798
+ tool_call_id: str | None = None,
802
799
  **kwargs: Any,
803
800
  ) -> Any:
804
801
  """Run the tool.
805
802
 
806
803
  Args:
807
804
  tool_input: The input to the tool.
808
- verbose: Whether to log the tool's progress. Defaults to None.
809
- start_color: The color to use when starting the tool. Defaults to 'green'.
810
- color: The color to use when ending the tool. Defaults to 'green'.
811
- callbacks: Callbacks to be called during tool execution. Defaults to None.
812
- tags: Optional list of tags associated with the tool. Defaults to None.
813
- metadata: Optional metadata associated with the tool. Defaults to None.
814
- run_name: The name of the run. Defaults to None.
815
- run_id: The id of the run. Defaults to None.
816
- config: The configuration for the tool. Defaults to None.
817
- tool_call_id: The id of the tool call. Defaults to None.
818
- kwargs: Keyword arguments to be passed to tool callbacks (event handler)
805
+ verbose: Whether to log the tool's progress.
806
+ start_color: The color to use when starting the tool.
807
+ color: The color to use when ending the tool.
808
+ callbacks: Callbacks to be called during tool execution.
809
+ tags: Optional list of tags associated with the tool.
810
+ metadata: Optional metadata associated with the tool.
811
+ run_name: The name of the run.
812
+ run_id: The id of the run.
813
+ config: The configuration for the tool.
814
+ tool_call_id: The id of the tool call.
815
+ **kwargs: Keyword arguments to be passed to tool callbacks (event handler)
819
816
 
820
817
  Returns:
821
818
  The output of the tool.
@@ -833,24 +830,36 @@ class ChildTool(BaseTool):
833
830
  self.metadata,
834
831
  )
835
832
 
833
+ # Filter out injected arguments from callback inputs
834
+ filtered_tool_input = (
835
+ self._filter_injected_args(tool_input)
836
+ if isinstance(tool_input, dict)
837
+ else None
838
+ )
839
+
840
+ # Use filtered inputs for the input_str parameter as well
841
+ tool_input_str = (
842
+ tool_input
843
+ if isinstance(tool_input, str)
844
+ else str(
845
+ filtered_tool_input if filtered_tool_input is not None else tool_input
846
+ )
847
+ )
848
+
836
849
  run_manager = callback_manager.on_tool_start(
837
850
  {"name": self.name, "description": self.description},
838
- tool_input if isinstance(tool_input, str) else str(tool_input),
851
+ tool_input_str,
839
852
  color=start_color,
840
853
  name=run_name,
841
854
  run_id=run_id,
842
- # Inputs by definition should always be dicts.
843
- # For now, it's unclear whether this assumption is ever violated,
844
- # but if it is we will send a `None` value to the callback instead
845
- # TODO: will need to address issue via a patch.
846
- inputs=tool_input if isinstance(tool_input, dict) else None,
855
+ inputs=filtered_tool_input,
847
856
  **kwargs,
848
857
  )
849
858
 
850
859
  content = None
851
860
  artifact = None
852
861
  status = "success"
853
- error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
862
+ error_to_raise: Exception | KeyboardInterrupt | None = None
854
863
  try:
855
864
  child_config = patch_config(config, callbacks=run_manager.get_child())
856
865
  with set_config_context(child_config) as context:
@@ -863,16 +872,19 @@ class ChildTool(BaseTool):
863
872
  tool_kwargs |= {config_param: config}
864
873
  response = context.run(self._run, *tool_args, **tool_kwargs)
865
874
  if self.response_format == "content_and_artifact":
866
- if not isinstance(response, tuple) or len(response) != 2:
867
- msg = (
868
- "Since response_format='content_and_artifact' "
869
- "a two-tuple of the message content and raw tool output is "
870
- f"expected. Instead generated response of type: "
871
- f"{type(response)}."
872
- )
875
+ msg = (
876
+ "Since response_format='content_and_artifact' "
877
+ "a two-tuple of the message content and raw tool output is "
878
+ f"expected. Instead, generated response is of type: "
879
+ f"{type(response)}."
880
+ )
881
+ if not isinstance(response, tuple):
873
882
  error_to_raise = ValueError(msg)
874
883
  else:
875
- content, artifact = response
884
+ try:
885
+ content, artifact = response
886
+ except ValueError:
887
+ error_to_raise = ValueError(msg)
876
888
  else:
877
889
  content = response
878
890
  except (ValidationError, ValidationErrorV1) as e:
@@ -899,35 +911,35 @@ class ChildTool(BaseTool):
899
911
 
900
912
  async def arun(
901
913
  self,
902
- tool_input: Union[str, dict],
903
- verbose: Optional[bool] = None, # noqa: FBT001
904
- start_color: Optional[str] = "green",
905
- color: Optional[str] = "green",
914
+ tool_input: str | dict,
915
+ verbose: bool | None = None, # noqa: FBT001
916
+ start_color: str | None = "green",
917
+ color: str | None = "green",
906
918
  callbacks: Callbacks = None,
907
919
  *,
908
- tags: Optional[list[str]] = None,
909
- metadata: Optional[dict[str, Any]] = None,
910
- run_name: Optional[str] = None,
911
- run_id: Optional[uuid.UUID] = None,
912
- config: Optional[RunnableConfig] = None,
913
- tool_call_id: Optional[str] = None,
920
+ tags: list[str] | None = None,
921
+ metadata: dict[str, Any] | None = None,
922
+ run_name: str | None = None,
923
+ run_id: uuid.UUID | None = None,
924
+ config: RunnableConfig | None = None,
925
+ tool_call_id: str | None = None,
914
926
  **kwargs: Any,
915
927
  ) -> Any:
916
928
  """Run the tool asynchronously.
917
929
 
918
930
  Args:
919
931
  tool_input: The input to the tool.
920
- verbose: Whether to log the tool's progress. Defaults to None.
921
- start_color: The color to use when starting the tool. Defaults to 'green'.
922
- color: The color to use when ending the tool. Defaults to 'green'.
923
- callbacks: Callbacks to be called during tool execution. Defaults to None.
924
- tags: Optional list of tags associated with the tool. Defaults to None.
925
- metadata: Optional metadata associated with the tool. Defaults to None.
926
- run_name: The name of the run. Defaults to None.
927
- run_id: The id of the run. Defaults to None.
928
- config: The configuration for the tool. Defaults to None.
929
- tool_call_id: The id of the tool call. Defaults to None.
930
- kwargs: Keyword arguments to be passed to tool callbacks
932
+ verbose: Whether to log the tool's progress.
933
+ start_color: The color to use when starting the tool.
934
+ color: The color to use when ending the tool.
935
+ callbacks: Callbacks to be called during tool execution.
936
+ tags: Optional list of tags associated with the tool.
937
+ metadata: Optional metadata associated with the tool.
938
+ run_name: The name of the run.
939
+ run_id: The id of the run.
940
+ config: The configuration for the tool.
941
+ tool_call_id: The id of the tool call.
942
+ **kwargs: Keyword arguments to be passed to tool callbacks
931
943
 
932
944
  Returns:
933
945
  The output of the tool.
@@ -944,23 +956,36 @@ class ChildTool(BaseTool):
944
956
  metadata,
945
957
  self.metadata,
946
958
  )
959
+
960
+ # Filter out injected arguments from callback inputs
961
+ filtered_tool_input = (
962
+ self._filter_injected_args(tool_input)
963
+ if isinstance(tool_input, dict)
964
+ else None
965
+ )
966
+
967
+ # Use filtered inputs for the input_str parameter as well
968
+ tool_input_str = (
969
+ tool_input
970
+ if isinstance(tool_input, str)
971
+ else str(
972
+ filtered_tool_input if filtered_tool_input is not None else tool_input
973
+ )
974
+ )
975
+
947
976
  run_manager = await callback_manager.on_tool_start(
948
977
  {"name": self.name, "description": self.description},
949
- tool_input if isinstance(tool_input, str) else str(tool_input),
978
+ tool_input_str,
950
979
  color=start_color,
951
980
  name=run_name,
952
981
  run_id=run_id,
953
- # Inputs by definition should always be dicts.
954
- # For now, it's unclear whether this assumption is ever violated,
955
- # but if it is we will send a `None` value to the callback instead
956
- # TODO: will need to address issue via a patch.
957
- inputs=tool_input if isinstance(tool_input, dict) else None,
982
+ inputs=filtered_tool_input,
958
983
  **kwargs,
959
984
  )
960
985
  content = None
961
986
  artifact = None
962
987
  status = "success"
963
- error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
988
+ error_to_raise: Exception | KeyboardInterrupt | None = None
964
989
  try:
965
990
  tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
966
991
  child_config = patch_config(config, callbacks=run_manager.get_child())
@@ -976,16 +1001,19 @@ class ChildTool(BaseTool):
976
1001
  coro = self._arun(*tool_args, **tool_kwargs)
977
1002
  response = await coro_with_context(coro, context)
978
1003
  if self.response_format == "content_and_artifact":
979
- if not isinstance(response, tuple) or len(response) != 2:
980
- msg = (
981
- "Since response_format='content_and_artifact' "
982
- "a two-tuple of the message content and raw tool output is "
983
- f"expected. Instead generated response of type: "
984
- f"{type(response)}."
985
- )
1004
+ msg = (
1005
+ "Since response_format='content_and_artifact' "
1006
+ "a two-tuple of the message content and raw tool output is "
1007
+ f"expected. Instead, generated response is of type: "
1008
+ f"{type(response)}."
1009
+ )
1010
+ if not isinstance(response, tuple):
986
1011
  error_to_raise = ValueError(msg)
987
1012
  else:
988
- content, artifact = response
1013
+ try:
1014
+ content, artifact = response
1015
+ except ValueError:
1016
+ error_to_raise = ValueError(msg)
989
1017
  else:
990
1018
  content = response
991
1019
  except ValidationError as e:
@@ -1011,19 +1039,6 @@ class ChildTool(BaseTool):
1011
1039
  await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
1012
1040
  return output
1013
1041
 
1014
- @deprecated("0.1.47", alternative="invoke", removal="1.0")
1015
- def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
1016
- """Make tool callable (deprecated).
1017
-
1018
- Args:
1019
- tool_input: The input to the tool.
1020
- callbacks: Callbacks to use during execution.
1021
-
1022
- Returns:
1023
- The tool's output.
1024
- """
1025
- return self.run(tool_input, callbacks=callbacks)
1026
-
1027
1042
 
1028
1043
  def _is_tool_call(x: Any) -> bool:
1029
1044
  """Check if the input is a tool call dictionary.
@@ -1032,23 +1047,21 @@ def _is_tool_call(x: Any) -> bool:
1032
1047
  x: The input to check.
1033
1048
 
1034
1049
  Returns:
1035
- True if the input is a tool call, False otherwise.
1050
+ `True` if the input is a tool call, `False` otherwise.
1036
1051
  """
1037
1052
  return isinstance(x, dict) and x.get("type") == "tool_call"
1038
1053
 
1039
1054
 
1040
1055
  def _handle_validation_error(
1041
- e: Union[ValidationError, ValidationErrorV1],
1056
+ e: ValidationError | ValidationErrorV1,
1042
1057
  *,
1043
- flag: Union[
1044
- Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
1045
- ],
1058
+ flag: Literal[True] | str | Callable[[ValidationError | ValidationErrorV1], str],
1046
1059
  ) -> str:
1047
1060
  """Handle validation errors based on the configured flag.
1048
1061
 
1049
1062
  Args:
1050
1063
  e: The validation error that occurred.
1051
- flag: How to handle the error (bool, string, or callable).
1064
+ flag: How to handle the error (`bool`, `str`, or `Callable`).
1052
1065
 
1053
1066
  Returns:
1054
1067
  The error message to return.
@@ -1074,13 +1087,13 @@ def _handle_validation_error(
1074
1087
  def _handle_tool_error(
1075
1088
  e: ToolException,
1076
1089
  *,
1077
- flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
1090
+ flag: Literal[True] | str | Callable[[ToolException], str] | None,
1078
1091
  ) -> str:
1079
1092
  """Handle tool execution errors based on the configured flag.
1080
1093
 
1081
1094
  Args:
1082
1095
  e: The tool exception that occurred.
1083
- flag: How to handle the error (bool, string, or callable).
1096
+ flag: How to handle the error (`bool`, `str`, or `Callable`).
1084
1097
 
1085
1098
  Returns:
1086
1099
  The error message to return.
@@ -1104,27 +1117,27 @@ def _handle_tool_error(
1104
1117
 
1105
1118
 
1106
1119
  def _prep_run_args(
1107
- value: Union[str, dict, ToolCall],
1108
- config: Optional[RunnableConfig],
1120
+ value: str | dict | ToolCall,
1121
+ config: RunnableConfig | None,
1109
1122
  **kwargs: Any,
1110
- ) -> tuple[Union[str, dict], dict]:
1123
+ ) -> tuple[str | dict, dict]:
1111
1124
  """Prepare arguments for tool execution.
1112
1125
 
1113
1126
  Args:
1114
- value: The input value (string, dict, or ToolCall).
1127
+ value: The input value (`str`, `dict`, or `ToolCall`).
1115
1128
  config: The runnable configuration.
1116
1129
  **kwargs: Additional keyword arguments.
1117
1130
 
1118
1131
  Returns:
1119
- A tuple of (tool_input, run_kwargs).
1132
+ A tuple of `(tool_input, run_kwargs)`.
1120
1133
  """
1121
1134
  config = ensure_config(config)
1122
1135
  if _is_tool_call(value):
1123
- tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
1124
- tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
1136
+ tool_call_id: str | None = cast("ToolCall", value)["id"]
1137
+ tool_input: str | dict = cast("ToolCall", value)["args"].copy()
1125
1138
  else:
1126
1139
  tool_call_id = None
1127
- tool_input = cast("Union[str, dict]", value)
1140
+ tool_input = cast("str | dict", value)
1128
1141
  return (
1129
1142
  tool_input,
1130
1143
  dict(
@@ -1143,11 +1156,11 @@ def _prep_run_args(
1143
1156
  def _format_output(
1144
1157
  content: Any,
1145
1158
  artifact: Any,
1146
- tool_call_id: Optional[str],
1159
+ tool_call_id: str | None,
1147
1160
  name: str,
1148
1161
  status: str,
1149
- ) -> Union[ToolOutputMixin, Any]:
1150
- """Format tool output as a ToolMessage if appropriate.
1162
+ ) -> ToolOutputMixin | Any:
1163
+ """Format tool output as a `ToolMessage` if appropriate.
1151
1164
 
1152
1165
  Args:
1153
1166
  content: The main content of the tool output.
@@ -1157,7 +1170,7 @@ def _format_output(
1157
1170
  status: The execution status.
1158
1171
 
1159
1172
  Returns:
1160
- The formatted output, either as a ToolMessage or the original content.
1173
+ The formatted output, either as a `ToolMessage` or the original content.
1161
1174
  """
1162
1175
  if isinstance(content, ToolOutputMixin) or tool_call_id is None:
1163
1176
  return content
@@ -1181,7 +1194,7 @@ def _is_message_content_type(obj: Any) -> bool:
1181
1194
  obj: The object to check.
1182
1195
 
1183
1196
  Returns:
1184
- True if the object is valid message content, False otherwise.
1197
+ `True` if the object is valid message content, `False` otherwise.
1185
1198
  """
1186
1199
  return isinstance(obj, str) or (
1187
1200
  isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
@@ -1197,7 +1210,7 @@ def _is_message_content_block(obj: Any) -> bool:
1197
1210
  obj: The object to check.
1198
1211
 
1199
1212
  Returns:
1200
- True if the object is a valid content block, False otherwise.
1213
+ `True` if the object is a valid content block, `False` otherwise.
1201
1214
  """
1202
1215
  if isinstance(obj, str):
1203
1216
  return True
@@ -1221,14 +1234,14 @@ def _stringify(content: Any) -> str:
1221
1234
  return str(content)
1222
1235
 
1223
1236
 
1224
- def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
1237
+ def _get_type_hints(func: Callable) -> dict[str, type] | None:
1225
1238
  """Get type hints from a function, handling partial functions.
1226
1239
 
1227
1240
  Args:
1228
1241
  func: The function to get type hints from.
1229
1242
 
1230
1243
  Returns:
1231
- Dictionary of type hints, or None if extraction fails.
1244
+ `dict` of type hints, or `None` if extraction fails.
1232
1245
  """
1233
1246
  if isinstance(func, functools.partial):
1234
1247
  func = func.func
@@ -1238,14 +1251,14 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
1238
1251
  return None
1239
1252
 
1240
1253
 
1241
- def _get_runnable_config_param(func: Callable) -> Optional[str]:
1242
- """Find the parameter name for RunnableConfig in a function.
1254
+ def _get_runnable_config_param(func: Callable) -> str | None:
1255
+ """Find the parameter name for `RunnableConfig` in a function.
1243
1256
 
1244
1257
  Args:
1245
1258
  func: The function to check.
1246
1259
 
1247
1260
  Returns:
1248
- The parameter name for RunnableConfig, or None if not found.
1261
+ The parameter name for `RunnableConfig`, or `None` if not found.
1249
1262
  """
1250
1263
  type_hints = _get_type_hints(func)
1251
1264
  if not type_hints:
@@ -1264,35 +1277,75 @@ class InjectedToolArg:
1264
1277
  """
1265
1278
 
1266
1279
 
1280
+ class _DirectlyInjectedToolArg:
1281
+ """Annotation for tool arguments that are injected at runtime.
1282
+
1283
+ Injected via direct type annotation, rather than annotated metadata.
1284
+
1285
+ For example, `ToolRuntime` is a directly injected argument.
1286
+
1287
+ Note the direct annotation rather than the verbose alternative:
1288
+ `Annotated[ToolRuntime, InjectedRuntime]`
1289
+
1290
+ ```python
1291
+ from langchain_core.tools import tool, ToolRuntime
1292
+
1293
+
1294
+ @tool
1295
+ def foo(x: int, runtime: ToolRuntime) -> str:
1296
+ # use runtime.state, runtime.context, runtime.store, etc.
1297
+ ...
1298
+ ```
1299
+ """
1300
+
1301
+
1267
1302
  class InjectedToolCallId(InjectedToolArg):
1268
1303
  """Annotation for injecting the tool call ID.
1269
1304
 
1270
1305
  This annotation is used to mark a tool parameter that should receive
1271
1306
  the tool call ID at runtime.
1272
1307
 
1273
- .. code-block:: python
1274
-
1275
- from typing import Annotated
1276
- from langchain_core.messages import ToolMessage
1277
- from langchain_core.tools import tool, InjectedToolCallId
1278
-
1279
- @tool
1280
- def foo(
1281
- x: int, tool_call_id: Annotated[str, InjectedToolCallId]
1282
- ) -> ToolMessage:
1283
- \"\"\"Return x.\"\"\"
1284
- return ToolMessage(
1285
- str(x),
1286
- artifact=x,
1287
- name="foo",
1288
- tool_call_id=tool_call_id
1289
- )
1308
+ ```python
1309
+ from typing import Annotated
1310
+ from langchain_core.messages import ToolMessage
1311
+ from langchain_core.tools import tool, InjectedToolCallId
1312
+
1313
+ @tool
1314
+ def foo(
1315
+ x: int, tool_call_id: Annotated[str, InjectedToolCallId]
1316
+ ) -> ToolMessage:
1317
+ \"\"\"Return x.\"\"\"
1318
+ return ToolMessage(
1319
+ str(x),
1320
+ artifact=x,
1321
+ name="foo",
1322
+ tool_call_id=tool_call_id
1323
+ )
1324
+
1325
+ ```
1326
+ """
1327
+
1290
1328
 
1329
+ def _is_directly_injected_arg_type(type_: Any) -> bool:
1330
+ """Check if a type annotation indicates a directly injected argument.
1331
+
1332
+ This is currently only used for `ToolRuntime`.
1333
+ Checks if either the annotation itself is a subclass of `_DirectlyInjectedToolArg`
1334
+ or the origin of the annotation is a subclass of `_DirectlyInjectedToolArg`.
1335
+
1336
+ Ex: `ToolRuntime` or `ToolRuntime[ContextT, StateT]` would both return `True`.
1291
1337
  """
1338
+ return (
1339
+ isinstance(type_, type) and issubclass(type_, _DirectlyInjectedToolArg)
1340
+ ) or (
1341
+ (origin := get_origin(type_)) is not None
1342
+ and isinstance(origin, type)
1343
+ and issubclass(origin, _DirectlyInjectedToolArg)
1344
+ )
1292
1345
 
1293
1346
 
1294
1347
  def _is_injected_arg_type(
1295
- type_: Union[type, TypeVar], injected_type: Optional[type[InjectedToolArg]] = None
1348
+ type_: type | TypeVar, injected_type: type[InjectedToolArg] | None = None
1296
1349
  ) -> bool:
1297
1350
  """Check if a type annotation indicates an injected argument.
1298
1351
 
@@ -1301,9 +1354,17 @@ def _is_injected_arg_type(
1301
1354
  injected_type: The specific injected type to check for.
1302
1355
 
1303
1356
  Returns:
1304
- True if the type is an injected argument, False otherwise.
1357
+ `True` if the type is an injected argument, `False` otherwise.
1305
1358
  """
1306
- injected_type = injected_type or InjectedToolArg
1359
+ if injected_type is None:
1360
+ # if no injected type is specified,
1361
+ # check if the type is a directly injected argument
1362
+ if _is_directly_injected_arg_type(type_):
1363
+ return True
1364
+ injected_type = InjectedToolArg
1365
+
1366
+ # if the type is an Annotated type, check if annotated metadata
1367
+ # is an intance or subclass of the injected type
1307
1368
  return any(
1308
1369
  isinstance(arg, injected_type)
1309
1370
  or (isinstance(arg, type) and issubclass(arg, injected_type))
@@ -1312,23 +1373,23 @@ def _is_injected_arg_type(
1312
1373
 
1313
1374
 
1314
1375
  def get_all_basemodel_annotations(
1315
- cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
1316
- ) -> dict[str, Union[type, TypeVar]]:
1317
- """Get all annotations from a Pydantic BaseModel and its parents.
1376
+ cls: TypeBaseModel | Any, *, default_to_bound: bool = True
1377
+ ) -> dict[str, type | TypeVar]:
1378
+ """Get all annotations from a Pydantic `BaseModel` and its parents.
1318
1379
 
1319
1380
  Args:
1320
- cls: The Pydantic BaseModel class.
1321
- default_to_bound: Whether to default to the bound of a TypeVar if it exists.
1381
+ cls: The Pydantic `BaseModel` class.
1382
+ default_to_bound: Whether to default to the bound of a `TypeVar` if it exists.
1322
1383
 
1323
1384
  Returns:
1324
- A dictionary of field names to their type annotations.
1385
+ `dict` of field names to their type annotations.
1325
1386
  """
1326
1387
  # cls has no subscript: cls = FooBar
1327
1388
  if isinstance(cls, type):
1328
1389
  fields = get_fields(cls)
1329
1390
  alias_map = {field.alias: name for name, field in fields.items() if field.alias}
1330
1391
 
1331
- annotations: dict[str, Union[type, TypeVar]] = {}
1392
+ annotations: dict[str, type | TypeVar] = {}
1332
1393
  for name, param in inspect.signature(cls).parameters.items():
1333
1394
  # Exclude hidden init args added by pydantic Config. For example if
1334
1395
  # BaseModel(extra="allow") then "extra_data" will part of init sig.
@@ -1369,7 +1430,7 @@ def get_all_basemodel_annotations(
1369
1430
  # generic_type_vars = (type vars in Baz)
1370
1431
  # generic_map = {type var in Baz: str}
1371
1432
  generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
1372
- generic_map = dict(zip(generic_type_vars, get_args(parent)))
1433
+ generic_map = dict(zip(generic_type_vars, get_args(parent), strict=False))
1373
1434
  for field in getattr(parent_origin, "__annotations__", {}):
1374
1435
  annotations[field] = _replace_type_vars(
1375
1436
  annotations[field], generic_map, default_to_bound=default_to_bound
@@ -1382,20 +1443,20 @@ def get_all_basemodel_annotations(
1382
1443
 
1383
1444
 
1384
1445
  def _replace_type_vars(
1385
- type_: Union[type, TypeVar],
1386
- generic_map: Optional[dict[TypeVar, type]] = None,
1446
+ type_: type | TypeVar,
1447
+ generic_map: dict[TypeVar, type] | None = None,
1387
1448
  *,
1388
1449
  default_to_bound: bool = True,
1389
- ) -> Union[type, TypeVar]:
1390
- """Replace TypeVars in a type annotation with concrete types.
1450
+ ) -> type | TypeVar:
1451
+ """Replace `TypeVar`s in a type annotation with concrete types.
1391
1452
 
1392
1453
  Args:
1393
1454
  type_: The type annotation to process.
1394
- generic_map: Mapping of TypeVars to concrete types.
1395
- default_to_bound: Whether to use TypeVar bounds as defaults.
1455
+ generic_map: Mapping of `TypeVar`s to concrete types.
1456
+ default_to_bound: Whether to use `TypeVar` bounds as defaults.
1396
1457
 
1397
1458
  Returns:
1398
- The type with TypeVars replaced.
1459
+ The type with `TypeVar`s replaced.
1399
1460
  """
1400
1461
  generic_map = generic_map or {}
1401
1462
  if isinstance(type_, TypeVar):