langchain 1.0.5__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +99 -40
- langchain/agents/middleware/__init__.py +5 -7
- langchain/agents/middleware/_execution.py +21 -20
- langchain/agents/middleware/_redaction.py +27 -12
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +26 -22
- langchain/agents/middleware/file_search.py +18 -13
- langchain/agents/middleware/human_in_the_loop.py +60 -54
- langchain/agents/middleware/model_call_limit.py +63 -17
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +300 -0
- langchain/agents/middleware/pii.py +80 -27
- langchain/agents/middleware/shell_tool.py +230 -103
- langchain/agents/middleware/summarization.py +439 -90
- langchain/agents/middleware/todo.py +111 -27
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +42 -33
- langchain/agents/middleware/tool_retry.py +171 -159
- langchain/agents/middleware/tool_selection.py +37 -27
- langchain/agents/middleware/types.py +754 -392
- langchain/agents/structured_output.py +22 -12
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +233 -184
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
- langchain-1.2.3.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -11,18 +11,18 @@ import subprocess
|
|
|
11
11
|
import tempfile
|
|
12
12
|
import threading
|
|
13
13
|
import time
|
|
14
|
-
import typing
|
|
15
14
|
import uuid
|
|
16
15
|
import weakref
|
|
17
16
|
from dataclasses import dataclass, field
|
|
18
17
|
from pathlib import Path
|
|
19
|
-
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
18
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
|
20
19
|
|
|
21
20
|
from langchain_core.messages import ToolMessage
|
|
22
|
-
from langchain_core.tools.base import
|
|
21
|
+
from langchain_core.tools.base import ToolException
|
|
23
22
|
from langgraph.channels.untracked_value import UntrackedValue
|
|
24
23
|
from pydantic import BaseModel, model_validator
|
|
25
|
-
from
|
|
24
|
+
from pydantic.json_schema import SkipJsonSchema
|
|
25
|
+
from typing_extensions import NotRequired, override
|
|
26
26
|
|
|
27
27
|
from langchain.agents.middleware._execution import (
|
|
28
28
|
SHELL_TEMP_PREFIX,
|
|
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
|
|
|
38
38
|
ResolvedRedactionRule,
|
|
39
39
|
)
|
|
40
40
|
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
|
41
|
+
from langchain.tools import ToolRuntime, tool
|
|
41
42
|
|
|
42
43
|
if TYPE_CHECKING:
|
|
43
44
|
from collections.abc import Mapping, Sequence
|
|
44
45
|
|
|
45
46
|
from langgraph.runtime import Runtime
|
|
46
|
-
from langgraph.types import Command
|
|
47
47
|
|
|
48
|
-
from langchain.agents.middleware.types import ToolCallRequest
|
|
49
48
|
|
|
50
49
|
LOGGER = logging.getLogger(__name__)
|
|
51
50
|
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
|
|
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
|
|
|
59
58
|
"session remains stable. Outputs may be truncated when they become very large, and long "
|
|
60
59
|
"running commands will be terminated once their configured timeout elapses."
|
|
61
60
|
)
|
|
61
|
+
SHELL_TOOL_NAME = "shell"
|
|
62
62
|
|
|
63
63
|
|
|
64
64
|
def _cleanup_resources(
|
|
@@ -78,10 +78,10 @@ class _SessionResources:
|
|
|
78
78
|
session: ShellSession
|
|
79
79
|
tempdir: tempfile.TemporaryDirectory[str] | None
|
|
80
80
|
policy: BaseExecutionPolicy
|
|
81
|
-
|
|
81
|
+
finalizer: weakref.finalize = field(init=False, repr=False)
|
|
82
82
|
|
|
83
83
|
def __post_init__(self) -> None:
|
|
84
|
-
self.
|
|
84
|
+
self.finalizer = weakref.finalize(
|
|
85
85
|
self,
|
|
86
86
|
_cleanup_resources,
|
|
87
87
|
self.session,
|
|
@@ -211,9 +211,14 @@ class ShellSession:
|
|
|
211
211
|
with self._lock:
|
|
212
212
|
self._drain_queue()
|
|
213
213
|
payload = command if command.endswith("\n") else f"{command}\n"
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
214
|
+
try:
|
|
215
|
+
self._stdin.write(payload)
|
|
216
|
+
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
|
|
217
|
+
self._stdin.flush()
|
|
218
|
+
except (BrokenPipeError, OSError):
|
|
219
|
+
# The shell exited before we could write the marker command.
|
|
220
|
+
# This happens when commands like 'exit 1' terminate the shell.
|
|
221
|
+
return self._collect_output_after_exit(deadline)
|
|
217
222
|
|
|
218
223
|
return self._collect_output(marker, deadline, timeout)
|
|
219
224
|
|
|
@@ -248,6 +253,10 @@ class ShellSession:
|
|
|
248
253
|
if source == "stdout" and data.startswith(marker):
|
|
249
254
|
_, _, status = data.partition(" ")
|
|
250
255
|
exit_code = self._safe_int(status.strip())
|
|
256
|
+
# Drain any remaining stderr that may have arrived concurrently.
|
|
257
|
+
# The stderr reader thread runs independently, so output might
|
|
258
|
+
# still be in flight when the stdout marker arrives.
|
|
259
|
+
self._drain_remaining_stderr(collected, deadline)
|
|
251
260
|
break
|
|
252
261
|
|
|
253
262
|
total_lines += 1
|
|
@@ -300,6 +309,80 @@ class ShellSession:
|
|
|
300
309
|
total_bytes=total_bytes,
|
|
301
310
|
)
|
|
302
311
|
|
|
312
|
+
def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
|
|
313
|
+
"""Collect output after the shell exited unexpectedly.
|
|
314
|
+
|
|
315
|
+
Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
|
|
316
|
+
shell process terminated (e.g., due to an 'exit' command).
|
|
317
|
+
|
|
318
|
+
Args:
|
|
319
|
+
deadline: Absolute time by which collection must complete.
|
|
320
|
+
|
|
321
|
+
Returns:
|
|
322
|
+
`CommandExecutionResult` with collected output and the process exit code.
|
|
323
|
+
"""
|
|
324
|
+
collected: list[str] = []
|
|
325
|
+
total_lines = 0
|
|
326
|
+
total_bytes = 0
|
|
327
|
+
truncated_by_lines = False
|
|
328
|
+
truncated_by_bytes = False
|
|
329
|
+
|
|
330
|
+
# Give reader threads a brief moment to enqueue any remaining output.
|
|
331
|
+
drain_timeout = 0.1
|
|
332
|
+
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
|
|
333
|
+
|
|
334
|
+
while True:
|
|
335
|
+
remaining = drain_deadline - time.monotonic()
|
|
336
|
+
if remaining <= 0:
|
|
337
|
+
break
|
|
338
|
+
try:
|
|
339
|
+
source, data = self._queue.get(timeout=remaining)
|
|
340
|
+
except queue.Empty:
|
|
341
|
+
break
|
|
342
|
+
|
|
343
|
+
if data is None:
|
|
344
|
+
# EOF marker from a reader thread; continue draining.
|
|
345
|
+
continue
|
|
346
|
+
|
|
347
|
+
total_lines += 1
|
|
348
|
+
encoded = data.encode("utf-8", "replace")
|
|
349
|
+
total_bytes += len(encoded)
|
|
350
|
+
|
|
351
|
+
if total_lines > self._policy.max_output_lines:
|
|
352
|
+
truncated_by_lines = True
|
|
353
|
+
continue
|
|
354
|
+
|
|
355
|
+
if (
|
|
356
|
+
self._policy.max_output_bytes is not None
|
|
357
|
+
and total_bytes > self._policy.max_output_bytes
|
|
358
|
+
):
|
|
359
|
+
truncated_by_bytes = True
|
|
360
|
+
continue
|
|
361
|
+
|
|
362
|
+
if source == "stderr":
|
|
363
|
+
stripped = data.rstrip("\n")
|
|
364
|
+
collected.append(f"[stderr] {stripped}")
|
|
365
|
+
if data.endswith("\n"):
|
|
366
|
+
collected.append("\n")
|
|
367
|
+
else:
|
|
368
|
+
collected.append(data)
|
|
369
|
+
|
|
370
|
+
# Get exit code from the terminated process.
|
|
371
|
+
exit_code: int | None = None
|
|
372
|
+
if self._process:
|
|
373
|
+
exit_code = self._process.poll()
|
|
374
|
+
|
|
375
|
+
output = "".join(collected)
|
|
376
|
+
return CommandExecutionResult(
|
|
377
|
+
output=output,
|
|
378
|
+
exit_code=exit_code,
|
|
379
|
+
timed_out=False,
|
|
380
|
+
truncated_by_lines=truncated_by_lines,
|
|
381
|
+
truncated_by_bytes=truncated_by_bytes,
|
|
382
|
+
total_lines=total_lines,
|
|
383
|
+
total_bytes=total_bytes,
|
|
384
|
+
)
|
|
385
|
+
|
|
303
386
|
def _kill_process(self) -> None:
|
|
304
387
|
if not self._process:
|
|
305
388
|
return
|
|
@@ -323,6 +406,37 @@ class ShellSession:
|
|
|
323
406
|
except queue.Empty:
|
|
324
407
|
break
|
|
325
408
|
|
|
409
|
+
def _drain_remaining_stderr(
|
|
410
|
+
self, collected: list[str], deadline: float, drain_timeout: float = 0.05
|
|
411
|
+
) -> None:
|
|
412
|
+
"""Drain any stderr output that arrived concurrently with the done marker.
|
|
413
|
+
|
|
414
|
+
The stdout and stderr reader threads run independently. When a command writes to
|
|
415
|
+
stderr just before exiting, the stderr output may still be in transit when the
|
|
416
|
+
done marker arrives on stdout. This method briefly polls the queue to capture
|
|
417
|
+
such output.
|
|
418
|
+
|
|
419
|
+
Args:
|
|
420
|
+
collected: The list to append collected stderr lines to.
|
|
421
|
+
deadline: The original command deadline (used as an upper bound).
|
|
422
|
+
drain_timeout: Maximum time to wait for additional stderr output.
|
|
423
|
+
"""
|
|
424
|
+
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
|
|
425
|
+
while True:
|
|
426
|
+
remaining = drain_deadline - time.monotonic()
|
|
427
|
+
if remaining <= 0:
|
|
428
|
+
break
|
|
429
|
+
try:
|
|
430
|
+
source, data = self._queue.get(timeout=remaining)
|
|
431
|
+
except queue.Empty:
|
|
432
|
+
break
|
|
433
|
+
if data is None or source != "stderr":
|
|
434
|
+
continue
|
|
435
|
+
stripped = data.rstrip("\n")
|
|
436
|
+
collected.append(f"[stderr] {stripped}")
|
|
437
|
+
if data.endswith("\n"):
|
|
438
|
+
collected.append("\n")
|
|
439
|
+
|
|
326
440
|
@staticmethod
|
|
327
441
|
def _safe_int(value: str) -> int | None:
|
|
328
442
|
with contextlib.suppress(ValueError):
|
|
@@ -334,7 +448,17 @@ class _ShellToolInput(BaseModel):
|
|
|
334
448
|
"""Input schema for the persistent shell tool."""
|
|
335
449
|
|
|
336
450
|
command: str | None = None
|
|
451
|
+
"""The shell command to execute."""
|
|
452
|
+
|
|
337
453
|
restart: bool | None = None
|
|
454
|
+
"""Whether to restart the shell session."""
|
|
455
|
+
|
|
456
|
+
runtime: Annotated[Any, SkipJsonSchema()] = None
|
|
457
|
+
"""The runtime for the shell tool.
|
|
458
|
+
|
|
459
|
+
Included as a workaround at the moment bc args_schema doesn't work with
|
|
460
|
+
injected ToolRuntime.
|
|
461
|
+
"""
|
|
338
462
|
|
|
339
463
|
@model_validator(mode="after")
|
|
340
464
|
def validate_payload(self) -> _ShellToolInput:
|
|
@@ -347,38 +471,21 @@ class _ShellToolInput(BaseModel):
|
|
|
347
471
|
return self
|
|
348
472
|
|
|
349
473
|
|
|
350
|
-
class _PersistentShellTool(BaseTool):
|
|
351
|
-
"""Tool wrapper that relies on middleware interception for execution."""
|
|
352
|
-
|
|
353
|
-
name: str = "shell"
|
|
354
|
-
description: str = DEFAULT_TOOL_DESCRIPTION
|
|
355
|
-
args_schema: type[BaseModel] = _ShellToolInput
|
|
356
|
-
|
|
357
|
-
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
|
|
358
|
-
super().__init__()
|
|
359
|
-
self._middleware = middleware
|
|
360
|
-
if description is not None:
|
|
361
|
-
self.description = description
|
|
362
|
-
|
|
363
|
-
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
|
|
364
|
-
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
|
|
365
|
-
raise RuntimeError(msg)
|
|
366
|
-
|
|
367
|
-
|
|
368
474
|
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
369
475
|
"""Middleware that registers a persistent shell tool for agents.
|
|
370
476
|
|
|
371
|
-
The middleware exposes a single long-lived shell session. Use the execution policy
|
|
372
|
-
match your deployment's security posture:
|
|
477
|
+
The middleware exposes a single long-lived shell session. Use the execution policy
|
|
478
|
+
to match your deployment's security posture:
|
|
373
479
|
|
|
374
|
-
*
|
|
375
|
-
|
|
376
|
-
*
|
|
377
|
-
|
|
378
|
-
*
|
|
379
|
-
|
|
480
|
+
* `HostExecutionPolicy` – full host access; best for trusted environments where the
|
|
481
|
+
agent already runs inside a container or VM that provides isolation.
|
|
482
|
+
* `CodexSandboxExecutionPolicy` – reuses the Codex CLI sandbox for additional
|
|
483
|
+
syscall/filesystem restrictions when the CLI is available.
|
|
484
|
+
* `DockerExecutionPolicy` – launches a separate Docker container for each agent run,
|
|
485
|
+
providing harder isolation, optional read-only root filesystems, and user
|
|
486
|
+
remapping.
|
|
380
487
|
|
|
381
|
-
When no policy is provided the middleware defaults to
|
|
488
|
+
When no policy is provided the middleware defaults to `HostExecutionPolicy`.
|
|
382
489
|
"""
|
|
383
490
|
|
|
384
491
|
state_schema = ShellToolState
|
|
@@ -392,29 +499,49 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
392
499
|
execution_policy: BaseExecutionPolicy | None = None,
|
|
393
500
|
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
|
|
394
501
|
tool_description: str | None = None,
|
|
502
|
+
tool_name: str = SHELL_TOOL_NAME,
|
|
395
503
|
shell_command: Sequence[str] | str | None = None,
|
|
396
504
|
env: Mapping[str, Any] | None = None,
|
|
397
505
|
) -> None:
|
|
398
|
-
"""Initialize
|
|
506
|
+
"""Initialize an instance of `ShellToolMiddleware`.
|
|
399
507
|
|
|
400
508
|
Args:
|
|
401
|
-
workspace_root: Base directory for the shell session.
|
|
402
|
-
|
|
403
|
-
|
|
509
|
+
workspace_root: Base directory for the shell session.
|
|
510
|
+
|
|
511
|
+
If omitted, a temporary directory is created when the agent starts and
|
|
512
|
+
removed when it ends.
|
|
513
|
+
startup_commands: Optional commands executed sequentially after the session
|
|
514
|
+
starts.
|
|
404
515
|
shutdown_commands: Optional commands executed before the session shuts down.
|
|
405
|
-
execution_policy: Execution policy controlling timeouts, output limits, and
|
|
406
|
-
configuration.
|
|
516
|
+
execution_policy: Execution policy controlling timeouts, output limits, and
|
|
517
|
+
resource configuration.
|
|
518
|
+
|
|
519
|
+
Defaults to `HostExecutionPolicy` for native execution.
|
|
407
520
|
redaction_rules: Optional redaction rules to sanitize command output before
|
|
408
521
|
returning it to the model.
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
522
|
+
|
|
523
|
+
!!! warning
|
|
524
|
+
Redaction rules are applied post execution and do not prevent
|
|
525
|
+
exfiltration of secrets or sensitive data when using
|
|
526
|
+
`HostExecutionPolicy`.
|
|
527
|
+
|
|
528
|
+
tool_description: Optional override for the registered shell tool
|
|
529
|
+
description.
|
|
530
|
+
tool_name: Name for the registered shell tool.
|
|
531
|
+
|
|
532
|
+
Defaults to `"shell"`.
|
|
533
|
+
shell_command: Optional shell executable (string) or argument sequence used
|
|
534
|
+
to launch the persistent session.
|
|
535
|
+
|
|
536
|
+
Defaults to an implementation-defined bash command.
|
|
537
|
+
env: Optional environment variables to supply to the shell session.
|
|
538
|
+
|
|
539
|
+
Values are coerced to strings before command execution. If omitted, the
|
|
540
|
+
session inherits the parent process environment.
|
|
415
541
|
"""
|
|
416
542
|
super().__init__()
|
|
417
543
|
self._workspace_root = Path(workspace_root) if workspace_root else None
|
|
544
|
+
self._tool_name = tool_name
|
|
418
545
|
self._shell_command = self._normalize_shell_command(shell_command)
|
|
419
546
|
self._environment = self._normalize_env(env)
|
|
420
547
|
if execution_policy is not None:
|
|
@@ -428,9 +555,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
428
555
|
self._startup_commands = self._normalize_commands(startup_commands)
|
|
429
556
|
self._shutdown_commands = self._normalize_commands(shutdown_commands)
|
|
430
557
|
|
|
558
|
+
# Create a proper tool that executes directly (no interception needed)
|
|
431
559
|
description = tool_description or DEFAULT_TOOL_DESCRIPTION
|
|
432
|
-
|
|
433
|
-
self.
|
|
560
|
+
|
|
561
|
+
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
|
|
562
|
+
def shell_tool(
|
|
563
|
+
*,
|
|
564
|
+
runtime: ToolRuntime[None, ShellToolState],
|
|
565
|
+
command: str | None = None,
|
|
566
|
+
restart: bool = False,
|
|
567
|
+
) -> ToolMessage | str:
|
|
568
|
+
resources = self._get_or_create_resources(runtime.state)
|
|
569
|
+
return self._run_shell_tool(
|
|
570
|
+
resources,
|
|
571
|
+
{"command": command, "restart": restart},
|
|
572
|
+
tool_call_id=runtime.tool_call_id,
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
self._shell_tool = shell_tool
|
|
576
|
+
self.tools = [self._shell_tool]
|
|
434
577
|
|
|
435
578
|
@staticmethod
|
|
436
579
|
def _normalize_commands(
|
|
@@ -466,38 +609,52 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
466
609
|
normalized[key] = str(value)
|
|
467
610
|
return normalized
|
|
468
611
|
|
|
469
|
-
|
|
612
|
+
@override
|
|
613
|
+
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
|
470
614
|
"""Start the shell session and run startup commands."""
|
|
471
|
-
resources = self.
|
|
615
|
+
resources = self._get_or_create_resources(state)
|
|
472
616
|
return {"shell_session_resources": resources}
|
|
473
617
|
|
|
474
618
|
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
|
475
|
-
"""Async
|
|
619
|
+
"""Async start the shell session and run startup commands."""
|
|
476
620
|
return self.before_agent(state, runtime)
|
|
477
621
|
|
|
478
|
-
|
|
622
|
+
@override
|
|
623
|
+
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
|
479
624
|
"""Run shutdown commands and release resources when an agent completes."""
|
|
480
|
-
resources =
|
|
625
|
+
resources = state.get("shell_session_resources")
|
|
626
|
+
if not isinstance(resources, _SessionResources):
|
|
627
|
+
# Resources were never created, nothing to clean up
|
|
628
|
+
return
|
|
481
629
|
try:
|
|
482
630
|
self._run_shutdown_commands(resources.session)
|
|
483
631
|
finally:
|
|
484
|
-
resources.
|
|
632
|
+
resources.finalizer()
|
|
485
633
|
|
|
486
634
|
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
|
487
|
-
"""Async
|
|
635
|
+
"""Async run shutdown commands and release resources when an agent completes."""
|
|
488
636
|
return self.after_agent(state, runtime)
|
|
489
637
|
|
|
490
|
-
def
|
|
638
|
+
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
|
|
639
|
+
"""Get existing resources from state or create new ones if they don't exist.
|
|
640
|
+
|
|
641
|
+
This method enables resumability by checking if resources already exist in the state
|
|
642
|
+
(e.g., after an interrupt), and only creating new resources if they're not present.
|
|
643
|
+
|
|
644
|
+
Args:
|
|
645
|
+
state: The agent state which may contain shell session resources.
|
|
646
|
+
|
|
647
|
+
Returns:
|
|
648
|
+
Session resources, either retrieved from state or newly created.
|
|
649
|
+
"""
|
|
491
650
|
resources = state.get("shell_session_resources")
|
|
492
|
-
if
|
|
493
|
-
resources
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
raise ToolException(msg)
|
|
500
|
-
return resources
|
|
651
|
+
if isinstance(resources, _SessionResources):
|
|
652
|
+
return resources
|
|
653
|
+
|
|
654
|
+
new_resources = self._create_resources()
|
|
655
|
+
# Cast needed to make state dict-like for mutation
|
|
656
|
+
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
|
|
657
|
+
return new_resources
|
|
501
658
|
|
|
502
659
|
def _create_resources(self) -> _SessionResources:
|
|
503
660
|
workspace = self._workspace_root
|
|
@@ -533,7 +690,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
533
690
|
return
|
|
534
691
|
for command in self._startup_commands:
|
|
535
692
|
result = session.execute(command, timeout=self._execution_policy.startup_timeout)
|
|
536
|
-
if result.timed_out or (result.exit_code not in
|
|
693
|
+
if result.timed_out or (result.exit_code not in {0, None}):
|
|
537
694
|
msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
|
|
538
695
|
raise RuntimeError(msg)
|
|
539
696
|
|
|
@@ -545,7 +702,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
545
702
|
result = session.execute(command, timeout=self._execution_policy.command_timeout)
|
|
546
703
|
if result.timed_out:
|
|
547
704
|
LOGGER.warning("Shutdown command '%s' timed out.", command)
|
|
548
|
-
elif result.exit_code not in
|
|
705
|
+
elif result.exit_code not in {0, None}:
|
|
549
706
|
LOGGER.warning(
|
|
550
707
|
"Shutdown command '%s' exited with %s.", command, result.exit_code
|
|
551
708
|
)
|
|
@@ -636,7 +793,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
636
793
|
f"(observed {result.total_bytes})."
|
|
637
794
|
)
|
|
638
795
|
|
|
639
|
-
if result.exit_code not in
|
|
796
|
+
if result.exit_code not in {0, None}:
|
|
640
797
|
sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
|
|
641
798
|
final_status: Literal["success", "error"] = "error"
|
|
642
799
|
else:
|
|
@@ -659,36 +816,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
659
816
|
artifact=artifact,
|
|
660
817
|
)
|
|
661
818
|
|
|
662
|
-
def wrap_tool_call(
|
|
663
|
-
self,
|
|
664
|
-
request: ToolCallRequest,
|
|
665
|
-
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
|
|
666
|
-
) -> ToolMessage | Command:
|
|
667
|
-
"""Intercept local shell tool calls and execute them via the managed session."""
|
|
668
|
-
if isinstance(request.tool, _PersistentShellTool):
|
|
669
|
-
resources = self._ensure_resources(request.state)
|
|
670
|
-
return self._run_shell_tool(
|
|
671
|
-
resources,
|
|
672
|
-
request.tool_call["args"],
|
|
673
|
-
tool_call_id=request.tool_call.get("id"),
|
|
674
|
-
)
|
|
675
|
-
return handler(request)
|
|
676
|
-
|
|
677
|
-
async def awrap_tool_call(
|
|
678
|
-
self,
|
|
679
|
-
request: ToolCallRequest,
|
|
680
|
-
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
|
|
681
|
-
) -> ToolMessage | Command:
|
|
682
|
-
"""Async interception mirroring the synchronous tool handler."""
|
|
683
|
-
if isinstance(request.tool, _PersistentShellTool):
|
|
684
|
-
resources = self._ensure_resources(request.state)
|
|
685
|
-
return self._run_shell_tool(
|
|
686
|
-
resources,
|
|
687
|
-
request.tool_call["args"],
|
|
688
|
-
tool_call_id=request.tool_call.get("id"),
|
|
689
|
-
)
|
|
690
|
-
return await handler(request)
|
|
691
|
-
|
|
692
819
|
def _format_tool_message(
|
|
693
820
|
self,
|
|
694
821
|
content: str,
|
|
@@ -703,7 +830,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|
|
703
830
|
return ToolMessage(
|
|
704
831
|
content=content,
|
|
705
832
|
tool_call_id=tool_call_id,
|
|
706
|
-
name=self.
|
|
833
|
+
name=self._tool_name,
|
|
707
834
|
status=status,
|
|
708
835
|
artifact=artifact,
|
|
709
836
|
)
|