langchain-core 1.0.0a6__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_core/__init__.py +1 -1
- langchain_core/_api/__init__.py +3 -4
- langchain_core/_api/beta_decorator.py +23 -26
- langchain_core/_api/deprecation.py +51 -64
- langchain_core/_api/path.py +3 -6
- langchain_core/_import_utils.py +3 -4
- langchain_core/agents.py +20 -22
- langchain_core/caches.py +65 -66
- langchain_core/callbacks/__init__.py +1 -8
- langchain_core/callbacks/base.py +321 -336
- langchain_core/callbacks/file.py +44 -44
- langchain_core/callbacks/manager.py +436 -513
- langchain_core/callbacks/stdout.py +29 -30
- langchain_core/callbacks/streaming_stdout.py +32 -32
- langchain_core/callbacks/usage.py +60 -57
- langchain_core/chat_history.py +53 -68
- langchain_core/document_loaders/base.py +27 -25
- langchain_core/document_loaders/blob_loaders.py +1 -1
- langchain_core/document_loaders/langsmith.py +44 -48
- langchain_core/documents/__init__.py +23 -3
- langchain_core/documents/base.py +98 -90
- langchain_core/documents/compressor.py +10 -10
- langchain_core/documents/transformers.py +34 -35
- langchain_core/embeddings/fake.py +50 -54
- langchain_core/example_selectors/length_based.py +1 -1
- langchain_core/example_selectors/semantic_similarity.py +28 -32
- langchain_core/exceptions.py +21 -20
- langchain_core/globals.py +3 -151
- langchain_core/indexing/__init__.py +1 -1
- langchain_core/indexing/api.py +121 -126
- langchain_core/indexing/base.py +73 -75
- langchain_core/indexing/in_memory.py +4 -6
- langchain_core/language_models/__init__.py +14 -29
- langchain_core/language_models/_utils.py +58 -61
- langchain_core/language_models/base.py +53 -162
- langchain_core/language_models/chat_models.py +298 -387
- langchain_core/language_models/fake.py +11 -11
- langchain_core/language_models/fake_chat_models.py +42 -36
- langchain_core/language_models/llms.py +125 -235
- langchain_core/load/dump.py +9 -12
- langchain_core/load/load.py +18 -28
- langchain_core/load/mapping.py +2 -4
- langchain_core/load/serializable.py +42 -40
- langchain_core/messages/__init__.py +10 -16
- langchain_core/messages/ai.py +148 -148
- langchain_core/messages/base.py +53 -51
- langchain_core/messages/block_translators/__init__.py +19 -22
- langchain_core/messages/block_translators/anthropic.py +6 -6
- langchain_core/messages/block_translators/bedrock_converse.py +5 -5
- langchain_core/messages/block_translators/google_genai.py +10 -7
- langchain_core/messages/block_translators/google_vertexai.py +4 -32
- langchain_core/messages/block_translators/groq.py +117 -21
- langchain_core/messages/block_translators/langchain_v0.py +5 -5
- langchain_core/messages/block_translators/openai.py +11 -11
- langchain_core/messages/chat.py +2 -6
- langchain_core/messages/content.py +337 -328
- langchain_core/messages/function.py +6 -10
- langchain_core/messages/human.py +24 -31
- langchain_core/messages/modifier.py +2 -2
- langchain_core/messages/system.py +19 -29
- langchain_core/messages/tool.py +74 -90
- langchain_core/messages/utils.py +474 -504
- langchain_core/output_parsers/__init__.py +13 -10
- langchain_core/output_parsers/base.py +61 -61
- langchain_core/output_parsers/format_instructions.py +9 -4
- langchain_core/output_parsers/json.py +12 -10
- langchain_core/output_parsers/list.py +21 -23
- langchain_core/output_parsers/openai_functions.py +49 -47
- langchain_core/output_parsers/openai_tools.py +16 -21
- langchain_core/output_parsers/pydantic.py +13 -14
- langchain_core/output_parsers/string.py +5 -5
- langchain_core/output_parsers/transform.py +15 -17
- langchain_core/output_parsers/xml.py +35 -34
- langchain_core/outputs/__init__.py +1 -1
- langchain_core/outputs/chat_generation.py +18 -18
- langchain_core/outputs/chat_result.py +1 -3
- langchain_core/outputs/generation.py +10 -11
- langchain_core/outputs/llm_result.py +10 -10
- langchain_core/prompt_values.py +11 -17
- langchain_core/prompts/__init__.py +3 -27
- langchain_core/prompts/base.py +48 -56
- langchain_core/prompts/chat.py +275 -325
- langchain_core/prompts/dict.py +5 -5
- langchain_core/prompts/few_shot.py +81 -88
- langchain_core/prompts/few_shot_with_templates.py +11 -13
- langchain_core/prompts/image.py +12 -14
- langchain_core/prompts/loading.py +4 -6
- langchain_core/prompts/message.py +3 -3
- langchain_core/prompts/prompt.py +24 -39
- langchain_core/prompts/string.py +26 -10
- langchain_core/prompts/structured.py +49 -53
- langchain_core/rate_limiters.py +51 -60
- langchain_core/retrievers.py +61 -198
- langchain_core/runnables/base.py +1476 -1626
- langchain_core/runnables/branch.py +53 -57
- langchain_core/runnables/config.py +72 -89
- langchain_core/runnables/configurable.py +120 -137
- langchain_core/runnables/fallbacks.py +83 -79
- langchain_core/runnables/graph.py +91 -97
- langchain_core/runnables/graph_ascii.py +27 -28
- langchain_core/runnables/graph_mermaid.py +38 -50
- langchain_core/runnables/graph_png.py +15 -16
- langchain_core/runnables/history.py +135 -148
- langchain_core/runnables/passthrough.py +124 -150
- langchain_core/runnables/retry.py +46 -51
- langchain_core/runnables/router.py +25 -30
- langchain_core/runnables/schema.py +75 -80
- langchain_core/runnables/utils.py +60 -67
- langchain_core/stores.py +85 -121
- langchain_core/structured_query.py +8 -8
- langchain_core/sys_info.py +27 -29
- langchain_core/tools/__init__.py +1 -14
- langchain_core/tools/base.py +284 -229
- langchain_core/tools/convert.py +160 -155
- langchain_core/tools/render.py +10 -10
- langchain_core/tools/retriever.py +12 -11
- langchain_core/tools/simple.py +19 -24
- langchain_core/tools/structured.py +32 -39
- langchain_core/tracers/__init__.py +1 -9
- langchain_core/tracers/base.py +97 -99
- langchain_core/tracers/context.py +29 -52
- langchain_core/tracers/core.py +49 -53
- langchain_core/tracers/evaluation.py +11 -11
- langchain_core/tracers/event_stream.py +65 -64
- langchain_core/tracers/langchain.py +21 -21
- langchain_core/tracers/log_stream.py +45 -45
- langchain_core/tracers/memory_stream.py +3 -3
- langchain_core/tracers/root_listeners.py +16 -16
- langchain_core/tracers/run_collector.py +2 -4
- langchain_core/tracers/schemas.py +0 -129
- langchain_core/tracers/stdout.py +3 -3
- langchain_core/utils/__init__.py +1 -4
- langchain_core/utils/_merge.py +2 -2
- langchain_core/utils/aiter.py +57 -61
- langchain_core/utils/env.py +9 -9
- langchain_core/utils/function_calling.py +89 -186
- langchain_core/utils/html.py +7 -8
- langchain_core/utils/input.py +6 -6
- langchain_core/utils/interactive_env.py +1 -1
- langchain_core/utils/iter.py +36 -40
- langchain_core/utils/json.py +4 -3
- langchain_core/utils/json_schema.py +9 -9
- langchain_core/utils/mustache.py +8 -10
- langchain_core/utils/pydantic.py +33 -35
- langchain_core/utils/strings.py +6 -9
- langchain_core/utils/usage.py +1 -1
- langchain_core/utils/utils.py +66 -62
- langchain_core/vectorstores/base.py +182 -216
- langchain_core/vectorstores/in_memory.py +101 -176
- langchain_core/vectorstores/utils.py +5 -5
- langchain_core/version.py +1 -1
- langchain_core-1.0.3.dist-info/METADATA +69 -0
- langchain_core-1.0.3.dist-info/RECORD +172 -0
- {langchain_core-1.0.0a6.dist-info → langchain_core-1.0.3.dist-info}/WHEEL +1 -1
- langchain_core/memory.py +0 -120
- langchain_core/messages/block_translators/ollama.py +0 -47
- langchain_core/prompts/pipeline.py +0 -138
- langchain_core/pydantic_v1/__init__.py +0 -30
- langchain_core/pydantic_v1/dataclasses.py +0 -23
- langchain_core/pydantic_v1/main.py +0 -23
- langchain_core/tracers/langchain_v1.py +0 -31
- langchain_core/utils/loading.py +0 -35
- langchain_core-1.0.0a6.dist-info/METADATA +0 -67
- langchain_core-1.0.0a6.dist-info/RECORD +0 -181
- langchain_core-1.0.0a6.dist-info/entry_points.txt +0 -4
langchain_core/tools/base.py
CHANGED
|
@@ -8,16 +8,14 @@ import json
|
|
|
8
8
|
import typing
|
|
9
9
|
import warnings
|
|
10
10
|
from abc import ABC, abstractmethod
|
|
11
|
+
from collections.abc import Callable
|
|
11
12
|
from inspect import signature
|
|
12
13
|
from typing import (
|
|
13
14
|
TYPE_CHECKING,
|
|
14
15
|
Annotated,
|
|
15
16
|
Any,
|
|
16
|
-
Callable,
|
|
17
17
|
Literal,
|
|
18
|
-
Optional,
|
|
19
18
|
TypeVar,
|
|
20
|
-
Union,
|
|
21
19
|
cast,
|
|
22
20
|
get_args,
|
|
23
21
|
get_origin,
|
|
@@ -31,7 +29,6 @@ from pydantic import (
|
|
|
31
29
|
PydanticDeprecationWarning,
|
|
32
30
|
SkipValidation,
|
|
33
31
|
ValidationError,
|
|
34
|
-
model_validator,
|
|
35
32
|
validate_arguments,
|
|
36
33
|
)
|
|
37
34
|
from pydantic.v1 import BaseModel as BaseModelV1
|
|
@@ -39,10 +36,8 @@ from pydantic.v1 import ValidationError as ValidationErrorV1
|
|
|
39
36
|
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
|
40
37
|
from typing_extensions import override
|
|
41
38
|
|
|
42
|
-
from langchain_core._api import deprecated
|
|
43
39
|
from langchain_core.callbacks import (
|
|
44
40
|
AsyncCallbackManager,
|
|
45
|
-
BaseCallbackManager,
|
|
46
41
|
CallbackManager,
|
|
47
42
|
Callbacks,
|
|
48
43
|
)
|
|
@@ -97,7 +92,7 @@ def _is_annotated_type(typ: type[Any]) -> bool:
|
|
|
97
92
|
typ: The type to check.
|
|
98
93
|
|
|
99
94
|
Returns:
|
|
100
|
-
True if the type is an Annotated type, False otherwise.
|
|
95
|
+
`True` if the type is an Annotated type, `False` otherwise.
|
|
101
96
|
"""
|
|
102
97
|
return get_origin(typ) is typing.Annotated
|
|
103
98
|
|
|
@@ -231,7 +226,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
|
|
|
231
226
|
pydantic_version: The Pydantic version to check against ("v1" or "v2").
|
|
232
227
|
|
|
233
228
|
Returns:
|
|
234
|
-
True if the annotation is a Pydantic model, False otherwise.
|
|
229
|
+
`True` if the annotation is a Pydantic model, `False` otherwise.
|
|
235
230
|
"""
|
|
236
231
|
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
|
237
232
|
try:
|
|
@@ -250,7 +245,7 @@ def _function_annotations_are_pydantic_v1(
|
|
|
250
245
|
func: The function being checked.
|
|
251
246
|
|
|
252
247
|
Returns:
|
|
253
|
-
True if all Pydantic annotations are from V1, False otherwise.
|
|
248
|
+
True if all Pydantic annotations are from V1, `False` otherwise.
|
|
254
249
|
|
|
255
250
|
Raises:
|
|
256
251
|
NotImplementedError: If the function contains mixed V1 and V2 annotations.
|
|
@@ -285,29 +280,28 @@ def create_schema_from_function(
|
|
|
285
280
|
model_name: str,
|
|
286
281
|
func: Callable,
|
|
287
282
|
*,
|
|
288
|
-
filter_args:
|
|
283
|
+
filter_args: Sequence[str] | None = None,
|
|
289
284
|
parse_docstring: bool = False,
|
|
290
285
|
error_on_invalid_docstring: bool = False,
|
|
291
286
|
include_injected: bool = True,
|
|
292
287
|
) -> type[BaseModel]:
|
|
293
|
-
"""Create a
|
|
288
|
+
"""Create a Pydantic schema from a function's signature.
|
|
294
289
|
|
|
295
290
|
Args:
|
|
296
|
-
model_name: Name to assign to the generated
|
|
291
|
+
model_name: Name to assign to the generated Pydantic schema.
|
|
297
292
|
func: Function to generate the schema from.
|
|
298
293
|
filter_args: Optional list of arguments to exclude from the schema.
|
|
299
|
-
Defaults to FILTERED_ARGS
|
|
294
|
+
Defaults to `FILTERED_ARGS`.
|
|
300
295
|
parse_docstring: Whether to parse the function's docstring for descriptions
|
|
301
|
-
for each argument.
|
|
302
|
-
error_on_invalid_docstring: if
|
|
303
|
-
whether to raise ValueError on invalid Google Style docstrings.
|
|
304
|
-
Defaults to False.
|
|
296
|
+
for each argument.
|
|
297
|
+
error_on_invalid_docstring: if `parse_docstring` is provided, configure
|
|
298
|
+
whether to raise `ValueError` on invalid Google Style docstrings.
|
|
305
299
|
include_injected: Whether to include injected arguments in the schema.
|
|
306
|
-
Defaults to True
|
|
300
|
+
Defaults to `True`, since we want to include them in the schema
|
|
307
301
|
when *validating* tool inputs.
|
|
308
302
|
|
|
309
303
|
Returns:
|
|
310
|
-
A
|
|
304
|
+
A Pydantic model with the same arguments as the function.
|
|
311
305
|
"""
|
|
312
306
|
sig = inspect.signature(func)
|
|
313
307
|
|
|
@@ -317,7 +311,7 @@ def create_schema_from_function(
|
|
|
317
311
|
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
|
318
312
|
with warnings.catch_warnings():
|
|
319
313
|
# We are using deprecated functionality here.
|
|
320
|
-
# This code should be re-written to simply construct a
|
|
314
|
+
# This code should be re-written to simply construct a Pydantic model
|
|
321
315
|
# using inspect.signature and create_model.
|
|
322
316
|
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
|
323
317
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
|
|
@@ -390,13 +384,14 @@ class ToolException(Exception): # noqa: N818
|
|
|
390
384
|
"""
|
|
391
385
|
|
|
392
386
|
|
|
393
|
-
ArgsSchema =
|
|
387
|
+
ArgsSchema = TypeBaseModel | dict[str, Any]
|
|
394
388
|
|
|
395
389
|
|
|
396
|
-
class BaseTool(RunnableSerializable[
|
|
390
|
+
class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
|
|
397
391
|
"""Base class for all LangChain tools.
|
|
398
392
|
|
|
399
393
|
This abstract class defines the interface that all LangChain tools must implement.
|
|
394
|
+
|
|
400
395
|
Tools are components that can be called by agents to perform specific actions.
|
|
401
396
|
"""
|
|
402
397
|
|
|
@@ -407,7 +402,7 @@ class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
|
|
407
402
|
**kwargs: Additional keyword arguments passed to the parent class.
|
|
408
403
|
|
|
409
404
|
Raises:
|
|
410
|
-
SchemaAnnotationError: If args_schema has incorrect type annotation.
|
|
405
|
+
SchemaAnnotationError: If `args_schema` has incorrect type annotation.
|
|
411
406
|
"""
|
|
412
407
|
super().__init_subclass__(**kwargs)
|
|
413
408
|
|
|
@@ -441,22 +436,22 @@ class ChildTool(BaseTool):
|
|
|
441
436
|
You can provide few-shot examples as a part of the description.
|
|
442
437
|
"""
|
|
443
438
|
|
|
444
|
-
args_schema: Annotated[
|
|
439
|
+
args_schema: Annotated[ArgsSchema | None, SkipValidation()] = Field(
|
|
445
440
|
default=None, description="The tool schema."
|
|
446
441
|
)
|
|
447
442
|
"""Pydantic model class to validate and parse the tool's input arguments.
|
|
448
443
|
|
|
449
444
|
Args schema should be either:
|
|
450
445
|
|
|
451
|
-
- A subclass of pydantic.BaseModel
|
|
452
|
-
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
|
453
|
-
-
|
|
446
|
+
- A subclass of `pydantic.BaseModel`.
|
|
447
|
+
- A subclass of `pydantic.v1.BaseModel` if accessing v1 namespace in pydantic 2
|
|
448
|
+
- A JSON schema dict
|
|
454
449
|
"""
|
|
455
450
|
return_direct: bool = False
|
|
456
451
|
"""Whether to return the tool's output directly.
|
|
457
452
|
|
|
458
|
-
Setting this to True means
|
|
459
|
-
|
|
453
|
+
Setting this to `True` means that after the tool is called, the `AgentExecutor` will
|
|
454
|
+
stop looping.
|
|
460
455
|
"""
|
|
461
456
|
verbose: bool = False
|
|
462
457
|
"""Whether to log the tool's progress."""
|
|
@@ -464,52 +459,47 @@ class ChildTool(BaseTool):
|
|
|
464
459
|
callbacks: Callbacks = Field(default=None, exclude=True)
|
|
465
460
|
"""Callbacks to be called during tool execution."""
|
|
466
461
|
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
Field(
|
|
471
|
-
default=None,
|
|
472
|
-
exclude=True,
|
|
473
|
-
description="Callback manager to add to the run trace.",
|
|
474
|
-
)
|
|
475
|
-
)
|
|
476
|
-
tags: Optional[list[str]] = None
|
|
477
|
-
"""Optional list of tags associated with the tool. Defaults to None.
|
|
462
|
+
tags: list[str] | None = None
|
|
463
|
+
"""Optional list of tags associated with the tool.
|
|
464
|
+
|
|
478
465
|
These tags will be associated with each call to this tool,
|
|
479
466
|
and passed as arguments to the handlers defined in `callbacks`.
|
|
480
|
-
|
|
467
|
+
|
|
468
|
+
You can use these to, e.g., identify a specific instance of a tool with its use
|
|
469
|
+
case.
|
|
481
470
|
"""
|
|
482
|
-
metadata:
|
|
483
|
-
"""Optional metadata associated with the tool.
|
|
471
|
+
metadata: dict[str, Any] | None = None
|
|
472
|
+
"""Optional metadata associated with the tool.
|
|
473
|
+
|
|
484
474
|
This metadata will be associated with each call to this tool,
|
|
485
475
|
and passed as arguments to the handlers defined in `callbacks`.
|
|
486
|
-
|
|
476
|
+
|
|
477
|
+
You can use these to, e.g., identify a specific instance of a tool with its use
|
|
478
|
+
case.
|
|
487
479
|
"""
|
|
488
480
|
|
|
489
|
-
handle_tool_error:
|
|
490
|
-
|
|
491
|
-
)
|
|
492
|
-
"""Handle the content of the ToolException thrown."""
|
|
481
|
+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False
|
|
482
|
+
"""Handle the content of the `ToolException` thrown."""
|
|
493
483
|
|
|
494
|
-
handle_validation_error:
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
"""Handle the content of the ValidationError thrown."""
|
|
484
|
+
handle_validation_error: (
|
|
485
|
+
bool | str | Callable[[ValidationError | ValidationErrorV1], str] | None
|
|
486
|
+
) = False
|
|
487
|
+
"""Handle the content of the `ValidationError` thrown."""
|
|
498
488
|
|
|
499
489
|
response_format: Literal["content", "content_and_artifact"] = "content"
|
|
500
|
-
"""The tool response format.
|
|
490
|
+
"""The tool response format.
|
|
501
491
|
|
|
502
|
-
If
|
|
503
|
-
ToolMessage
|
|
504
|
-
two-tuple corresponding to the (content, artifact) of a ToolMessage
|
|
492
|
+
If `'content'` then the output of the tool is interpreted as the contents of a
|
|
493
|
+
`ToolMessage`. If `'content_and_artifact'` then the output is expected to be a
|
|
494
|
+
two-tuple corresponding to the `(content, artifact)` of a `ToolMessage`.
|
|
505
495
|
"""
|
|
506
496
|
|
|
507
497
|
def __init__(self, **kwargs: Any) -> None:
|
|
508
498
|
"""Initialize the tool.
|
|
509
499
|
|
|
510
500
|
Raises:
|
|
511
|
-
TypeError: If
|
|
512
|
-
dict
|
|
501
|
+
TypeError: If `args_schema` is not a subclass of pydantic `BaseModel` or
|
|
502
|
+
`dict`.
|
|
513
503
|
"""
|
|
514
504
|
if (
|
|
515
505
|
"args_schema" in kwargs
|
|
@@ -533,7 +523,7 @@ class ChildTool(BaseTool):
|
|
|
533
523
|
"""Check if the tool accepts only a single input argument.
|
|
534
524
|
|
|
535
525
|
Returns:
|
|
536
|
-
True if the tool has only one input argument, False otherwise.
|
|
526
|
+
`True` if the tool has only one input argument, `False` otherwise.
|
|
537
527
|
"""
|
|
538
528
|
keys = {k for k in self.args if k != "kwargs"}
|
|
539
529
|
return len(keys) == 1
|
|
@@ -543,7 +533,7 @@ class ChildTool(BaseTool):
|
|
|
543
533
|
"""Get the tool's input arguments schema.
|
|
544
534
|
|
|
545
535
|
Returns:
|
|
546
|
-
|
|
536
|
+
`dict` containing the tool's argument properties.
|
|
547
537
|
"""
|
|
548
538
|
if isinstance(self.args_schema, dict):
|
|
549
539
|
json_schema = self.args_schema
|
|
@@ -582,9 +572,7 @@ class ChildTool(BaseTool):
|
|
|
582
572
|
# --- Runnable ---
|
|
583
573
|
|
|
584
574
|
@override
|
|
585
|
-
def get_input_schema(
|
|
586
|
-
self, config: Optional[RunnableConfig] = None
|
|
587
|
-
) -> type[BaseModel]:
|
|
575
|
+
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
|
588
576
|
"""The tool's input schema.
|
|
589
577
|
|
|
590
578
|
Args:
|
|
@@ -602,8 +590,8 @@ class ChildTool(BaseTool):
|
|
|
602
590
|
@override
|
|
603
591
|
def invoke(
|
|
604
592
|
self,
|
|
605
|
-
input:
|
|
606
|
-
config:
|
|
593
|
+
input: str | dict | ToolCall,
|
|
594
|
+
config: RunnableConfig | None = None,
|
|
607
595
|
**kwargs: Any,
|
|
608
596
|
) -> Any:
|
|
609
597
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
|
@@ -612,8 +600,8 @@ class ChildTool(BaseTool):
|
|
|
612
600
|
@override
|
|
613
601
|
async def ainvoke(
|
|
614
602
|
self,
|
|
615
|
-
input:
|
|
616
|
-
config:
|
|
603
|
+
input: str | dict | ToolCall,
|
|
604
|
+
config: RunnableConfig | None = None,
|
|
617
605
|
**kwargs: Any,
|
|
618
606
|
) -> Any:
|
|
619
607
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
|
@@ -622,8 +610,8 @@ class ChildTool(BaseTool):
|
|
|
622
610
|
# --- Tool ---
|
|
623
611
|
|
|
624
612
|
def _parse_input(
|
|
625
|
-
self, tool_input:
|
|
626
|
-
) ->
|
|
613
|
+
self, tool_input: str | dict, tool_call_id: str | None
|
|
614
|
+
) -> str | dict[str, Any]:
|
|
627
615
|
"""Parse and validate tool input using the args schema.
|
|
628
616
|
|
|
629
617
|
Args:
|
|
@@ -634,10 +622,10 @@ class ChildTool(BaseTool):
|
|
|
634
622
|
The parsed and validated input.
|
|
635
623
|
|
|
636
624
|
Raises:
|
|
637
|
-
ValueError: If string input is provided with JSON schema
|
|
638
|
-
ValueError: If InjectedToolCallId is required but
|
|
625
|
+
ValueError: If `string` input is provided with JSON schema `args_schema`.
|
|
626
|
+
ValueError: If `InjectedToolCallId` is required but `tool_call_id` is not
|
|
639
627
|
provided.
|
|
640
|
-
TypeError: If args_schema is not a Pydantic
|
|
628
|
+
TypeError: If `args_schema` is not a Pydantic `BaseModel` or dict.
|
|
641
629
|
"""
|
|
642
630
|
input_args = self.args_schema
|
|
643
631
|
if isinstance(tool_input, str):
|
|
@@ -700,32 +688,12 @@ class ChildTool(BaseTool):
|
|
|
700
688
|
}
|
|
701
689
|
return tool_input
|
|
702
690
|
|
|
703
|
-
@model_validator(mode="before")
|
|
704
|
-
@classmethod
|
|
705
|
-
def raise_deprecation(cls, values: dict) -> Any:
|
|
706
|
-
"""Raise deprecation warning if callback_manager is used.
|
|
707
|
-
|
|
708
|
-
Args:
|
|
709
|
-
values: The values to validate.
|
|
710
|
-
|
|
711
|
-
Returns:
|
|
712
|
-
The validated values.
|
|
713
|
-
"""
|
|
714
|
-
if values.get("callback_manager") is not None:
|
|
715
|
-
warnings.warn(
|
|
716
|
-
"callback_manager is deprecated. Please use callbacks instead.",
|
|
717
|
-
DeprecationWarning,
|
|
718
|
-
stacklevel=6,
|
|
719
|
-
)
|
|
720
|
-
values["callbacks"] = values.pop("callback_manager", None)
|
|
721
|
-
return values
|
|
722
|
-
|
|
723
691
|
@abstractmethod
|
|
724
692
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
|
725
693
|
"""Use the tool.
|
|
726
694
|
|
|
727
|
-
Add run_manager:
|
|
728
|
-
|
|
695
|
+
Add `run_manager: CallbackManagerForToolRun | None = None` to child
|
|
696
|
+
implementations to enable tracing.
|
|
729
697
|
|
|
730
698
|
Returns:
|
|
731
699
|
The result of the tool execution.
|
|
@@ -734,8 +702,8 @@ class ChildTool(BaseTool):
|
|
|
734
702
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
|
735
703
|
"""Use the tool asynchronously.
|
|
736
704
|
|
|
737
|
-
Add run_manager:
|
|
738
|
-
|
|
705
|
+
Add `run_manager: AsyncCallbackManagerForToolRun | None = None` to child
|
|
706
|
+
implementations to enable tracing.
|
|
739
707
|
|
|
740
708
|
Returns:
|
|
741
709
|
The result of the tool execution.
|
|
@@ -746,8 +714,37 @@ class ChildTool(BaseTool):
|
|
|
746
714
|
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
|
747
715
|
return await run_in_executor(None, self._run, *args, **kwargs)
|
|
748
716
|
|
|
717
|
+
def _filter_injected_args(self, tool_input: dict) -> dict:
|
|
718
|
+
"""Filter out injected tool arguments from the input dictionary.
|
|
719
|
+
|
|
720
|
+
Injected arguments are those annotated with `InjectedToolArg` or its
|
|
721
|
+
subclasses, or arguments in `FILTERED_ARGS` like `run_manager` and callbacks.
|
|
722
|
+
|
|
723
|
+
Args:
|
|
724
|
+
tool_input: The tool input dictionary to filter.
|
|
725
|
+
|
|
726
|
+
Returns:
|
|
727
|
+
A filtered dictionary with injected arguments removed.
|
|
728
|
+
"""
|
|
729
|
+
# Start with filtered args from the constant
|
|
730
|
+
filtered_keys = set[str](FILTERED_ARGS)
|
|
731
|
+
|
|
732
|
+
# If we have an args_schema, use it to identify injected args
|
|
733
|
+
if self.args_schema is not None:
|
|
734
|
+
try:
|
|
735
|
+
annotations = get_all_basemodel_annotations(self.args_schema)
|
|
736
|
+
for field_name, field_type in annotations.items():
|
|
737
|
+
if _is_injected_arg_type(field_type):
|
|
738
|
+
filtered_keys.add(field_name)
|
|
739
|
+
except Exception: # noqa: S110
|
|
740
|
+
# If we can't get annotations, just use FILTERED_ARGS
|
|
741
|
+
pass
|
|
742
|
+
|
|
743
|
+
# Filter out the injected keys from tool_input
|
|
744
|
+
return {k: v for k, v in tool_input.items() if k not in filtered_keys}
|
|
745
|
+
|
|
749
746
|
def _to_args_and_kwargs(
|
|
750
|
-
self, tool_input:
|
|
747
|
+
self, tool_input: str | dict, tool_call_id: str | None
|
|
751
748
|
) -> tuple[tuple, dict]:
|
|
752
749
|
"""Convert tool input to positional and keyword arguments.
|
|
753
750
|
|
|
@@ -756,7 +753,7 @@ class ChildTool(BaseTool):
|
|
|
756
753
|
tool_call_id: The ID of the tool call, if available.
|
|
757
754
|
|
|
758
755
|
Returns:
|
|
759
|
-
A tuple of (positional_args, keyword_args) for the tool.
|
|
756
|
+
A tuple of `(positional_args, keyword_args)` for the tool.
|
|
760
757
|
|
|
761
758
|
Raises:
|
|
762
759
|
TypeError: If the tool input type is invalid.
|
|
@@ -787,35 +784,35 @@ class ChildTool(BaseTool):
|
|
|
787
784
|
|
|
788
785
|
def run(
|
|
789
786
|
self,
|
|
790
|
-
tool_input:
|
|
791
|
-
verbose:
|
|
792
|
-
start_color:
|
|
793
|
-
color:
|
|
787
|
+
tool_input: str | dict[str, Any],
|
|
788
|
+
verbose: bool | None = None, # noqa: FBT001
|
|
789
|
+
start_color: str | None = "green",
|
|
790
|
+
color: str | None = "green",
|
|
794
791
|
callbacks: Callbacks = None,
|
|
795
792
|
*,
|
|
796
|
-
tags:
|
|
797
|
-
metadata:
|
|
798
|
-
run_name:
|
|
799
|
-
run_id:
|
|
800
|
-
config:
|
|
801
|
-
tool_call_id:
|
|
793
|
+
tags: list[str] | None = None,
|
|
794
|
+
metadata: dict[str, Any] | None = None,
|
|
795
|
+
run_name: str | None = None,
|
|
796
|
+
run_id: uuid.UUID | None = None,
|
|
797
|
+
config: RunnableConfig | None = None,
|
|
798
|
+
tool_call_id: str | None = None,
|
|
802
799
|
**kwargs: Any,
|
|
803
800
|
) -> Any:
|
|
804
801
|
"""Run the tool.
|
|
805
802
|
|
|
806
803
|
Args:
|
|
807
804
|
tool_input: The input to the tool.
|
|
808
|
-
verbose: Whether to log the tool's progress.
|
|
809
|
-
start_color: The color to use when starting the tool.
|
|
810
|
-
color: The color to use when ending the tool.
|
|
811
|
-
callbacks: Callbacks to be called during tool execution.
|
|
812
|
-
tags: Optional list of tags associated with the tool.
|
|
813
|
-
metadata: Optional metadata associated with the tool.
|
|
814
|
-
run_name: The name of the run.
|
|
815
|
-
run_id: The id of the run.
|
|
816
|
-
config: The configuration for the tool.
|
|
817
|
-
tool_call_id: The id of the tool call.
|
|
818
|
-
kwargs: Keyword arguments to be passed to tool callbacks (event handler)
|
|
805
|
+
verbose: Whether to log the tool's progress.
|
|
806
|
+
start_color: The color to use when starting the tool.
|
|
807
|
+
color: The color to use when ending the tool.
|
|
808
|
+
callbacks: Callbacks to be called during tool execution.
|
|
809
|
+
tags: Optional list of tags associated with the tool.
|
|
810
|
+
metadata: Optional metadata associated with the tool.
|
|
811
|
+
run_name: The name of the run.
|
|
812
|
+
run_id: The id of the run.
|
|
813
|
+
config: The configuration for the tool.
|
|
814
|
+
tool_call_id: The id of the tool call.
|
|
815
|
+
**kwargs: Keyword arguments to be passed to tool callbacks (event handler)
|
|
819
816
|
|
|
820
817
|
Returns:
|
|
821
818
|
The output of the tool.
|
|
@@ -833,24 +830,36 @@ class ChildTool(BaseTool):
|
|
|
833
830
|
self.metadata,
|
|
834
831
|
)
|
|
835
832
|
|
|
833
|
+
# Filter out injected arguments from callback inputs
|
|
834
|
+
filtered_tool_input = (
|
|
835
|
+
self._filter_injected_args(tool_input)
|
|
836
|
+
if isinstance(tool_input, dict)
|
|
837
|
+
else None
|
|
838
|
+
)
|
|
839
|
+
|
|
840
|
+
# Use filtered inputs for the input_str parameter as well
|
|
841
|
+
tool_input_str = (
|
|
842
|
+
tool_input
|
|
843
|
+
if isinstance(tool_input, str)
|
|
844
|
+
else str(
|
|
845
|
+
filtered_tool_input if filtered_tool_input is not None else tool_input
|
|
846
|
+
)
|
|
847
|
+
)
|
|
848
|
+
|
|
836
849
|
run_manager = callback_manager.on_tool_start(
|
|
837
850
|
{"name": self.name, "description": self.description},
|
|
838
|
-
|
|
851
|
+
tool_input_str,
|
|
839
852
|
color=start_color,
|
|
840
853
|
name=run_name,
|
|
841
854
|
run_id=run_id,
|
|
842
|
-
|
|
843
|
-
# For now, it's unclear whether this assumption is ever violated,
|
|
844
|
-
# but if it is we will send a `None` value to the callback instead
|
|
845
|
-
# TODO: will need to address issue via a patch.
|
|
846
|
-
inputs=tool_input if isinstance(tool_input, dict) else None,
|
|
855
|
+
inputs=filtered_tool_input,
|
|
847
856
|
**kwargs,
|
|
848
857
|
)
|
|
849
858
|
|
|
850
859
|
content = None
|
|
851
860
|
artifact = None
|
|
852
861
|
status = "success"
|
|
853
|
-
error_to_raise:
|
|
862
|
+
error_to_raise: Exception | KeyboardInterrupt | None = None
|
|
854
863
|
try:
|
|
855
864
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
|
856
865
|
with set_config_context(child_config) as context:
|
|
@@ -899,35 +908,35 @@ class ChildTool(BaseTool):
|
|
|
899
908
|
|
|
900
909
|
async def arun(
|
|
901
910
|
self,
|
|
902
|
-
tool_input:
|
|
903
|
-
verbose:
|
|
904
|
-
start_color:
|
|
905
|
-
color:
|
|
911
|
+
tool_input: str | dict,
|
|
912
|
+
verbose: bool | None = None, # noqa: FBT001
|
|
913
|
+
start_color: str | None = "green",
|
|
914
|
+
color: str | None = "green",
|
|
906
915
|
callbacks: Callbacks = None,
|
|
907
916
|
*,
|
|
908
|
-
tags:
|
|
909
|
-
metadata:
|
|
910
|
-
run_name:
|
|
911
|
-
run_id:
|
|
912
|
-
config:
|
|
913
|
-
tool_call_id:
|
|
917
|
+
tags: list[str] | None = None,
|
|
918
|
+
metadata: dict[str, Any] | None = None,
|
|
919
|
+
run_name: str | None = None,
|
|
920
|
+
run_id: uuid.UUID | None = None,
|
|
921
|
+
config: RunnableConfig | None = None,
|
|
922
|
+
tool_call_id: str | None = None,
|
|
914
923
|
**kwargs: Any,
|
|
915
924
|
) -> Any:
|
|
916
925
|
"""Run the tool asynchronously.
|
|
917
926
|
|
|
918
927
|
Args:
|
|
919
928
|
tool_input: The input to the tool.
|
|
920
|
-
verbose: Whether to log the tool's progress.
|
|
921
|
-
start_color: The color to use when starting the tool.
|
|
922
|
-
color: The color to use when ending the tool.
|
|
923
|
-
callbacks: Callbacks to be called during tool execution.
|
|
924
|
-
tags: Optional list of tags associated with the tool.
|
|
925
|
-
metadata: Optional metadata associated with the tool.
|
|
926
|
-
run_name: The name of the run.
|
|
927
|
-
run_id: The id of the run.
|
|
928
|
-
config: The configuration for the tool.
|
|
929
|
-
tool_call_id: The id of the tool call.
|
|
930
|
-
kwargs: Keyword arguments to be passed to tool callbacks
|
|
929
|
+
verbose: Whether to log the tool's progress.
|
|
930
|
+
start_color: The color to use when starting the tool.
|
|
931
|
+
color: The color to use when ending the tool.
|
|
932
|
+
callbacks: Callbacks to be called during tool execution.
|
|
933
|
+
tags: Optional list of tags associated with the tool.
|
|
934
|
+
metadata: Optional metadata associated with the tool.
|
|
935
|
+
run_name: The name of the run.
|
|
936
|
+
run_id: The id of the run.
|
|
937
|
+
config: The configuration for the tool.
|
|
938
|
+
tool_call_id: The id of the tool call.
|
|
939
|
+
**kwargs: Keyword arguments to be passed to tool callbacks
|
|
931
940
|
|
|
932
941
|
Returns:
|
|
933
942
|
The output of the tool.
|
|
@@ -944,23 +953,36 @@ class ChildTool(BaseTool):
|
|
|
944
953
|
metadata,
|
|
945
954
|
self.metadata,
|
|
946
955
|
)
|
|
956
|
+
|
|
957
|
+
# Filter out injected arguments from callback inputs
|
|
958
|
+
filtered_tool_input = (
|
|
959
|
+
self._filter_injected_args(tool_input)
|
|
960
|
+
if isinstance(tool_input, dict)
|
|
961
|
+
else None
|
|
962
|
+
)
|
|
963
|
+
|
|
964
|
+
# Use filtered inputs for the input_str parameter as well
|
|
965
|
+
tool_input_str = (
|
|
966
|
+
tool_input
|
|
967
|
+
if isinstance(tool_input, str)
|
|
968
|
+
else str(
|
|
969
|
+
filtered_tool_input if filtered_tool_input is not None else tool_input
|
|
970
|
+
)
|
|
971
|
+
)
|
|
972
|
+
|
|
947
973
|
run_manager = await callback_manager.on_tool_start(
|
|
948
974
|
{"name": self.name, "description": self.description},
|
|
949
|
-
|
|
975
|
+
tool_input_str,
|
|
950
976
|
color=start_color,
|
|
951
977
|
name=run_name,
|
|
952
978
|
run_id=run_id,
|
|
953
|
-
|
|
954
|
-
# For now, it's unclear whether this assumption is ever violated,
|
|
955
|
-
# but if it is we will send a `None` value to the callback instead
|
|
956
|
-
# TODO: will need to address issue via a patch.
|
|
957
|
-
inputs=tool_input if isinstance(tool_input, dict) else None,
|
|
979
|
+
inputs=filtered_tool_input,
|
|
958
980
|
**kwargs,
|
|
959
981
|
)
|
|
960
982
|
content = None
|
|
961
983
|
artifact = None
|
|
962
984
|
status = "success"
|
|
963
|
-
error_to_raise:
|
|
985
|
+
error_to_raise: Exception | KeyboardInterrupt | None = None
|
|
964
986
|
try:
|
|
965
987
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
|
966
988
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
|
@@ -1011,19 +1033,6 @@ class ChildTool(BaseTool):
|
|
|
1011
1033
|
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
|
1012
1034
|
return output
|
|
1013
1035
|
|
|
1014
|
-
@deprecated("0.1.47", alternative="invoke", removal="1.0")
|
|
1015
|
-
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
|
1016
|
-
"""Make tool callable (deprecated).
|
|
1017
|
-
|
|
1018
|
-
Args:
|
|
1019
|
-
tool_input: The input to the tool.
|
|
1020
|
-
callbacks: Callbacks to use during execution.
|
|
1021
|
-
|
|
1022
|
-
Returns:
|
|
1023
|
-
The tool's output.
|
|
1024
|
-
"""
|
|
1025
|
-
return self.run(tool_input, callbacks=callbacks)
|
|
1026
|
-
|
|
1027
1036
|
|
|
1028
1037
|
def _is_tool_call(x: Any) -> bool:
|
|
1029
1038
|
"""Check if the input is a tool call dictionary.
|
|
@@ -1032,23 +1041,21 @@ def _is_tool_call(x: Any) -> bool:
|
|
|
1032
1041
|
x: The input to check.
|
|
1033
1042
|
|
|
1034
1043
|
Returns:
|
|
1035
|
-
True if the input is a tool call, False otherwise.
|
|
1044
|
+
`True` if the input is a tool call, `False` otherwise.
|
|
1036
1045
|
"""
|
|
1037
1046
|
return isinstance(x, dict) and x.get("type") == "tool_call"
|
|
1038
1047
|
|
|
1039
1048
|
|
|
1040
1049
|
def _handle_validation_error(
|
|
1041
|
-
e:
|
|
1050
|
+
e: ValidationError | ValidationErrorV1,
|
|
1042
1051
|
*,
|
|
1043
|
-
flag:
|
|
1044
|
-
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
|
|
1045
|
-
],
|
|
1052
|
+
flag: Literal[True] | str | Callable[[ValidationError | ValidationErrorV1], str],
|
|
1046
1053
|
) -> str:
|
|
1047
1054
|
"""Handle validation errors based on the configured flag.
|
|
1048
1055
|
|
|
1049
1056
|
Args:
|
|
1050
1057
|
e: The validation error that occurred.
|
|
1051
|
-
flag: How to handle the error (bool
|
|
1058
|
+
flag: How to handle the error (`bool`, `str`, or `Callable`).
|
|
1052
1059
|
|
|
1053
1060
|
Returns:
|
|
1054
1061
|
The error message to return.
|
|
@@ -1074,13 +1081,13 @@ def _handle_validation_error(
|
|
|
1074
1081
|
def _handle_tool_error(
|
|
1075
1082
|
e: ToolException,
|
|
1076
1083
|
*,
|
|
1077
|
-
flag:
|
|
1084
|
+
flag: Literal[True] | str | Callable[[ToolException], str] | None,
|
|
1078
1085
|
) -> str:
|
|
1079
1086
|
"""Handle tool execution errors based on the configured flag.
|
|
1080
1087
|
|
|
1081
1088
|
Args:
|
|
1082
1089
|
e: The tool exception that occurred.
|
|
1083
|
-
flag: How to handle the error (bool
|
|
1090
|
+
flag: How to handle the error (`bool`, `str`, or `Callable`).
|
|
1084
1091
|
|
|
1085
1092
|
Returns:
|
|
1086
1093
|
The error message to return.
|
|
@@ -1104,27 +1111,27 @@ def _handle_tool_error(
|
|
|
1104
1111
|
|
|
1105
1112
|
|
|
1106
1113
|
def _prep_run_args(
|
|
1107
|
-
value:
|
|
1108
|
-
config:
|
|
1114
|
+
value: str | dict | ToolCall,
|
|
1115
|
+
config: RunnableConfig | None,
|
|
1109
1116
|
**kwargs: Any,
|
|
1110
|
-
) -> tuple[
|
|
1117
|
+
) -> tuple[str | dict, dict]:
|
|
1111
1118
|
"""Prepare arguments for tool execution.
|
|
1112
1119
|
|
|
1113
1120
|
Args:
|
|
1114
|
-
value: The input value (
|
|
1121
|
+
value: The input value (`str`, `dict`, or `ToolCall`).
|
|
1115
1122
|
config: The runnable configuration.
|
|
1116
1123
|
**kwargs: Additional keyword arguments.
|
|
1117
1124
|
|
|
1118
1125
|
Returns:
|
|
1119
|
-
A tuple of (tool_input, run_kwargs)
|
|
1126
|
+
A tuple of `(tool_input, run_kwargs)`.
|
|
1120
1127
|
"""
|
|
1121
1128
|
config = ensure_config(config)
|
|
1122
1129
|
if _is_tool_call(value):
|
|
1123
|
-
tool_call_id:
|
|
1124
|
-
tool_input:
|
|
1130
|
+
tool_call_id: str | None = cast("ToolCall", value)["id"]
|
|
1131
|
+
tool_input: str | dict = cast("ToolCall", value)["args"].copy()
|
|
1125
1132
|
else:
|
|
1126
1133
|
tool_call_id = None
|
|
1127
|
-
tool_input = cast("
|
|
1134
|
+
tool_input = cast("str | dict", value)
|
|
1128
1135
|
return (
|
|
1129
1136
|
tool_input,
|
|
1130
1137
|
dict(
|
|
@@ -1143,11 +1150,11 @@ def _prep_run_args(
|
|
|
1143
1150
|
def _format_output(
|
|
1144
1151
|
content: Any,
|
|
1145
1152
|
artifact: Any,
|
|
1146
|
-
tool_call_id:
|
|
1153
|
+
tool_call_id: str | None,
|
|
1147
1154
|
name: str,
|
|
1148
1155
|
status: str,
|
|
1149
|
-
) ->
|
|
1150
|
-
"""Format tool output as a ToolMessage if appropriate.
|
|
1156
|
+
) -> ToolOutputMixin | Any:
|
|
1157
|
+
"""Format tool output as a `ToolMessage` if appropriate.
|
|
1151
1158
|
|
|
1152
1159
|
Args:
|
|
1153
1160
|
content: The main content of the tool output.
|
|
@@ -1157,7 +1164,7 @@ def _format_output(
|
|
|
1157
1164
|
status: The execution status.
|
|
1158
1165
|
|
|
1159
1166
|
Returns:
|
|
1160
|
-
The formatted output, either as a ToolMessage or the original content.
|
|
1167
|
+
The formatted output, either as a `ToolMessage` or the original content.
|
|
1161
1168
|
"""
|
|
1162
1169
|
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
|
1163
1170
|
return content
|
|
@@ -1181,7 +1188,7 @@ def _is_message_content_type(obj: Any) -> bool:
|
|
|
1181
1188
|
obj: The object to check.
|
|
1182
1189
|
|
|
1183
1190
|
Returns:
|
|
1184
|
-
True if the object is valid message content, False otherwise.
|
|
1191
|
+
`True` if the object is valid message content, `False` otherwise.
|
|
1185
1192
|
"""
|
|
1186
1193
|
return isinstance(obj, str) or (
|
|
1187
1194
|
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
|
|
@@ -1197,7 +1204,7 @@ def _is_message_content_block(obj: Any) -> bool:
|
|
|
1197
1204
|
obj: The object to check.
|
|
1198
1205
|
|
|
1199
1206
|
Returns:
|
|
1200
|
-
True if the object is a valid content block, False otherwise.
|
|
1207
|
+
`True` if the object is a valid content block, `False` otherwise.
|
|
1201
1208
|
"""
|
|
1202
1209
|
if isinstance(obj, str):
|
|
1203
1210
|
return True
|
|
@@ -1221,14 +1228,14 @@ def _stringify(content: Any) -> str:
|
|
|
1221
1228
|
return str(content)
|
|
1222
1229
|
|
|
1223
1230
|
|
|
1224
|
-
def _get_type_hints(func: Callable) ->
|
|
1231
|
+
def _get_type_hints(func: Callable) -> dict[str, type] | None:
|
|
1225
1232
|
"""Get type hints from a function, handling partial functions.
|
|
1226
1233
|
|
|
1227
1234
|
Args:
|
|
1228
1235
|
func: The function to get type hints from.
|
|
1229
1236
|
|
|
1230
1237
|
Returns:
|
|
1231
|
-
|
|
1238
|
+
`dict` of type hints, or `None` if extraction fails.
|
|
1232
1239
|
"""
|
|
1233
1240
|
if isinstance(func, functools.partial):
|
|
1234
1241
|
func = func.func
|
|
@@ -1238,14 +1245,14 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
|
|
|
1238
1245
|
return None
|
|
1239
1246
|
|
|
1240
1247
|
|
|
1241
|
-
def _get_runnable_config_param(func: Callable) ->
|
|
1242
|
-
"""Find the parameter name for RunnableConfig in a function.
|
|
1248
|
+
def _get_runnable_config_param(func: Callable) -> str | None:
|
|
1249
|
+
"""Find the parameter name for `RunnableConfig` in a function.
|
|
1243
1250
|
|
|
1244
1251
|
Args:
|
|
1245
1252
|
func: The function to check.
|
|
1246
1253
|
|
|
1247
1254
|
Returns:
|
|
1248
|
-
The parameter name for RunnableConfig
|
|
1255
|
+
The parameter name for `RunnableConfig`, or `None` if not found.
|
|
1249
1256
|
"""
|
|
1250
1257
|
type_hints = _get_type_hints(func)
|
|
1251
1258
|
if not type_hints:
|
|
@@ -1264,35 +1271,75 @@ class InjectedToolArg:
|
|
|
1264
1271
|
"""
|
|
1265
1272
|
|
|
1266
1273
|
|
|
1274
|
+
class _DirectlyInjectedToolArg:
|
|
1275
|
+
"""Annotation for tool arguments that are injected at runtime.
|
|
1276
|
+
|
|
1277
|
+
Injected via direct type annotation, rather than annotated metadata.
|
|
1278
|
+
|
|
1279
|
+
For example, `ToolRuntime` is a directly injected argument.
|
|
1280
|
+
|
|
1281
|
+
Note the direct annotation rather than the verbose alternative:
|
|
1282
|
+
`Annotated[ToolRuntime, InjectedRuntime]`
|
|
1283
|
+
|
|
1284
|
+
```python
|
|
1285
|
+
from langchain_core.tools import tool, ToolRuntime
|
|
1286
|
+
|
|
1287
|
+
|
|
1288
|
+
@tool
|
|
1289
|
+
def foo(x: int, runtime: ToolRuntime) -> str:
|
|
1290
|
+
# use runtime.state, runtime.context, runtime.store, etc.
|
|
1291
|
+
...
|
|
1292
|
+
```
|
|
1293
|
+
"""
|
|
1294
|
+
|
|
1295
|
+
|
|
1267
1296
|
class InjectedToolCallId(InjectedToolArg):
|
|
1268
1297
|
"""Annotation for injecting the tool call ID.
|
|
1269
1298
|
|
|
1270
1299
|
This annotation is used to mark a tool parameter that should receive
|
|
1271
1300
|
the tool call ID at runtime.
|
|
1272
1301
|
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
)
|
|
1302
|
+
```python
|
|
1303
|
+
from typing import Annotated
|
|
1304
|
+
from langchain_core.messages import ToolMessage
|
|
1305
|
+
from langchain_core.tools import tool, InjectedToolCallId
|
|
1306
|
+
|
|
1307
|
+
@tool
|
|
1308
|
+
def foo(
|
|
1309
|
+
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
|
|
1310
|
+
) -> ToolMessage:
|
|
1311
|
+
\"\"\"Return x.\"\"\"
|
|
1312
|
+
return ToolMessage(
|
|
1313
|
+
str(x),
|
|
1314
|
+
artifact=x,
|
|
1315
|
+
name="foo",
|
|
1316
|
+
tool_call_id=tool_call_id
|
|
1317
|
+
)
|
|
1290
1318
|
|
|
1319
|
+
```
|
|
1291
1320
|
"""
|
|
1292
1321
|
|
|
1293
1322
|
|
|
1323
|
+
def _is_directly_injected_arg_type(type_: Any) -> bool:
|
|
1324
|
+
"""Check if a type annotation indicates a directly injected argument.
|
|
1325
|
+
|
|
1326
|
+
This is currently only used for `ToolRuntime`.
|
|
1327
|
+
Checks if either the annotation itself is a subclass of `_DirectlyInjectedToolArg`
|
|
1328
|
+
or the origin of the annotation is a subclass of `_DirectlyInjectedToolArg`.
|
|
1329
|
+
|
|
1330
|
+
Ex: `ToolRuntime` or `ToolRuntime[ContextT, StateT]` would both return `True`.
|
|
1331
|
+
"""
|
|
1332
|
+
return (
|
|
1333
|
+
isinstance(type_, type) and issubclass(type_, _DirectlyInjectedToolArg)
|
|
1334
|
+
) or (
|
|
1335
|
+
(origin := get_origin(type_)) is not None
|
|
1336
|
+
and isinstance(origin, type)
|
|
1337
|
+
and issubclass(origin, _DirectlyInjectedToolArg)
|
|
1338
|
+
)
|
|
1339
|
+
|
|
1340
|
+
|
|
1294
1341
|
def _is_injected_arg_type(
|
|
1295
|
-
type_:
|
|
1342
|
+
type_: type | TypeVar, injected_type: type[InjectedToolArg] | None = None
|
|
1296
1343
|
) -> bool:
|
|
1297
1344
|
"""Check if a type annotation indicates an injected argument.
|
|
1298
1345
|
|
|
@@ -1301,9 +1348,17 @@ def _is_injected_arg_type(
|
|
|
1301
1348
|
injected_type: The specific injected type to check for.
|
|
1302
1349
|
|
|
1303
1350
|
Returns:
|
|
1304
|
-
True if the type is an injected argument, False otherwise.
|
|
1351
|
+
`True` if the type is an injected argument, `False` otherwise.
|
|
1305
1352
|
"""
|
|
1306
|
-
|
|
1353
|
+
if injected_type is None:
|
|
1354
|
+
# if no injected type is specified,
|
|
1355
|
+
# check if the type is a directly injected argument
|
|
1356
|
+
if _is_directly_injected_arg_type(type_):
|
|
1357
|
+
return True
|
|
1358
|
+
injected_type = InjectedToolArg
|
|
1359
|
+
|
|
1360
|
+
# if the type is an Annotated type, check if annotated metadata
|
|
1361
|
+
# is an instance or subclass of the injected type
|
|
1307
1362
|
return any(
|
|
1308
1363
|
isinstance(arg, injected_type)
|
|
1309
1364
|
or (isinstance(arg, type) and issubclass(arg, injected_type))
|
|
@@ -1312,23 +1367,23 @@ def _is_injected_arg_type(
|
|
|
1312
1367
|
|
|
1313
1368
|
|
|
1314
1369
|
def get_all_basemodel_annotations(
|
|
1315
|
-
cls:
|
|
1316
|
-
) -> dict[str,
|
|
1317
|
-
"""Get all annotations from a Pydantic BaseModel and its parents.
|
|
1370
|
+
cls: TypeBaseModel | Any, *, default_to_bound: bool = True
|
|
1371
|
+
) -> dict[str, type | TypeVar]:
|
|
1372
|
+
"""Get all annotations from a Pydantic `BaseModel` and its parents.
|
|
1318
1373
|
|
|
1319
1374
|
Args:
|
|
1320
|
-
cls: The Pydantic BaseModel class.
|
|
1321
|
-
default_to_bound: Whether to default to the bound of a TypeVar if it exists.
|
|
1375
|
+
cls: The Pydantic `BaseModel` class.
|
|
1376
|
+
default_to_bound: Whether to default to the bound of a `TypeVar` if it exists.
|
|
1322
1377
|
|
|
1323
1378
|
Returns:
|
|
1324
|
-
|
|
1379
|
+
`dict` of field names to their type annotations.
|
|
1325
1380
|
"""
|
|
1326
1381
|
# cls has no subscript: cls = FooBar
|
|
1327
1382
|
if isinstance(cls, type):
|
|
1328
1383
|
fields = get_fields(cls)
|
|
1329
1384
|
alias_map = {field.alias: name for name, field in fields.items() if field.alias}
|
|
1330
1385
|
|
|
1331
|
-
annotations: dict[str,
|
|
1386
|
+
annotations: dict[str, type | TypeVar] = {}
|
|
1332
1387
|
for name, param in inspect.signature(cls).parameters.items():
|
|
1333
1388
|
# Exclude hidden init args added by pydantic Config. For example if
|
|
1334
1389
|
# BaseModel(extra="allow") then "extra_data" will be part of init sig.
|
|
@@ -1369,7 +1424,7 @@ def get_all_basemodel_annotations(
|
|
|
1369
1424
|
# generic_type_vars = (type vars in Baz)
|
|
1370
1425
|
# generic_map = {type var in Baz: str}
|
|
1371
1426
|
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
|
|
1372
|
-
generic_map = dict(zip(generic_type_vars, get_args(parent)))
|
|
1427
|
+
generic_map = dict(zip(generic_type_vars, get_args(parent), strict=False))
|
|
1373
1428
|
for field in getattr(parent_origin, "__annotations__", {}):
|
|
1374
1429
|
annotations[field] = _replace_type_vars(
|
|
1375
1430
|
annotations[field], generic_map, default_to_bound=default_to_bound
|
|
@@ -1382,20 +1437,20 @@ def get_all_basemodel_annotations(
|
|
|
1382
1437
|
|
|
1383
1438
|
|
|
1384
1439
|
def _replace_type_vars(
|
|
1385
|
-
type_:
|
|
1386
|
-
generic_map:
|
|
1440
|
+
type_: type | TypeVar,
|
|
1441
|
+
generic_map: dict[TypeVar, type] | None = None,
|
|
1387
1442
|
*,
|
|
1388
1443
|
default_to_bound: bool = True,
|
|
1389
|
-
) ->
|
|
1390
|
-
"""Replace
|
|
1444
|
+
) -> type | TypeVar:
|
|
1445
|
+
"""Replace `TypeVar`s in a type annotation with concrete types.
|
|
1391
1446
|
|
|
1392
1447
|
Args:
|
|
1393
1448
|
type_: The type annotation to process.
|
|
1394
|
-
generic_map: Mapping of
|
|
1395
|
-
default_to_bound: Whether to use TypeVar bounds as defaults.
|
|
1449
|
+
generic_map: Mapping of `TypeVar`s to concrete types.
|
|
1450
|
+
default_to_bound: Whether to use `TypeVar` bounds as defaults.
|
|
1396
1451
|
|
|
1397
1452
|
Returns:
|
|
1398
|
-
The type with
|
|
1453
|
+
The type with `TypeVar`s replaced.
|
|
1399
1454
|
"""
|
|
1400
1455
|
generic_map = generic_map or {}
|
|
1401
1456
|
if isinstance(type_, TypeVar):
|