langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +1 -7
  3. langchain/agents/factory.py +153 -79
  4. langchain/agents/middleware/__init__.py +18 -23
  5. langchain/agents/middleware/_execution.py +29 -32
  6. langchain/agents/middleware/_redaction.py +108 -22
  7. langchain/agents/middleware/_retry.py +123 -0
  8. langchain/agents/middleware/context_editing.py +47 -25
  9. langchain/agents/middleware/file_search.py +19 -14
  10. langchain/agents/middleware/human_in_the_loop.py +87 -57
  11. langchain/agents/middleware/model_call_limit.py +64 -18
  12. langchain/agents/middleware/model_fallback.py +7 -9
  13. langchain/agents/middleware/model_retry.py +307 -0
  14. langchain/agents/middleware/pii.py +82 -29
  15. langchain/agents/middleware/shell_tool.py +254 -107
  16. langchain/agents/middleware/summarization.py +469 -95
  17. langchain/agents/middleware/todo.py +129 -31
  18. langchain/agents/middleware/tool_call_limit.py +105 -71
  19. langchain/agents/middleware/tool_emulator.py +47 -38
  20. langchain/agents/middleware/tool_retry.py +183 -164
  21. langchain/agents/middleware/tool_selection.py +81 -37
  22. langchain/agents/middleware/types.py +856 -427
  23. langchain/agents/structured_output.py +65 -42
  24. langchain/chat_models/__init__.py +1 -7
  25. langchain/chat_models/base.py +253 -196
  26. langchain/embeddings/__init__.py +0 -5
  27. langchain/embeddings/base.py +79 -65
  28. langchain/messages/__init__.py +0 -5
  29. langchain/tools/__init__.py +1 -7
  30. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
  31. langchain-1.2.4.dist-info/RECORD +36 -0
  32. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
  33. langchain-1.0.5.dist-info/RECORD +0 -34
  34. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
@@ -11,18 +11,18 @@ import subprocess
11
11
  import tempfile
12
12
  import threading
13
13
  import time
14
- import typing
15
14
  import uuid
16
15
  import weakref
17
16
  from dataclasses import dataclass, field
18
17
  from pathlib import Path
19
- from typing import TYPE_CHECKING, Annotated, Any, Literal
18
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
20
19
 
21
20
  from langchain_core.messages import ToolMessage
22
- from langchain_core.tools.base import BaseTool, ToolException
21
+ from langchain_core.tools.base import ToolException
23
22
  from langgraph.channels.untracked_value import UntrackedValue
24
23
  from pydantic import BaseModel, model_validator
25
- from typing_extensions import NotRequired
24
+ from pydantic.json_schema import SkipJsonSchema
25
+ from typing_extensions import NotRequired, override
26
26
 
27
27
  from langchain.agents.middleware._execution import (
28
28
  SHELL_TEMP_PREFIX,
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
38
38
  ResolvedRedactionRule,
39
39
  )
40
40
  from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
41
+ from langchain.tools import ToolRuntime, tool
41
42
 
42
43
  if TYPE_CHECKING:
43
44
  from collections.abc import Mapping, Sequence
44
45
 
45
46
  from langgraph.runtime import Runtime
46
- from langgraph.types import Command
47
47
 
48
- from langchain.agents.middleware.types import ToolCallRequest
49
48
 
50
49
  LOGGER = logging.getLogger(__name__)
51
50
  _DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
59
58
  "session remains stable. Outputs may be truncated when they become very large, and long "
60
59
  "running commands will be terminated once their configured timeout elapses."
61
60
  )
61
+ SHELL_TOOL_NAME = "shell"
62
62
 
63
63
 
64
64
  def _cleanup_resources(
@@ -78,10 +78,10 @@ class _SessionResources:
78
78
  session: ShellSession
79
79
  tempdir: tempfile.TemporaryDirectory[str] | None
80
80
  policy: BaseExecutionPolicy
81
- _finalizer: weakref.finalize = field(init=False, repr=False)
81
+ finalizer: weakref.finalize = field(init=False, repr=False) # type: ignore[type-arg]
82
82
 
83
83
  def __post_init__(self) -> None:
84
- self._finalizer = weakref.finalize(
84
+ self.finalizer = weakref.finalize(
85
85
  self,
86
86
  _cleanup_resources,
87
87
  self.session,
@@ -90,7 +90,7 @@ class _SessionResources:
90
90
  )
91
91
 
92
92
 
93
- class ShellToolState(AgentState):
93
+ class ShellToolState(AgentState[Any]):
94
94
  """Agent state extension for tracking shell session resources."""
95
95
 
96
96
  shell_session_resources: NotRequired[
@@ -134,7 +134,11 @@ class ShellSession:
134
134
  self._terminated = False
135
135
 
136
136
  def start(self) -> None:
137
- """Start the shell subprocess and reader threads."""
137
+ """Start the shell subprocess and reader threads.
138
+
139
+ Raises:
140
+ RuntimeError: If the shell session pipes cannot be initialized.
141
+ """
138
142
  if self._process and self._process.poll() is None:
139
143
  return
140
144
 
@@ -211,9 +215,14 @@ class ShellSession:
211
215
  with self._lock:
212
216
  self._drain_queue()
213
217
  payload = command if command.endswith("\n") else f"{command}\n"
214
- self._stdin.write(payload)
215
- self._stdin.write(f"printf '{marker} %s\\n' $?\n")
216
- self._stdin.flush()
218
+ try:
219
+ self._stdin.write(payload)
220
+ self._stdin.write(f"printf '{marker} %s\\n' $?\n")
221
+ self._stdin.flush()
222
+ except (BrokenPipeError, OSError):
223
+ # The shell exited before we could write the marker command.
224
+ # This happens when commands like 'exit 1' terminate the shell.
225
+ return self._collect_output_after_exit(deadline)
217
226
 
218
227
  return self._collect_output(marker, deadline, timeout)
219
228
 
@@ -248,6 +257,10 @@ class ShellSession:
248
257
  if source == "stdout" and data.startswith(marker):
249
258
  _, _, status = data.partition(" ")
250
259
  exit_code = self._safe_int(status.strip())
260
+ # Drain any remaining stderr that may have arrived concurrently.
261
+ # The stderr reader thread runs independently, so output might
262
+ # still be in flight when the stdout marker arrives.
263
+ self._drain_remaining_stderr(collected, deadline)
251
264
  break
252
265
 
253
266
  total_lines += 1
@@ -300,6 +313,80 @@ class ShellSession:
300
313
  total_bytes=total_bytes,
301
314
  )
302
315
 
316
+ def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
317
+ """Collect output after the shell exited unexpectedly.
318
+
319
+ Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
320
+ shell process terminated (e.g., due to an 'exit' command).
321
+
322
+ Args:
323
+ deadline: Absolute time by which collection must complete.
324
+
325
+ Returns:
326
+ `CommandExecutionResult` with collected output and the process exit code.
327
+ """
328
+ collected: list[str] = []
329
+ total_lines = 0
330
+ total_bytes = 0
331
+ truncated_by_lines = False
332
+ truncated_by_bytes = False
333
+
334
+ # Give reader threads a brief moment to enqueue any remaining output.
335
+ drain_timeout = 0.1
336
+ drain_deadline = min(time.monotonic() + drain_timeout, deadline)
337
+
338
+ while True:
339
+ remaining = drain_deadline - time.monotonic()
340
+ if remaining <= 0:
341
+ break
342
+ try:
343
+ source, data = self._queue.get(timeout=remaining)
344
+ except queue.Empty:
345
+ break
346
+
347
+ if data is None:
348
+ # EOF marker from a reader thread; continue draining.
349
+ continue
350
+
351
+ total_lines += 1
352
+ encoded = data.encode("utf-8", "replace")
353
+ total_bytes += len(encoded)
354
+
355
+ if total_lines > self._policy.max_output_lines:
356
+ truncated_by_lines = True
357
+ continue
358
+
359
+ if (
360
+ self._policy.max_output_bytes is not None
361
+ and total_bytes > self._policy.max_output_bytes
362
+ ):
363
+ truncated_by_bytes = True
364
+ continue
365
+
366
+ if source == "stderr":
367
+ stripped = data.rstrip("\n")
368
+ collected.append(f"[stderr] {stripped}")
369
+ if data.endswith("\n"):
370
+ collected.append("\n")
371
+ else:
372
+ collected.append(data)
373
+
374
+ # Get exit code from the terminated process.
375
+ exit_code: int | None = None
376
+ if self._process:
377
+ exit_code = self._process.poll()
378
+
379
+ output = "".join(collected)
380
+ return CommandExecutionResult(
381
+ output=output,
382
+ exit_code=exit_code,
383
+ timed_out=False,
384
+ truncated_by_lines=truncated_by_lines,
385
+ truncated_by_bytes=truncated_by_bytes,
386
+ total_lines=total_lines,
387
+ total_bytes=total_bytes,
388
+ )
389
+
303
390
  def _kill_process(self) -> None:
304
391
  if not self._process:
305
392
  return
@@ -323,6 +410,37 @@ class ShellSession:
323
410
  except queue.Empty:
324
411
  break
325
412
 
413
+ def _drain_remaining_stderr(
414
+ self, collected: list[str], deadline: float, drain_timeout: float = 0.05
415
+ ) -> None:
416
+ """Drain any stderr output that arrived concurrently with the done marker.
417
+
418
+ The stdout and stderr reader threads run independently. When a command writes to
419
+ stderr just before exiting, the stderr output may still be in transit when the
420
+ done marker arrives on stdout. This method briefly polls the queue to capture
421
+ such output.
422
+
423
+ Args:
424
+ collected: The list to append collected stderr lines to.
425
+ deadline: The original command deadline (used as an upper bound).
426
+ drain_timeout: Maximum time to wait for additional stderr output.
427
+ """
428
+ drain_deadline = min(time.monotonic() + drain_timeout, deadline)
429
+ while True:
430
+ remaining = drain_deadline - time.monotonic()
431
+ if remaining <= 0:
432
+ break
433
+ try:
434
+ source, data = self._queue.get(timeout=remaining)
435
+ except queue.Empty:
436
+ break
437
+ if data is None or source != "stderr":
438
+ continue
439
+ stripped = data.rstrip("\n")
440
+ collected.append(f"[stderr] {stripped}")
441
+ if data.endswith("\n"):
442
+ collected.append("\n")
443
+
326
444
  @staticmethod
327
445
  def _safe_int(value: str) -> int | None:
328
446
  with contextlib.suppress(ValueError):
@@ -334,7 +452,17 @@ class _ShellToolInput(BaseModel):
334
452
  """Input schema for the persistent shell tool."""
335
453
 
336
454
  command: str | None = None
455
+ """The shell command to execute."""
456
+
337
457
  restart: bool | None = None
458
+ """Whether to restart the shell session."""
459
+
460
+ runtime: Annotated[Any, SkipJsonSchema()] = None
461
+ """The runtime for the shell tool.
462
+
463
+ Included as a workaround at the moment bc args_schema doesn't work with
464
+ injected ToolRuntime.
465
+ """
338
466
 
339
467
  @model_validator(mode="after")
340
468
  def validate_payload(self) -> _ShellToolInput:
@@ -347,38 +475,21 @@ class _ShellToolInput(BaseModel):
347
475
  return self
348
476
 
349
477
 
350
- class _PersistentShellTool(BaseTool):
351
- """Tool wrapper that relies on middleware interception for execution."""
352
-
353
- name: str = "shell"
354
- description: str = DEFAULT_TOOL_DESCRIPTION
355
- args_schema: type[BaseModel] = _ShellToolInput
356
-
357
- def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
358
- super().__init__()
359
- self._middleware = middleware
360
- if description is not None:
361
- self.description = description
362
-
363
- def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
364
- msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
365
- raise RuntimeError(msg)
366
-
367
-
368
478
  class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
369
479
  """Middleware that registers a persistent shell tool for agents.
370
480
 
371
- The middleware exposes a single long-lived shell session. Use the execution policy to
372
- match your deployment's security posture:
481
+ The middleware exposes a single long-lived shell session. Use the execution policy
482
+ to match your deployment's security posture:
373
483
 
374
- * ``HostExecutionPolicy`` - full host access; best for trusted environments where the
375
- agent already runs inside a container or VM that provides isolation.
376
- * ``CodexSandboxExecutionPolicy`` - reuses the Codex CLI sandbox for additional
377
- syscall/filesystem restrictions when the CLI is available.
378
- * ``DockerExecutionPolicy`` - launches a separate Docker container for each agent run,
379
- providing harder isolation, optional read-only root filesystems, and user remapping.
484
+ * `HostExecutionPolicy` full host access; best for trusted environments where the
485
+ agent already runs inside a container or VM that provides isolation.
486
+ * `CodexSandboxExecutionPolicy` reuses the Codex CLI sandbox for additional
487
+ syscall/filesystem restrictions when the CLI is available.
488
+ * `DockerExecutionPolicy` launches a separate Docker container for each agent run,
489
+ providing harder isolation, optional read-only root filesystems, and user
490
+ remapping.
380
491
 
381
- When no policy is provided the middleware defaults to ``HostExecutionPolicy``.
492
+ When no policy is provided the middleware defaults to `HostExecutionPolicy`.
382
493
  """
383
494
 
384
495
  state_schema = ShellToolState
@@ -392,29 +503,49 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
392
503
  execution_policy: BaseExecutionPolicy | None = None,
393
504
  redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
394
505
  tool_description: str | None = None,
506
+ tool_name: str = SHELL_TOOL_NAME,
395
507
  shell_command: Sequence[str] | str | None = None,
396
508
  env: Mapping[str, Any] | None = None,
397
509
  ) -> None:
398
- """Initialize the middleware.
510
+ """Initialize an instance of `ShellToolMiddleware`.
399
511
 
400
512
  Args:
401
- workspace_root: Base directory for the shell session. If omitted, a temporary
402
- directory is created when the agent starts and removed when it ends.
403
- startup_commands: Optional commands executed sequentially after the session starts.
513
+ workspace_root: Base directory for the shell session.
514
+
515
+ If omitted, a temporary directory is created when the agent starts and
516
+ removed when it ends.
517
+ startup_commands: Optional commands executed sequentially after the session
518
+ starts.
404
519
  shutdown_commands: Optional commands executed before the session shuts down.
405
- execution_policy: Execution policy controlling timeouts, output limits, and resource
406
- configuration. Defaults to :class:`HostExecutionPolicy` for native execution.
520
+ execution_policy: Execution policy controlling timeouts, output limits, and
521
+ resource configuration.
522
+
523
+ Defaults to `HostExecutionPolicy` for native execution.
407
524
  redaction_rules: Optional redaction rules to sanitize command output before
408
525
  returning it to the model.
409
- tool_description: Optional override for the registered shell tool description.
410
- shell_command: Optional shell executable (string) or argument sequence used to
411
- launch the persistent session. Defaults to an implementation-defined bash command.
412
- env: Optional environment variables to supply to the shell session. Values are
413
- coerced to strings before command execution. If omitted, the session inherits the
414
- parent process environment.
526
+
527
+ !!! warning
528
+ Redaction rules are applied post execution and do not prevent
529
+ exfiltration of secrets or sensitive data when using
530
+ `HostExecutionPolicy`.
531
+
532
+ tool_description: Optional override for the registered shell tool
533
+ description.
534
+ tool_name: Name for the registered shell tool.
535
+
536
+ Defaults to `"shell"`.
537
+ shell_command: Optional shell executable (string) or argument sequence used
538
+ to launch the persistent session.
539
+
540
+ Defaults to an implementation-defined bash command.
541
+ env: Optional environment variables to supply to the shell session.
542
+
543
+ Values are coerced to strings before command execution. If omitted, the
544
+ session inherits the parent process environment.
415
545
  """
416
546
  super().__init__()
417
547
  self._workspace_root = Path(workspace_root) if workspace_root else None
548
+ self._tool_name = tool_name
418
549
  self._shell_command = self._normalize_shell_command(shell_command)
419
550
  self._environment = self._normalize_env(env)
420
551
  if execution_policy is not None:
@@ -428,9 +559,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
428
559
  self._startup_commands = self._normalize_commands(startup_commands)
429
560
  self._shutdown_commands = self._normalize_commands(shutdown_commands)
430
561
 
562
+ # Create a proper tool that executes directly (no interception needed)
431
563
  description = tool_description or DEFAULT_TOOL_DESCRIPTION
432
- self._tool = _PersistentShellTool(self, description=description)
433
- self.tools = [self._tool]
564
+
565
+ @tool(self._tool_name, args_schema=_ShellToolInput, description=description)
566
+ def shell_tool(
567
+ *,
568
+ runtime: ToolRuntime[None, ShellToolState],
569
+ command: str | None = None,
570
+ restart: bool = False,
571
+ ) -> ToolMessage | str:
572
+ resources = self._get_or_create_resources(runtime.state)
573
+ return self._run_shell_tool(
574
+ resources,
575
+ {"command": command, "restart": restart},
576
+ tool_call_id=runtime.tool_call_id,
577
+ )
578
+
579
+ self._shell_tool = shell_tool
580
+ self.tools = [self._shell_tool]
434
581
 
435
582
  @staticmethod
436
583
  def _normalize_commands(
@@ -461,43 +608,73 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
461
608
  normalized: dict[str, str] = {}
462
609
  for key, value in env.items():
463
610
  if not isinstance(key, str):
464
- msg = "Environment variable names must be strings."
611
+ msg = "Environment variable names must be strings." # type: ignore[unreachable]
465
612
  raise TypeError(msg)
466
613
  normalized[key] = str(value)
467
614
  return normalized
468
615
 
469
- def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
470
- """Start the shell session and run startup commands."""
471
- resources = self._create_resources()
616
+ @override
617
+ def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
618
+ """Start the shell session and run startup commands.
619
+
620
+ Args:
621
+ state: The current agent state.
622
+ runtime: The runtime context.
623
+
624
+ Returns:
625
+ Shell session resources to be stored in the agent state.
626
+ """
627
+ resources = self._get_or_create_resources(state)
472
628
  return {"shell_session_resources": resources}
473
629
 
474
630
  async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
475
- """Async counterpart to `before_agent`."""
631
+ """Async start the shell session and run startup commands.
632
+
633
+ Args:
634
+ state: The current agent state.
635
+ runtime: The runtime context.
636
+
637
+ Returns:
638
+ Shell session resources to be stored in the agent state.
639
+ """
476
640
  return self.before_agent(state, runtime)
477
641
 
478
- def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
642
+ @override
643
+ def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:
479
644
  """Run shutdown commands and release resources when an agent completes."""
480
- resources = self._ensure_resources(state)
645
+ resources = state.get("shell_session_resources")
646
+ if not isinstance(resources, _SessionResources):
647
+ # Resources were never created, nothing to clean up
648
+ return
481
649
  try:
482
650
  self._run_shutdown_commands(resources.session)
483
651
  finally:
484
- resources._finalizer()
652
+ resources.finalizer()
485
653
 
486
654
  async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
487
- """Async counterpart to `after_agent`."""
655
+ """Async run shutdown commands and release resources when an agent completes."""
488
656
  return self.after_agent(state, runtime)
489
657
 
490
- def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
658
+ def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
659
+ """Get existing resources from state or create new ones if they don't exist.
660
+
661
+ This method enables resumability by checking if resources already exist in the state
662
+ (e.g., after an interrupt), and only creating new resources if they're not present.
663
+
664
+ Args:
665
+ state: The agent state which may contain shell session resources.
666
+
667
+ Returns:
668
+ Session resources, either retrieved from state or newly created.
669
+ """
491
670
  resources = state.get("shell_session_resources")
492
- if resources is not None and not isinstance(resources, _SessionResources):
493
- resources = None
494
- if resources is None:
495
- msg = (
496
- "Shell session resources are unavailable. Ensure `before_agent` ran successfully "
497
- "before invoking the shell tool."
498
- )
499
- raise ToolException(msg)
500
- return resources
671
+ if isinstance(resources, _SessionResources):
672
+ return resources
673
+
674
+ new_resources = self._create_resources()
675
+ # Cast needed to make state dict-like for mutation
676
+ cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
677
+ return new_resources
501
678
 
502
679
  def _create_resources(self) -> _SessionResources:
503
680
  workspace = self._workspace_root
@@ -533,7 +710,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
533
710
  return
534
711
  for command in self._startup_commands:
535
712
  result = session.execute(command, timeout=self._execution_policy.startup_timeout)
536
- if result.timed_out or (result.exit_code not in (0, None)):
713
+ if result.timed_out or (result.exit_code not in {0, None}):
537
714
  msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
538
715
  raise RuntimeError(msg)
539
716
 
@@ -545,7 +722,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
545
722
  result = session.execute(command, timeout=self._execution_policy.command_timeout)
546
723
  if result.timed_out:
547
724
  LOGGER.warning("Shutdown command '%s' timed out.", command)
548
- elif result.exit_code not in (0, None):
725
+ elif result.exit_code not in {0, None}:
549
726
  LOGGER.warning(
550
727
  "Shutdown command '%s' exited with %s.", command, result.exit_code
551
728
  )
@@ -636,7 +813,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
636
813
  f"(observed {result.total_bytes})."
637
814
  )
638
815
 
639
- if result.exit_code not in (0, None):
816
+ if result.exit_code not in {0, None}:
640
817
  sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
641
818
  final_status: Literal["success", "error"] = "error"
642
819
  else:
@@ -659,36 +836,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
659
836
  artifact=artifact,
660
837
  )
661
838
 
662
- def wrap_tool_call(
663
- self,
664
- request: ToolCallRequest,
665
- handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
666
- ) -> ToolMessage | Command:
667
- """Intercept local shell tool calls and execute them via the managed session."""
668
- if isinstance(request.tool, _PersistentShellTool):
669
- resources = self._ensure_resources(request.state)
670
- return self._run_shell_tool(
671
- resources,
672
- request.tool_call["args"],
673
- tool_call_id=request.tool_call.get("id"),
674
- )
675
- return handler(request)
676
-
677
- async def awrap_tool_call(
678
- self,
679
- request: ToolCallRequest,
680
- handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
681
- ) -> ToolMessage | Command:
682
- """Async interception mirroring the synchronous tool handler."""
683
- if isinstance(request.tool, _PersistentShellTool):
684
- resources = self._ensure_resources(request.state)
685
- return self._run_shell_tool(
686
- resources,
687
- request.tool_call["args"],
688
- tool_call_id=request.tool_call.get("id"),
689
- )
690
- return await handler(request)
691
-
692
839
  def _format_tool_message(
693
840
  self,
694
841
  content: str,
@@ -703,7 +850,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
703
850
  return ToolMessage(
704
851
  content=content,
705
852
  tool_call_id=tool_call_id,
706
- name=self._tool.name,
853
+ name=self._tool_name,
707
854
  status=status,
708
855
  artifact=artifact,
709
856
  )