langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +153 -79
- langchain/agents/middleware/__init__.py +18 -23
- langchain/agents/middleware/_execution.py +29 -32
- langchain/agents/middleware/_redaction.py +108 -22
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +47 -25
- langchain/agents/middleware/file_search.py +19 -14
- langchain/agents/middleware/human_in_the_loop.py +87 -57
- langchain/agents/middleware/model_call_limit.py +64 -18
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +307 -0
- langchain/agents/middleware/pii.py +82 -29
- langchain/agents/middleware/shell_tool.py +254 -107
- langchain/agents/middleware/summarization.py +469 -95
- langchain/agents/middleware/todo.py +129 -31
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +47 -38
- langchain/agents/middleware/tool_retry.py +183 -164
- langchain/agents/middleware/tool_selection.py +81 -37
- langchain/agents/middleware/types.py +856 -427
- langchain/agents/structured_output.py +65 -42
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +253 -196
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
- langchain-1.2.4.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -11,18 +11,18 @@ import subprocess
|
|
|
11
11
|
import tempfile
|
|
12
12
|
import threading
|
|
13
13
|
import time
|
|
14
|
-
import typing
|
|
15
14
|
import uuid
|
|
16
15
|
import weakref
|
|
17
16
|
from dataclasses import dataclass, field
|
|
18
17
|
from pathlib import Path
|
|
19
|
-
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
18
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
|
20
19
|
|
|
21
20
|
from langchain_core.messages import ToolMessage
|
|
22
|
-
from langchain_core.tools.base import
|
|
21
|
+
from langchain_core.tools.base import ToolException
|
|
23
22
|
from langgraph.channels.untracked_value import UntrackedValue
|
|
24
23
|
from pydantic import BaseModel, model_validator
|
|
25
|
-
from
|
|
24
|
+
from pydantic.json_schema import SkipJsonSchema
|
|
25
|
+
from typing_extensions import NotRequired, override
|
|
26
26
|
|
|
27
27
|
from langchain.agents.middleware._execution import (
|
|
28
28
|
SHELL_TEMP_PREFIX,
|
|
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
|
|
|
38
38
|
ResolvedRedactionRule,
|
|
39
39
|
)
|
|
40
40
|
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
|
41
|
+
from langchain.tools import ToolRuntime, tool
|
|
41
42
|
|
|
42
43
|
if TYPE_CHECKING:
|
|
43
44
|
from collections.abc import Mapping, Sequence
|
|
44
45
|
|
|
45
46
|
from langgraph.runtime import Runtime
|
|
46
|
-
from langgraph.types import Command
|
|
47
47
|
|
|
48
|
-
from langchain.agents.middleware.types import ToolCallRequest
|
|
49
48
|
|
|
50
49
|
LOGGER = logging.getLogger(__name__)
|
|
51
50
|
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
|
|
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
|
|
|
59
58
|
"session remains stable. Outputs may be truncated when they become very large, and long "
|
|
60
59
|
"running commands will be terminated once their configured timeout elapses."
|
|
61
60
|
)
|
|
61
|
+
SHELL_TOOL_NAME = "shell"
|
|
62
62
|
|
|
63
63
|
|
|
64
64
|
def _cleanup_resources(
|
|
@@ -78,10 +78,10 @@ class _SessionResources:
|
|
|
78
78
|
session: ShellSession
|
|
79
79
|
tempdir: tempfile.TemporaryDirectory[str] | None
|
|
80
80
|
policy: BaseExecutionPolicy
|
|
81
|
-
|
|
81
|
+
finalizer: weakref.finalize = field(init=False, repr=False) # type: ignore[type-arg]
|
|
82
82
|
|
|
83
83
|
def __post_init__(self) -> None:
|
|
84
|
-
self.
|
|
84
|
+
self.finalizer = weakref.finalize(
|
|
85
85
|
self,
|
|
86
86
|
_cleanup_resources,
|
|
87
87
|
self.session,
|
|
@@ -90,7 +90,7 @@ class _SessionResources:
|
|
|
90
90
|
)
|
|
91
91
|
|
|
92
92
|
|
|
93
|
-
class ShellToolState(AgentState):
|
|
93
|
+
class ShellToolState(AgentState[Any]):
|
|
94
94
|
"""Agent state extension for tracking shell session resources."""
|
|
95
95
|
|
|
96
96
|
shell_session_resources: NotRequired[
|
|
@@ -134,7 +134,11 @@ class ShellSession:
|
|
|
134
134
|
self._terminated = False
|
|
135
135
|
|
|
136
136
|
def start(self) -> None:
|
|
137
|
-
"""Start the shell subprocess and reader threads.
|
|
137
|
+
"""Start the shell subprocess and reader threads.
|
|
138
|
+
|
|
139
|
+
Raises:
|
|
140
|
+
RuntimeError: If the shell session pipes cannot be initialized.
|
|
141
|
+
"""
|
|
138
142
|
if self._process and self._process.poll() is None:
|
|
139
143
|
return
|
|
140
144
|
|
|
@@ -211,9 +215,14 @@ class ShellSession:
|
|
|
211
215
|
with self._lock:
|
|
212
216
|
self._drain_queue()
|
|
213
217
|
payload = command if command.endswith("\n") else f"{command}\n"
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
218
|
+
try:
|
|
219
|
+
self._stdin.write(payload)
|
|
220
|
+
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
|
|
221
|
+
self._stdin.flush()
|
|
222
|
+
except (BrokenPipeError, OSError):
|
|
223
|
+
# The shell exited before we could write the marker command.
|
|
224
|
+
# This happens when commands like 'exit 1' terminate the shell.
|
|
225
|
+
return self._collect_output_after_exit(deadline)
|
|
217
226
|
|
|
218
227
|
return self._collect_output(marker, deadline, timeout)
|
|
219
228
|
|
|
@@ -248,6 +257,10 @@ class ShellSession:
|
|
|
248
257
|
if source == "stdout" and data.startswith(marker):
|
|
249
258
|
_, _, status = data.partition(" ")
|
|
250
259
|
exit_code = self._safe_int(status.strip())
|
|
260
|
+
# Drain any remaining stderr that may have arrived concurrently.
|
|
261
|
+
# The stderr reader thread runs independently, so output might
|
|
262
|
+
# still be in flight when the stdout marker arrives.
|
|
263
|
+
self._drain_remaining_stderr(collected, deadline)
|
|
251
264
|
break
|
|
252
265
|
|
|
253
266
|
total_lines += 1
|
|
@@ -300,6 +313,80 @@ class ShellSession:
|
|
|
300
313
|
total_bytes=total_bytes,
|
|
301
314
|
)
|
|
302
315
|
|
|
316
|
+
def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
|
|
317
|
+
"""Collect output after the shell exited unexpectedly.
|
|
318
|
+
|
|
319
|
+
Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
|
|
320
|
+
shell process terminated (e.g., due to an 'exit' command).
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
deadline: Absolute time by which collection must complete.
|
|
324
|
+
|
|
325
|
+
Returns:
|
|
326
|
+
`CommandExecutionResult` with collected output and the process exit code.
|
|
327
|
+
"""
|
|
328
|
+
collected: list[str] = []
|
|
329
|
+
total_lines = 0
|
|
330
|
+
total_bytes = 0
|
|
331
|
+
truncated_by_lines = False
|
|
332
|
+
truncated_by_bytes = False
|
|
333
|
+
|
|
334
|
+
# Give reader threads a brief moment to enqueue any remaining output.
|
|
335
|
+
drain_timeout = 0.1
|
|
336
|
+
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
|
|
337
|
+
|
|
338
|
+
while True:
|
|
339
|
+
remaining = drain_deadline - time.monotonic()
|
|
340
|
+
if remaining <= 0:
|
|
341
|
+
break
|
|
342
|
+
try:
|
|
343
|
+
source, data = self._queue.get(timeout=remaining)
|
|
344
|
+
except queue.Empty:
|
|
345
|
+
break
|
|
346
|
+
|
|
347
|
+
if data is None:
|
|
348
|
+
# EOF marker from a reader thread; continue draining.
|
|
349
|
+
continue
|
|
350
|
+
|
|
351
|
+
total_lines += 1
|
|
352
|
+
encoded = data.encode("utf-8", "replace")
|
|
353
|
+
total_bytes += len(encoded)
|
|
354
|
+
|
|
355
|
+
if total_lines > self._policy.max_output_lines:
|
|
356
|
+
truncated_by_lines = True
|
|
357
|
+
continue
|
|
358
|
+
|
|
359
|
+
if (
|
|
360
|
+
self._policy.max_output_bytes is not None
|
|
361
|
+
and total_bytes > self._policy.max_output_bytes
|
|
362
|
+
):
|
|
363
|
+
truncated_by_bytes = True
|
|
364
|
+
continue
|
|
365
|
+
|
|
366
|
+
if source == "stderr":
|
|
367
|
+
stripped = data.rstrip("\n")
|
|
368
|
+
collected.append(f"[stderr] {stripped}")
|
|
369
|
+
if data.endswith("\n"):
|
|
370
|
+
collected.append("\n")
|
|
371
|
+
else:
|
|
372
|
+
collected.append(data)
|
|
373
|
+
|
|
374
|
+
# Get exit code from the terminated process.
|
|
375
|
+
exit_code: int | None = None
|
|
376
|
+
if self._process:
|
|
377
|
+
exit_code = self._process.poll()
|
|
378
|
+
|
|
379
|
+
output = "".join(collected)
|
|
380
|
+
return CommandExecutionResult(
|
|
381
|
+
output=output,
|
|
382
|
+
exit_code=exit_code,
|
|
383
|
+
timed_out=False,
|
|
384
|
+
truncated_by_lines=truncated_by_lines,
|
|
385
|
+
truncated_by_bytes=truncated_by_bytes,
|
|
386
|
+
total_lines=total_lines,
|
|
387
|
+
total_bytes=total_bytes,
|
|
388
|
+
)
|
|
389
|
+
|
|
303
390
|
def _kill_process(self) -> None:
|
|
304
391
|
if not self._process:
|
|
305
392
|
return
|
|
@@ -323,6 +410,37 @@ class ShellSession:
|
|
|
323
410
|
except queue.Empty:
|
|
324
411
|
break
|
|
325
412
|
|
|
413
|
+
def _drain_remaining_stderr(
|
|
414
|
+
self, collected: list[str], deadline: float, drain_timeout: float = 0.05
|
|
415
|
+
) -> None:
|
|
416
|
+
"""Drain any stderr output that arrived concurrently with the done marker.
|
|
417
|
+
|
|
418
|
+
The stdout and stderr reader threads run independently. When a command writes to
|
|
419
|
+
stderr just before exiting, the stderr output may still be in transit when the
|
|
420
|
+
done marker arrives on stdout. This method briefly polls the queue to capture
|
|
421
|
+
such output.
|
|
422
|
+
|
|
423
|
+
Args:
|
|
424
|
+
collected: The list to append collected stderr lines to.
|
|
425
|
+
deadline: The original command deadline (used as an upper bound).
|
|
426
|
+
drain_timeout: Maximum time to wait for additional stderr output.
|
|
427
|
+
"""
|
|
428
|
+
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
|
|
429
|
+
while True:
|
|
430
|
+
remaining = drain_deadline - time.monotonic()
|
|
431
|
+
if remaining <= 0:
|
|
432
|
+
break
|
|
433
|
+
try:
|
|
434
|
+
source, data = self._queue.get(timeout=remaining)
|
|
435
|
+
except queue.Empty:
|
|
436
|
+
break
|
|
437
|
+
if data is None or source != "stderr":
|
|
438
|
+
continue
|
|
439
|
+
stripped = data.rstrip("\n")
|
|
440
|
+
collected.append(f"[stderr] {stripped}")
|
|
441
|
+
if data.endswith("\n"):
|
|
442
|
+
collected.append("\n")
|
|
443
|
+
|
|
326
444
|
@staticmethod
|
|
327
445
|
def _safe_int(value: str) -> int | None:
|
|
328
446
|
with contextlib.suppress(ValueError):
|
|
@@ -334,7 +452,17 @@ class _ShellToolInput(BaseModel):
|
|
|
334
452
|
"""Input schema for the persistent shell tool."""
|
|
335
453
|
|
|
336
454
|
command: str | None = None
|
|
455
|
+
"""The shell command to execute."""
|
|
456
|
+
|
|
337
457
|
restart: bool | None = None
|
|
458
|
+
"""Whether to restart the shell session."""
|
|
459
|
+
|
|
460
|
+
runtime: Annotated[Any, SkipJsonSchema()] = None
|
|
461
|
+
"""The runtime for the shell tool.
|
|
462
|
+
|
|
463
|
+
Included as a workaround at the moment bc args_schema doesn't work with
|
|
464
|
+
injected ToolRuntime.
|
|
465
|
+
"""
|
|
338
466
|
|
|
339
467
|
@model_validator(mode="after")
|
|
340
468
|
def validate_payload(self) -> _ShellToolInput:
|
|
@@ -347,38 +475,21 @@ class _ShellToolInput(BaseModel):
|
|
|
347
475
|
return self
|
|
348
476
|
|
|
349
477
|
|
|
350
|
-
class _PersistentShellTool(BaseTool):
|
|
351
|
-
"""Tool wrapper that relies on middleware interception for execution."""
|
|
352
|
-
|
|
353
|
-
name: str = "shell"
|
|
354
|
-
description: str = DEFAULT_TOOL_DESCRIPTION
|
|
355
|
-
args_schema: type[BaseModel] = _ShellToolInput
|
|
356
|
-
|
|
357
|
-
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
|
|
358
|
-
super().__init__()
|
|
359
|
-
self._middleware = middleware
|
|
360
|
-
if description is not None:
|
|
361
|
-
self.description = description
|
|
362
|
-
|
|
363
|
-
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
|
|
364
|
-
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
|
|
365
|
-
raise RuntimeError(msg)
|
|
366
|
-
|
|
367
|
-
|
|
368
478
|
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
369
479
|
"""Middleware that registers a persistent shell tool for agents.
|
|
370
480
|
|
|
371
|
-
The middleware exposes a single long-lived shell session. Use the execution policy
|
|
372
|
-
match your deployment's security posture:
|
|
481
|
+
The middleware exposes a single long-lived shell session. Use the execution policy
|
|
482
|
+
to match your deployment's security posture:
|
|
373
483
|
|
|
374
|
-
*
|
|
375
|
-
|
|
376
|
-
*
|
|
377
|
-
|
|
378
|
-
*
|
|
379
|
-
|
|
484
|
+
* `HostExecutionPolicy` – full host access; best for trusted environments where the
|
|
485
|
+
agent already runs inside a container or VM that provides isolation.
|
|
486
|
+
* `CodexSandboxExecutionPolicy` – reuses the Codex CLI sandbox for additional
|
|
487
|
+
syscall/filesystem restrictions when the CLI is available.
|
|
488
|
+
* `DockerExecutionPolicy` – launches a separate Docker container for each agent run,
|
|
489
|
+
providing harder isolation, optional read-only root filesystems, and user
|
|
490
|
+
remapping.
|
|
380
491
|
|
|
381
|
-
When no policy is provided the middleware defaults to
|
|
492
|
+
When no policy is provided the middleware defaults to `HostExecutionPolicy`.
|
|
382
493
|
"""
|
|
383
494
|
|
|
384
495
|
state_schema = ShellToolState
|
|
@@ -392,29 +503,49 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
392
503
|
execution_policy: BaseExecutionPolicy | None = None,
|
|
393
504
|
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
|
|
394
505
|
tool_description: str | None = None,
|
|
506
|
+
tool_name: str = SHELL_TOOL_NAME,
|
|
395
507
|
shell_command: Sequence[str] | str | None = None,
|
|
396
508
|
env: Mapping[str, Any] | None = None,
|
|
397
509
|
) -> None:
|
|
398
|
-
"""Initialize
|
|
510
|
+
"""Initialize an instance of `ShellToolMiddleware`.
|
|
399
511
|
|
|
400
512
|
Args:
|
|
401
|
-
workspace_root: Base directory for the shell session.
|
|
402
|
-
|
|
403
|
-
|
|
513
|
+
workspace_root: Base directory for the shell session.
|
|
514
|
+
|
|
515
|
+
If omitted, a temporary directory is created when the agent starts and
|
|
516
|
+
removed when it ends.
|
|
517
|
+
startup_commands: Optional commands executed sequentially after the session
|
|
518
|
+
starts.
|
|
404
519
|
shutdown_commands: Optional commands executed before the session shuts down.
|
|
405
|
-
execution_policy: Execution policy controlling timeouts, output limits, and
|
|
406
|
-
configuration.
|
|
520
|
+
execution_policy: Execution policy controlling timeouts, output limits, and
|
|
521
|
+
resource configuration.
|
|
522
|
+
|
|
523
|
+
Defaults to `HostExecutionPolicy` for native execution.
|
|
407
524
|
redaction_rules: Optional redaction rules to sanitize command output before
|
|
408
525
|
returning it to the model.
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
526
|
+
|
|
527
|
+
!!! warning
|
|
528
|
+
Redaction rules are applied post execution and do not prevent
|
|
529
|
+
exfiltration of secrets or sensitive data when using
|
|
530
|
+
`HostExecutionPolicy`.
|
|
531
|
+
|
|
532
|
+
tool_description: Optional override for the registered shell tool
|
|
533
|
+
description.
|
|
534
|
+
tool_name: Name for the registered shell tool.
|
|
535
|
+
|
|
536
|
+
Defaults to `"shell"`.
|
|
537
|
+
shell_command: Optional shell executable (string) or argument sequence used
|
|
538
|
+
to launch the persistent session.
|
|
539
|
+
|
|
540
|
+
Defaults to an implementation-defined bash command.
|
|
541
|
+
env: Optional environment variables to supply to the shell session.
|
|
542
|
+
|
|
543
|
+
Values are coerced to strings before command execution. If omitted, the
|
|
544
|
+
session inherits the parent process environment.
|
|
415
545
|
"""
|
|
416
546
|
super().__init__()
|
|
417
547
|
self._workspace_root = Path(workspace_root) if workspace_root else None
|
|
548
|
+
self._tool_name = tool_name
|
|
418
549
|
self._shell_command = self._normalize_shell_command(shell_command)
|
|
419
550
|
self._environment = self._normalize_env(env)
|
|
420
551
|
if execution_policy is not None:
|
|
@@ -428,9 +559,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
428
559
|
self._startup_commands = self._normalize_commands(startup_commands)
|
|
429
560
|
self._shutdown_commands = self._normalize_commands(shutdown_commands)
|
|
430
561
|
|
|
562
|
+
# Create a proper tool that executes directly (no interception needed)
|
|
431
563
|
description = tool_description or DEFAULT_TOOL_DESCRIPTION
|
|
432
|
-
|
|
433
|
-
self.
|
|
564
|
+
|
|
565
|
+
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
|
|
566
|
+
def shell_tool(
|
|
567
|
+
*,
|
|
568
|
+
runtime: ToolRuntime[None, ShellToolState],
|
|
569
|
+
command: str | None = None,
|
|
570
|
+
restart: bool = False,
|
|
571
|
+
) -> ToolMessage | str:
|
|
572
|
+
resources = self._get_or_create_resources(runtime.state)
|
|
573
|
+
return self._run_shell_tool(
|
|
574
|
+
resources,
|
|
575
|
+
{"command": command, "restart": restart},
|
|
576
|
+
tool_call_id=runtime.tool_call_id,
|
|
577
|
+
)
|
|
578
|
+
|
|
579
|
+
self._shell_tool = shell_tool
|
|
580
|
+
self.tools = [self._shell_tool]
|
|
434
581
|
|
|
435
582
|
@staticmethod
|
|
436
583
|
def _normalize_commands(
|
|
@@ -461,43 +608,73 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
461
608
|
normalized: dict[str, str] = {}
|
|
462
609
|
for key, value in env.items():
|
|
463
610
|
if not isinstance(key, str):
|
|
464
|
-
msg = "Environment variable names must be strings."
|
|
611
|
+
msg = "Environment variable names must be strings." # type: ignore[unreachable]
|
|
465
612
|
raise TypeError(msg)
|
|
466
613
|
normalized[key] = str(value)
|
|
467
614
|
return normalized
|
|
468
615
|
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
616
|
+
@override
|
|
617
|
+
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
|
618
|
+
"""Start the shell session and run startup commands.
|
|
619
|
+
|
|
620
|
+
Args:
|
|
621
|
+
state: The current agent state.
|
|
622
|
+
runtime: The runtime context.
|
|
623
|
+
|
|
624
|
+
Returns:
|
|
625
|
+
Shell session resources to be stored in the agent state.
|
|
626
|
+
"""
|
|
627
|
+
resources = self._get_or_create_resources(state)
|
|
472
628
|
return {"shell_session_resources": resources}
|
|
473
629
|
|
|
474
630
|
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
|
475
|
-
"""Async
|
|
631
|
+
"""Async start the shell session and run startup commands.
|
|
632
|
+
|
|
633
|
+
Args:
|
|
634
|
+
state: The current agent state.
|
|
635
|
+
runtime: The runtime context.
|
|
636
|
+
|
|
637
|
+
Returns:
|
|
638
|
+
Shell session resources to be stored in the agent state.
|
|
639
|
+
"""
|
|
476
640
|
return self.before_agent(state, runtime)
|
|
477
641
|
|
|
478
|
-
|
|
642
|
+
@override
|
|
643
|
+
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
|
479
644
|
"""Run shutdown commands and release resources when an agent completes."""
|
|
480
|
-
resources =
|
|
645
|
+
resources = state.get("shell_session_resources")
|
|
646
|
+
if not isinstance(resources, _SessionResources):
|
|
647
|
+
# Resources were never created, nothing to clean up
|
|
648
|
+
return
|
|
481
649
|
try:
|
|
482
650
|
self._run_shutdown_commands(resources.session)
|
|
483
651
|
finally:
|
|
484
|
-
resources.
|
|
652
|
+
resources.finalizer()
|
|
485
653
|
|
|
486
654
|
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
|
487
|
-
"""Async
|
|
655
|
+
"""Async run shutdown commands and release resources when an agent completes."""
|
|
488
656
|
return self.after_agent(state, runtime)
|
|
489
657
|
|
|
490
|
-
def
|
|
658
|
+
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
|
|
659
|
+
"""Get existing resources from state or create new ones if they don't exist.
|
|
660
|
+
|
|
661
|
+
This method enables resumability by checking if resources already exist in the state
|
|
662
|
+
(e.g., after an interrupt), and only creating new resources if they're not present.
|
|
663
|
+
|
|
664
|
+
Args:
|
|
665
|
+
state: The agent state which may contain shell session resources.
|
|
666
|
+
|
|
667
|
+
Returns:
|
|
668
|
+
Session resources, either retrieved from state or newly created.
|
|
669
|
+
"""
|
|
491
670
|
resources = state.get("shell_session_resources")
|
|
492
|
-
if
|
|
493
|
-
resources
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
raise ToolException(msg)
|
|
500
|
-
return resources
|
|
671
|
+
if isinstance(resources, _SessionResources):
|
|
672
|
+
return resources
|
|
673
|
+
|
|
674
|
+
new_resources = self._create_resources()
|
|
675
|
+
# Cast needed to make state dict-like for mutation
|
|
676
|
+
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
|
|
677
|
+
return new_resources
|
|
501
678
|
|
|
502
679
|
def _create_resources(self) -> _SessionResources:
|
|
503
680
|
workspace = self._workspace_root
|
|
@@ -533,7 +710,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
533
710
|
return
|
|
534
711
|
for command in self._startup_commands:
|
|
535
712
|
result = session.execute(command, timeout=self._execution_policy.startup_timeout)
|
|
536
|
-
if result.timed_out or (result.exit_code not in
|
|
713
|
+
if result.timed_out or (result.exit_code not in {0, None}):
|
|
537
714
|
msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
|
|
538
715
|
raise RuntimeError(msg)
|
|
539
716
|
|
|
@@ -545,7 +722,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
545
722
|
result = session.execute(command, timeout=self._execution_policy.command_timeout)
|
|
546
723
|
if result.timed_out:
|
|
547
724
|
LOGGER.warning("Shutdown command '%s' timed out.", command)
|
|
548
|
-
elif result.exit_code not in
|
|
725
|
+
elif result.exit_code not in {0, None}:
|
|
549
726
|
LOGGER.warning(
|
|
550
727
|
"Shutdown command '%s' exited with %s.", command, result.exit_code
|
|
551
728
|
)
|
|
@@ -636,7 +813,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
636
813
|
f"(observed {result.total_bytes})."
|
|
637
814
|
)
|
|
638
815
|
|
|
639
|
-
if result.exit_code not in
|
|
816
|
+
if result.exit_code not in {0, None}:
|
|
640
817
|
sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
|
|
641
818
|
final_status: Literal["success", "error"] = "error"
|
|
642
819
|
else:
|
|
@@ -659,36 +836,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
659
836
|
artifact=artifact,
|
|
660
837
|
)
|
|
661
838
|
|
|
662
|
-
def wrap_tool_call(
|
|
663
|
-
self,
|
|
664
|
-
request: ToolCallRequest,
|
|
665
|
-
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
|
|
666
|
-
) -> ToolMessage | Command:
|
|
667
|
-
"""Intercept local shell tool calls and execute them via the managed session."""
|
|
668
|
-
if isinstance(request.tool, _PersistentShellTool):
|
|
669
|
-
resources = self._ensure_resources(request.state)
|
|
670
|
-
return self._run_shell_tool(
|
|
671
|
-
resources,
|
|
672
|
-
request.tool_call["args"],
|
|
673
|
-
tool_call_id=request.tool_call.get("id"),
|
|
674
|
-
)
|
|
675
|
-
return handler(request)
|
|
676
|
-
|
|
677
|
-
async def awrap_tool_call(
|
|
678
|
-
self,
|
|
679
|
-
request: ToolCallRequest,
|
|
680
|
-
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
|
|
681
|
-
) -> ToolMessage | Command:
|
|
682
|
-
"""Async interception mirroring the synchronous tool handler."""
|
|
683
|
-
if isinstance(request.tool, _PersistentShellTool):
|
|
684
|
-
resources = self._ensure_resources(request.state)
|
|
685
|
-
return self._run_shell_tool(
|
|
686
|
-
resources,
|
|
687
|
-
request.tool_call["args"],
|
|
688
|
-
tool_call_id=request.tool_call.get("id"),
|
|
689
|
-
)
|
|
690
|
-
return await handler(request)
|
|
691
|
-
|
|
692
839
|
def _format_tool_message(
|
|
693
840
|
self,
|
|
694
841
|
content: str,
|
|
@@ -703,7 +850,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
703
850
|
return ToolMessage(
|
|
704
851
|
content=content,
|
|
705
852
|
tool_call_id=tool_call_id,
|
|
706
|
-
name=self.
|
|
853
|
+
name=self._tool_name,
|
|
707
854
|
status=status,
|
|
708
855
|
artifact=artifact,
|
|
709
856
|
)
|