predict-rlm 0.2.2__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/.gitignore +1 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/PKG-INFO +3 -3
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/README.md +2 -2
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/pyproject.toml +1 -1
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/__init__.py +4 -1
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/_shared.py +21 -7
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/files.py +53 -1
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/interpreter.py +162 -8
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/predict_rlm.py +375 -125
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/sandbox/runner.js +49 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/spreadsheet/skill.py +17 -1
- predict_rlm-0.2.3/src/predict_rlm/trace.py +364 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/LICENSE +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/rlm_skills.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/docx/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/docx/modules/md2docx.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/docx/skill.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/pdf/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/pdf/skill.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/spreadsheet/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.3}/src/predict_rlm/skills/spreadsheet/modules/formula_eval.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: predict-rlm
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: Production-grade RLMs (Recursive Language Models) with tool use, built on DSPy
|
|
5
5
|
Project-URL: Homepage, https://www.trampoline.ai/
|
|
6
6
|
Project-URL: Repository, https://github.com/Trampoline-AI/predict-rlm
|
|
@@ -25,7 +25,7 @@ Requires-Dist: pymupdf>=1.24.0; extra == 'examples'
|
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
|
|
27
27
|
# predict-rlm
|
|
28
|
-
|
|
28
|
+
Production-focused, self-harnessed LM runtime (RLM) that allows the LM to call its sub-LM with [DSPy](https://dspy.ai) signatures. Define your inputs, outputs, and tools — the model handles its own control flow. Get fully interpretable trajectories and performance that scales directly with model improvements. Without context rot.
|
|
29
29
|
|
|
30
30
|
Based on the [Recursive Language Models](https://arxiv.org/abs/2512.24601v1) paper by [Alex L. Zhang](https://x.com/a1zhang), [Tim Kraska](https://x.com/tim_kraska), and [Omar Khattab](https://x.com/lateinteraction) from the Stanford NLP lab.<br/>
|
|
31
31
|
|
|
@@ -68,7 +68,7 @@ uv add predict-rlm
|
|
|
68
68
|
|
|
69
69
|
- **Multimodal** — process images, documents, audio, and video through sub-LM calls using native provider multimodal APIs.
|
|
70
70
|
- **Async tool calling** — native RLM async support in the WASM sandbox, enabling concurrent sub-LM invocations and tool calls
|
|
71
|
-
- **Prompt-optimized skills & tools** —
|
|
71
|
+
- **Prompt-optimized skills & tools** — predict-rlm skills come tested and optimized to ensure maximum LM interoperability and performance, bundling instructions, PyPI packages, and tools for domain-specific tasks
|
|
72
72
|
- **Simple file I/O** — pass local or cloud files as typed inputs and outputs via `File`, keeping interop with your existing data pipelines straightforward. (S3 files support soon)
|
|
73
73
|
- **Structured sub-LM calls** — native Pydantic and DSPy signature support for type-safe sub-LM invocations with structured outputs
|
|
74
74
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# predict-rlm
|
|
2
|
-
|
|
2
|
+
Production-focused, self-harnessed LM runtime (RLM) that allows the LM to call its sub-LM with [DSPy](https://dspy.ai) signatures. Define your inputs, outputs, and tools — the model handles its own control flow. Get fully interpretable trajectories and performance that scales directly with model improvements. Without context rot.
|
|
3
3
|
|
|
4
4
|
Based on the [Recursive Language Models](https://arxiv.org/abs/2512.24601v1) paper by [Alex L. Zhang](https://x.com/a1zhang), [Tim Kraska](https://x.com/tim_kraska), and [Omar Khattab](https://x.com/lateinteraction) from the Stanford NLP lab.<br/>
|
|
5
5
|
|
|
@@ -42,7 +42,7 @@ uv add predict-rlm
|
|
|
42
42
|
|
|
43
43
|
- **Multimodal** — process images, documents, audio, and video through sub-LM calls using native provider multimodal APIs.
|
|
44
44
|
- **Async tool calling** — native RLM async support in the WASM sandbox, enabling concurrent sub-LM invocations and tool calls
|
|
45
|
-
- **Prompt-optimized skills & tools** —
|
|
45
|
+
- **Prompt-optimized skills & tools** — predict-rlm skills come tested and optimized to ensure maximum LM interoperability and performance, bundling instructions, PyPI packages, and tools for domain-specific tasks
|
|
46
46
|
- **Simple file I/O** — pass local or cloud files as typed inputs and outputs via `File`, keeping interop with your existing data pipelines straightforward. (S3 files support soon)
|
|
47
47
|
- **Structured sub-LM calls** — native Pydantic and DSPy signature support for type-safe sub-LM invocations with structured outputs
|
|
48
48
|
|
|
@@ -9,9 +9,10 @@ File I/O:
|
|
|
9
9
|
(sync from sandbox). Use ``list[File]`` for multiple files.
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
from .files import File, LocalDir, LocalFile, OutputDir, OutputFile
|
|
12
|
+
from .files import File, LocalDir, LocalFile, OutputDir, OutputFile, SyncedFile
|
|
13
13
|
from .predict_rlm import PredictRLM
|
|
14
14
|
from .rlm_skills import Skill
|
|
15
|
+
from .trace import RunTrace
|
|
15
16
|
|
|
16
17
|
__all__ = [
|
|
17
18
|
"File",
|
|
@@ -20,5 +21,7 @@ __all__ = [
|
|
|
20
21
|
"OutputDir",
|
|
21
22
|
"OutputFile",
|
|
22
23
|
"PredictRLM",
|
|
24
|
+
"RunTrace",
|
|
23
25
|
"Skill",
|
|
26
|
+
"SyncedFile",
|
|
24
27
|
]
|
|
@@ -4,7 +4,9 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import inspect
|
|
6
6
|
import textwrap
|
|
7
|
-
|
|
7
|
+
import typing
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING, Annotated, Callable
|
|
8
10
|
|
|
9
11
|
import dspy
|
|
10
12
|
from dspy.adapters.utils import translate_field_type
|
|
@@ -29,20 +31,32 @@ def format_tool_docs_full(tools: dict[str, Callable]) -> str:
|
|
|
29
31
|
# Get function signature with types
|
|
30
32
|
try:
|
|
31
33
|
sig = inspect.signature(func)
|
|
34
|
+
# Resolve string annotations (from `from __future__ import annotations`)
|
|
35
|
+
try:
|
|
36
|
+
resolved = typing.get_type_hints(func, include_extras=True)
|
|
37
|
+
except (TypeError, NameError):
|
|
38
|
+
resolved = {}
|
|
39
|
+
|
|
32
40
|
params = []
|
|
33
41
|
for p in sig.parameters.values():
|
|
34
|
-
|
|
35
|
-
|
|
42
|
+
ann = resolved.get(p.name)
|
|
43
|
+
if ann is not None:
|
|
44
|
+
# Unwrap Annotated[X, ...] → X (e.g. SyncedFile markers)
|
|
45
|
+
if typing.get_origin(ann) is Annotated:
|
|
46
|
+
ann = typing.get_args(ann)[0]
|
|
47
|
+
# Show Path as str — the RLM passes sandbox paths as strings
|
|
48
|
+
if ann is Path:
|
|
49
|
+
ann = str
|
|
50
|
+
type_name = getattr(ann, "__name__", str(ann))
|
|
36
51
|
params.append(f"{p.name}: {type_name}")
|
|
37
52
|
else:
|
|
38
53
|
params.append(p.name)
|
|
39
54
|
params_str = ", ".join(params)
|
|
40
55
|
|
|
41
56
|
# Get return type
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
)
|
|
57
|
+
ret_ann = resolved.get("return")
|
|
58
|
+
if ret_ann is not None:
|
|
59
|
+
ret_type = getattr(ret_ann, "__name__", str(ret_ann))
|
|
46
60
|
sig_str = f"{name}({params_str}) -> {ret_type}"
|
|
47
61
|
else:
|
|
48
62
|
sig_str = f"{name}({params_str})"
|
|
@@ -25,7 +25,8 @@ from __future__ import annotations
|
|
|
25
25
|
|
|
26
26
|
import os
|
|
27
27
|
import typing
|
|
28
|
-
from
|
|
28
|
+
from dataclasses import dataclass
|
|
29
|
+
from typing import Annotated, Any
|
|
29
30
|
|
|
30
31
|
from pydantic import BaseModel, Field
|
|
31
32
|
|
|
@@ -64,6 +65,34 @@ OutputFile = File
|
|
|
64
65
|
OutputDir = File
|
|
65
66
|
|
|
66
67
|
|
|
68
|
+
@dataclass(frozen=True)
|
|
69
|
+
class SyncedFile:
|
|
70
|
+
"""Annotation marker for tool parameters that need sandbox-host file sync.
|
|
71
|
+
|
|
72
|
+
Use with ``typing.Annotated`` on tool function parameters to declare that
|
|
73
|
+
a parameter is a sandbox file path. The framework automatically syncs the
|
|
74
|
+
file from the sandbox to the host before calling the tool, and optionally
|
|
75
|
+
mounts the modified file back into the sandbox after the tool returns.
|
|
76
|
+
|
|
77
|
+
Example::
|
|
78
|
+
|
|
79
|
+
def recalculate(
|
|
80
|
+
workbook: Annotated[Path, SyncedFile(host_dir="/tmp/wb")],
|
|
81
|
+
reference: Annotated[Path, SyncedFile(writeback=False)],
|
|
82
|
+
) -> str:
|
|
83
|
+
...
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
writeback: bool = True
|
|
87
|
+
"""If True (default), mount the file back into the sandbox after the tool
|
|
88
|
+
returns. Set to False for read-only access (skip the mount-after step)."""
|
|
89
|
+
|
|
90
|
+
host_dir: str | None = None
|
|
91
|
+
"""Host directory for the synced file. If None, a temporary directory is
|
|
92
|
+
created and cleaned up after the call. If specified, the directory is used
|
|
93
|
+
as-is and not cleaned up."""
|
|
94
|
+
|
|
95
|
+
|
|
67
96
|
def _unwrap_annotation(annotation: Any) -> Any:
|
|
68
97
|
"""Unwrap Optional/Annotated/list to get the inner file type."""
|
|
69
98
|
origin = typing.get_origin(annotation)
|
|
@@ -263,3 +292,26 @@ def build_file_plan(
|
|
|
263
292
|
"output_field_map": output_field_map,
|
|
264
293
|
"instructions": instructions,
|
|
265
294
|
}
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def get_synced_file_params(fn: Any) -> dict[str, SyncedFile]:
|
|
298
|
+
"""Extract SyncedFile annotations from a tool function's type hints.
|
|
299
|
+
|
|
300
|
+
Returns a dict mapping parameter names to their ``SyncedFile`` marker
|
|
301
|
+
for all parameters annotated with ``Annotated[..., SyncedFile(...)]``.
|
|
302
|
+
"""
|
|
303
|
+
try:
|
|
304
|
+
hints = typing.get_type_hints(fn, include_extras=True)
|
|
305
|
+
except (TypeError, NameError):
|
|
306
|
+
return {}
|
|
307
|
+
|
|
308
|
+
result: dict[str, SyncedFile] = {}
|
|
309
|
+
for name, hint in hints.items():
|
|
310
|
+
if name == "return":
|
|
311
|
+
continue
|
|
312
|
+
if typing.get_origin(hint) is Annotated:
|
|
313
|
+
for arg in typing.get_args(hint)[1:]:
|
|
314
|
+
if isinstance(arg, SyncedFile):
|
|
315
|
+
result[name] = arg
|
|
316
|
+
break
|
|
317
|
+
return result
|
|
@@ -16,11 +16,15 @@ from __future__ import annotations
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import concurrent.futures
|
|
18
18
|
import functools
|
|
19
|
+
import inspect
|
|
19
20
|
import json
|
|
20
21
|
import logging
|
|
21
22
|
import os
|
|
22
23
|
import re
|
|
23
24
|
import select
|
|
25
|
+
import shutil
|
|
26
|
+
import tempfile
|
|
27
|
+
import time
|
|
24
28
|
from pathlib import Path
|
|
25
29
|
from typing import TYPE_CHECKING, Any
|
|
26
30
|
|
|
@@ -35,6 +39,7 @@ if TYPE_CHECKING:
|
|
|
35
39
|
|
|
36
40
|
logger = logging.getLogger(__name__)
|
|
37
41
|
|
|
42
|
+
|
|
38
43
|
# JSON-RPC 2.0 helpers (local to avoid coupling to dspy internals)
|
|
39
44
|
JSONRPC_APP_ERRORS = {
|
|
40
45
|
"SyntaxError": -32000,
|
|
@@ -223,6 +228,17 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
223
228
|
all_read_paths = list(enable_read_paths or []) + list(extra_read_paths or [])
|
|
224
229
|
all_write_paths = list(enable_write_paths or []) + list(extra_write_paths or [])
|
|
225
230
|
|
|
231
|
+
# Scan tools for SyncedFile annotations with custom host_dir paths
|
|
232
|
+
# and add them to Deno permissions so the runner can write there.
|
|
233
|
+
if tools:
|
|
234
|
+
from predict_rlm.files import get_synced_file_params
|
|
235
|
+
|
|
236
|
+
for tool_fn in tools.values():
|
|
237
|
+
for sf in get_synced_file_params(tool_fn).values():
|
|
238
|
+
if sf.host_dir is not None:
|
|
239
|
+
all_write_paths.append(sf.host_dir)
|
|
240
|
+
all_read_paths.append(sf.host_dir)
|
|
241
|
+
|
|
226
242
|
# Build custom deno command if not provided
|
|
227
243
|
if deno_command is None:
|
|
228
244
|
deno_command = self._build_deno_command(
|
|
@@ -249,6 +265,9 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
249
265
|
# Per-interpreter thread pool for sync tool calls (avoids starving
|
|
250
266
|
# the shared default executor when many interpreters run concurrently)
|
|
251
267
|
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
268
|
+
# Pending file-sync operations requested by tools during execution.
|
|
269
|
+
# Maps request ID → asyncio.Future resolved by the execute loop.
|
|
270
|
+
self._pending_file_ops: dict[int, asyncio.Future] = {}
|
|
252
271
|
|
|
253
272
|
def _ensure_deno_process(self) -> None:
|
|
254
273
|
"""Override to capture raw fds for non-blocking I/O."""
|
|
@@ -327,6 +346,11 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
327
346
|
allowed_read.extend(str(p) for p in read_paths)
|
|
328
347
|
allowed_read.extend(str(p) for p in write_paths)
|
|
329
348
|
|
|
349
|
+
# Allow reading temp dirs so @file_sync tools can mount files back
|
|
350
|
+
import tempfile as _tempfile
|
|
351
|
+
allowed_read.append(_tempfile.gettempdir())
|
|
352
|
+
allowed_read.append("/tmp")
|
|
353
|
+
|
|
330
354
|
if allowed_read:
|
|
331
355
|
args.append(f"--allow-read={','.join(allowed_read)}")
|
|
332
356
|
|
|
@@ -681,6 +705,14 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
681
705
|
logger.info(f"Skipping malformed JSON: {output_line[:100]}")
|
|
682
706
|
continue
|
|
683
707
|
|
|
708
|
+
# Route file-sync responses to pending futures (from _execute_tool_async)
|
|
709
|
+
resp_id = result.get("id")
|
|
710
|
+
if resp_id is not None and resp_id in self._pending_file_ops:
|
|
711
|
+
future = self._pending_file_ops.pop(resp_id)
|
|
712
|
+
if not future.done():
|
|
713
|
+
future.set_result(result)
|
|
714
|
+
continue
|
|
715
|
+
|
|
684
716
|
# JSON-RPC request from sandbox (tool call)
|
|
685
717
|
if "method" in result:
|
|
686
718
|
if result["method"] == "tool_call":
|
|
@@ -823,8 +855,53 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
823
855
|
line, self._read_buf = self._read_buf.split("\n", 1)
|
|
824
856
|
return line.strip()
|
|
825
857
|
|
|
858
|
+
async def _sync_file_during_tool(self, virtual_path: str, host_path: str) -> None:
|
|
859
|
+
"""Sync a file from sandbox MEMFS to host during a tool call.
|
|
860
|
+
|
|
861
|
+
Sends a sync_file request to the Deno runner's responseReader (which
|
|
862
|
+
handles it during tool execution) and awaits the response via a Future
|
|
863
|
+
resolved by the _execute_async loop.
|
|
864
|
+
"""
|
|
865
|
+
self._request_id += 1
|
|
866
|
+
req_id = self._request_id
|
|
867
|
+
loop = asyncio.get_running_loop()
|
|
868
|
+
future = loop.create_future()
|
|
869
|
+
self._pending_file_ops[req_id] = future
|
|
870
|
+
msg = json.dumps({
|
|
871
|
+
"jsonrpc": "2.0", "method": "sync_file",
|
|
872
|
+
"params": {"virtual_path": virtual_path, "host_path": host_path},
|
|
873
|
+
"id": req_id,
|
|
874
|
+
})
|
|
875
|
+
await self._write_stdin_async(msg + "\n")
|
|
876
|
+
result = await future
|
|
877
|
+
if "error" in result:
|
|
878
|
+
raise CodeInterpreterError(
|
|
879
|
+
f"sync_file failed: {result['error'].get('message', result['error'])}"
|
|
880
|
+
)
|
|
881
|
+
|
|
882
|
+
async def _mount_file_during_tool(self, host_path: str, virtual_path: str) -> None:
|
|
883
|
+
"""Mount a file from host into sandbox MEMFS during a tool call."""
|
|
884
|
+
self._request_id += 1
|
|
885
|
+
req_id = self._request_id
|
|
886
|
+
loop = asyncio.get_running_loop()
|
|
887
|
+
future = loop.create_future()
|
|
888
|
+
self._pending_file_ops[req_id] = future
|
|
889
|
+
msg = json.dumps({
|
|
890
|
+
"jsonrpc": "2.0", "method": "mount_file",
|
|
891
|
+
"params": {"host_path": host_path, "virtual_path": virtual_path},
|
|
892
|
+
"id": req_id,
|
|
893
|
+
})
|
|
894
|
+
await self._write_stdin_async(msg + "\n")
|
|
895
|
+
result = await future
|
|
896
|
+
if "error" in result:
|
|
897
|
+
raise CodeInterpreterError(
|
|
898
|
+
f"mount_file failed: {result['error'].get('message', result['error'])}"
|
|
899
|
+
)
|
|
900
|
+
|
|
826
901
|
async def _execute_tool_async(self, tool_name: str, call_args: dict) -> dict:
|
|
827
902
|
"""Execute a tool asynchronously and return the response dict."""
|
|
903
|
+
from .trace import ToolCall, ms_since, record_tool_call
|
|
904
|
+
|
|
828
905
|
if self._debug:
|
|
829
906
|
import sys
|
|
830
907
|
|
|
@@ -832,39 +909,116 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
832
909
|
print(
|
|
833
910
|
f"\n\033[33m── Tool: {tool_name}({kwargs_preview}) ──\033[0m", file=sys.stderr
|
|
834
911
|
)
|
|
912
|
+
|
|
913
|
+
call_start = time.perf_counter()
|
|
914
|
+
# Copy to mutable containers so the SyncedFile handler below can
|
|
915
|
+
# rewrite sandbox paths to host paths before invoking the tool.
|
|
916
|
+
args = list(call_args.get("args", []))
|
|
917
|
+
kwargs = dict(call_args.get("kwargs", {}))
|
|
918
|
+
temp_dir: str | None = None
|
|
919
|
+
|
|
835
920
|
try:
|
|
836
921
|
if tool_name not in self.tools:
|
|
837
922
|
raise CodeInterpreterError(f"Unknown tool: {tool_name}")
|
|
838
923
|
|
|
839
924
|
tool_fn = self.tools[tool_name]
|
|
840
|
-
args = call_args.get("args", [])
|
|
841
|
-
kwargs = call_args.get("kwargs", {})
|
|
842
925
|
|
|
843
926
|
# Pass pydantic_schemas through to predict tool if present
|
|
844
927
|
pydantic_schemas = call_args.get("pydantic_schemas")
|
|
845
928
|
if pydantic_schemas and tool_name == "predict":
|
|
846
929
|
kwargs["pydantic_schemas"] = pydantic_schemas
|
|
847
930
|
|
|
931
|
+
# Handle SyncedFile-annotated tool parameters: sync sandbox files
|
|
932
|
+
# to host before calling, and mount modified files back after.
|
|
933
|
+
from predict_rlm.files import get_synced_file_params
|
|
934
|
+
|
|
935
|
+
synced_params = get_synced_file_params(tool_fn)
|
|
936
|
+
temp_dir = None
|
|
937
|
+
# (sandbox_path, host_path, writeback) for each synced param
|
|
938
|
+
synced_entries: list[tuple[str, str, bool]] = []
|
|
939
|
+
|
|
940
|
+
if synced_params:
|
|
941
|
+
sig = inspect.signature(tool_fn)
|
|
942
|
+
param_names = list(sig.parameters.keys())
|
|
943
|
+
|
|
944
|
+
for param_name, sf in synced_params.items():
|
|
945
|
+
# Resolve the sandbox path from args or kwargs
|
|
946
|
+
sandbox_path = kwargs.get(param_name)
|
|
947
|
+
if sandbox_path is None and param_name in param_names:
|
|
948
|
+
idx = param_names.index(param_name)
|
|
949
|
+
if idx < len(args):
|
|
950
|
+
sandbox_path = args[idx]
|
|
951
|
+
if not sandbox_path or not isinstance(sandbox_path, str):
|
|
952
|
+
continue
|
|
953
|
+
|
|
954
|
+
# Determine host directory
|
|
955
|
+
if sf.host_dir is not None:
|
|
956
|
+
host_dir = sf.host_dir
|
|
957
|
+
os.makedirs(host_dir, exist_ok=True)
|
|
958
|
+
else:
|
|
959
|
+
if temp_dir is None:
|
|
960
|
+
temp_dir = tempfile.mkdtemp(prefix="tool-file-sync-")
|
|
961
|
+
host_dir = temp_dir
|
|
962
|
+
|
|
963
|
+
host_path = os.path.join(host_dir, os.path.basename(sandbox_path))
|
|
964
|
+
await self._sync_file_during_tool(sandbox_path, host_path)
|
|
965
|
+
synced_entries.append((sandbox_path, host_path, sf.writeback))
|
|
966
|
+
|
|
967
|
+
# Replace the sandbox path with the host path in args/kwargs
|
|
968
|
+
if param_name in kwargs:
|
|
969
|
+
kwargs[param_name] = host_path
|
|
970
|
+
elif param_name in param_names:
|
|
971
|
+
idx = param_names.index(param_name)
|
|
972
|
+
if idx < len(args):
|
|
973
|
+
args[idx] = host_path
|
|
974
|
+
|
|
848
975
|
# Check if tool is async or sync
|
|
849
976
|
if asyncio.iscoroutinefunction(tool_fn):
|
|
850
977
|
result = await tool_fn(*args, **kwargs)
|
|
851
978
|
else:
|
|
852
|
-
# Run sync function in per-interpreter thread pool (not the
|
|
853
|
-
# shared default pool) to prevent starvation when many
|
|
854
|
-
# interpreters run concurrently.
|
|
855
|
-
# loop.run_in_executor only accepts positional args, so wrap
|
|
856
|
-
# the call in functools.partial to bind **kwargs.
|
|
857
979
|
loop = asyncio.get_running_loop()
|
|
858
980
|
result = await loop.run_in_executor(
|
|
859
981
|
self._executor, functools.partial(tool_fn, *args, **kwargs)
|
|
860
982
|
)
|
|
861
983
|
|
|
984
|
+
# Mount modified files back into the sandbox (only for writeback params)
|
|
985
|
+
if synced_entries:
|
|
986
|
+
for sandbox_path, host_path, writeback in synced_entries:
|
|
987
|
+
if writeback and os.path.isfile(host_path):
|
|
988
|
+
await self._mount_file_during_tool(host_path, sandbox_path)
|
|
989
|
+
if temp_dir:
|
|
990
|
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
991
|
+
|
|
862
992
|
is_json = isinstance(result, (list, dict))
|
|
863
|
-
|
|
993
|
+
response = {
|
|
864
994
|
"value": json.dumps(result) if is_json else str(result or ""),
|
|
865
995
|
"type": "json" if is_json else "string",
|
|
866
996
|
}
|
|
997
|
+
|
|
998
|
+
# Record non-predict tool calls (predict records itself with richer detail)
|
|
999
|
+
if tool_name != "predict":
|
|
1000
|
+
record_tool_call(ToolCall(
|
|
1001
|
+
name=tool_name,
|
|
1002
|
+
args=args,
|
|
1003
|
+
kwargs={k: v for k, v in kwargs.items() if k != "pydantic_schemas"},
|
|
1004
|
+
result=result,
|
|
1005
|
+
duration_ms=ms_since(call_start),
|
|
1006
|
+
))
|
|
1007
|
+
|
|
1008
|
+
return response
|
|
867
1009
|
except Exception as e:
|
|
1010
|
+
# Clean up any SyncedFile temp dir before returning
|
|
1011
|
+
if temp_dir:
|
|
1012
|
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
1013
|
+
if tool_name != "predict":
|
|
1014
|
+
record_tool_call(ToolCall(
|
|
1015
|
+
name=tool_name,
|
|
1016
|
+
args=args,
|
|
1017
|
+
kwargs={k: v for k, v in kwargs.items() if k != "pydantic_schemas"},
|
|
1018
|
+
result=None,
|
|
1019
|
+
error=str(e),
|
|
1020
|
+
duration_ms=ms_since(call_start),
|
|
1021
|
+
))
|
|
868
1022
|
return {"error": str(e)}
|
|
869
1023
|
|
|
870
1024
|
def _write_stdin(self, data: str) -> None:
|