Package not found. Please check the package name and try again.
predict-rlm 0.2.2__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/.gitignore +1 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/PKG-INFO +3 -3
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/README.md +2 -2
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/pyproject.toml +1 -1
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/__init__.py +4 -1
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/_shared.py +22 -8
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/files.py +53 -1
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/interpreter.py +201 -41
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/predict_rlm.py +403 -137
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/sandbox/runner.js +55 -3
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/spreadsheet/skill.py +17 -1
- predict_rlm-0.2.4/src/predict_rlm/trace.py +364 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/LICENSE +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/rlm_skills.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/docx/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/docx/modules/md2docx.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/docx/skill.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/pdf/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/pdf/skill.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/spreadsheet/__init__.py +0 -0
- {predict_rlm-0.2.2 → predict_rlm-0.2.4}/src/predict_rlm/skills/spreadsheet/modules/formula_eval.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: predict-rlm
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.2.4
|
|
4
4
|
Summary: Production-grade RLMs (Recursive Language Models) with tool use, built on DSPy
|
|
5
5
|
Project-URL: Homepage, https://www.trampoline.ai/
|
|
6
6
|
Project-URL: Repository, https://github.com/Trampoline-AI/predict-rlm
|
|
@@ -25,7 +25,7 @@ Requires-Dist: pymupdf>=1.24.0; extra == 'examples'
|
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
|
|
27
27
|
# predict-rlm
|
|
28
|
-
|
|
28
|
+
Production-focused, self-harnessed LM runtime (RLM) that lets the LM call its sub-LM with [DSPy](https://dspy.ai) signatures. Define your inputs, outputs, and tools — the model handles its own control flow. Get fully interpretable trajectories and performance that scales directly with model improvements, without context rot.
|
|
29
29
|
|
|
30
30
|
Based on the [Recursive Language Models](https://arxiv.org/abs/2512.24601v1) paper by [Alex L. Zhang](https://x.com/a1zhang), [Tim Kraska](https://x.com/tim_kraska), and [Omar Khattab](https://x.com/lateinteraction) from the Stanford NLP lab.<br/>
|
|
31
31
|
|
|
@@ -68,7 +68,7 @@ uv add predict-rlm
|
|
|
68
68
|
|
|
69
69
|
- **Multimodal** — process images, documents, audio, and video through sub-LM calls using native provider multimodal APIs.
|
|
70
70
|
- **Async tool calling** — native RLM async support in the WASM sandbox, enabling concurrent sub-LM invocations and tool calls
|
|
71
|
-
- **Prompt-optimized skills & tools** —
|
|
71
|
+
- **Prompt-optimized skills & tools** — predict-rlm skills come tested and optimized to ensure maximum LM interoperability and performance, bundling instructions, PyPI packages, and tools for domain-specific tasks
|
|
72
72
|
- **Simple file I/O** — pass local or cloud files as typed inputs and outputs via `File`, keeping interop with your existing data pipelines straightforward. (S3 file support coming soon)
|
|
73
73
|
- **Structured sub-LM calls** — native Pydantic and DSPy signature support for type-safe sub-LM invocations with structured outputs
|
|
74
74
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# predict-rlm
|
|
2
|
-
|
|
2
|
+
Production-focused, self-harnessed LM runtime (RLM) that lets the LM call its sub-LM with [DSPy](https://dspy.ai) signatures. Define your inputs, outputs, and tools — the model handles its own control flow. Get fully interpretable trajectories and performance that scales directly with model improvements, without context rot.
|
|
3
3
|
|
|
4
4
|
Based on the [Recursive Language Models](https://arxiv.org/abs/2512.24601v1) paper by [Alex L. Zhang](https://x.com/a1zhang), [Tim Kraska](https://x.com/tim_kraska), and [Omar Khattab](https://x.com/lateinteraction) from the Stanford NLP lab.<br/>
|
|
5
5
|
|
|
@@ -42,7 +42,7 @@ uv add predict-rlm
|
|
|
42
42
|
|
|
43
43
|
- **Multimodal** — process images, documents, audio, and video through sub-LM calls using native provider multimodal APIs.
|
|
44
44
|
- **Async tool calling** — native RLM async support in the WASM sandbox, enabling concurrent sub-LM invocations and tool calls
|
|
45
|
-
- **Prompt-optimized skills & tools** —
|
|
45
|
+
- **Prompt-optimized skills & tools** — predict-rlm skills come tested and optimized to ensure maximum LM interoperability and performance, bundling instructions, PyPI packages, and tools for domain-specific tasks
|
|
46
46
|
- **Simple file I/O** — pass local or cloud files as typed inputs and outputs via `File`, keeping interop with your existing data pipelines straightforward. (S3 file support coming soon)
|
|
47
47
|
- **Structured sub-LM calls** — native Pydantic and DSPy signature support for type-safe sub-LM invocations with structured outputs
|
|
48
48
|
|
|
@@ -9,9 +9,10 @@ File I/O:
|
|
|
9
9
|
(sync from sandbox). Use ``list[File]`` for multiple files.
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
from .files import File, LocalDir, LocalFile, OutputDir, OutputFile
|
|
12
|
+
from .files import File, LocalDir, LocalFile, OutputDir, OutputFile, SyncedFile
|
|
13
13
|
from .predict_rlm import PredictRLM
|
|
14
14
|
from .rlm_skills import Skill
|
|
15
|
+
from .trace import RunTrace
|
|
15
16
|
|
|
16
17
|
__all__ = [
|
|
17
18
|
"File",
|
|
@@ -20,5 +21,7 @@ __all__ = [
|
|
|
20
21
|
"OutputDir",
|
|
21
22
|
"OutputFile",
|
|
22
23
|
"PredictRLM",
|
|
24
|
+
"RunTrace",
|
|
23
25
|
"Skill",
|
|
26
|
+
"SyncedFile",
|
|
24
27
|
]
|
|
@@ -4,7 +4,9 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import inspect
|
|
6
6
|
import textwrap
|
|
7
|
-
|
|
7
|
+
import typing
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING, Annotated, Callable
|
|
8
10
|
|
|
9
11
|
import dspy
|
|
10
12
|
from dspy.adapters.utils import translate_field_type
|
|
@@ -29,20 +31,32 @@ def format_tool_docs_full(tools: dict[str, Callable]) -> str:
|
|
|
29
31
|
# Get function signature with types
|
|
30
32
|
try:
|
|
31
33
|
sig = inspect.signature(func)
|
|
34
|
+
# Resolve string annotations (from `from __future__ import annotations`)
|
|
35
|
+
try:
|
|
36
|
+
resolved = typing.get_type_hints(func, include_extras=True)
|
|
37
|
+
except (TypeError, NameError):
|
|
38
|
+
resolved = {}
|
|
39
|
+
|
|
32
40
|
params = []
|
|
33
41
|
for p in sig.parameters.values():
|
|
34
|
-
|
|
35
|
-
|
|
42
|
+
ann = resolved.get(p.name)
|
|
43
|
+
if ann is not None:
|
|
44
|
+
# Unwrap Annotated[X, ...] → X (e.g. SyncedFile markers)
|
|
45
|
+
if typing.get_origin(ann) is Annotated:
|
|
46
|
+
ann = typing.get_args(ann)[0]
|
|
47
|
+
# Show Path as str — the RLM passes sandbox paths as strings
|
|
48
|
+
if ann is Path:
|
|
49
|
+
ann = str
|
|
50
|
+
type_name = getattr(ann, "__name__", str(ann))
|
|
36
51
|
params.append(f"{p.name}: {type_name}")
|
|
37
52
|
else:
|
|
38
53
|
params.append(p.name)
|
|
39
54
|
params_str = ", ".join(params)
|
|
40
55
|
|
|
41
56
|
# Get return type
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
)
|
|
57
|
+
ret_ann = resolved.get("return")
|
|
58
|
+
if ret_ann is not None:
|
|
59
|
+
ret_type = getattr(ret_ann, "__name__", str(ret_ann))
|
|
46
60
|
sig_str = f"{name}({params_str}) -> {ret_type}"
|
|
47
61
|
else:
|
|
48
62
|
sig_str = f"{name}({params_str})"
|
|
@@ -128,7 +142,7 @@ def build_rlm_signatures(
|
|
|
128
142
|
)
|
|
129
143
|
action_sig = action_sig.append(
|
|
130
144
|
"code",
|
|
131
|
-
dspy.OutputField(desc="Python code wrapped in ```
|
|
145
|
+
dspy.OutputField(desc="Python code wrapped in ```python blocks."),
|
|
132
146
|
type_=str,
|
|
133
147
|
)
|
|
134
148
|
|
|
@@ -25,7 +25,8 @@ from __future__ import annotations
|
|
|
25
25
|
|
|
26
26
|
import os
|
|
27
27
|
import typing
|
|
28
|
-
from
|
|
28
|
+
from dataclasses import dataclass
|
|
29
|
+
from typing import Annotated, Any
|
|
29
30
|
|
|
30
31
|
from pydantic import BaseModel, Field
|
|
31
32
|
|
|
@@ -64,6 +65,34 @@ OutputFile = File
|
|
|
64
65
|
OutputDir = File
|
|
65
66
|
|
|
66
67
|
|
|
68
|
+
@dataclass(frozen=True)
|
|
69
|
+
class SyncedFile:
|
|
70
|
+
"""Annotation marker for tool parameters that need sandbox-host file sync.
|
|
71
|
+
|
|
72
|
+
Use with ``typing.Annotated`` on tool function parameters to declare that
|
|
73
|
+
a parameter is a sandbox file path. The framework automatically syncs the
|
|
74
|
+
file from the sandbox to the host before calling the tool, and optionally
|
|
75
|
+
mounts the modified file back into the sandbox after the tool returns.
|
|
76
|
+
|
|
77
|
+
Example::
|
|
78
|
+
|
|
79
|
+
def recalculate(
|
|
80
|
+
workbook: Annotated[Path, SyncedFile(host_dir="/tmp/wb")],
|
|
81
|
+
reference: Annotated[Path, SyncedFile(writeback=False)],
|
|
82
|
+
) -> str:
|
|
83
|
+
...
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
writeback: bool = True
|
|
87
|
+
"""If True (default), mount the file back into the sandbox after the tool
|
|
88
|
+
returns. Set to False for read-only access (skip the mount-after step)."""
|
|
89
|
+
|
|
90
|
+
host_dir: str | None = None
|
|
91
|
+
"""Host directory for the synced file. If None, a temporary directory is
|
|
92
|
+
created and cleaned up after the call. If specified, the directory is used
|
|
93
|
+
as-is and not cleaned up."""
|
|
94
|
+
|
|
95
|
+
|
|
67
96
|
def _unwrap_annotation(annotation: Any) -> Any:
|
|
68
97
|
"""Unwrap Optional/Annotated/list to get the inner file type."""
|
|
69
98
|
origin = typing.get_origin(annotation)
|
|
@@ -263,3 +292,26 @@ def build_file_plan(
|
|
|
263
292
|
"output_field_map": output_field_map,
|
|
264
293
|
"instructions": instructions,
|
|
265
294
|
}
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def get_synced_file_params(fn: Any) -> dict[str, SyncedFile]:
|
|
298
|
+
"""Extract SyncedFile annotations from a tool function's type hints.
|
|
299
|
+
|
|
300
|
+
Returns a dict mapping parameter names to their ``SyncedFile`` marker
|
|
301
|
+
for all parameters annotated with ``Annotated[..., SyncedFile(...)]``.
|
|
302
|
+
"""
|
|
303
|
+
try:
|
|
304
|
+
hints = typing.get_type_hints(fn, include_extras=True)
|
|
305
|
+
except (TypeError, NameError):
|
|
306
|
+
return {}
|
|
307
|
+
|
|
308
|
+
result: dict[str, SyncedFile] = {}
|
|
309
|
+
for name, hint in hints.items():
|
|
310
|
+
if name == "return":
|
|
311
|
+
continue
|
|
312
|
+
if typing.get_origin(hint) is Annotated:
|
|
313
|
+
for arg in typing.get_args(hint)[1:]:
|
|
314
|
+
if isinstance(arg, SyncedFile):
|
|
315
|
+
result[name] = arg
|
|
316
|
+
break
|
|
317
|
+
return result
|
|
@@ -16,11 +16,15 @@ from __future__ import annotations
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import concurrent.futures
|
|
18
18
|
import functools
|
|
19
|
+
import inspect
|
|
19
20
|
import json
|
|
20
21
|
import logging
|
|
21
22
|
import os
|
|
22
23
|
import re
|
|
23
24
|
import select
|
|
25
|
+
import shutil
|
|
26
|
+
import tempfile
|
|
27
|
+
import time
|
|
24
28
|
from pathlib import Path
|
|
25
29
|
from typing import TYPE_CHECKING, Any
|
|
26
30
|
|
|
@@ -35,6 +39,7 @@ if TYPE_CHECKING:
|
|
|
35
39
|
|
|
36
40
|
logger = logging.getLogger(__name__)
|
|
37
41
|
|
|
42
|
+
|
|
38
43
|
# JSON-RPC 2.0 helpers (local to avoid coupling to dspy internals)
|
|
39
44
|
JSONRPC_APP_ERRORS = {
|
|
40
45
|
"SyntaxError": -32000,
|
|
@@ -223,6 +228,17 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
223
228
|
all_read_paths = list(enable_read_paths or []) + list(extra_read_paths or [])
|
|
224
229
|
all_write_paths = list(enable_write_paths or []) + list(extra_write_paths or [])
|
|
225
230
|
|
|
231
|
+
# Scan tools for SyncedFile annotations with custom host_dir paths
|
|
232
|
+
# and add them to Deno permissions so the runner can write there.
|
|
233
|
+
if tools:
|
|
234
|
+
from predict_rlm.files import get_synced_file_params
|
|
235
|
+
|
|
236
|
+
for tool_fn in tools.values():
|
|
237
|
+
for sf in get_synced_file_params(tool_fn).values():
|
|
238
|
+
if sf.host_dir is not None:
|
|
239
|
+
all_write_paths.append(sf.host_dir)
|
|
240
|
+
all_read_paths.append(sf.host_dir)
|
|
241
|
+
|
|
226
242
|
# Build custom deno command if not provided
|
|
227
243
|
if deno_command is None:
|
|
228
244
|
deno_command = self._build_deno_command(
|
|
@@ -249,6 +265,9 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
249
265
|
# Per-interpreter thread pool for sync tool calls (avoids starving
|
|
250
266
|
# the shared default executor when many interpreters run concurrently)
|
|
251
267
|
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
268
|
+
# Pending file-sync operations requested by tools during execution.
|
|
269
|
+
# Maps request ID → asyncio.Future resolved by the execute loop.
|
|
270
|
+
self._pending_file_ops: dict[int, asyncio.Future] = {}
|
|
252
271
|
|
|
253
272
|
def _ensure_deno_process(self) -> None:
|
|
254
273
|
"""Override to capture raw fds for non-blocking I/O."""
|
|
@@ -327,6 +346,11 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
327
346
|
allowed_read.extend(str(p) for p in read_paths)
|
|
328
347
|
allowed_read.extend(str(p) for p in write_paths)
|
|
329
348
|
|
|
349
|
+
# Allow reading temp dirs so @file_sync tools can mount files back
|
|
350
|
+
import tempfile as _tempfile
|
|
351
|
+
allowed_read.append(_tempfile.gettempdir())
|
|
352
|
+
allowed_read.append("/tmp")
|
|
353
|
+
|
|
330
354
|
if allowed_read:
|
|
331
355
|
args.append(f"--allow-read={','.join(allowed_read)}")
|
|
332
356
|
|
|
@@ -384,7 +408,11 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
384
408
|
self._write_stdin(msg + "\n")
|
|
385
409
|
|
|
386
410
|
def _send_request(self, method: str, params: dict, context: str) -> dict:
|
|
387
|
-
"""Send a JSON-RPC request without blocking the OS pipe.
|
|
411
|
+
"""Send a JSON-RPC request without blocking the OS pipe.
|
|
412
|
+
|
|
413
|
+
Skips non-JSON lines (e.g. Pyodide package-loading messages) matching
|
|
414
|
+
the parent PythonInterpreter behaviour.
|
|
415
|
+
"""
|
|
388
416
|
self._request_id += 1
|
|
389
417
|
request_id = self._request_id
|
|
390
418
|
msg = json.dumps(
|
|
@@ -397,26 +425,40 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
397
425
|
)
|
|
398
426
|
self._write_stdin(msg + "\n")
|
|
399
427
|
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
428
|
+
max_skip = 100
|
|
429
|
+
skipped = 0
|
|
430
|
+
while skipped <= max_skip:
|
|
431
|
+
response_line = self._read_with_timeout(timeout=None)
|
|
432
|
+
if not response_line:
|
|
433
|
+
exit_code = self.deno_process.poll()
|
|
434
|
+
if exit_code is not None:
|
|
435
|
+
stderr = self.deno_process.stderr.read() if self.deno_process.stderr else ""
|
|
436
|
+
raise CodeInterpreterError(
|
|
437
|
+
f"Deno exited (code {exit_code}) {context}: {stderr}"
|
|
438
|
+
)
|
|
439
|
+
raise CodeInterpreterError(f"No response {context}")
|
|
440
|
+
|
|
441
|
+
if not response_line.startswith("{"):
|
|
442
|
+
skipped += 1
|
|
443
|
+
continue
|
|
444
|
+
|
|
445
|
+
try:
|
|
446
|
+
response = json.loads(response_line)
|
|
447
|
+
except json.JSONDecodeError:
|
|
448
|
+
skipped += 1
|
|
449
|
+
continue
|
|
450
|
+
|
|
451
|
+
if response.get("id") != request_id:
|
|
405
452
|
raise CodeInterpreterError(
|
|
406
|
-
f"
|
|
453
|
+
f"Response ID mismatch {context}: expected {request_id}, got {response.get('id')}"
|
|
407
454
|
)
|
|
408
|
-
|
|
455
|
+
if "error" in response:
|
|
456
|
+
raise CodeInterpreterError(
|
|
457
|
+
f"Error {context}: {response['error'].get('message', 'Unknown error')}"
|
|
458
|
+
)
|
|
459
|
+
return response
|
|
409
460
|
|
|
410
|
-
|
|
411
|
-
if response.get("id") != request_id:
|
|
412
|
-
raise CodeInterpreterError(
|
|
413
|
-
f"Response ID mismatch {context}: expected {request_id}, got {response.get('id')}"
|
|
414
|
-
)
|
|
415
|
-
if "error" in response:
|
|
416
|
-
raise CodeInterpreterError(
|
|
417
|
-
f"Error {context}: {response['error'].get('message', 'Unknown error')}"
|
|
418
|
-
)
|
|
419
|
-
return response
|
|
461
|
+
raise CodeInterpreterError(f"Too many non-JSON lines ({skipped}) {context}")
|
|
420
462
|
|
|
421
463
|
def _get_deno_dir(self) -> list[str]:
|
|
422
464
|
"""Get Deno cache directory paths (may have multiple on different platforms)."""
|
|
@@ -473,28 +515,16 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
473
515
|
self.deno_process.stdin.flush()
|
|
474
516
|
|
|
475
517
|
def _strip_code_fences(self, code: str) -> str:
|
|
476
|
-
"""Extract code from
|
|
518
|
+
"""Extract code from markdown fences.
|
|
477
519
|
|
|
478
|
-
Uses a specific ```repl tag (like the original RLM) to avoid ambiguity.
|
|
479
520
|
The closing ``` must be on its own line (^```$) to handle:
|
|
480
521
|
1. Code containing inline ``` (like in strings) - not on own line, won't match
|
|
481
522
|
2. Double fences from model (```...```\\n```) - stops at first proper close
|
|
482
|
-
Falls back to generic fence matching for backwards compatibility.
|
|
483
523
|
|
|
484
|
-
Supports multiple
|
|
524
|
+
Supports multiple fenced blocks - all blocks are concatenated with newlines.
|
|
485
525
|
"""
|
|
486
|
-
# Primary: look for ```repl blocks
|
|
487
|
-
# Use MULTILINE so ^ matches start of line - closing ``` must be alone on a line
|
|
488
|
-
# Non-greedy .*? stops at FIRST ``` on its own line
|
|
489
|
-
# findall to get ALL blocks, not just the first
|
|
490
|
-
matches = re.findall(r"```repl\s*\n(.*?)^```\s*$", code, re.DOTALL | re.MULTILINE)
|
|
491
|
-
if matches:
|
|
492
|
-
# Join all blocks with double newlines to ensure separation
|
|
493
|
-
return "\n\n".join(block.rstrip() for block in matches)
|
|
494
|
-
|
|
495
|
-
# Fallback: try generic ```python or ``` blocks for backwards compatibility
|
|
496
526
|
matches = re.findall(
|
|
497
|
-
r"```(?:python|py)?\s*\n(.*?)^```\s*$", code, re.DOTALL | re.MULTILINE
|
|
527
|
+
r"```(?:python|py|repl)?\s*\n(.*?)^```\s*$", code, re.DOTALL | re.MULTILINE
|
|
498
528
|
)
|
|
499
529
|
if matches:
|
|
500
530
|
return "\n\n".join(block.rstrip() for block in matches)
|
|
@@ -681,6 +711,14 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
681
711
|
logger.info(f"Skipping malformed JSON: {output_line[:100]}")
|
|
682
712
|
continue
|
|
683
713
|
|
|
714
|
+
# Route file-sync responses to pending futures (from _execute_tool_async)
|
|
715
|
+
resp_id = result.get("id")
|
|
716
|
+
if resp_id is not None and resp_id in self._pending_file_ops:
|
|
717
|
+
future = self._pending_file_ops.pop(resp_id)
|
|
718
|
+
if not future.done():
|
|
719
|
+
future.set_result(result)
|
|
720
|
+
continue
|
|
721
|
+
|
|
684
722
|
# JSON-RPC request from sandbox (tool call)
|
|
685
723
|
if "method" in result:
|
|
686
724
|
if result["method"] == "tool_call":
|
|
@@ -823,8 +861,53 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
823
861
|
line, self._read_buf = self._read_buf.split("\n", 1)
|
|
824
862
|
return line.strip()
|
|
825
863
|
|
|
864
|
+
async def _sync_file_during_tool(self, virtual_path: str, host_path: str) -> None:
|
|
865
|
+
"""Sync a file from sandbox MEMFS to host during a tool call.
|
|
866
|
+
|
|
867
|
+
Sends a sync_file request to the Deno runner's responseReader (which
|
|
868
|
+
handles it during tool execution) and awaits the response via a Future
|
|
869
|
+
resolved by the _execute_async loop.
|
|
870
|
+
"""
|
|
871
|
+
self._request_id += 1
|
|
872
|
+
req_id = self._request_id
|
|
873
|
+
loop = asyncio.get_running_loop()
|
|
874
|
+
future = loop.create_future()
|
|
875
|
+
self._pending_file_ops[req_id] = future
|
|
876
|
+
msg = json.dumps({
|
|
877
|
+
"jsonrpc": "2.0", "method": "sync_file",
|
|
878
|
+
"params": {"virtual_path": virtual_path, "host_path": host_path},
|
|
879
|
+
"id": req_id,
|
|
880
|
+
})
|
|
881
|
+
await self._write_stdin_async(msg + "\n")
|
|
882
|
+
result = await future
|
|
883
|
+
if "error" in result:
|
|
884
|
+
raise CodeInterpreterError(
|
|
885
|
+
f"sync_file failed: {result['error'].get('message', result['error'])}"
|
|
886
|
+
)
|
|
887
|
+
|
|
888
|
+
async def _mount_file_during_tool(self, host_path: str, virtual_path: str) -> None:
|
|
889
|
+
"""Mount a file from host into sandbox MEMFS during a tool call."""
|
|
890
|
+
self._request_id += 1
|
|
891
|
+
req_id = self._request_id
|
|
892
|
+
loop = asyncio.get_running_loop()
|
|
893
|
+
future = loop.create_future()
|
|
894
|
+
self._pending_file_ops[req_id] = future
|
|
895
|
+
msg = json.dumps({
|
|
896
|
+
"jsonrpc": "2.0", "method": "mount_file",
|
|
897
|
+
"params": {"host_path": host_path, "virtual_path": virtual_path},
|
|
898
|
+
"id": req_id,
|
|
899
|
+
})
|
|
900
|
+
await self._write_stdin_async(msg + "\n")
|
|
901
|
+
result = await future
|
|
902
|
+
if "error" in result:
|
|
903
|
+
raise CodeInterpreterError(
|
|
904
|
+
f"mount_file failed: {result['error'].get('message', result['error'])}"
|
|
905
|
+
)
|
|
906
|
+
|
|
826
907
|
async def _execute_tool_async(self, tool_name: str, call_args: dict) -> dict:
|
|
827
908
|
"""Execute a tool asynchronously and return the response dict."""
|
|
909
|
+
from .trace import ToolCall, ms_since, record_tool_call
|
|
910
|
+
|
|
828
911
|
if self._debug:
|
|
829
912
|
import sys
|
|
830
913
|
|
|
@@ -832,39 +915,116 @@ class JspiInterpreter(PythonInterpreter):
|
|
|
832
915
|
print(
|
|
833
916
|
f"\n\033[33m── Tool: {tool_name}({kwargs_preview}) ──\033[0m", file=sys.stderr
|
|
834
917
|
)
|
|
918
|
+
|
|
919
|
+
call_start = time.perf_counter()
|
|
920
|
+
# Copy to mutable containers so the SyncedFile handler below can
|
|
921
|
+
# rewrite sandbox paths to host paths before invoking the tool.
|
|
922
|
+
args = list(call_args.get("args", []))
|
|
923
|
+
kwargs = dict(call_args.get("kwargs", {}))
|
|
924
|
+
temp_dir: str | None = None
|
|
925
|
+
|
|
835
926
|
try:
|
|
836
927
|
if tool_name not in self.tools:
|
|
837
928
|
raise CodeInterpreterError(f"Unknown tool: {tool_name}")
|
|
838
929
|
|
|
839
930
|
tool_fn = self.tools[tool_name]
|
|
840
|
-
args = call_args.get("args", [])
|
|
841
|
-
kwargs = call_args.get("kwargs", {})
|
|
842
931
|
|
|
843
932
|
# Pass pydantic_schemas through to predict tool if present
|
|
844
933
|
pydantic_schemas = call_args.get("pydantic_schemas")
|
|
845
934
|
if pydantic_schemas and tool_name == "predict":
|
|
846
935
|
kwargs["pydantic_schemas"] = pydantic_schemas
|
|
847
936
|
|
|
937
|
+
# Handle SyncedFile-annotated tool parameters: sync sandbox files
|
|
938
|
+
# to host before calling, and mount modified files back after.
|
|
939
|
+
from predict_rlm.files import get_synced_file_params
|
|
940
|
+
|
|
941
|
+
synced_params = get_synced_file_params(tool_fn)
|
|
942
|
+
temp_dir = None
|
|
943
|
+
# (sandbox_path, host_path, writeback) for each synced param
|
|
944
|
+
synced_entries: list[tuple[str, str, bool]] = []
|
|
945
|
+
|
|
946
|
+
if synced_params:
|
|
947
|
+
sig = inspect.signature(tool_fn)
|
|
948
|
+
param_names = list(sig.parameters.keys())
|
|
949
|
+
|
|
950
|
+
for param_name, sf in synced_params.items():
|
|
951
|
+
# Resolve the sandbox path from args or kwargs
|
|
952
|
+
sandbox_path = kwargs.get(param_name)
|
|
953
|
+
if sandbox_path is None and param_name in param_names:
|
|
954
|
+
idx = param_names.index(param_name)
|
|
955
|
+
if idx < len(args):
|
|
956
|
+
sandbox_path = args[idx]
|
|
957
|
+
if not sandbox_path or not isinstance(sandbox_path, str):
|
|
958
|
+
continue
|
|
959
|
+
|
|
960
|
+
# Determine host directory
|
|
961
|
+
if sf.host_dir is not None:
|
|
962
|
+
host_dir = sf.host_dir
|
|
963
|
+
os.makedirs(host_dir, exist_ok=True)
|
|
964
|
+
else:
|
|
965
|
+
if temp_dir is None:
|
|
966
|
+
temp_dir = tempfile.mkdtemp(prefix="tool-file-sync-")
|
|
967
|
+
host_dir = temp_dir
|
|
968
|
+
|
|
969
|
+
host_path = os.path.join(host_dir, os.path.basename(sandbox_path))
|
|
970
|
+
await self._sync_file_during_tool(sandbox_path, host_path)
|
|
971
|
+
synced_entries.append((sandbox_path, host_path, sf.writeback))
|
|
972
|
+
|
|
973
|
+
# Replace the sandbox path with the host path in args/kwargs
|
|
974
|
+
if param_name in kwargs:
|
|
975
|
+
kwargs[param_name] = host_path
|
|
976
|
+
elif param_name in param_names:
|
|
977
|
+
idx = param_names.index(param_name)
|
|
978
|
+
if idx < len(args):
|
|
979
|
+
args[idx] = host_path
|
|
980
|
+
|
|
848
981
|
# Check if tool is async or sync
|
|
849
982
|
if asyncio.iscoroutinefunction(tool_fn):
|
|
850
983
|
result = await tool_fn(*args, **kwargs)
|
|
851
984
|
else:
|
|
852
|
-
# Run sync function in per-interpreter thread pool (not the
|
|
853
|
-
# shared default pool) to prevent starvation when many
|
|
854
|
-
# interpreters run concurrently.
|
|
855
|
-
# loop.run_in_executor only accepts positional args, so wrap
|
|
856
|
-
# the call in functools.partial to bind **kwargs.
|
|
857
985
|
loop = asyncio.get_running_loop()
|
|
858
986
|
result = await loop.run_in_executor(
|
|
859
987
|
self._executor, functools.partial(tool_fn, *args, **kwargs)
|
|
860
988
|
)
|
|
861
989
|
|
|
990
|
+
# Mount modified files back into the sandbox (only for writeback params)
|
|
991
|
+
if synced_entries:
|
|
992
|
+
for sandbox_path, host_path, writeback in synced_entries:
|
|
993
|
+
if writeback and os.path.isfile(host_path):
|
|
994
|
+
await self._mount_file_during_tool(host_path, sandbox_path)
|
|
995
|
+
if temp_dir:
|
|
996
|
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
997
|
+
|
|
862
998
|
is_json = isinstance(result, (list, dict))
|
|
863
|
-
|
|
999
|
+
response = {
|
|
864
1000
|
"value": json.dumps(result) if is_json else str(result or ""),
|
|
865
1001
|
"type": "json" if is_json else "string",
|
|
866
1002
|
}
|
|
1003
|
+
|
|
1004
|
+
# Record non-predict tool calls (predict records itself with richer detail)
|
|
1005
|
+
if tool_name != "predict":
|
|
1006
|
+
record_tool_call(ToolCall(
|
|
1007
|
+
name=tool_name,
|
|
1008
|
+
args=args,
|
|
1009
|
+
kwargs={k: v for k, v in kwargs.items() if k != "pydantic_schemas"},
|
|
1010
|
+
result=result,
|
|
1011
|
+
duration_ms=ms_since(call_start),
|
|
1012
|
+
))
|
|
1013
|
+
|
|
1014
|
+
return response
|
|
867
1015
|
except Exception as e:
|
|
1016
|
+
# Clean up any SyncedFile temp dir before returning
|
|
1017
|
+
if temp_dir:
|
|
1018
|
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
1019
|
+
if tool_name != "predict":
|
|
1020
|
+
record_tool_call(ToolCall(
|
|
1021
|
+
name=tool_name,
|
|
1022
|
+
args=args,
|
|
1023
|
+
kwargs={k: v for k, v in kwargs.items() if k != "pydantic_schemas"},
|
|
1024
|
+
result=None,
|
|
1025
|
+
error=str(e),
|
|
1026
|
+
duration_ms=ms_since(call_start),
|
|
1027
|
+
))
|
|
868
1028
|
return {"error": str(e)}
|
|
869
1029
|
|
|
870
1030
|
def _write_stdin(self, data: str) -> None:
|