langchain 1.1.2__tar.gz → 1.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langchain-1.1.2 → langchain-1.2.2}/PKG-INFO +2 -2
- {langchain-1.1.2 → langchain-1.2.2}/langchain/__init__.py +1 -1
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/factory.py +16 -13
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/_redaction.py +27 -12
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/_retry.py +1 -1
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/context_editing.py +2 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/model_call_limit.py +6 -4
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/model_fallback.py +4 -4
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/model_retry.py +2 -2
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/pii.py +8 -2
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/shell_tool.py +134 -12
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/summarization.py +35 -17
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/todo.py +85 -6
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/tool_call_limit.py +4 -3
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/tool_retry.py +2 -2
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/tool_selection.py +3 -3
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/types.py +114 -19
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/structured_output.py +22 -12
- {langchain-1.1.2 → langchain-1.2.2}/langchain/chat_models/base.py +217 -171
- {langchain-1.1.2 → langchain-1.2.2}/langchain/embeddings/base.py +79 -65
- {langchain-1.1.2 → langchain-1.2.2}/pyproject.toml +34 -68
- {langchain-1.1.2 → langchain-1.2.2}/scripts/check_imports.py +3 -1
- langchain-1.2.2/tests/cassettes/test_strict_mode[False].yaml.gz +0 -0
- langchain-1.2.2/tests/cassettes/test_strict_mode[True].yaml.gz +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/agents/middleware/test_shell_tool_integration.py +20 -20
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/cache/fake_embeddings.py +4 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/chat_models/test_base.py +2 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/conftest.py +1 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/embeddings/test_base.py +1 -1
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/any_str.py +3 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/conftest.py +1 -1
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/memory_assert.py +1 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/messages.py +2 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_composition.py +76 -34
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_decorators.py +144 -86
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_diagram.py +26 -25
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_framework.py +60 -85
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_overrides.py +2 -10
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_sync_async_wrappers.py +83 -20
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_tools.py +19 -19
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py +466 -121
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/test_wrap_tool_call.py +101 -44
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_context_editing.py +39 -30
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_file_search.py +41 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py +123 -105
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py +29 -26
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_model_fallback.py +107 -32
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_model_retry.py +6 -6
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_pii.py +134 -129
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_shell_execution_policies.py +25 -21
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py +10 -7
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_summarization.py +84 -21
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_todo.py +299 -19
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit.py +12 -4
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_tool_emulator.py +21 -22
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_tool_retry.py +58 -58
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_tool_selection.py +14 -25
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/model.py +3 -4
- langchain-1.2.2/tests/unit_tests/agents/test_agent_name.py +99 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_create_agent_tool_validation.py +14 -5
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_injected_runtime_create_agent.py +12 -16
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_response_format.py +59 -38
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_response_format_integration.py +59 -12
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_responses.py +2 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_responses_spec.py +4 -3
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_return_direct_spec.py +2 -2
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_state_schema.py +8 -4
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_system_message.py +27 -29
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/chat_models/test_chat_models.py +23 -3
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/conftest.py +5 -8
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/embeddings/test_base.py +15 -14
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/test_dependencies.py +7 -1
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/test_pytest_config.py +1 -3
- langchain-1.2.2/tests/unit_tests/test_version.py +27 -0
- {langchain-1.1.2 → langchain-1.2.2}/uv.lock +148 -138
- langchain-1.1.2/tests/unit_tests/stubs.py +0 -46
- {langchain-1.1.2 → langchain-1.2.2}/.gitignore +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/LICENSE +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/Makefile +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/README.md +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/extended_testing_deps.txt +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/_execution.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/file_search.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/human_in_the_loop.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/agents/middleware/tool_emulator.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/chat_models/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/embeddings/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/messages/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/py.typed +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/rate_limiters/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/tools/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/langchain/tools/tool_node.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/cassettes/test_inference_to_native_output[False].yaml.gz +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/cassettes/test_inference_to_native_output[True].yaml.gz +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/cassettes/test_inference_to_tool_output[False].yaml.gz +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/cassettes/test_inference_to_tool_output[True].yaml.gz +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/agents/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/agents/middleware/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/cache/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/chat_models/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/embeddings/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/integration_tests/test_compile.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/__snapshots__/test_middleware_decorators.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/__snapshots__/test_middleware_framework.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/__snapshots__/test_return_direct_graph.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/compose-postgres.yml +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/compose-redis.yml +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/conftest_checkpointer.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/conftest_store.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_decorators.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_diagram.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_framework.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/__snapshots__/test_decorators.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/__snapshots__/test_diagram.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/core/__snapshots__/test_framework.ambr +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/middleware/implementations/test_structured_output_retry.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/specifications/responses.json +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/specifications/return_direct.json +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_react_agent.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/test_return_direct_graph.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/agents/utils.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/chat_models/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/embeddings/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/embeddings/test_imports.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/test_imports.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/tools/__init__.py +0 -0
- {langchain-1.1.2 → langchain-1.2.2}/tests/unit_tests/tools/test_imports.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: langchain
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.2.2
|
|
4
4
|
Summary: Building applications with LLMs through composability
|
|
5
5
|
Project-URL: Homepage, https://docs.langchain.com/
|
|
6
6
|
Project-URL: Documentation, https://reference.langchain.com/python/langchain/langchain/
|
|
@@ -12,7 +12,7 @@ Project-URL: Reddit, https://www.reddit.com/r/LangChain/
|
|
|
12
12
|
License: MIT
|
|
13
13
|
License-File: LICENSE
|
|
14
14
|
Requires-Python: <4.0.0,>=3.10.0
|
|
15
|
-
Requires-Dist: langchain-core<2.0.0,>=1.1
|
|
15
|
+
Requires-Dist: langchain-core<2.0.0,>=1.2.1
|
|
16
16
|
Requires-Dist: langgraph<1.1.0,>=1.0.2
|
|
17
17
|
Requires-Dist: pydantic<3.0.0,>=2.7.4
|
|
18
18
|
Provides-Extra: anthropic
|
|
@@ -20,9 +20,7 @@ from langgraph._internal._runnable import RunnableCallable
|
|
|
20
20
|
from langgraph.constants import END, START
|
|
21
21
|
from langgraph.graph.state import StateGraph
|
|
22
22
|
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
|
|
23
|
-
from langgraph.runtime import Runtime # noqa: TC002
|
|
24
23
|
from langgraph.types import Command, Send
|
|
25
|
-
from langgraph.typing import ContextT # noqa: TC002
|
|
26
24
|
from typing_extensions import NotRequired, Required, TypedDict
|
|
27
25
|
|
|
28
26
|
from langchain.agents.middleware.types import (
|
|
@@ -56,8 +54,10 @@ if TYPE_CHECKING:
|
|
|
56
54
|
from langchain_core.runnables import Runnable
|
|
57
55
|
from langgraph.cache.base import BaseCache
|
|
58
56
|
from langgraph.graph.state import CompiledStateGraph
|
|
57
|
+
from langgraph.runtime import Runtime
|
|
59
58
|
from langgraph.store.base import BaseStore
|
|
60
59
|
from langgraph.types import Checkpointer
|
|
60
|
+
from langgraph.typing import ContextT
|
|
61
61
|
|
|
62
62
|
from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
|
|
63
63
|
|
|
@@ -314,7 +314,7 @@ def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None
|
|
|
314
314
|
def _extract_metadata(type_: type) -> list:
|
|
315
315
|
"""Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
|
|
316
316
|
# Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
|
|
317
|
-
if get_origin(type_) in
|
|
317
|
+
if get_origin(type_) in {Required, NotRequired}:
|
|
318
318
|
inner_type = get_args(type_)[0]
|
|
319
319
|
if get_origin(inner_type) is Annotated:
|
|
320
320
|
return list(get_args(inner_type)[1:])
|
|
@@ -538,7 +538,7 @@ def _chain_async_tool_call_wrappers(
|
|
|
538
538
|
return result
|
|
539
539
|
|
|
540
540
|
|
|
541
|
-
def create_agent(
|
|
541
|
+
def create_agent(
|
|
542
542
|
model: str | BaseChatModel,
|
|
543
543
|
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
|
544
544
|
*,
|
|
@@ -791,9 +791,9 @@ def create_agent( # noqa: PLR0915
|
|
|
791
791
|
default_tools = list(built_in_tools)
|
|
792
792
|
|
|
793
793
|
# validate middleware
|
|
794
|
-
|
|
795
|
-
"Please remove duplicate middleware instances."
|
|
796
|
-
|
|
794
|
+
if len({m.name for m in middleware}) != len(middleware):
|
|
795
|
+
msg = "Please remove duplicate middleware instances."
|
|
796
|
+
raise AssertionError(msg)
|
|
797
797
|
middleware_w_before_agent = [
|
|
798
798
|
m
|
|
799
799
|
for m in middleware
|
|
@@ -886,12 +886,12 @@ def create_agent( # noqa: PLR0915
|
|
|
886
886
|
)
|
|
887
887
|
try:
|
|
888
888
|
structured_response = provider_strategy_binding.parse(output)
|
|
889
|
-
except Exception as exc:
|
|
889
|
+
except Exception as exc:
|
|
890
890
|
schema_name = getattr(
|
|
891
891
|
effective_response_format.schema_spec.schema, "__name__", "response_format"
|
|
892
892
|
)
|
|
893
893
|
validation_error = StructuredOutputValidationError(schema_name, exc, output)
|
|
894
|
-
raise validation_error
|
|
894
|
+
raise validation_error from exc
|
|
895
895
|
else:
|
|
896
896
|
return {"messages": [output], "structured_response": structured_response}
|
|
897
897
|
return {"messages": [output]}
|
|
@@ -937,8 +937,7 @@ def create_agent( # noqa: PLR0915
|
|
|
937
937
|
|
|
938
938
|
tool_message_content = (
|
|
939
939
|
effective_response_format.tool_message_content
|
|
940
|
-
|
|
941
|
-
else f"Returning structured response: {structured_response}"
|
|
940
|
+
or f"Returning structured response: {structured_response}"
|
|
942
941
|
)
|
|
943
942
|
|
|
944
943
|
return {
|
|
@@ -952,13 +951,13 @@ def create_agent( # noqa: PLR0915
|
|
|
952
951
|
],
|
|
953
952
|
"structured_response": structured_response,
|
|
954
953
|
}
|
|
955
|
-
except Exception as exc:
|
|
954
|
+
except Exception as exc:
|
|
956
955
|
exception = StructuredOutputValidationError(tool_call["name"], exc, output)
|
|
957
956
|
should_retry, error_message = _handle_structured_output_error(
|
|
958
957
|
exception, effective_response_format
|
|
959
958
|
)
|
|
960
959
|
if not should_retry:
|
|
961
|
-
raise exception
|
|
960
|
+
raise exception from exc
|
|
962
961
|
|
|
963
962
|
return {
|
|
964
963
|
"messages": [
|
|
@@ -1100,6 +1099,8 @@ def create_agent( # noqa: PLR0915
|
|
|
1100
1099
|
messages = [request.system_message, *messages]
|
|
1101
1100
|
|
|
1102
1101
|
output = model_.invoke(messages)
|
|
1102
|
+
if name:
|
|
1103
|
+
output.name = name
|
|
1103
1104
|
|
|
1104
1105
|
# Handle model output to get messages and structured_response
|
|
1105
1106
|
handled_output = _handle_model_output(output, effective_response_format)
|
|
@@ -1153,6 +1154,8 @@ def create_agent( # noqa: PLR0915
|
|
|
1153
1154
|
messages = [request.system_message, *messages]
|
|
1154
1155
|
|
|
1155
1156
|
output = await model_.ainvoke(messages)
|
|
1157
|
+
if name:
|
|
1158
|
+
output.name = name
|
|
1156
1159
|
|
|
1157
1160
|
# Handle model output to get messages and structured_response
|
|
1158
1161
|
handled_output = _handle_model_output(output, effective_response_format)
|
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import hashlib
|
|
6
6
|
import ipaddress
|
|
7
|
+
import operator
|
|
7
8
|
import re
|
|
8
9
|
from collections.abc import Callable, Sequence
|
|
9
10
|
from dataclasses import dataclass
|
|
@@ -127,7 +128,7 @@ def detect_url(content: str) -> list[PIIMatch]:
|
|
|
127
128
|
for match in re.finditer(scheme_pattern, content):
|
|
128
129
|
url = match.group()
|
|
129
130
|
result = urlparse(url)
|
|
130
|
-
if result.scheme in
|
|
131
|
+
if result.scheme in {"http", "https"} and result.netloc:
|
|
131
132
|
matches.append(
|
|
132
133
|
PIIMatch(
|
|
133
134
|
type="url",
|
|
@@ -179,11 +180,14 @@ BUILTIN_DETECTORS: dict[str, Detector] = {
|
|
|
179
180
|
}
|
|
180
181
|
"""Registry of built-in detectors keyed by type name."""
|
|
181
182
|
|
|
183
|
+
_CARD_NUMBER_MIN_DIGITS = 13
|
|
184
|
+
_CARD_NUMBER_MAX_DIGITS = 19
|
|
185
|
+
|
|
182
186
|
|
|
183
187
|
def _passes_luhn(card_number: str) -> bool:
|
|
184
188
|
"""Validate credit card number using the Luhn checksum."""
|
|
185
189
|
digits = [int(d) for d in card_number if d.isdigit()]
|
|
186
|
-
if not
|
|
190
|
+
if not _CARD_NUMBER_MIN_DIGITS <= len(digits) <= _CARD_NUMBER_MAX_DIGITS:
|
|
187
191
|
return False
|
|
188
192
|
|
|
189
193
|
checksum = 0
|
|
@@ -191,7 +195,7 @@ def _passes_luhn(card_number: str) -> bool:
|
|
|
191
195
|
value = digit
|
|
192
196
|
if index % 2 == 1:
|
|
193
197
|
value *= 2
|
|
194
|
-
if value > 9:
|
|
198
|
+
if value > 9: # noqa: PLR2004
|
|
195
199
|
value -= 9
|
|
196
200
|
checksum += value
|
|
197
201
|
return checksum % 10 == 0
|
|
@@ -199,24 +203,28 @@ def _passes_luhn(card_number: str) -> bool:
|
|
|
199
203
|
|
|
200
204
|
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
201
205
|
result = content
|
|
202
|
-
for match in sorted(matches, key=
|
|
206
|
+
for match in sorted(matches, key=operator.itemgetter("start"), reverse=True):
|
|
203
207
|
replacement = f"[REDACTED_{match['type'].upper()}]"
|
|
204
208
|
result = result[: match["start"]] + replacement + result[match["end"] :]
|
|
205
209
|
return result
|
|
206
210
|
|
|
207
211
|
|
|
212
|
+
_UNMASKED_CHAR_NUMBER = 4
|
|
213
|
+
_IPV4_PARTS_NUMBER = 4
|
|
214
|
+
|
|
215
|
+
|
|
208
216
|
def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
209
217
|
result = content
|
|
210
|
-
for match in sorted(matches, key=
|
|
218
|
+
for match in sorted(matches, key=operator.itemgetter("start"), reverse=True):
|
|
211
219
|
value = match["value"]
|
|
212
220
|
pii_type = match["type"]
|
|
213
221
|
if pii_type == "email":
|
|
214
222
|
parts = value.split("@")
|
|
215
|
-
if len(parts) == 2:
|
|
223
|
+
if len(parts) == 2: # noqa: PLR2004
|
|
216
224
|
domain_parts = parts[1].split(".")
|
|
217
225
|
masked = (
|
|
218
226
|
f"{parts[0]}@****.{domain_parts[-1]}"
|
|
219
|
-
if len(domain_parts)
|
|
227
|
+
if len(domain_parts) > 1
|
|
220
228
|
else f"{parts[0]}@****"
|
|
221
229
|
)
|
|
222
230
|
else:
|
|
@@ -225,12 +233,15 @@ def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
|
225
233
|
digits_only = "".join(c for c in value if c.isdigit())
|
|
226
234
|
separator = "-" if "-" in value else " " if " " in value else ""
|
|
227
235
|
if separator:
|
|
228
|
-
masked =
|
|
236
|
+
masked = (
|
|
237
|
+
f"****{separator}****{separator}****{separator}"
|
|
238
|
+
f"{digits_only[-_UNMASKED_CHAR_NUMBER:]}"
|
|
239
|
+
)
|
|
229
240
|
else:
|
|
230
|
-
masked = f"************{digits_only[-
|
|
241
|
+
masked = f"************{digits_only[-_UNMASKED_CHAR_NUMBER:]}"
|
|
231
242
|
elif pii_type == "ip":
|
|
232
243
|
octets = value.split(".")
|
|
233
|
-
masked = f"*.*.*.{octets[-1]}" if len(octets) ==
|
|
244
|
+
masked = f"*.*.*.{octets[-1]}" if len(octets) == _IPV4_PARTS_NUMBER else "****"
|
|
234
245
|
elif pii_type == "mac_address":
|
|
235
246
|
separator = ":" if ":" in value else "-"
|
|
236
247
|
masked = (
|
|
@@ -239,14 +250,18 @@ def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
|
239
250
|
elif pii_type == "url":
|
|
240
251
|
masked = "[MASKED_URL]"
|
|
241
252
|
else:
|
|
242
|
-
masked =
|
|
253
|
+
masked = (
|
|
254
|
+
f"****{value[-_UNMASKED_CHAR_NUMBER:]}"
|
|
255
|
+
if len(value) > _UNMASKED_CHAR_NUMBER
|
|
256
|
+
else "****"
|
|
257
|
+
)
|
|
243
258
|
result = result[: match["start"]] + masked + result[match["end"] :]
|
|
244
259
|
return result
|
|
245
260
|
|
|
246
261
|
|
|
247
262
|
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
248
263
|
result = content
|
|
249
|
-
for match in sorted(matches, key=
|
|
264
|
+
for match in sorted(matches, key=operator.itemgetter("start"), reverse=True):
|
|
250
265
|
digest = hashlib.sha256(match["value"].encode()).hexdigest()[:8]
|
|
251
266
|
replacement = f"<{match['type']}_hash:{digest}>"
|
|
252
267
|
result = result[: match["start"]] + replacement + result[match["end"] :]
|
|
@@ -116,7 +116,7 @@ def calculate_delay(
|
|
|
116
116
|
|
|
117
117
|
if jitter and delay > 0:
|
|
118
118
|
jitter_amount = delay * 0.25 # ±25% jitter
|
|
119
|
-
delay
|
|
119
|
+
delay += random.uniform(-jitter_amount, jitter_amount) # noqa: S311
|
|
120
120
|
# Ensure delay is not negative after jitter
|
|
121
121
|
delay = max(0, delay)
|
|
122
122
|
|
|
@@ -228,6 +228,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|
|
228
228
|
|
|
229
229
|
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
|
230
230
|
return count_tokens_approximately(messages)
|
|
231
|
+
|
|
231
232
|
else:
|
|
232
233
|
system_msg = [request.system_message] if request.system_message else []
|
|
233
234
|
|
|
@@ -255,6 +256,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|
|
255
256
|
|
|
256
257
|
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
|
257
258
|
return count_tokens_approximately(messages)
|
|
259
|
+
|
|
258
260
|
else:
|
|
259
261
|
system_msg = [request.system_message] if request.system_message else []
|
|
260
262
|
|
|
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
|
6
6
|
|
|
7
7
|
from langchain_core.messages import AIMessage
|
|
8
8
|
from langgraph.channels.untracked_value import UntrackedValue
|
|
9
|
-
from typing_extensions import NotRequired
|
|
9
|
+
from typing_extensions import NotRequired, override
|
|
10
10
|
|
|
11
11
|
from langchain.agents.middleware.types import (
|
|
12
12
|
AgentMiddleware,
|
|
@@ -148,7 +148,7 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
148
148
|
msg = "At least one limit must be specified (thread_limit or run_limit)"
|
|
149
149
|
raise ValueError(msg)
|
|
150
150
|
|
|
151
|
-
if exit_behavior not in
|
|
151
|
+
if exit_behavior not in {"end", "error"}:
|
|
152
152
|
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
|
|
153
153
|
raise ValueError(msg)
|
|
154
154
|
|
|
@@ -157,7 +157,8 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
157
157
|
self.exit_behavior = exit_behavior
|
|
158
158
|
|
|
159
159
|
@hook_config(can_jump_to=["end"])
|
|
160
|
-
|
|
160
|
+
@override
|
|
161
|
+
def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
|
|
161
162
|
"""Check model call limits before making a model call.
|
|
162
163
|
|
|
163
164
|
Args:
|
|
@@ -222,7 +223,8 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
222
223
|
"""
|
|
223
224
|
return self.before_model(state, runtime)
|
|
224
225
|
|
|
225
|
-
|
|
226
|
+
@override
|
|
227
|
+
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
|
|
226
228
|
"""Increment model call counts after a model call.
|
|
227
229
|
|
|
228
230
|
Args:
|
|
@@ -87,14 +87,14 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
|
|
87
87
|
last_exception: Exception
|
|
88
88
|
try:
|
|
89
89
|
return handler(request)
|
|
90
|
-
except Exception as e:
|
|
90
|
+
except Exception as e:
|
|
91
91
|
last_exception = e
|
|
92
92
|
|
|
93
93
|
# Try fallback models
|
|
94
94
|
for fallback_model in self.models:
|
|
95
95
|
try:
|
|
96
96
|
return handler(request.override(model=fallback_model))
|
|
97
|
-
except Exception as e:
|
|
97
|
+
except Exception as e:
|
|
98
98
|
last_exception = e
|
|
99
99
|
continue
|
|
100
100
|
|
|
@@ -121,14 +121,14 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
|
|
121
121
|
last_exception: Exception
|
|
122
122
|
try:
|
|
123
123
|
return await handler(request)
|
|
124
|
-
except Exception as e:
|
|
124
|
+
except Exception as e:
|
|
125
125
|
last_exception = e
|
|
126
126
|
|
|
127
127
|
# Try fallback models
|
|
128
128
|
for fallback_model in self.models:
|
|
129
129
|
try:
|
|
130
130
|
return await handler(request.override(model=fallback_model))
|
|
131
|
-
except Exception as e:
|
|
131
|
+
except Exception as e:
|
|
132
132
|
last_exception = e
|
|
133
133
|
continue
|
|
134
134
|
|
|
@@ -223,7 +223,7 @@ class ModelRetryMiddleware(AgentMiddleware):
|
|
|
223
223
|
for attempt in range(self.max_retries + 1):
|
|
224
224
|
try:
|
|
225
225
|
return handler(request)
|
|
226
|
-
except Exception as exc:
|
|
226
|
+
except Exception as exc:
|
|
227
227
|
attempts_made = attempt + 1 # attempt is 0-indexed
|
|
228
228
|
|
|
229
229
|
# Check if we should retry this exception
|
|
@@ -270,7 +270,7 @@ class ModelRetryMiddleware(AgentMiddleware):
|
|
|
270
270
|
for attempt in range(self.max_retries + 1):
|
|
271
271
|
try:
|
|
272
272
|
return await handler(request)
|
|
273
|
-
except Exception as exc:
|
|
273
|
+
except Exception as exc:
|
|
274
274
|
attempts_made = attempt + 1 # attempt is 0-indexed
|
|
275
275
|
|
|
276
276
|
# Check if we should retry this exception
|
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
from typing import TYPE_CHECKING, Any, Literal
|
|
6
6
|
|
|
7
7
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
|
|
8
|
+
from typing_extensions import override
|
|
8
9
|
|
|
9
10
|
from langchain.agents.middleware._redaction import (
|
|
10
11
|
PIIDetectionError,
|
|
@@ -92,6 +93,8 @@ class PIIMiddleware(AgentMiddleware):
|
|
|
92
93
|
|
|
93
94
|
def __init__(
|
|
94
95
|
self,
|
|
96
|
+
# From a typing point of view, the literals are covered by 'str'.
|
|
97
|
+
# Nonetheless, we escape PYI051 to keep hints and autocompletion for the caller.
|
|
95
98
|
pii_type: Literal["email", "credit_card", "ip", "mac_address", "url"] | str, # noqa: PYI051
|
|
96
99
|
*,
|
|
97
100
|
strategy: Literal["block", "redact", "mask", "hash"] = "redact",
|
|
@@ -158,10 +161,11 @@ class PIIMiddleware(AgentMiddleware):
|
|
|
158
161
|
return sanitized, matches
|
|
159
162
|
|
|
160
163
|
@hook_config(can_jump_to=["end"])
|
|
164
|
+
@override
|
|
161
165
|
def before_model(
|
|
162
166
|
self,
|
|
163
167
|
state: AgentState,
|
|
164
|
-
runtime: Runtime,
|
|
168
|
+
runtime: Runtime,
|
|
165
169
|
) -> dict[str, Any] | None:
|
|
166
170
|
"""Check user messages and tool results for PII before model invocation.
|
|
167
171
|
|
|
@@ -273,10 +277,11 @@ class PIIMiddleware(AgentMiddleware):
|
|
|
273
277
|
"""
|
|
274
278
|
return self.before_model(state, runtime)
|
|
275
279
|
|
|
280
|
+
@override
|
|
276
281
|
def after_model(
|
|
277
282
|
self,
|
|
278
283
|
state: AgentState,
|
|
279
|
-
runtime: Runtime,
|
|
284
|
+
runtime: Runtime,
|
|
280
285
|
) -> dict[str, Any] | None:
|
|
281
286
|
"""Check AI messages for PII after model invocation.
|
|
282
287
|
|
|
@@ -355,6 +360,7 @@ class PIIMiddleware(AgentMiddleware):
|
|
|
355
360
|
|
|
356
361
|
__all__ = [
|
|
357
362
|
"PIIDetectionError",
|
|
363
|
+
"PIIMatch",
|
|
358
364
|
"PIIMiddleware",
|
|
359
365
|
"detect_credit_card",
|
|
360
366
|
"detect_email",
|
|
@@ -22,7 +22,7 @@ from langchain_core.tools.base import ToolException
|
|
|
22
22
|
from langgraph.channels.untracked_value import UntrackedValue
|
|
23
23
|
from pydantic import BaseModel, model_validator
|
|
24
24
|
from pydantic.json_schema import SkipJsonSchema
|
|
25
|
-
from typing_extensions import NotRequired
|
|
25
|
+
from typing_extensions import NotRequired, override
|
|
26
26
|
|
|
27
27
|
from langchain.agents.middleware._execution import (
|
|
28
28
|
SHELL_TEMP_PREFIX,
|
|
@@ -78,10 +78,10 @@ class _SessionResources:
|
|
|
78
78
|
session: ShellSession
|
|
79
79
|
tempdir: tempfile.TemporaryDirectory[str] | None
|
|
80
80
|
policy: BaseExecutionPolicy
|
|
81
|
-
|
|
81
|
+
finalizer: weakref.finalize = field(init=False, repr=False)
|
|
82
82
|
|
|
83
83
|
def __post_init__(self) -> None:
|
|
84
|
-
self.
|
|
84
|
+
self.finalizer = weakref.finalize(
|
|
85
85
|
self,
|
|
86
86
|
_cleanup_resources,
|
|
87
87
|
self.session,
|
|
@@ -211,9 +211,14 @@ class ShellSession:
|
|
|
211
211
|
with self._lock:
|
|
212
212
|
self._drain_queue()
|
|
213
213
|
payload = command if command.endswith("\n") else f"{command}\n"
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
214
|
+
try:
|
|
215
|
+
self._stdin.write(payload)
|
|
216
|
+
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
|
|
217
|
+
self._stdin.flush()
|
|
218
|
+
except (BrokenPipeError, OSError):
|
|
219
|
+
# The shell exited before we could write the marker command.
|
|
220
|
+
# This happens when commands like 'exit 1' terminate the shell.
|
|
221
|
+
return self._collect_output_after_exit(deadline)
|
|
217
222
|
|
|
218
223
|
return self._collect_output(marker, deadline, timeout)
|
|
219
224
|
|
|
@@ -248,6 +253,10 @@ class ShellSession:
|
|
|
248
253
|
if source == "stdout" and data.startswith(marker):
|
|
249
254
|
_, _, status = data.partition(" ")
|
|
250
255
|
exit_code = self._safe_int(status.strip())
|
|
256
|
+
# Drain any remaining stderr that may have arrived concurrently.
|
|
257
|
+
# The stderr reader thread runs independently, so output might
|
|
258
|
+
# still be in flight when the stdout marker arrives.
|
|
259
|
+
self._drain_remaining_stderr(collected, deadline)
|
|
251
260
|
break
|
|
252
261
|
|
|
253
262
|
total_lines += 1
|
|
@@ -300,6 +309,80 @@ class ShellSession:
|
|
|
300
309
|
total_bytes=total_bytes,
|
|
301
310
|
)
|
|
302
311
|
|
|
312
|
+
def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
|
|
313
|
+
"""Collect output after the shell exited unexpectedly.
|
|
314
|
+
|
|
315
|
+
Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
|
|
316
|
+
shell process terminated (e.g., due to an 'exit' command).
|
|
317
|
+
|
|
318
|
+
Args:
|
|
319
|
+
deadline: Absolute time by which collection must complete.
|
|
320
|
+
|
|
321
|
+
Returns:
|
|
322
|
+
`CommandExecutionResult` with collected output and the process exit code.
|
|
323
|
+
"""
|
|
324
|
+
collected: list[str] = []
|
|
325
|
+
total_lines = 0
|
|
326
|
+
total_bytes = 0
|
|
327
|
+
truncated_by_lines = False
|
|
328
|
+
truncated_by_bytes = False
|
|
329
|
+
|
|
330
|
+
# Give reader threads a brief moment to enqueue any remaining output.
|
|
331
|
+
drain_timeout = 0.1
|
|
332
|
+
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
|
|
333
|
+
|
|
334
|
+
while True:
|
|
335
|
+
remaining = drain_deadline - time.monotonic()
|
|
336
|
+
if remaining <= 0:
|
|
337
|
+
break
|
|
338
|
+
try:
|
|
339
|
+
source, data = self._queue.get(timeout=remaining)
|
|
340
|
+
except queue.Empty:
|
|
341
|
+
break
|
|
342
|
+
|
|
343
|
+
if data is None:
|
|
344
|
+
# EOF marker from a reader thread; continue draining.
|
|
345
|
+
continue
|
|
346
|
+
|
|
347
|
+
total_lines += 1
|
|
348
|
+
encoded = data.encode("utf-8", "replace")
|
|
349
|
+
total_bytes += len(encoded)
|
|
350
|
+
|
|
351
|
+
if total_lines > self._policy.max_output_lines:
|
|
352
|
+
truncated_by_lines = True
|
|
353
|
+
continue
|
|
354
|
+
|
|
355
|
+
if (
|
|
356
|
+
self._policy.max_output_bytes is not None
|
|
357
|
+
and total_bytes > self._policy.max_output_bytes
|
|
358
|
+
):
|
|
359
|
+
truncated_by_bytes = True
|
|
360
|
+
continue
|
|
361
|
+
|
|
362
|
+
if source == "stderr":
|
|
363
|
+
stripped = data.rstrip("\n")
|
|
364
|
+
collected.append(f"[stderr] {stripped}")
|
|
365
|
+
if data.endswith("\n"):
|
|
366
|
+
collected.append("\n")
|
|
367
|
+
else:
|
|
368
|
+
collected.append(data)
|
|
369
|
+
|
|
370
|
+
# Get exit code from the terminated process.
|
|
371
|
+
exit_code: int | None = None
|
|
372
|
+
if self._process:
|
|
373
|
+
exit_code = self._process.poll()
|
|
374
|
+
|
|
375
|
+
output = "".join(collected)
|
|
376
|
+
return CommandExecutionResult(
|
|
377
|
+
output=output,
|
|
378
|
+
exit_code=exit_code,
|
|
379
|
+
timed_out=False,
|
|
380
|
+
truncated_by_lines=truncated_by_lines,
|
|
381
|
+
truncated_by_bytes=truncated_by_bytes,
|
|
382
|
+
total_lines=total_lines,
|
|
383
|
+
total_bytes=total_bytes,
|
|
384
|
+
)
|
|
385
|
+
|
|
303
386
|
def _kill_process(self) -> None:
|
|
304
387
|
if not self._process:
|
|
305
388
|
return
|
|
@@ -323,6 +406,37 @@ class ShellSession:
|
|
|
323
406
|
except queue.Empty:
|
|
324
407
|
break
|
|
325
408
|
|
|
409
|
+
def _drain_remaining_stderr(
|
|
410
|
+
self, collected: list[str], deadline: float, drain_timeout: float = 0.05
|
|
411
|
+
) -> None:
|
|
412
|
+
"""Drain any stderr output that arrived concurrently with the done marker.
|
|
413
|
+
|
|
414
|
+
The stdout and stderr reader threads run independently. When a command writes to
|
|
415
|
+
stderr just before exiting, the stderr output may still be in transit when the
|
|
416
|
+
done marker arrives on stdout. This method briefly polls the queue to capture
|
|
417
|
+
such output.
|
|
418
|
+
|
|
419
|
+
Args:
|
|
420
|
+
collected: The list to append collected stderr lines to.
|
|
421
|
+
deadline: The original command deadline (used as an upper bound).
|
|
422
|
+
drain_timeout: Maximum time to wait for additional stderr output.
|
|
423
|
+
"""
|
|
424
|
+
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
|
|
425
|
+
while True:
|
|
426
|
+
remaining = drain_deadline - time.monotonic()
|
|
427
|
+
if remaining <= 0:
|
|
428
|
+
break
|
|
429
|
+
try:
|
|
430
|
+
source, data = self._queue.get(timeout=remaining)
|
|
431
|
+
except queue.Empty:
|
|
432
|
+
break
|
|
433
|
+
if data is None or source != "stderr":
|
|
434
|
+
continue
|
|
435
|
+
stripped = data.rstrip("\n")
|
|
436
|
+
collected.append(f"[stderr] {stripped}")
|
|
437
|
+
if data.endswith("\n"):
|
|
438
|
+
collected.append("\n")
|
|
439
|
+
|
|
326
440
|
@staticmethod
|
|
327
441
|
def _safe_int(value: str) -> int | None:
|
|
328
442
|
with contextlib.suppress(ValueError):
|
|
@@ -405,6 +519,12 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
405
519
|
Defaults to `HostExecutionPolicy` for native execution.
|
|
406
520
|
redaction_rules: Optional redaction rules to sanitize command output before
|
|
407
521
|
returning it to the model.
|
|
522
|
+
|
|
523
|
+
!!! warning
|
|
524
|
+
Redaction rules are applied post execution and do not prevent
|
|
525
|
+
exfiltration of secrets or sensitive data when using
|
|
526
|
+
`HostExecutionPolicy`.
|
|
527
|
+
|
|
408
528
|
tool_description: Optional override for the registered shell tool
|
|
409
529
|
description.
|
|
410
530
|
tool_name: Name for the registered shell tool.
|
|
@@ -489,7 +609,8 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
489
609
|
normalized[key] = str(value)
|
|
490
610
|
return normalized
|
|
491
611
|
|
|
492
|
-
|
|
612
|
+
@override
|
|
613
|
+
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
|
493
614
|
"""Start the shell session and run startup commands."""
|
|
494
615
|
resources = self._get_or_create_resources(state)
|
|
495
616
|
return {"shell_session_resources": resources}
|
|
@@ -498,7 +619,8 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
498
619
|
"""Async start the shell session and run startup commands."""
|
|
499
620
|
return self.before_agent(state, runtime)
|
|
500
621
|
|
|
501
|
-
|
|
622
|
+
@override
|
|
623
|
+
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
|
502
624
|
"""Run shutdown commands and release resources when an agent completes."""
|
|
503
625
|
resources = state.get("shell_session_resources")
|
|
504
626
|
if not isinstance(resources, _SessionResources):
|
|
@@ -507,7 +629,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
507
629
|
try:
|
|
508
630
|
self._run_shutdown_commands(resources.session)
|
|
509
631
|
finally:
|
|
510
|
-
resources.
|
|
632
|
+
resources.finalizer()
|
|
511
633
|
|
|
512
634
|
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
|
513
635
|
"""Async run shutdown commands and release resources when an agent completes."""
|
|
@@ -568,7 +690,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
568
690
|
return
|
|
569
691
|
for command in self._startup_commands:
|
|
570
692
|
result = session.execute(command, timeout=self._execution_policy.startup_timeout)
|
|
571
|
-
if result.timed_out or (result.exit_code not in
|
|
693
|
+
if result.timed_out or (result.exit_code not in {0, None}):
|
|
572
694
|
msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
|
|
573
695
|
raise RuntimeError(msg)
|
|
574
696
|
|
|
@@ -580,7 +702,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
580
702
|
result = session.execute(command, timeout=self._execution_policy.command_timeout)
|
|
581
703
|
if result.timed_out:
|
|
582
704
|
LOGGER.warning("Shutdown command '%s' timed out.", command)
|
|
583
|
-
elif result.exit_code not in
|
|
705
|
+
elif result.exit_code not in {0, None}:
|
|
584
706
|
LOGGER.warning(
|
|
585
707
|
"Shutdown command '%s' exited with %s.", command, result.exit_code
|
|
586
708
|
)
|
|
@@ -671,7 +793,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
671
793
|
f"(observed {result.total_bytes})."
|
|
672
794
|
)
|
|
673
795
|
|
|
674
|
-
if result.exit_code not in
|
|
796
|
+
if result.exit_code not in {0, None}:
|
|
675
797
|
sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
|
|
676
798
|
final_status: Literal["success", "error"] = "error"
|
|
677
799
|
else:
|