langchain-core 0.4.0.dev0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-core might be problematic. Click here for more details.
- langchain_core/__init__.py +1 -1
- langchain_core/_api/__init__.py +3 -4
- langchain_core/_api/beta_decorator.py +45 -70
- langchain_core/_api/deprecation.py +80 -80
- langchain_core/_api/path.py +22 -8
- langchain_core/_import_utils.py +10 -4
- langchain_core/agents.py +25 -21
- langchain_core/caches.py +53 -63
- langchain_core/callbacks/__init__.py +1 -8
- langchain_core/callbacks/base.py +341 -348
- langchain_core/callbacks/file.py +55 -44
- langchain_core/callbacks/manager.py +546 -683
- langchain_core/callbacks/stdout.py +29 -30
- langchain_core/callbacks/streaming_stdout.py +35 -36
- langchain_core/callbacks/usage.py +65 -70
- langchain_core/chat_history.py +48 -55
- langchain_core/document_loaders/base.py +46 -21
- langchain_core/document_loaders/langsmith.py +39 -36
- langchain_core/documents/__init__.py +0 -1
- langchain_core/documents/base.py +96 -74
- langchain_core/documents/compressor.py +12 -9
- langchain_core/documents/transformers.py +29 -28
- langchain_core/embeddings/fake.py +56 -57
- langchain_core/env.py +2 -3
- langchain_core/example_selectors/base.py +12 -0
- langchain_core/example_selectors/length_based.py +1 -1
- langchain_core/example_selectors/semantic_similarity.py +21 -25
- langchain_core/exceptions.py +15 -9
- langchain_core/globals.py +4 -163
- langchain_core/indexing/api.py +132 -125
- langchain_core/indexing/base.py +64 -67
- langchain_core/indexing/in_memory.py +26 -6
- langchain_core/language_models/__init__.py +15 -27
- langchain_core/language_models/_utils.py +267 -117
- langchain_core/language_models/base.py +92 -177
- langchain_core/language_models/chat_models.py +547 -407
- langchain_core/language_models/fake.py +11 -11
- langchain_core/language_models/fake_chat_models.py +72 -118
- langchain_core/language_models/llms.py +168 -242
- langchain_core/load/dump.py +8 -11
- langchain_core/load/load.py +32 -28
- langchain_core/load/mapping.py +2 -4
- langchain_core/load/serializable.py +50 -56
- langchain_core/messages/__init__.py +36 -51
- langchain_core/messages/ai.py +377 -150
- langchain_core/messages/base.py +239 -47
- langchain_core/messages/block_translators/__init__.py +111 -0
- langchain_core/messages/block_translators/anthropic.py +470 -0
- langchain_core/messages/block_translators/bedrock.py +94 -0
- langchain_core/messages/block_translators/bedrock_converse.py +297 -0
- langchain_core/messages/block_translators/google_genai.py +530 -0
- langchain_core/messages/block_translators/google_vertexai.py +21 -0
- langchain_core/messages/block_translators/groq.py +143 -0
- langchain_core/messages/block_translators/langchain_v0.py +301 -0
- langchain_core/messages/block_translators/openai.py +1010 -0
- langchain_core/messages/chat.py +2 -3
- langchain_core/messages/content.py +1423 -0
- langchain_core/messages/function.py +7 -7
- langchain_core/messages/human.py +44 -38
- langchain_core/messages/modifier.py +3 -2
- langchain_core/messages/system.py +40 -27
- langchain_core/messages/tool.py +160 -58
- langchain_core/messages/utils.py +527 -638
- langchain_core/output_parsers/__init__.py +1 -14
- langchain_core/output_parsers/base.py +68 -104
- langchain_core/output_parsers/json.py +13 -17
- langchain_core/output_parsers/list.py +11 -33
- langchain_core/output_parsers/openai_functions.py +56 -74
- langchain_core/output_parsers/openai_tools.py +68 -109
- langchain_core/output_parsers/pydantic.py +15 -13
- langchain_core/output_parsers/string.py +6 -2
- langchain_core/output_parsers/transform.py +17 -60
- langchain_core/output_parsers/xml.py +34 -44
- langchain_core/outputs/__init__.py +1 -1
- langchain_core/outputs/chat_generation.py +26 -11
- langchain_core/outputs/chat_result.py +1 -3
- langchain_core/outputs/generation.py +17 -6
- langchain_core/outputs/llm_result.py +15 -8
- langchain_core/prompt_values.py +29 -123
- langchain_core/prompts/__init__.py +3 -27
- langchain_core/prompts/base.py +48 -63
- langchain_core/prompts/chat.py +259 -288
- langchain_core/prompts/dict.py +19 -11
- langchain_core/prompts/few_shot.py +84 -90
- langchain_core/prompts/few_shot_with_templates.py +14 -12
- langchain_core/prompts/image.py +19 -14
- langchain_core/prompts/loading.py +6 -8
- langchain_core/prompts/message.py +7 -8
- langchain_core/prompts/prompt.py +42 -43
- langchain_core/prompts/string.py +37 -16
- langchain_core/prompts/structured.py +43 -46
- langchain_core/rate_limiters.py +51 -60
- langchain_core/retrievers.py +52 -192
- langchain_core/runnables/base.py +1727 -1683
- langchain_core/runnables/branch.py +52 -73
- langchain_core/runnables/config.py +89 -103
- langchain_core/runnables/configurable.py +128 -130
- langchain_core/runnables/fallbacks.py +93 -82
- langchain_core/runnables/graph.py +127 -127
- langchain_core/runnables/graph_ascii.py +63 -41
- langchain_core/runnables/graph_mermaid.py +87 -70
- langchain_core/runnables/graph_png.py +31 -36
- langchain_core/runnables/history.py +145 -161
- langchain_core/runnables/passthrough.py +141 -144
- langchain_core/runnables/retry.py +84 -68
- langchain_core/runnables/router.py +33 -37
- langchain_core/runnables/schema.py +79 -72
- langchain_core/runnables/utils.py +95 -139
- langchain_core/stores.py +85 -131
- langchain_core/structured_query.py +11 -15
- langchain_core/sys_info.py +31 -32
- langchain_core/tools/__init__.py +1 -14
- langchain_core/tools/base.py +221 -247
- langchain_core/tools/convert.py +144 -161
- langchain_core/tools/render.py +10 -10
- langchain_core/tools/retriever.py +12 -19
- langchain_core/tools/simple.py +52 -29
- langchain_core/tools/structured.py +56 -60
- langchain_core/tracers/__init__.py +1 -9
- langchain_core/tracers/_streaming.py +6 -7
- langchain_core/tracers/base.py +103 -112
- langchain_core/tracers/context.py +29 -48
- langchain_core/tracers/core.py +142 -105
- langchain_core/tracers/evaluation.py +30 -34
- langchain_core/tracers/event_stream.py +162 -117
- langchain_core/tracers/langchain.py +34 -36
- langchain_core/tracers/log_stream.py +87 -49
- langchain_core/tracers/memory_stream.py +3 -3
- langchain_core/tracers/root_listeners.py +18 -34
- langchain_core/tracers/run_collector.py +8 -20
- langchain_core/tracers/schemas.py +0 -125
- langchain_core/tracers/stdout.py +3 -3
- langchain_core/utils/__init__.py +1 -4
- langchain_core/utils/_merge.py +47 -9
- langchain_core/utils/aiter.py +70 -66
- langchain_core/utils/env.py +12 -9
- langchain_core/utils/function_calling.py +139 -206
- langchain_core/utils/html.py +7 -8
- langchain_core/utils/input.py +6 -6
- langchain_core/utils/interactive_env.py +6 -2
- langchain_core/utils/iter.py +48 -45
- langchain_core/utils/json.py +14 -4
- langchain_core/utils/json_schema.py +159 -43
- langchain_core/utils/mustache.py +32 -25
- langchain_core/utils/pydantic.py +67 -40
- langchain_core/utils/strings.py +5 -5
- langchain_core/utils/usage.py +1 -1
- langchain_core/utils/utils.py +104 -62
- langchain_core/vectorstores/base.py +131 -179
- langchain_core/vectorstores/in_memory.py +113 -182
- langchain_core/vectorstores/utils.py +23 -17
- langchain_core/version.py +1 -1
- langchain_core-1.0.0.dist-info/METADATA +68 -0
- langchain_core-1.0.0.dist-info/RECORD +172 -0
- {langchain_core-0.4.0.dev0.dist-info → langchain_core-1.0.0.dist-info}/WHEEL +1 -1
- langchain_core/beta/__init__.py +0 -1
- langchain_core/beta/runnables/__init__.py +0 -1
- langchain_core/beta/runnables/context.py +0 -448
- langchain_core/memory.py +0 -116
- langchain_core/messages/content_blocks.py +0 -1435
- langchain_core/prompts/pipeline.py +0 -133
- langchain_core/pydantic_v1/__init__.py +0 -30
- langchain_core/pydantic_v1/dataclasses.py +0 -23
- langchain_core/pydantic_v1/main.py +0 -23
- langchain_core/tracers/langchain_v1.py +0 -23
- langchain_core/utils/loading.py +0 -31
- langchain_core/v1/__init__.py +0 -1
- langchain_core/v1/chat_models.py +0 -1047
- langchain_core/v1/messages.py +0 -755
- langchain_core-0.4.0.dev0.dist-info/METADATA +0 -108
- langchain_core-0.4.0.dev0.dist-info/RECORD +0 -177
- langchain_core-0.4.0.dev0.dist-info/entry_points.txt +0 -4
langchain_core/tools/base.py
CHANGED
|
@@ -8,16 +8,14 @@ import json
|
|
|
8
8
|
import typing
|
|
9
9
|
import warnings
|
|
10
10
|
from abc import ABC, abstractmethod
|
|
11
|
+
from collections.abc import Callable
|
|
11
12
|
from inspect import signature
|
|
12
13
|
from typing import (
|
|
13
14
|
TYPE_CHECKING,
|
|
14
15
|
Annotated,
|
|
15
16
|
Any,
|
|
16
|
-
Callable,
|
|
17
17
|
Literal,
|
|
18
|
-
Optional,
|
|
19
18
|
TypeVar,
|
|
20
|
-
Union,
|
|
21
19
|
cast,
|
|
22
20
|
get_args,
|
|
23
21
|
get_origin,
|
|
@@ -31,7 +29,6 @@ from pydantic import (
|
|
|
31
29
|
PydanticDeprecationWarning,
|
|
32
30
|
SkipValidation,
|
|
33
31
|
ValidationError,
|
|
34
|
-
model_validator,
|
|
35
32
|
validate_arguments,
|
|
36
33
|
)
|
|
37
34
|
from pydantic.v1 import BaseModel as BaseModelV1
|
|
@@ -39,10 +36,8 @@ from pydantic.v1 import ValidationError as ValidationErrorV1
|
|
|
39
36
|
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
|
40
37
|
from typing_extensions import override
|
|
41
38
|
|
|
42
|
-
from langchain_core._api import deprecated
|
|
43
39
|
from langchain_core.callbacks import (
|
|
44
40
|
AsyncCallbackManager,
|
|
45
|
-
BaseCallbackManager,
|
|
46
41
|
CallbackManager,
|
|
47
42
|
Callbacks,
|
|
48
43
|
)
|
|
@@ -68,14 +63,22 @@ from langchain_core.utils.pydantic import (
|
|
|
68
63
|
is_pydantic_v1_subclass,
|
|
69
64
|
is_pydantic_v2_subclass,
|
|
70
65
|
)
|
|
71
|
-
from langchain_core.v1.messages import ToolMessage as ToolMessageV1
|
|
72
66
|
|
|
73
67
|
if TYPE_CHECKING:
|
|
74
68
|
import uuid
|
|
75
69
|
from collections.abc import Sequence
|
|
76
70
|
|
|
77
71
|
FILTERED_ARGS = ("run_manager", "callbacks")
|
|
78
|
-
TOOL_MESSAGE_BLOCK_TYPES = (
|
|
72
|
+
TOOL_MESSAGE_BLOCK_TYPES = (
|
|
73
|
+
"text",
|
|
74
|
+
"image_url",
|
|
75
|
+
"image",
|
|
76
|
+
"json",
|
|
77
|
+
"search_result",
|
|
78
|
+
"custom_tool_call_output",
|
|
79
|
+
"document",
|
|
80
|
+
"file",
|
|
81
|
+
)
|
|
79
82
|
|
|
80
83
|
|
|
81
84
|
class SchemaAnnotationError(TypeError):
|
|
@@ -89,7 +92,7 @@ def _is_annotated_type(typ: type[Any]) -> bool:
|
|
|
89
92
|
typ: The type to check.
|
|
90
93
|
|
|
91
94
|
Returns:
|
|
92
|
-
True if the type is an Annotated type, False otherwise.
|
|
95
|
+
`True` if the type is an Annotated type, `False` otherwise.
|
|
93
96
|
"""
|
|
94
97
|
return get_origin(typ) is typing.Annotated
|
|
95
98
|
|
|
@@ -223,7 +226,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
|
|
|
223
226
|
pydantic_version: The Pydantic version to check against ("v1" or "v2").
|
|
224
227
|
|
|
225
228
|
Returns:
|
|
226
|
-
True if the annotation is a Pydantic model, False otherwise.
|
|
229
|
+
`True` if the annotation is a Pydantic model, `False` otherwise.
|
|
227
230
|
"""
|
|
228
231
|
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
|
229
232
|
try:
|
|
@@ -242,7 +245,7 @@ def _function_annotations_are_pydantic_v1(
|
|
|
242
245
|
func: The function being checked.
|
|
243
246
|
|
|
244
247
|
Returns:
|
|
245
|
-
True if all Pydantic annotations are from V1, False otherwise.
|
|
248
|
+
True if all Pydantic annotations are from V1, `False` otherwise.
|
|
246
249
|
|
|
247
250
|
Raises:
|
|
248
251
|
NotImplementedError: If the function contains mixed V1 and V2 annotations.
|
|
@@ -265,44 +268,40 @@ def _function_annotations_are_pydantic_v1(
|
|
|
265
268
|
|
|
266
269
|
|
|
267
270
|
class _SchemaConfig:
|
|
268
|
-
"""Configuration for Pydantic models generated from function signatures.
|
|
269
|
-
|
|
270
|
-
Attributes:
|
|
271
|
-
extra: Whether to allow extra fields in the model.
|
|
272
|
-
arbitrary_types_allowed: Whether to allow arbitrary types in the model.
|
|
273
|
-
"""
|
|
271
|
+
"""Configuration for Pydantic models generated from function signatures."""
|
|
274
272
|
|
|
275
273
|
extra: str = "forbid"
|
|
274
|
+
"""Whether to allow extra fields in the model."""
|
|
276
275
|
arbitrary_types_allowed: bool = True
|
|
276
|
+
"""Whether to allow arbitrary types in the model."""
|
|
277
277
|
|
|
278
278
|
|
|
279
279
|
def create_schema_from_function(
|
|
280
280
|
model_name: str,
|
|
281
281
|
func: Callable,
|
|
282
282
|
*,
|
|
283
|
-
filter_args:
|
|
283
|
+
filter_args: Sequence[str] | None = None,
|
|
284
284
|
parse_docstring: bool = False,
|
|
285
285
|
error_on_invalid_docstring: bool = False,
|
|
286
286
|
include_injected: bool = True,
|
|
287
287
|
) -> type[BaseModel]:
|
|
288
|
-
"""Create a
|
|
288
|
+
"""Create a Pydantic schema from a function's signature.
|
|
289
289
|
|
|
290
290
|
Args:
|
|
291
|
-
model_name: Name to assign to the generated
|
|
291
|
+
model_name: Name to assign to the generated Pydantic schema.
|
|
292
292
|
func: Function to generate the schema from.
|
|
293
293
|
filter_args: Optional list of arguments to exclude from the schema.
|
|
294
|
-
Defaults to FILTERED_ARGS
|
|
294
|
+
Defaults to `FILTERED_ARGS`.
|
|
295
295
|
parse_docstring: Whether to parse the function's docstring for descriptions
|
|
296
|
-
for each argument.
|
|
297
|
-
error_on_invalid_docstring: if
|
|
298
|
-
whether to raise ValueError on invalid Google Style docstrings.
|
|
299
|
-
Defaults to False.
|
|
296
|
+
for each argument.
|
|
297
|
+
error_on_invalid_docstring: if `parse_docstring` is provided, configure
|
|
298
|
+
whether to raise `ValueError` on invalid Google Style docstrings.
|
|
300
299
|
include_injected: Whether to include injected arguments in the schema.
|
|
301
|
-
Defaults to True
|
|
300
|
+
Defaults to `True`, since we want to include them in the schema
|
|
302
301
|
when *validating* tool inputs.
|
|
303
302
|
|
|
304
303
|
Returns:
|
|
305
|
-
A
|
|
304
|
+
A Pydantic model with the same arguments as the function.
|
|
306
305
|
"""
|
|
307
306
|
sig = inspect.signature(func)
|
|
308
307
|
|
|
@@ -312,7 +311,7 @@ def create_schema_from_function(
|
|
|
312
311
|
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
|
313
312
|
with warnings.catch_warnings():
|
|
314
313
|
# We are using deprecated functionality here.
|
|
315
|
-
# This code should be re-written to simply construct a
|
|
314
|
+
# This code should be re-written to simply construct a Pydantic model
|
|
316
315
|
# using inspect.signature and create_model.
|
|
317
316
|
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
|
318
317
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
|
|
@@ -385,10 +384,10 @@ class ToolException(Exception): # noqa: N818
|
|
|
385
384
|
"""
|
|
386
385
|
|
|
387
386
|
|
|
388
|
-
ArgsSchema =
|
|
387
|
+
ArgsSchema = TypeBaseModel | dict[str, Any]
|
|
389
388
|
|
|
390
389
|
|
|
391
|
-
class BaseTool(RunnableSerializable[
|
|
390
|
+
class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
|
|
392
391
|
"""Base class for all LangChain tools.
|
|
393
392
|
|
|
394
393
|
This abstract class defines the interface that all LangChain tools must implement.
|
|
@@ -436,7 +435,7 @@ class ChildTool(BaseTool):
|
|
|
436
435
|
You can provide few-shot examples as a part of the description.
|
|
437
436
|
"""
|
|
438
437
|
|
|
439
|
-
args_schema: Annotated[
|
|
438
|
+
args_schema: Annotated[ArgsSchema | None, SkipValidation()] = Field(
|
|
440
439
|
default=None, description="The tool schema."
|
|
441
440
|
)
|
|
442
441
|
"""Pydantic model class to validate and parse the tool's input arguments.
|
|
@@ -459,56 +458,42 @@ class ChildTool(BaseTool):
|
|
|
459
458
|
callbacks: Callbacks = Field(default=None, exclude=True)
|
|
460
459
|
"""Callbacks to be called during tool execution."""
|
|
461
460
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
)(
|
|
465
|
-
Field(
|
|
466
|
-
default=None,
|
|
467
|
-
exclude=True,
|
|
468
|
-
description="Callback manager to add to the run trace.",
|
|
469
|
-
)
|
|
470
|
-
)
|
|
471
|
-
tags: Optional[list[str]] = None
|
|
472
|
-
"""Optional list of tags associated with the tool. Defaults to None.
|
|
461
|
+
tags: list[str] | None = None
|
|
462
|
+
"""Optional list of tags associated with the tool.
|
|
473
463
|
These tags will be associated with each call to this tool,
|
|
474
464
|
and passed as arguments to the handlers defined in `callbacks`.
|
|
475
465
|
You can use these to eg identify a specific instance of a tool with its use case.
|
|
476
466
|
"""
|
|
477
|
-
metadata:
|
|
478
|
-
"""Optional metadata associated with the tool.
|
|
467
|
+
metadata: dict[str, Any] | None = None
|
|
468
|
+
"""Optional metadata associated with the tool.
|
|
479
469
|
This metadata will be associated with each call to this tool,
|
|
480
470
|
and passed as arguments to the handlers defined in `callbacks`.
|
|
481
471
|
You can use these to eg identify a specific instance of a tool with its use case.
|
|
482
472
|
"""
|
|
483
473
|
|
|
484
|
-
handle_tool_error:
|
|
485
|
-
False
|
|
486
|
-
)
|
|
474
|
+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False
|
|
487
475
|
"""Handle the content of the ToolException thrown."""
|
|
488
476
|
|
|
489
|
-
handle_validation_error:
|
|
490
|
-
|
|
491
|
-
|
|
477
|
+
handle_validation_error: (
|
|
478
|
+
bool | str | Callable[[ValidationError | ValidationErrorV1], str] | None
|
|
479
|
+
) = False
|
|
492
480
|
"""Handle the content of the ValidationError thrown."""
|
|
493
481
|
|
|
494
482
|
response_format: Literal["content", "content_and_artifact"] = "content"
|
|
495
|
-
"""The tool response format.
|
|
483
|
+
"""The tool response format.
|
|
496
484
|
|
|
497
|
-
If "content" then the output of the tool is interpreted as the contents of a
|
|
498
|
-
ToolMessage. If "content_and_artifact" then the output is expected to be a
|
|
499
|
-
two-tuple corresponding to the (content, artifact) of a ToolMessage
|
|
500
|
-
"""
|
|
501
|
-
|
|
502
|
-
message_version: Literal["v0", "v1"] = "v0"
|
|
503
|
-
"""Version of ToolMessage to return given
|
|
504
|
-
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
|
|
505
|
-
|
|
506
|
-
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
|
|
507
|
-
If ``"v1"``, output will be a v1 :class:`~langchain_core.messages.v1.ToolMessage`.
|
|
485
|
+
If `"content"` then the output of the tool is interpreted as the contents of a
|
|
486
|
+
ToolMessage. If `"content_and_artifact"` then the output is expected to be a
|
|
487
|
+
two-tuple corresponding to the (content, artifact) of a `ToolMessage`.
|
|
508
488
|
"""
|
|
509
489
|
|
|
510
490
|
def __init__(self, **kwargs: Any) -> None:
|
|
511
|
-
"""Initialize the tool.
|
|
491
|
+
"""Initialize the tool.
|
|
492
|
+
|
|
493
|
+
Raises:
|
|
494
|
+
TypeError: If `args_schema` is not a subclass of pydantic `BaseModel` or
|
|
495
|
+
dict.
|
|
496
|
+
"""
|
|
512
497
|
if (
|
|
513
498
|
"args_schema" in kwargs
|
|
514
499
|
and kwargs["args_schema"] is not None
|
|
@@ -531,7 +516,7 @@ class ChildTool(BaseTool):
|
|
|
531
516
|
"""Check if the tool accepts only a single input argument.
|
|
532
517
|
|
|
533
518
|
Returns:
|
|
534
|
-
True if the tool has only one input argument, False otherwise.
|
|
519
|
+
`True` if the tool has only one input argument, `False` otherwise.
|
|
535
520
|
"""
|
|
536
521
|
keys = {k for k in self.args if k != "kwargs"}
|
|
537
522
|
return len(keys) == 1
|
|
@@ -545,6 +530,8 @@ class ChildTool(BaseTool):
|
|
|
545
530
|
"""
|
|
546
531
|
if isinstance(self.args_schema, dict):
|
|
547
532
|
json_schema = self.args_schema
|
|
533
|
+
elif self.args_schema and issubclass(self.args_schema, BaseModelV1):
|
|
534
|
+
json_schema = self.args_schema.schema()
|
|
548
535
|
else:
|
|
549
536
|
input_schema = self.get_input_schema()
|
|
550
537
|
json_schema = input_schema.model_json_schema()
|
|
@@ -578,9 +565,7 @@ class ChildTool(BaseTool):
|
|
|
578
565
|
# --- Runnable ---
|
|
579
566
|
|
|
580
567
|
@override
|
|
581
|
-
def get_input_schema(
|
|
582
|
-
self, config: Optional[RunnableConfig] = None
|
|
583
|
-
) -> type[BaseModel]:
|
|
568
|
+
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
|
584
569
|
"""The tool's input schema.
|
|
585
570
|
|
|
586
571
|
Args:
|
|
@@ -598,8 +583,8 @@ class ChildTool(BaseTool):
|
|
|
598
583
|
@override
|
|
599
584
|
def invoke(
|
|
600
585
|
self,
|
|
601
|
-
input:
|
|
602
|
-
config:
|
|
586
|
+
input: str | dict | ToolCall,
|
|
587
|
+
config: RunnableConfig | None = None,
|
|
603
588
|
**kwargs: Any,
|
|
604
589
|
) -> Any:
|
|
605
590
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
|
@@ -608,8 +593,8 @@ class ChildTool(BaseTool):
|
|
|
608
593
|
@override
|
|
609
594
|
async def ainvoke(
|
|
610
595
|
self,
|
|
611
|
-
input:
|
|
612
|
-
config:
|
|
596
|
+
input: str | dict | ToolCall,
|
|
597
|
+
config: RunnableConfig | None = None,
|
|
613
598
|
**kwargs: Any,
|
|
614
599
|
) -> Any:
|
|
615
600
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
|
@@ -618,8 +603,8 @@ class ChildTool(BaseTool):
|
|
|
618
603
|
# --- Tool ---
|
|
619
604
|
|
|
620
605
|
def _parse_input(
|
|
621
|
-
self, tool_input:
|
|
622
|
-
) ->
|
|
606
|
+
self, tool_input: str | dict, tool_call_id: str | None
|
|
607
|
+
) -> str | dict[str, Any]:
|
|
623
608
|
"""Parse and validate tool input using the args schema.
|
|
624
609
|
|
|
625
610
|
Args:
|
|
@@ -630,9 +615,10 @@ class ChildTool(BaseTool):
|
|
|
630
615
|
The parsed and validated input.
|
|
631
616
|
|
|
632
617
|
Raises:
|
|
633
|
-
ValueError: If string input is provided with JSON schema
|
|
634
|
-
|
|
635
|
-
|
|
618
|
+
ValueError: If `string` input is provided with JSON schema `args_schema`.
|
|
619
|
+
ValueError: If InjectedToolCallId is required but `tool_call_id` is not
|
|
620
|
+
provided.
|
|
621
|
+
TypeError: If args_schema is not a Pydantic `BaseModel` or dict.
|
|
636
622
|
"""
|
|
637
623
|
input_args = self.args_schema
|
|
638
624
|
if isinstance(tool_input, str):
|
|
@@ -657,10 +643,7 @@ class ChildTool(BaseTool):
|
|
|
657
643
|
return tool_input
|
|
658
644
|
if issubclass(input_args, BaseModel):
|
|
659
645
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
|
660
|
-
if (
|
|
661
|
-
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
|
662
|
-
and k not in tool_input
|
|
663
|
-
):
|
|
646
|
+
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
|
664
647
|
if tool_call_id is None:
|
|
665
648
|
msg = (
|
|
666
649
|
"When tool includes an InjectedToolCallId "
|
|
@@ -675,10 +658,7 @@ class ChildTool(BaseTool):
|
|
|
675
658
|
result_dict = result.model_dump()
|
|
676
659
|
elif issubclass(input_args, BaseModelV1):
|
|
677
660
|
for k, v in get_all_basemodel_annotations(input_args).items():
|
|
678
|
-
if (
|
|
679
|
-
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
|
|
680
|
-
and k not in tool_input
|
|
681
|
-
):
|
|
661
|
+
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
|
682
662
|
if tool_call_id is None:
|
|
683
663
|
msg = (
|
|
684
664
|
"When tool includes an InjectedToolCallId "
|
|
@@ -701,39 +681,25 @@ class ChildTool(BaseTool):
|
|
|
701
681
|
}
|
|
702
682
|
return tool_input
|
|
703
683
|
|
|
704
|
-
@model_validator(mode="before")
|
|
705
|
-
@classmethod
|
|
706
|
-
def raise_deprecation(cls, values: dict) -> Any:
|
|
707
|
-
"""Raise deprecation warning if callback_manager is used.
|
|
708
|
-
|
|
709
|
-
Args:
|
|
710
|
-
values: The values to validate.
|
|
711
|
-
|
|
712
|
-
Returns:
|
|
713
|
-
The validated values.
|
|
714
|
-
"""
|
|
715
|
-
if values.get("callback_manager") is not None:
|
|
716
|
-
warnings.warn(
|
|
717
|
-
"callback_manager is deprecated. Please use callbacks instead.",
|
|
718
|
-
DeprecationWarning,
|
|
719
|
-
stacklevel=6,
|
|
720
|
-
)
|
|
721
|
-
values["callbacks"] = values.pop("callback_manager", None)
|
|
722
|
-
return values
|
|
723
|
-
|
|
724
684
|
@abstractmethod
|
|
725
685
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
|
726
686
|
"""Use the tool.
|
|
727
687
|
|
|
728
|
-
Add run_manager:
|
|
729
|
-
|
|
688
|
+
Add `run_manager: CallbackManagerForToolRun | None = None` to child
|
|
689
|
+
implementations to enable tracing.
|
|
690
|
+
|
|
691
|
+
Returns:
|
|
692
|
+
The result of the tool execution.
|
|
730
693
|
"""
|
|
731
694
|
|
|
732
695
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
|
733
696
|
"""Use the tool asynchronously.
|
|
734
697
|
|
|
735
|
-
Add run_manager:
|
|
736
|
-
|
|
698
|
+
Add `run_manager: AsyncCallbackManagerForToolRun | None = None` to child
|
|
699
|
+
implementations to enable tracing.
|
|
700
|
+
|
|
701
|
+
Returns:
|
|
702
|
+
The result of the tool execution.
|
|
737
703
|
"""
|
|
738
704
|
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
|
739
705
|
"run_manager"
|
|
@@ -742,7 +708,7 @@ class ChildTool(BaseTool):
|
|
|
742
708
|
return await run_in_executor(None, self._run, *args, **kwargs)
|
|
743
709
|
|
|
744
710
|
def _to_args_and_kwargs(
|
|
745
|
-
self, tool_input:
|
|
711
|
+
self, tool_input: str | dict, tool_call_id: str | None
|
|
746
712
|
) -> tuple[tuple, dict]:
|
|
747
713
|
"""Convert tool input to positional and keyword arguments.
|
|
748
714
|
|
|
@@ -782,35 +748,35 @@ class ChildTool(BaseTool):
|
|
|
782
748
|
|
|
783
749
|
def run(
|
|
784
750
|
self,
|
|
785
|
-
tool_input:
|
|
786
|
-
verbose:
|
|
787
|
-
start_color:
|
|
788
|
-
color:
|
|
751
|
+
tool_input: str | dict[str, Any],
|
|
752
|
+
verbose: bool | None = None, # noqa: FBT001
|
|
753
|
+
start_color: str | None = "green",
|
|
754
|
+
color: str | None = "green",
|
|
789
755
|
callbacks: Callbacks = None,
|
|
790
756
|
*,
|
|
791
|
-
tags:
|
|
792
|
-
metadata:
|
|
793
|
-
run_name:
|
|
794
|
-
run_id:
|
|
795
|
-
config:
|
|
796
|
-
tool_call_id:
|
|
757
|
+
tags: list[str] | None = None,
|
|
758
|
+
metadata: dict[str, Any] | None = None,
|
|
759
|
+
run_name: str | None = None,
|
|
760
|
+
run_id: uuid.UUID | None = None,
|
|
761
|
+
config: RunnableConfig | None = None,
|
|
762
|
+
tool_call_id: str | None = None,
|
|
797
763
|
**kwargs: Any,
|
|
798
764
|
) -> Any:
|
|
799
765
|
"""Run the tool.
|
|
800
766
|
|
|
801
767
|
Args:
|
|
802
768
|
tool_input: The input to the tool.
|
|
803
|
-
verbose: Whether to log the tool's progress.
|
|
804
|
-
start_color: The color to use when starting the tool.
|
|
805
|
-
color: The color to use when ending the tool.
|
|
806
|
-
callbacks: Callbacks to be called during tool execution.
|
|
807
|
-
tags: Optional list of tags associated with the tool.
|
|
808
|
-
metadata: Optional metadata associated with the tool.
|
|
809
|
-
run_name: The name of the run.
|
|
810
|
-
run_id: The id of the run.
|
|
811
|
-
config: The configuration for the tool.
|
|
812
|
-
tool_call_id: The id of the tool call.
|
|
813
|
-
kwargs: Keyword arguments to be passed to tool callbacks (event handler)
|
|
769
|
+
verbose: Whether to log the tool's progress.
|
|
770
|
+
start_color: The color to use when starting the tool.
|
|
771
|
+
color: The color to use when ending the tool.
|
|
772
|
+
callbacks: Callbacks to be called during tool execution.
|
|
773
|
+
tags: Optional list of tags associated with the tool.
|
|
774
|
+
metadata: Optional metadata associated with the tool.
|
|
775
|
+
run_name: The name of the run.
|
|
776
|
+
run_id: The id of the run.
|
|
777
|
+
config: The configuration for the tool.
|
|
778
|
+
tool_call_id: The id of the tool call.
|
|
779
|
+
**kwargs: Keyword arguments to be passed to tool callbacks (event handler)
|
|
814
780
|
|
|
815
781
|
Returns:
|
|
816
782
|
The output of the tool.
|
|
@@ -844,8 +810,8 @@ class ChildTool(BaseTool):
|
|
|
844
810
|
|
|
845
811
|
content = None
|
|
846
812
|
artifact = None
|
|
847
|
-
status
|
|
848
|
-
error_to_raise:
|
|
813
|
+
status = "success"
|
|
814
|
+
error_to_raise: Exception | KeyboardInterrupt | None = None
|
|
849
815
|
try:
|
|
850
816
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
|
851
817
|
with set_config_context(child_config) as context:
|
|
@@ -888,48 +854,41 @@ class ChildTool(BaseTool):
|
|
|
888
854
|
if error_to_raise:
|
|
889
855
|
run_manager.on_tool_error(error_to_raise)
|
|
890
856
|
raise error_to_raise
|
|
891
|
-
output = _format_output(
|
|
892
|
-
content,
|
|
893
|
-
artifact,
|
|
894
|
-
tool_call_id,
|
|
895
|
-
self.name,
|
|
896
|
-
status,
|
|
897
|
-
message_version=self.message_version,
|
|
898
|
-
)
|
|
857
|
+
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
|
899
858
|
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
|
900
859
|
return output
|
|
901
860
|
|
|
902
861
|
async def arun(
|
|
903
862
|
self,
|
|
904
|
-
tool_input:
|
|
905
|
-
verbose:
|
|
906
|
-
start_color:
|
|
907
|
-
color:
|
|
863
|
+
tool_input: str | dict,
|
|
864
|
+
verbose: bool | None = None, # noqa: FBT001
|
|
865
|
+
start_color: str | None = "green",
|
|
866
|
+
color: str | None = "green",
|
|
908
867
|
callbacks: Callbacks = None,
|
|
909
868
|
*,
|
|
910
|
-
tags:
|
|
911
|
-
metadata:
|
|
912
|
-
run_name:
|
|
913
|
-
run_id:
|
|
914
|
-
config:
|
|
915
|
-
tool_call_id:
|
|
869
|
+
tags: list[str] | None = None,
|
|
870
|
+
metadata: dict[str, Any] | None = None,
|
|
871
|
+
run_name: str | None = None,
|
|
872
|
+
run_id: uuid.UUID | None = None,
|
|
873
|
+
config: RunnableConfig | None = None,
|
|
874
|
+
tool_call_id: str | None = None,
|
|
916
875
|
**kwargs: Any,
|
|
917
876
|
) -> Any:
|
|
918
877
|
"""Run the tool asynchronously.
|
|
919
878
|
|
|
920
879
|
Args:
|
|
921
880
|
tool_input: The input to the tool.
|
|
922
|
-
verbose: Whether to log the tool's progress.
|
|
923
|
-
start_color: The color to use when starting the tool.
|
|
924
|
-
color: The color to use when ending the tool.
|
|
925
|
-
callbacks: Callbacks to be called during tool execution.
|
|
926
|
-
tags: Optional list of tags associated with the tool.
|
|
927
|
-
metadata: Optional metadata associated with the tool.
|
|
928
|
-
run_name: The name of the run.
|
|
929
|
-
run_id: The id of the run.
|
|
930
|
-
config: The configuration for the tool.
|
|
931
|
-
tool_call_id: The id of the tool call.
|
|
932
|
-
kwargs: Keyword arguments to be passed to tool callbacks
|
|
881
|
+
verbose: Whether to log the tool's progress.
|
|
882
|
+
start_color: The color to use when starting the tool.
|
|
883
|
+
color: The color to use when ending the tool.
|
|
884
|
+
callbacks: Callbacks to be called during tool execution.
|
|
885
|
+
tags: Optional list of tags associated with the tool.
|
|
886
|
+
metadata: Optional metadata associated with the tool.
|
|
887
|
+
run_name: The name of the run.
|
|
888
|
+
run_id: The id of the run.
|
|
889
|
+
config: The configuration for the tool.
|
|
890
|
+
tool_call_id: The id of the tool call.
|
|
891
|
+
**kwargs: Keyword arguments to be passed to tool callbacks
|
|
933
892
|
|
|
934
893
|
Returns:
|
|
935
894
|
The output of the tool.
|
|
@@ -961,8 +920,8 @@ class ChildTool(BaseTool):
|
|
|
961
920
|
)
|
|
962
921
|
content = None
|
|
963
922
|
artifact = None
|
|
964
|
-
status
|
|
965
|
-
error_to_raise:
|
|
923
|
+
status = "success"
|
|
924
|
+
error_to_raise: Exception | KeyboardInterrupt | None = None
|
|
966
925
|
try:
|
|
967
926
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
|
968
927
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
|
@@ -1009,30 +968,10 @@ class ChildTool(BaseTool):
|
|
|
1009
968
|
await run_manager.on_tool_error(error_to_raise)
|
|
1010
969
|
raise error_to_raise
|
|
1011
970
|
|
|
1012
|
-
output = _format_output(
|
|
1013
|
-
content,
|
|
1014
|
-
artifact,
|
|
1015
|
-
tool_call_id,
|
|
1016
|
-
self.name,
|
|
1017
|
-
status,
|
|
1018
|
-
message_version=self.message_version,
|
|
1019
|
-
)
|
|
971
|
+
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
|
1020
972
|
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
|
1021
973
|
return output
|
|
1022
974
|
|
|
1023
|
-
@deprecated("0.1.47", alternative="invoke", removal="1.0")
|
|
1024
|
-
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
|
1025
|
-
"""Make tool callable (deprecated).
|
|
1026
|
-
|
|
1027
|
-
Args:
|
|
1028
|
-
tool_input: The input to the tool.
|
|
1029
|
-
callbacks: Callbacks to use during execution.
|
|
1030
|
-
|
|
1031
|
-
Returns:
|
|
1032
|
-
The tool's output.
|
|
1033
|
-
"""
|
|
1034
|
-
return self.run(tool_input, callbacks=callbacks)
|
|
1035
|
-
|
|
1036
975
|
|
|
1037
976
|
def _is_tool_call(x: Any) -> bool:
|
|
1038
977
|
"""Check if the input is a tool call dictionary.
|
|
@@ -1041,17 +980,15 @@ def _is_tool_call(x: Any) -> bool:
|
|
|
1041
980
|
x: The input to check.
|
|
1042
981
|
|
|
1043
982
|
Returns:
|
|
1044
|
-
True if the input is a tool call, False otherwise.
|
|
983
|
+
`True` if the input is a tool call, `False` otherwise.
|
|
1045
984
|
"""
|
|
1046
985
|
return isinstance(x, dict) and x.get("type") == "tool_call"
|
|
1047
986
|
|
|
1048
987
|
|
|
1049
988
|
def _handle_validation_error(
|
|
1050
|
-
e:
|
|
989
|
+
e: ValidationError | ValidationErrorV1,
|
|
1051
990
|
*,
|
|
1052
|
-
flag:
|
|
1053
|
-
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
|
1054
|
-
],
|
|
991
|
+
flag: Literal[True] | str | Callable[[ValidationError | ValidationErrorV1], str],
|
|
1055
992
|
) -> str:
|
|
1056
993
|
"""Handle validation errors based on the configured flag.
|
|
1057
994
|
|
|
@@ -1083,7 +1020,7 @@ def _handle_validation_error(
|
|
|
1083
1020
|
def _handle_tool_error(
|
|
1084
1021
|
e: ToolException,
|
|
1085
1022
|
*,
|
|
1086
|
-
flag:
|
|
1023
|
+
flag: Literal[True] | str | Callable[[ToolException], str] | None,
|
|
1087
1024
|
) -> str:
|
|
1088
1025
|
"""Handle tool execution errors based on the configured flag.
|
|
1089
1026
|
|
|
@@ -1113,10 +1050,10 @@ def _handle_tool_error(
|
|
|
1113
1050
|
|
|
1114
1051
|
|
|
1115
1052
|
def _prep_run_args(
|
|
1116
|
-
value:
|
|
1117
|
-
config:
|
|
1053
|
+
value: str | dict | ToolCall,
|
|
1054
|
+
config: RunnableConfig | None,
|
|
1118
1055
|
**kwargs: Any,
|
|
1119
|
-
) -> tuple[
|
|
1056
|
+
) -> tuple[str | dict, dict]:
|
|
1120
1057
|
"""Prepare arguments for tool execution.
|
|
1121
1058
|
|
|
1122
1059
|
Args:
|
|
@@ -1129,11 +1066,11 @@ def _prep_run_args(
|
|
|
1129
1066
|
"""
|
|
1130
1067
|
config = ensure_config(config)
|
|
1131
1068
|
if _is_tool_call(value):
|
|
1132
|
-
tool_call_id:
|
|
1133
|
-
tool_input:
|
|
1069
|
+
tool_call_id: str | None = cast("ToolCall", value)["id"]
|
|
1070
|
+
tool_input: str | dict = cast("ToolCall", value)["args"].copy()
|
|
1134
1071
|
else:
|
|
1135
1072
|
tool_call_id = None
|
|
1136
|
-
tool_input = cast("
|
|
1073
|
+
tool_input = cast("str | dict", value)
|
|
1137
1074
|
return (
|
|
1138
1075
|
tool_input,
|
|
1139
1076
|
dict(
|
|
@@ -1152,12 +1089,10 @@ def _prep_run_args(
|
|
|
1152
1089
|
def _format_output(
|
|
1153
1090
|
content: Any,
|
|
1154
1091
|
artifact: Any,
|
|
1155
|
-
tool_call_id:
|
|
1092
|
+
tool_call_id: str | None,
|
|
1156
1093
|
name: str,
|
|
1157
|
-
status:
|
|
1158
|
-
|
|
1159
|
-
message_version: Literal["v0", "v1"] = "v0",
|
|
1160
|
-
) -> Union[ToolOutputMixin, Any]:
|
|
1094
|
+
status: str,
|
|
1095
|
+
) -> ToolOutputMixin | Any:
|
|
1161
1096
|
"""Format tool output as a ToolMessage if appropriate.
|
|
1162
1097
|
|
|
1163
1098
|
Args:
|
|
@@ -1166,7 +1101,6 @@ def _format_output(
|
|
|
1166
1101
|
tool_call_id: The ID of the tool call.
|
|
1167
1102
|
name: The name of the tool.
|
|
1168
1103
|
status: The execution status.
|
|
1169
|
-
message_version: The version of the ToolMessage to return.
|
|
1170
1104
|
|
|
1171
1105
|
Returns:
|
|
1172
1106
|
The formatted output, either as a ToolMessage or the original content.
|
|
@@ -1175,15 +1109,7 @@ def _format_output(
|
|
|
1175
1109
|
return content
|
|
1176
1110
|
if not _is_message_content_type(content):
|
|
1177
1111
|
content = _stringify(content)
|
|
1178
|
-
|
|
1179
|
-
return ToolMessage(
|
|
1180
|
-
content,
|
|
1181
|
-
artifact=artifact,
|
|
1182
|
-
tool_call_id=tool_call_id,
|
|
1183
|
-
name=name,
|
|
1184
|
-
status=status,
|
|
1185
|
-
)
|
|
1186
|
-
return ToolMessageV1(
|
|
1112
|
+
return ToolMessage(
|
|
1187
1113
|
content,
|
|
1188
1114
|
artifact=artifact,
|
|
1189
1115
|
tool_call_id=tool_call_id,
|
|
@@ -1201,7 +1127,7 @@ def _is_message_content_type(obj: Any) -> bool:
|
|
|
1201
1127
|
obj: The object to check.
|
|
1202
1128
|
|
|
1203
1129
|
Returns:
|
|
1204
|
-
True if the object is valid message content, False otherwise.
|
|
1130
|
+
`True` if the object is valid message content, `False` otherwise.
|
|
1205
1131
|
"""
|
|
1206
1132
|
return isinstance(obj, str) or (
|
|
1207
1133
|
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
|
|
@@ -1217,7 +1143,7 @@ def _is_message_content_block(obj: Any) -> bool:
|
|
|
1217
1143
|
obj: The object to check.
|
|
1218
1144
|
|
|
1219
1145
|
Returns:
|
|
1220
|
-
True if the object is a valid content block, False otherwise.
|
|
1146
|
+
`True` if the object is a valid content block, `False` otherwise.
|
|
1221
1147
|
"""
|
|
1222
1148
|
if isinstance(obj, str):
|
|
1223
1149
|
return True
|
|
@@ -1241,7 +1167,7 @@ def _stringify(content: Any) -> str:
|
|
|
1241
1167
|
return str(content)
|
|
1242
1168
|
|
|
1243
1169
|
|
|
1244
|
-
def _get_type_hints(func: Callable) ->
|
|
1170
|
+
def _get_type_hints(func: Callable) -> dict[str, type] | None:
|
|
1245
1171
|
"""Get type hints from a function, handling partial functions.
|
|
1246
1172
|
|
|
1247
1173
|
Args:
|
|
@@ -1258,7 +1184,7 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
|
|
|
1258
1184
|
return None
|
|
1259
1185
|
|
|
1260
1186
|
|
|
1261
|
-
def _get_runnable_config_param(func: Callable) ->
|
|
1187
|
+
def _get_runnable_config_param(func: Callable) -> str | None:
|
|
1262
1188
|
"""Find the parameter name for RunnableConfig in a function.
|
|
1263
1189
|
|
|
1264
1190
|
Args:
|
|
@@ -1284,35 +1210,73 @@ class InjectedToolArg:
|
|
|
1284
1210
|
"""
|
|
1285
1211
|
|
|
1286
1212
|
|
|
1213
|
+
class _DirectlyInjectedToolArg:
|
|
1214
|
+
"""Annotation for tool arguments that are injected at runtime.
|
|
1215
|
+
|
|
1216
|
+
Injected via direct type annotation, rather than annotated metadata.
|
|
1217
|
+
|
|
1218
|
+
For example, ToolRuntime is a directly injected argument.
|
|
1219
|
+
Note the direct annotation rather than the verbose alternative:
|
|
1220
|
+
Annotated[ToolRuntime, InjectedRuntime]
|
|
1221
|
+
```python
|
|
1222
|
+
from langchain_core.tools import tool, ToolRuntime
|
|
1223
|
+
|
|
1224
|
+
|
|
1225
|
+
@tool
|
|
1226
|
+
def foo(x: int, runtime: ToolRuntime) -> str:
|
|
1227
|
+
# use runtime.state, runtime.context, runtime.store, etc.
|
|
1228
|
+
...
|
|
1229
|
+
```
|
|
1230
|
+
"""
|
|
1231
|
+
|
|
1232
|
+
|
|
1287
1233
|
class InjectedToolCallId(InjectedToolArg):
|
|
1288
1234
|
"""Annotation for injecting the tool call ID.
|
|
1289
1235
|
|
|
1290
1236
|
This annotation is used to mark a tool parameter that should receive
|
|
1291
1237
|
the tool call ID at runtime.
|
|
1292
1238
|
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1239
|
+
```python
|
|
1240
|
+
from typing import Annotated
|
|
1241
|
+
from langchain_core.messages import ToolMessage
|
|
1242
|
+
from langchain_core.tools import tool, InjectedToolCallId
|
|
1243
|
+
|
|
1244
|
+
@tool
|
|
1245
|
+
def foo(
|
|
1246
|
+
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
|
|
1247
|
+
) -> ToolMessage:
|
|
1248
|
+
\"\"\"Return x.\"\"\"
|
|
1249
|
+
return ToolMessage(
|
|
1250
|
+
str(x),
|
|
1251
|
+
artifact=x,
|
|
1252
|
+
name="foo",
|
|
1253
|
+
tool_call_id=tool_call_id
|
|
1254
|
+
)
|
|
1255
|
+
|
|
1256
|
+
```
|
|
1257
|
+
"""
|
|
1310
1258
|
|
|
1259
|
+
|
|
1260
|
+
def _is_directly_injected_arg_type(type_: Any) -> bool:
|
|
1261
|
+
"""Check if a type annotation indicates a directly injected argument.
|
|
1262
|
+
|
|
1263
|
+
This is currently only used for ToolRuntime.
|
|
1264
|
+
Checks if either the annotation itself is a subclass of _DirectlyInjectedToolArg
|
|
1265
|
+
or the origin of the annotation is a subclass of _DirectlyInjectedToolArg.
|
|
1266
|
+
|
|
1267
|
+
Ex: ToolRuntime or ToolRuntime[ContextT, StateT] would both return True.
|
|
1311
1268
|
"""
|
|
1269
|
+
return (
|
|
1270
|
+
isinstance(type_, type) and issubclass(type_, _DirectlyInjectedToolArg)
|
|
1271
|
+
) or (
|
|
1272
|
+
(origin := get_origin(type_)) is not None
|
|
1273
|
+
and isinstance(origin, type)
|
|
1274
|
+
and issubclass(origin, _DirectlyInjectedToolArg)
|
|
1275
|
+
)
|
|
1312
1276
|
|
|
1313
1277
|
|
|
1314
1278
|
def _is_injected_arg_type(
|
|
1315
|
-
type_: type, injected_type:
|
|
1279
|
+
type_: type | TypeVar, injected_type: type[InjectedToolArg] | None = None
|
|
1316
1280
|
) -> bool:
|
|
1317
1281
|
"""Check if a type annotation indicates an injected argument.
|
|
1318
1282
|
|
|
@@ -1321,9 +1285,17 @@ def _is_injected_arg_type(
|
|
|
1321
1285
|
injected_type: The specific injected type to check for.
|
|
1322
1286
|
|
|
1323
1287
|
Returns:
|
|
1324
|
-
True if the type is an injected argument, False otherwise.
|
|
1288
|
+
`True` if the type is an injected argument, `False` otherwise.
|
|
1325
1289
|
"""
|
|
1326
|
-
|
|
1290
|
+
if injected_type is None:
|
|
1291
|
+
# if no injected type is specified,
|
|
1292
|
+
# check if the type is a directly injected argument
|
|
1293
|
+
if _is_directly_injected_arg_type(type_):
|
|
1294
|
+
return True
|
|
1295
|
+
injected_type = InjectedToolArg
|
|
1296
|
+
|
|
1297
|
+
# if the type is an Annotated type, check if annotated metadata
|
|
1298
|
+
# is an intance or subclass of the injected type
|
|
1327
1299
|
return any(
|
|
1328
1300
|
isinstance(arg, injected_type)
|
|
1329
1301
|
or (isinstance(arg, type) and issubclass(arg, injected_type))
|
|
@@ -1332,21 +1304,23 @@ def _is_injected_arg_type(
|
|
|
1332
1304
|
|
|
1333
1305
|
|
|
1334
1306
|
def get_all_basemodel_annotations(
|
|
1335
|
-
cls:
|
|
1336
|
-
) -> dict[str, type]:
|
|
1307
|
+
cls: TypeBaseModel | Any, *, default_to_bound: bool = True
|
|
1308
|
+
) -> dict[str, type | TypeVar]:
|
|
1337
1309
|
"""Get all annotations from a Pydantic BaseModel and its parents.
|
|
1338
1310
|
|
|
1339
1311
|
Args:
|
|
1340
1312
|
cls: The Pydantic BaseModel class.
|
|
1341
1313
|
default_to_bound: Whether to default to the bound of a TypeVar if it exists.
|
|
1314
|
+
|
|
1315
|
+
Returns:
|
|
1316
|
+
A dictionary of field names to their type annotations.
|
|
1342
1317
|
"""
|
|
1343
1318
|
# cls has no subscript: cls = FooBar
|
|
1344
1319
|
if isinstance(cls, type):
|
|
1345
|
-
|
|
1346
|
-
fields = getattr(cls, "model_fields", {}) or getattr(cls, "__fields__", {})
|
|
1320
|
+
fields = get_fields(cls)
|
|
1347
1321
|
alias_map = {field.alias: name for name, field in fields.items() if field.alias}
|
|
1348
1322
|
|
|
1349
|
-
annotations: dict[str, type] = {}
|
|
1323
|
+
annotations: dict[str, type | TypeVar] = {}
|
|
1350
1324
|
for name, param in inspect.signature(cls).parameters.items():
|
|
1351
1325
|
# Exclude hidden init args added by pydantic Config. For example if
|
|
1352
1326
|
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
|
@@ -1382,12 +1356,12 @@ def get_all_basemodel_annotations(
|
|
|
1382
1356
|
continue
|
|
1383
1357
|
|
|
1384
1358
|
# if class = FooBar inherits from Baz[str]:
|
|
1385
|
-
# parent = Baz[str],
|
|
1386
|
-
# parent_origin = Baz,
|
|
1359
|
+
# parent = class Baz[str],
|
|
1360
|
+
# parent_origin = class Baz,
|
|
1387
1361
|
# generic_type_vars = (type vars in Baz)
|
|
1388
1362
|
# generic_map = {type var in Baz: str}
|
|
1389
1363
|
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
|
|
1390
|
-
generic_map = dict(zip(generic_type_vars, get_args(parent)))
|
|
1364
|
+
generic_map = dict(zip(generic_type_vars, get_args(parent), strict=False))
|
|
1391
1365
|
for field in getattr(parent_origin, "__annotations__", {}):
|
|
1392
1366
|
annotations[field] = _replace_type_vars(
|
|
1393
1367
|
annotations[field], generic_map, default_to_bound=default_to_bound
|
|
@@ -1400,11 +1374,11 @@ def get_all_basemodel_annotations(
|
|
|
1400
1374
|
|
|
1401
1375
|
|
|
1402
1376
|
def _replace_type_vars(
|
|
1403
|
-
type_: type,
|
|
1404
|
-
generic_map:
|
|
1377
|
+
type_: type | TypeVar,
|
|
1378
|
+
generic_map: dict[TypeVar, type] | None = None,
|
|
1405
1379
|
*,
|
|
1406
1380
|
default_to_bound: bool = True,
|
|
1407
|
-
) -> type:
|
|
1381
|
+
) -> type | TypeVar:
|
|
1408
1382
|
"""Replace TypeVars in a type annotation with concrete types.
|
|
1409
1383
|
|
|
1410
1384
|
Args:
|
|
@@ -1420,7 +1394,7 @@ def _replace_type_vars(
|
|
|
1420
1394
|
if type_ in generic_map:
|
|
1421
1395
|
return generic_map[type_]
|
|
1422
1396
|
if default_to_bound:
|
|
1423
|
-
return type_.__bound__
|
|
1397
|
+
return type_.__bound__ if type_.__bound__ is not None else Any
|
|
1424
1398
|
return type_
|
|
1425
1399
|
if (origin := get_origin(type_)) and (args := get_args(type_)):
|
|
1426
1400
|
new_args = tuple(
|