langchain 1.0.0rc1__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +16 -19
- langchain/agents/middleware/__init__.py +12 -0
- langchain/agents/middleware/_execution.py +388 -0
- langchain/agents/middleware/_redaction.py +350 -0
- langchain/agents/middleware/file_search.py +382 -0
- langchain/agents/middleware/pii.py +43 -477
- langchain/agents/middleware/shell_tool.py +718 -0
- langchain/agents/middleware/types.py +7 -5
- langchain/chat_models/base.py +7 -17
- langchain/embeddings/__init__.py +6 -0
- langchain/embeddings/base.py +21 -7
- langchain/tools/tool_node.py +49 -46
- {langchain-1.0.0rc1.dist-info → langchain-1.0.1.dist-info}/METADATA +12 -9
- {langchain-1.0.0rc1.dist-info → langchain-1.0.1.dist-info}/RECORD +17 -13
- {langchain-1.0.0rc1.dist-info → langchain-1.0.1.dist-info}/WHEEL +0 -0
- {langchain-1.0.0rc1.dist-info → langchain-1.0.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,718 @@
|
|
|
1
|
+
"""Middleware that exposes a persistent shell tool to agents."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import queue
|
|
9
|
+
import signal
|
|
10
|
+
import subprocess
|
|
11
|
+
import tempfile
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
import typing
|
|
15
|
+
import uuid
|
|
16
|
+
import weakref
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
20
|
+
|
|
21
|
+
from langchain_core.messages import ToolMessage
|
|
22
|
+
from langchain_core.tools.base import BaseTool, ToolException
|
|
23
|
+
from langgraph.channels.untracked_value import UntrackedValue
|
|
24
|
+
from pydantic import BaseModel, model_validator
|
|
25
|
+
from typing_extensions import NotRequired
|
|
26
|
+
|
|
27
|
+
from langchain.agents.middleware._execution import (
|
|
28
|
+
SHELL_TEMP_PREFIX,
|
|
29
|
+
BaseExecutionPolicy,
|
|
30
|
+
CodexSandboxExecutionPolicy,
|
|
31
|
+
DockerExecutionPolicy,
|
|
32
|
+
HostExecutionPolicy,
|
|
33
|
+
)
|
|
34
|
+
from langchain.agents.middleware._redaction import (
|
|
35
|
+
PIIDetectionError,
|
|
36
|
+
PIIMatch,
|
|
37
|
+
RedactionRule,
|
|
38
|
+
ResolvedRedactionRule,
|
|
39
|
+
)
|
|
40
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
|
41
|
+
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
from collections.abc import Mapping, Sequence
|
|
44
|
+
|
|
45
|
+
from langgraph.runtime import Runtime
|
|
46
|
+
from langgraph.types import Command
|
|
47
|
+
|
|
48
|
+
from langchain.tools.tool_node import ToolCallRequest
|
|
49
|
+
|
|
50
|
+
LOGGER = logging.getLogger(__name__)

# Prefix of the sentinel printed after every command; a random suffix is appended per
# command so the output reader can tell where that command's output ends.
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"

# Default description advertised to the model for the shell tool (one sentence per line).
DEFAULT_TOOL_DESCRIPTION = (
    "Execute a shell command inside a persistent session. "
    "Before running a command, confirm the working directory is correct "
    "(e.g., inspect with `ls` or `pwd`) and ensure any parent directories exist. "
    "Prefer absolute paths and quote paths containing spaces, "
    'such as `cd "/path/with spaces"`. '
    "Chain multiple commands with `&&` or `;` instead of embedding newlines. "
    "Avoid unnecessary `cd` usage unless explicitly required so the session remains stable. "
    "Outputs may be truncated when they become very large, "
    "and long running commands will be terminated once their configured timeout elapses."
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _cleanup_resources(
    session: ShellSession, tempdir: tempfile.TemporaryDirectory[str] | None, timeout: float
) -> None:
    """Tear down a shell session and its optional temporary workspace, best effort.

    Each step runs with every exception suppressed, so a failure while stopping the
    session never prevents the workspace from being removed.

    Args:
        session: Session to stop.
        tempdir: Temporary workspace to delete, or ``None`` when there is none.
        timeout: Seconds forwarded to ``session.stop``.
    """
    teardown_steps = [lambda: session.stop(timeout)]
    if tempdir is not None:
        teardown_steps.append(tempdir.cleanup)
    for step in teardown_steps:
        with contextlib.suppress(Exception):
            step()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclass
class _SessionResources:
    """Shell resources owned by one agent run.

    Construction registers a ``weakref.finalize`` hook. The hook stops the shell
    session and removes the temporary workspace (if any) when it is called
    explicitly, when this object is garbage-collected, or at interpreter exit.
    Whichever happens first wins; a finalizer runs at most once.
    """

    session: ShellSession
    tempdir: tempfile.TemporaryDirectory[str] | None
    policy: BaseExecutionPolicy
    # Cleanup hook created in ``__post_init__``; excluded from ``__init__`` and ``repr``.
    _finalizer: weakref.finalize = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Register the cleanup hook for this container's resources."""
        cleanup_args = (self.session, self.tempdir, self.policy.termination_timeout)
        # Only the resources themselves (never ``self``) are passed to the hook, so the
        # finalizer does not keep this container alive.
        self._finalizer = weakref.finalize(self, _cleanup_resources, *cleanup_args)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ShellToolState(AgentState):
    """Agent state extension for tracking shell session resources."""

    # Per-run shell resources returned by ``ShellToolMiddleware.before_agent``.
    # Annotated with ``UntrackedValue`` and ``PrivateStateAttr``; presumably this keeps the
    # live session out of checkpoints and out of the public state schema - TODO confirm.
    shell_session_resources: NotRequired[
        Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr]
    ]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass(frozen=True)
class CommandExecutionResult:
    """Structured result from command execution."""

    # Captured output; stderr lines are prefixed with "[stderr] ". Empty on timeout.
    output: str
    # Exit status parsed from the completion marker; ``None`` on timeout or when unparsable.
    exit_code: int | None
    # True when the deadline passed before the completion marker was seen.
    timed_out: bool
    # True when lines beyond the policy's ``max_output_lines`` were dropped.
    truncated_by_lines: bool
    # True when lines were dropped because the policy's ``max_output_bytes`` was exceeded.
    truncated_by_bytes: bool
    # Number of output lines observed, including dropped ones.
    total_lines: int
    # UTF-8 encoded size of all observed output lines, including dropped ones.
    total_bytes: int


class ShellSession:
    """Persistent shell session that supports sequential command execution.

    One shell process is spawned through the execution policy and kept alive across
    commands. Two daemon threads copy the process's stdout and stderr lines into a
    queue. Each :meth:`execute` call writes the command followed by a ``printf`` of a
    unique completion marker and ``$?``, then reads queued lines until that marker
    appears or the timeout elapses.
    """

    def __init__(
        self,
        workspace: Path,
        policy: BaseExecutionPolicy,
        command: tuple[str, ...],
        environment: Mapping[str, str],
    ) -> None:
        """Store the session configuration; the process is spawned by :meth:`start`.

        Args:
            workspace: Directory passed to ``policy.spawn`` as the session workspace.
            policy: Execution policy that spawns the shell and supplies
                ``termination_timeout``, ``max_output_lines`` and ``max_output_bytes``.
            command: Argument vector used to launch the shell.
            environment: Environment passed to ``policy.spawn`` (copied).
        """
        self._workspace = workspace
        self._policy = policy
        self._command = command
        self._environment = dict(environment)
        self._process: subprocess.Popen[str] | None = None
        self._stdin: Any = None
        self._queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
        # Held for a whole command (write + output collection) so that concurrent
        # callers cannot consume each other's output or completion markers.
        self._lock = threading.Lock()
        self._stdout_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None
        self._terminated = False

    def start(self) -> None:
        """Start the shell subprocess and reader threads.

        Does nothing while the current process is still running.

        Raises:
            RuntimeError: If the spawned process is missing any of its stdio pipes.
        """
        if self._process and self._process.poll() is None:
            return

        self._process = self._policy.spawn(
            workspace=self._workspace,
            env=self._environment,
            command=self._command,
        )
        if (
            self._process.stdin is None
            or self._process.stdout is None
            or self._process.stderr is None
        ):
            msg = "Failed to initialize shell session pipes."
            raise RuntimeError(msg)

        self._stdin = self._process.stdin
        self._terminated = False
        # Each process gets its own queue, handed to its reader threads explicitly.
        # Readers of a previous (stopped) process may still be flushing buffered lines;
        # binding the queue keeps that stale output out of the new session.
        output_queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
        self._queue = output_queue

        self._stdout_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stdout, "stdout", output_queue),
            daemon=True,
        )
        self._stderr_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stderr, "stderr", output_queue),
            daemon=True,
        )
        self._stdout_thread.start()
        self._stderr_thread.start()

    def restart(self) -> None:
        """Restart the shell process (stop, then start a fresh one)."""
        self.stop(self._policy.termination_timeout)
        self.start()

    def stop(self, timeout: float) -> None:
        """Stop the shell subprocess.

        A still-running shell is first asked to ``exit``. If it has not finished within
        ``timeout`` seconds it is killed and then waited on once more so it is reaped
        rather than left behind as a zombie. In every case ``stdin`` is closed and the
        session forgets the process.

        Args:
            timeout: Seconds to wait for a graceful exit (and again after a kill).
        """
        process = self._process
        if not process:
            return

        if process.poll() is None and not self._terminated:
            try:
                self._stdin.write("exit\n")
                self._stdin.flush()
            except (BrokenPipeError, OSError):
                LOGGER.debug(
                    "Failed to write exit command; terminating shell session.",
                    exc_info=True,
                )

        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            self._kill_process()
            # Reap the killed process so it does not linger as a zombie.
            with contextlib.suppress(subprocess.TimeoutExpired):
                process.wait(timeout=timeout)
        finally:
            self._terminated = True
            with contextlib.suppress(Exception):
                self._stdin.close()
            self._process = None

    def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
        """Execute a command in the persistent shell.

        Writing the command and collecting its output both happen under the session
        lock, so concurrent callers run one after another. The timeout is measured
        from the moment the lock is acquired.

        Args:
            command: Shell text to run; a trailing newline is appended if missing.
            timeout: Seconds to wait for the completion marker. On expiry the shell is
                restarted and a timed-out result is returned.

        Returns:
            The collected output and status of this command.

        Raises:
            RuntimeError: If the shell process is not running.
        """
        with self._lock:
            # Checked under the lock: a previous command may have restarted the shell
            # while this caller was waiting.
            if not self._process or self._process.poll() is not None:
                msg = "Shell session is not running."
                raise RuntimeError(msg)

            marker = f"{_DONE_MARKER_PREFIX}{uuid.uuid4().hex}"
            deadline = time.monotonic() + timeout

            self._drain_queue()
            payload = command if command.endswith("\n") else f"{command}\n"
            self._stdin.write(payload)
            self._stdin.write(f"printf '{marker} %s\\n' $?\n")
            self._stdin.flush()

            return self._collect_output(marker, deadline, timeout)

    def _collect_output(
        self,
        marker: str,
        deadline: float,
        timeout: float,
    ) -> CommandExecutionResult:
        """Read queued output until ``marker`` appears or ``deadline`` passes.

        The marker is searched for anywhere in a stdout line. When a command's output
        does not end in a newline (``printf hi``, a file without a trailing newline),
        the shell prints the marker on that same line; the text before the marker is
        kept as regular output.

        On timeout the shell is restarted and a result with empty output and
        ``timed_out=True`` is returned.
        """
        collected: list[str] = []
        total_lines = 0
        total_bytes = 0
        truncated_by_lines = False
        truncated_by_bytes = False
        exit_code: int | None = None
        timed_out = False

        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = True
                break
            try:
                source, data = self._queue.get(timeout=remaining)
            except queue.Empty:
                timed_out = True
                break

            # ``None`` is the end-of-stream signal from a reader thread.
            if data is None:
                continue

            finished = False
            if source == "stdout":
                marker_index = data.find(marker)
                if marker_index != -1:
                    status = data[marker_index + len(marker) :]
                    exit_code = self._safe_int(status.strip())
                    # Keep any output that preceded the marker on the same line.
                    data = data[:marker_index]
                    finished = True

            if data:
                total_lines += 1
                total_bytes += len(data.encode("utf-8", "replace"))
                if total_lines > self._policy.max_output_lines:
                    truncated_by_lines = True
                elif (
                    self._policy.max_output_bytes is not None
                    and total_bytes > self._policy.max_output_bytes
                ):
                    truncated_by_bytes = True
                elif source == "stderr":
                    stripped = data.rstrip("\n")
                    collected.append(f"[stderr] {stripped}")
                    if data.endswith("\n"):
                        collected.append("\n")
                else:
                    collected.append(data)

            if finished:
                break

        if timed_out:
            LOGGER.warning(
                "Command timed out after %.2f seconds; restarting shell session.",
                timeout,
            )
            self.restart()
            return CommandExecutionResult(
                output="",
                exit_code=None,
                timed_out=True,
                truncated_by_lines=truncated_by_lines,
                truncated_by_bytes=truncated_by_bytes,
                total_lines=total_lines,
                total_bytes=total_bytes,
            )

        return CommandExecutionResult(
            output="".join(collected),
            exit_code=exit_code,
            timed_out=False,
            truncated_by_lines=truncated_by_lines,
            truncated_by_bytes=truncated_by_bytes,
            total_lines=total_lines,
            total_bytes=total_bytes,
        )

    def _kill_process(self) -> None:
        """Send SIGKILL to the shell's process group, or to the shell alone.

        The group is signalled only when it differs from this (parent) process's own
        group. Otherwise, a shell that was not started in its own group would make
        ``killpg`` kill the agent process as well.
        """
        process = self._process
        if not process:
            return

        if hasattr(os, "killpg"):
            with contextlib.suppress(ProcessLookupError):
                group_id = os.getpgid(process.pid)
                if group_id != os.getpgid(0):
                    os.killpg(group_id, signal.SIGKILL)
                    return
        with contextlib.suppress(ProcessLookupError):
            process.kill()

    def _enqueue_stream(
        self,
        stream: Any,
        label: str,
        target: queue.Queue[tuple[str, str | None]] | None = None,
    ) -> None:
        """Copy lines from ``stream`` into a queue as ``(label, line)`` pairs.

        A final ``(label, None)`` marks end-of-stream.

        Args:
            stream: Text stream to read line by line until EOF.
            label: Source tag, ``"stdout"`` or ``"stderr"``.
            target: Queue to fill; defaults to the session's current queue.
        """
        sink = self._queue if target is None else target
        for line in iter(stream.readline, ""):
            sink.put((label, line))
        sink.put((label, None))

    def _drain_queue(self) -> None:
        """Discard any output still queued from earlier commands."""
        with contextlib.suppress(queue.Empty):
            while True:
                self._queue.get_nowait()

    @staticmethod
    def _safe_int(value: str) -> int | None:
        """Parse ``value`` as an int, returning ``None`` when it is not numeric."""
        with contextlib.suppress(ValueError):
            return int(value)
        return None
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
class _ShellToolInput(BaseModel):
    """Arguments accepted by the persistent shell tool.

    Exactly one action must be requested: a ``command`` to run, or a truthy
    ``restart``.
    """

    command: str | None = None
    restart: bool | None = None

    @model_validator(mode="after")
    def validate_payload(self) -> _ShellToolInput:
        """Reject payloads that request neither action or both actions."""
        has_command = self.command is not None
        wants_restart = bool(self.restart)
        if has_command == wants_restart:
            msg = (
                "Specify only one of 'command' or 'restart'."
                if has_command
                else "Shell tool requires either 'command' or 'restart'."
            )
            raise ValueError(msg)
        return self
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class _PersistentShellTool(BaseTool):
    """Placeholder shell tool whose calls are executed by ``ShellToolMiddleware``.

    The tool only carries the name, description and argument schema shown to the
    model. ``ShellToolMiddleware`` intercepts calls to instances of this class in its
    tool-call wrappers, so ``_run`` is not expected to be reached.
    """

    name: str = "shell"
    description: str = DEFAULT_TOOL_DESCRIPTION
    args_schema: type[BaseModel] = _ShellToolInput

    def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
        """Bind the tool to its middleware, optionally overriding the description."""
        super().__init__()
        self._middleware = middleware
        if description is None:
            return
        self.description = description

    def _run(self, **_: Any) -> Any:  # pragma: no cover - executed via middleware wrapper
        """Always raise: direct invocation bypasses the middleware that owns execution."""
        error_text = "Persistent shell tool execution should be intercepted via middleware wrappers."
        raise RuntimeError(error_text)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
    """Middleware that registers a persistent shell tool for agents.

    The middleware exposes a single long-lived shell session. Use the execution policy to
    match your deployment's security posture:

    * ``HostExecutionPolicy`` - full host access; best for trusted environments where the
      agent already runs inside a container or VM that provides isolation.
    * ``CodexSandboxExecutionPolicy`` - reuses the Codex CLI sandbox for additional
      syscall/filesystem restrictions when the CLI is available.
    * ``DockerExecutionPolicy`` - launches a separate Docker container for each agent run,
      providing harder isolation, optional read-only root filesystems, and user remapping.

    When no policy is provided the middleware defaults to ``HostExecutionPolicy``.
    """

    state_schema = ShellToolState

    def __init__(
        self,
        workspace_root: str | Path | None = None,
        *,
        startup_commands: tuple[str, ...] | list[str] | str | None = None,
        shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
        execution_policy: BaseExecutionPolicy | None = None,
        redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
        tool_description: str | None = None,
        shell_command: Sequence[str] | str | None = None,
        env: Mapping[str, Any] | None = None,
    ) -> None:
        """Initialize the middleware.

        Args:
            workspace_root: Base directory for the shell session. If omitted, a temporary
                directory is created when the agent starts and removed when it ends.
            startup_commands: Optional commands executed sequentially after the session starts.
            shutdown_commands: Optional commands executed before the session shuts down.
            execution_policy: Execution policy controlling timeouts, output limits, and resource
                configuration. Defaults to :class:`HostExecutionPolicy` for native execution.
            redaction_rules: Optional redaction rules to sanitize command output before
                returning it to the model.
            tool_description: Optional override for the registered shell tool description.
            shell_command: Optional shell executable (string) or argument sequence used to
                launch the persistent session. Defaults to an implementation-defined bash command.
            env: Optional environment variables to supply to the shell session. Values are
                coerced to strings before command execution. If omitted, the session inherits the
                parent process environment.
        """
        super().__init__()
        self._workspace_root = Path(workspace_root) if workspace_root else None
        self._shell_command = self._normalize_shell_command(shell_command)
        self._environment = self._normalize_env(env)
        if execution_policy is not None:
            self._execution_policy = execution_policy
        else:
            self._execution_policy = HostExecutionPolicy()
        rules = redaction_rules or ()
        self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
            rule.resolve() for rule in rules
        )
        self._startup_commands = self._normalize_commands(startup_commands)
        self._shutdown_commands = self._normalize_commands(shutdown_commands)

        description = tool_description or DEFAULT_TOOL_DESCRIPTION
        self._tool = _PersistentShellTool(self, description=description)
        self.tools = [self._tool]

    @staticmethod
    def _normalize_commands(
        commands: tuple[str, ...] | list[str] | str | None,
    ) -> tuple[str, ...]:
        """Coerce ``None``, a single command, or a sequence of commands into a tuple."""
        if commands is None:
            return ()
        if isinstance(commands, str):
            return (commands,)
        return tuple(commands)

    @staticmethod
    def _normalize_shell_command(
        shell_command: Sequence[str] | str | None,
    ) -> tuple[str, ...]:
        """Return the shell argv, defaulting to ``("/bin/bash",)``.

        Raises:
            ValueError: If an empty sequence is supplied.
        """
        if shell_command is None:
            return ("/bin/bash",)
        normalized = (shell_command,) if isinstance(shell_command, str) else tuple(shell_command)
        if not normalized:
            msg = "Shell command must contain at least one argument."
            raise ValueError(msg)
        return normalized

    @staticmethod
    def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
        """Stringify environment values; ``None`` is passed through unchanged.

        Raises:
            TypeError: If any variable name is not a string.
        """
        if env is None:
            return None
        normalized: dict[str, str] = {}
        for key, value in env.items():
            if not isinstance(key, str):
                msg = "Environment variable names must be strings."
                raise TypeError(msg)
            normalized[key] = str(value)
        return normalized

    def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Start the shell session and run startup commands."""
        resources = self._create_resources()
        return {"shell_session_resources": resources}

    async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
        """Async counterpart to `before_agent`.

        Spawning the shell and running startup commands blocks; each startup command
        may wait up to the policy's ``startup_timeout``. The work therefore runs in a
        worker thread instead of on the event loop.
        """
        import asyncio  # only the async hooks need it

        return await asyncio.to_thread(self.before_agent, state, runtime)

    def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:  # noqa: ARG002
        """Run shutdown commands and release resources when an agent completes."""
        resources = self._ensure_resources(state)
        try:
            self._run_shutdown_commands(resources.session)
        finally:
            resources._finalizer()

    async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
        """Async counterpart to `after_agent`.

        Shutdown commands and process termination block, so they run in a worker
        thread instead of on the event loop.
        """
        import asyncio  # only the async hooks need it

        await asyncio.to_thread(self.after_agent, state, runtime)

    def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
        """Return this run's session resources from ``state``.

        Raises:
            ToolException: If ``before_agent`` has not stored valid resources.
        """
        resources = state.get("shell_session_resources")
        if resources is not None and not isinstance(resources, _SessionResources):
            resources = None
        if resources is None:
            msg = (
                "Shell session resources are unavailable. Ensure `before_agent` ran successfully "
                "before invoking the shell tool."
            )
            raise ToolException(msg)
        return resources

    def _create_resources(self) -> _SessionResources:
        """Create the workspace, start the shell, and run startup commands.

        On any failure the session is stopped and a temporary workspace removed
        before the exception propagates.
        """
        workspace = self._workspace_root
        tempdir: tempfile.TemporaryDirectory[str] | None = None
        if workspace is None:
            tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
            workspace_path = Path(tempdir.name)
        else:
            workspace_path = workspace
            workspace_path.mkdir(parents=True, exist_ok=True)

        session = ShellSession(
            workspace_path,
            self._execution_policy,
            self._shell_command,
            self._environment or {},
        )
        try:
            session.start()
            LOGGER.info("Started shell session in %s", workspace_path)
            self._run_startup_commands(session)
        except BaseException:
            LOGGER.exception("Starting shell session failed; cleaning up resources.")
            session.stop(self._execution_policy.termination_timeout)
            if tempdir is not None:
                tempdir.cleanup()
            raise

        return _SessionResources(session=session, tempdir=tempdir, policy=self._execution_policy)

    def _run_startup_commands(self, session: ShellSession) -> None:
        """Run startup commands in order using the policy's startup timeout.

        Raises:
            RuntimeError: If a command times out or exits with a non-zero status.
        """
        if not self._startup_commands:
            return
        for command in self._startup_commands:
            result = session.execute(command, timeout=self._execution_policy.startup_timeout)
            if result.timed_out or (result.exit_code not in (0, None)):
                msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
                raise RuntimeError(msg)

    def _run_shutdown_commands(self, session: ShellSession) -> None:
        """Run shutdown commands best-effort, logging (not raising) failures."""
        if not self._shutdown_commands:
            return
        for command in self._shutdown_commands:
            try:
                result = session.execute(command, timeout=self._execution_policy.command_timeout)
                if result.timed_out:
                    LOGGER.warning("Shutdown command '%s' timed out.", command)
                elif result.exit_code not in (0, None):
                    LOGGER.warning(
                        "Shutdown command '%s' exited with %s.", command, result.exit_code
                    )
            except (RuntimeError, ToolException, OSError) as exc:
                LOGGER.warning(
                    "Failed to run shutdown command '%s': %s", command, exc, exc_info=True
                )

    def _apply_redactions(self, content: str) -> tuple[str, dict[str, list[PIIMatch]]]:
        """Apply configured redaction rules to command output."""
        matches_by_type: dict[str, list[PIIMatch]] = {}
        updated = content
        for rule in self._redaction_rules:
            updated, matches = rule.apply(updated)
            if matches:
                matches_by_type.setdefault(rule.pii_type, []).extend(matches)
        return updated, matches_by_type

    def _run_shell_tool(
        self,
        resources: _SessionResources,
        payload: dict[str, Any],
        *,
        tool_call_id: str | None,
    ) -> Any:
        """Execute one shell tool call (restart or command) and format the reply.

        Raises:
            ToolException: If a restart fails or no command string is supplied.
        """
        session = resources.session

        if payload.get("restart"):
            LOGGER.info("Restarting shell session on request.")
            try:
                session.restart()
                self._run_startup_commands(session)
            except Exception as err:
                # ``Exception`` (not ``BaseException``) so KeyboardInterrupt, SystemExit
                # and task cancellation propagate instead of becoming a tool error.
                LOGGER.exception("Restarting shell session failed; session remains unavailable.")
                msg = "Failed to restart shell session."
                raise ToolException(msg) from err
            message = "Shell session restarted."
            return self._format_tool_message(message, tool_call_id, status="success")

        command = payload.get("command")
        if not command or not isinstance(command, str):
            msg = "Shell tool expects a 'command' string when restart is not requested."
            raise ToolException(msg)

        LOGGER.info("Executing shell command: %s", command)
        result = session.execute(command, timeout=self._execution_policy.command_timeout)

        if result.timed_out:
            # ShellSession restarts the shell after a timeout, discarding whatever the
            # startup commands set up; replay them as the explicit restart path does.
            try:
                self._run_startup_commands(session)
            except Exception:
                LOGGER.exception("Re-running startup commands after a timeout failed.")
            timeout_seconds = self._execution_policy.command_timeout
            message = f"Error: Command timed out after {timeout_seconds:.1f} seconds."
            return self._format_tool_message(
                message,
                tool_call_id,
                status="error",
                artifact={
                    "timed_out": True,
                    "exit_code": None,
                },
            )

        try:
            sanitized_output, matches = self._apply_redactions(result.output)
        except PIIDetectionError as error:
            LOGGER.warning("Blocking command output due to detected %s.", error.pii_type)
            message = f"Output blocked: detected {error.pii_type}."
            return self._format_tool_message(
                message,
                tool_call_id,
                status="error",
                artifact={
                    "timed_out": False,
                    "exit_code": result.exit_code,
                    "matches": {error.pii_type: error.matches},
                },
            )

        sanitized_output = sanitized_output or "<no output>"
        if result.truncated_by_lines:
            sanitized_output = (
                f"{sanitized_output.rstrip()}\n\n"
                f"... Output truncated at {self._execution_policy.max_output_lines} lines "
                f"(observed {result.total_lines})."
            )
        if result.truncated_by_bytes and self._execution_policy.max_output_bytes is not None:
            sanitized_output = (
                f"{sanitized_output.rstrip()}\n\n"
                f"... Output truncated at {self._execution_policy.max_output_bytes} bytes "
                f"(observed {result.total_bytes})."
            )

        if result.exit_code not in (0, None):
            sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
            final_status: Literal["success", "error"] = "error"
        else:
            final_status = "success"

        artifact = {
            "timed_out": False,
            "exit_code": result.exit_code,
            "truncated_by_lines": result.truncated_by_lines,
            "truncated_by_bytes": result.truncated_by_bytes,
            "total_lines": result.total_lines,
            "total_bytes": result.total_bytes,
            "redaction_matches": matches,
        }

        return self._format_tool_message(
            sanitized_output,
            tool_call_id,
            status=final_status,
            artifact=artifact,
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Intercept local shell tool calls and execute them via the managed session."""
        if isinstance(request.tool, _PersistentShellTool):
            resources = self._ensure_resources(request.state)
            return self._run_shell_tool(
                resources,
                request.tool_call["args"],
                tool_call_id=request.tool_call.get("id"),
            )
        return handler(request)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Async interception mirroring the synchronous tool handler."""
        if isinstance(request.tool, _PersistentShellTool):
            resources = self._ensure_resources(request.state)
            # NOTE: this call blocks the event loop for up to the command timeout.
            # Because it is made inline rather than in a worker thread, concurrent shell
            # calls on this async path run one at a time against the session.
            return self._run_shell_tool(
                resources,
                request.tool_call["args"],
                tool_call_id=request.tool_call.get("id"),
            )
        return await handler(request)

    def _format_tool_message(
        self,
        content: str,
        tool_call_id: str | None,
        *,
        status: Literal["success", "error"],
        artifact: dict[str, Any] | None = None,
    ) -> ToolMessage | str:
        """Wrap ``content`` in a ``ToolMessage``, or return it bare without a call id."""
        artifact = artifact or {}
        if tool_call_id is None:
            return content
        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=self._tool.name,
            status=status,
            artifact=artifact,
        )
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
# Public names exported by this module. The execution policies and ``RedactionRule``
# are imported from sibling modules and re-exported here.
__all__ = [
    "CodexSandboxExecutionPolicy",
    "DockerExecutionPolicy",
    "HostExecutionPolicy",
    "RedactionRule",
    "ShellToolMiddleware",
]
|