predict-rlm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ """predict-rlm — Production-grade RLMs with tool use, built on DSPy.
2
+
3
+ Core classes:
4
+ PredictRLM — RLM with a ``predict()`` tool for running DSPy signatures
5
+ Skill — Reusable bundle of instructions, packages, and tools
6
+
7
+ File I/O:
8
+ File — Unified file type for inputs (mount into sandbox) and outputs
9
+ (sync from sandbox). Use ``list[File]`` for multiple files.
10
+ """
11
+
12
+ from .files import File, LocalDir, LocalFile, OutputDir, OutputFile
13
+ from .predict_rlm import PredictRLM
14
+ from .rlm_skills import Skill
15
+
16
+ __all__ = [
17
+ "File",
18
+ "LocalDir",
19
+ "LocalFile",
20
+ "OutputDir",
21
+ "OutputFile",
22
+ "PredictRLM",
23
+ "Skill",
24
+ ]
predict_rlm/_shared.py ADDED
@@ -0,0 +1,164 @@
1
+ """Shared utilities for RLM subclasses."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ import textwrap
7
+ from typing import TYPE_CHECKING, Callable
8
+
9
+ import dspy
10
+ from dspy.adapters.utils import translate_field_type
11
+
12
+ if TYPE_CHECKING:
13
+ from dspy.signatures.signature import Signature
14
+
15
+
16
def format_tool_docs_full(tools: dict[str, Callable]) -> str:
    """Format tools with full docstrings for inclusion in instructions.

    Unlike DSPy's default _format_tool_docs which only uses the first line of
    the docstring, this includes the full docstring (Args, Returns, etc.).

    Args:
        tools: Mapping of the name a tool is exposed under to its callable.

    Returns:
        A markdown section with one ``### `await name(...)` `` heading per
        tool followed by its docstring, or ``""`` when ``tools`` is empty.
    """
    if not tools:
        return ""

    lines = [
        "\n## Additional Tools\n\nAll tools are async — use `await` when calling them. Use `asyncio.gather()` to run multiple tool calls in parallel."
    ]
    for name, func in tools.items():
        # Render the call signature with whatever type information exists.
        try:
            sig = inspect.signature(func)
        except (ValueError, TypeError):
            # Some callables (e.g. certain builtins) expose no signature.
            sig_str = f"{name}(...)"
        else:
            params_str = ", ".join(
                p.name
                if p.annotation is inspect.Parameter.empty
                else f"{p.name}: {_annotation_label(p.annotation)}"
                for p in sig.parameters.values()
            )
            sig_str = f"{name}({params_str})"
            if sig.return_annotation is not inspect.Signature.empty:
                sig_str += f" -> {_annotation_label(sig.return_annotation)}"

        # inspect.cleandoc ignores the (unindented) first line when computing
        # the common margin, so Args/Returns sections are actually dedented;
        # textwrap.dedent leaves them at their source indentation.
        if func.__doc__:
            doc = inspect.cleandoc(func.__doc__).strip()
        else:
            doc = "No description"

        lines.append(f"\n### `await {sig_str}`")
        lines.append(doc)

    return "\n".join(lines)


def _annotation_label(annotation: object) -> str:
    """Render a parameter or return annotation for display in tool docs."""
    from typing import get_origin

    if get_origin(annotation) is None:
        # Plain classes show their short name; string (postponed) annotations
        # and other objects without ``__name__`` fall back to ``str()``.
        return getattr(annotation, "__name__", str(annotation))
    # Parameterized generics proxy ``__name__`` to their origin, which would
    # turn ``list[int]`` into ``list``; ``str()`` keeps the type arguments.
    return str(annotation)
62
+
63
+
64
def build_rlm_signatures(
    signature: Signature,
    instructions_template: str,
    user_tools: dict[str, Callable],
    format_tool_docs: Callable[[dict[str, Callable]], str],
    skill_instructions: str = "",
    file_instructions: str = "",
) -> tuple[Signature, Signature]:
    """Create the action and extract signatures used by RLM subclasses.

    The base RLM's ACTION_INSTRUCTIONS_TEMPLATE embeds llm_query /
    llm_query_batched docs, and instructions are baked into a Signature when
    it is created, so both signatures are rebuilt here from scratch with
    custom instructions.

    Args:
        signature: The task signature; its instructions, input fields and
            output fields drive the generated text.
        instructions_template: ``str.format`` template receiving ``inputs``,
            ``final_output_names`` and ``output_fields``.
        user_tools: User-provided tools, passed to ``format_tool_docs``.
        format_tool_docs: Renders ``user_tools`` into a documentation string
            appended right after the formatted template.
        skill_instructions: Optional text appended under a ``## Skills``
            heading.
        file_instructions: Optional text appended before the skills section.

    Returns:
        ``(action_sig, extract_sig)``.
    """
    input_names = ", ".join(f"`{n}`" for n in signature.input_fields)
    output_names = ", ".join(signature.output_fields.keys())
    output_bullets = "\n".join(
        f"- {translate_field_type(n, f)}" for n, f in signature.output_fields.items()
    )

    # The task's own instructions (if any) lead the prompt.
    task_instructions = ""
    if signature.instructions:
        task_instructions = f"{signature.instructions}\n\n"

    tool_docs = format_tool_docs(user_tools)

    # Assemble the action prompt: task, template, tools, then optional extras.
    sections = [
        task_instructions,
        instructions_template.format(
            inputs=input_names,
            final_output_names=output_names,
            output_fields=output_bullets,
        ),
        tool_docs,
    ]
    if file_instructions:
        sections.append(f"\n\n{file_instructions}")
    if skill_instructions:
        sections.append(f"\n\n## Skills\n\n{skill_instructions}")
    full_instructions = "".join(sections)

    action_sig = dspy.Signature({}, full_instructions)
    action_fields = (
        (
            "variables_info",
            dspy.InputField(desc="Metadata about the variables available in the REPL"),
            str,
        ),
        (
            "repl_history",
            dspy.InputField(desc="Previous REPL code executions and their outputs"),
            dspy.primitives.repl_types.REPLHistory,
        ),
        (
            "iteration",
            dspy.InputField(desc="Current iteration number (1-indexed) out of max_iterations"),
            str,
        ),
        (
            "reasoning",
            dspy.OutputField(
                desc="Think step-by-step: what do you know? What remains? Plan your next action."
            ),
            str,
        ),
        (
            "code",
            dspy.OutputField(desc="Python code wrapped in ```repl blocks."),
            str,
        ),
    )
    for field_name, field, field_type in action_fields:
        action_sig = action_sig.append(field_name, field, type_=field_type)

    # The extract prompt restates the original objective ahead of the
    # generic extraction request.
    objective_preamble = ""
    if task_instructions:
        objective_preamble = (
            "The trajectory was generated with the following objective: \n"
            + task_instructions
            + "\n"
        )
    extract_request = (
        "Based on the REPL trajectory, extract the final outputs now.\n"
        "\n"
        "Review your trajectory to see what information you gathered and what values you computed, then provide the final outputs."
    )

    extract_sig = dspy.Signature(
        dict(signature.output_fields),
        objective_preamble + extract_request,
    )
    # Each prepend inserts at the front, so variables_info ends up first.
    extract_sig = extract_sig.prepend(
        "repl_history",
        dspy.InputField(desc="Your REPL interactions so far"),
        type_=dspy.primitives.repl_types.REPLHistory,
    )
    extract_sig = extract_sig.prepend(
        "variables_info",
        dspy.InputField(desc="Metadata about the variables available in the REPL"),
        type_=str,
    )

    return action_sig, extract_sig
predict_rlm/files.py ADDED
@@ -0,0 +1,265 @@
1
+ """Declarative file I/O types for PredictRLM signatures.
2
+
3
+ Use ``File`` as the type for file-typed fields in DSPy signatures.
4
+ The behavior is determined by whether the field is an input or output:
5
+
6
+ - **Input field** (``dspy.InputField``): the file is mounted from the host
7
+ into the sandbox at ``/sandbox/input/{field_name}/``.
8
+ - **Output field** (``dspy.OutputField``): the RLM writes to
9
+ ``/sandbox/output/{field_name}/`` and the file is synced back to the host.
10
+
11
+ ``list[File]`` works for both multiple inputs and multiple outputs.
12
+
13
+ Example::
14
+
15
+ class ConvertPDF(dspy.Signature):
16
+ source: File = dspy.InputField(desc="PDF to convert")
17
+ result: File = dspy.OutputField(desc="Generated Excel file")
18
+
19
+ rlm = PredictRLM(ConvertPDF, lm="openai/gpt-5.4", sub_lm="openai/gpt-5.1")
20
+ prediction = await rlm.acall(source=File(path="report.pdf"))
21
+ print(prediction.result.path) # host path to the generated file
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import os
27
+ import typing
28
+ from typing import Any
29
+
30
+ from pydantic import BaseModel, Field
31
+
32
+
33
class File(BaseModel):
    """A file reference for PredictRLM signatures.

    Behavior depends on the field position in the signature:
    - As an input field: mounts the file from the host into the sandbox.
    - As an output field: syncs the file from the sandbox back to the host.
    """

    path: str | None = Field(
        default=None,
        description="Path to the file. For inputs, the host path to mount. "
        "For outputs, populated after execution with the host path.",
    )

    @classmethod
    def from_dir(cls, path: str) -> list[File]:
        """Create a list of File references from all files in a directory.

        Walks the directory recursively and returns a File for each file
        found. Directories and the files inside them are visited in sorted
        name order, so the result is the same on every filesystem.

        Args:
            path: Directory to walk.

        Returns:
            One ``File`` per regular entry reported by ``os.walk``; empty if
            the directory has no files (or cannot be walked).
        """
        files: list[File] = []
        for root, dirs, filenames in os.walk(path):
            # os.walk descends into ``dirs`` in list order; sorting it in
            # place pins the traversal order instead of leaving it to the
            # order in which the filesystem happens to list entries.
            dirs.sort()
            files.extend(
                cls(path=os.path.join(root, fname)) for fname in sorted(filenames)
            )
        return files


# Deprecated aliases — kept for backwards compatibility
LocalFile = File
LocalDir = File
OutputFile = File
OutputDir = File
65
+
66
+
67
+ def _unwrap_annotation(annotation: Any) -> Any:
68
+ """Unwrap Optional/Annotated/list to get the inner file type."""
69
+ origin = typing.get_origin(annotation)
70
+ if origin is typing.Union:
71
+ args = [a for a in typing.get_args(annotation) if a is not type(None)]
72
+ if len(args) == 1:
73
+ return _unwrap_annotation(args[0])
74
+ if origin is typing.Annotated:
75
+ return _unwrap_annotation(typing.get_args(annotation)[0])
76
+ if origin is list:
77
+ args = typing.get_args(annotation)
78
+ if args:
79
+ return _unwrap_annotation(args[0])
80
+ return annotation
81
+
82
+
83
+ def _is_list_annotation(annotation: Any) -> bool:
84
+ """Check if an annotation is list[...] (possibly wrapped in Optional)."""
85
+ origin = typing.get_origin(annotation)
86
+ if origin is typing.Union:
87
+ args = [a for a in typing.get_args(annotation) if a is not type(None)]
88
+ if len(args) == 1:
89
+ return _is_list_annotation(args[0])
90
+ if origin is typing.Annotated:
91
+ return _is_list_annotation(typing.get_args(annotation)[0])
92
+ return origin is list
93
+
94
+
95
def is_file_type(annotation: Any) -> bool:
    """Tell whether a field annotation resolves to ``File`` (or a subclass).

    The annotation is first passed through ``_unwrap_annotation``, so
    wrapped forms such as ``list[File]`` are recognised as well.
    """
    candidate = _unwrap_annotation(annotation)
    if not isinstance(candidate, type):
        # Not a class at all (e.g. an unresolved union) — cannot be File.
        return False
    return issubclass(candidate, File)


# Deprecated aliases
is_input_file_type = is_file_type
is_output_file_type = is_file_type
104
+
105
+
106
def scan_file_fields(
    signature: Any,
) -> tuple[dict[str, str], dict[str, str]]:
    """Find the file-typed input and output fields of a DSPy signature.

    Args:
        signature: Object exposing ``input_fields`` and ``output_fields``
            mappings whose values carry an ``annotation`` attribute.

    Returns:
        ``(input_file_fields, output_file_fields)`` — each maps a field name
        to ``'list_file'`` for list annotations or ``'file'`` otherwise.
        Fields that are not file-typed are left out.
    """

    def classify(fields: Any) -> dict[str, str]:
        # Keep only file-typed fields and tag each as single or list.
        found: dict[str, str] = {}
        for field_name, field in fields.items():
            annotation = field.annotation
            if not is_file_type(annotation):
                continue
            found[field_name] = "list_file" if _is_list_annotation(annotation) else "file"
        return found

    return classify(signature.input_fields), classify(signature.output_fields)
131
+
132
+
133
+ def build_file_instructions(
134
+ input_mounts: dict[str, str | list[str]],
135
+ output_dirs: dict[str, str],
136
+ ) -> str:
137
+ """Generate the '## Files' instructions block for the RLM.
138
+
139
+ Args:
140
+ input_mounts: Maps field names to sandbox paths (str for file, list for dir).
141
+ output_dirs: Maps field names to sandbox output directory paths.
142
+ """
143
+ lines = ["## Files\n"]
144
+
145
+ if input_mounts:
146
+ lines.append(
147
+ "Input files (available in the sandbox filesystem "
148
+ "— use standard Python file I/O):"
149
+ )
150
+ for field_name, sandbox_path in input_mounts.items():
151
+ if isinstance(sandbox_path, list):
152
+ lines.append(f"- `{field_name}`: directory at /sandbox/input/{field_name}/")
153
+ for p in sandbox_path:
154
+ lines.append(f" - {p}")
155
+ else:
156
+ lines.append(f"- `{field_name}`: {sandbox_path}")
157
+ lines.append("")
158
+
159
+ if output_dirs:
160
+ lines.append(
161
+ "Output directories (write your output files here, "
162
+ "then SUBMIT the sandbox path you wrote to):"
163
+ )
164
+ for field_name, sandbox_dir in output_dirs.items():
165
+ lines.append(f"- `{field_name}`: write to {sandbox_dir}")
166
+ lines.append("")
167
+
168
+ return "\n".join(lines)
169
+
170
+
171
def build_file_plan(
    input_args: dict[str, Any],
    input_file_fields: dict[str, str],
    output_file_fields: dict[str, str],
    output_dir: str | None = None,
) -> dict[str, Any] | None:
    """Build the file plan for mounting/syncing.

    Args:
        input_args: Call arguments keyed by field name. Values for input file
            fields are objects with a ``path`` attribute (or lists of them);
            an output field may also appear here to pin its host directory.
        input_file_fields: Input field name -> ``'file'`` or ``'list_file'``.
        output_file_fields: Output field name -> ``'file'`` or ``'list_file'``.
        output_dir: Host base directory for outputs. When omitted and there
            are output fields, a fresh temporary directory is created.

    Returns:
        None if there are no file fields. Otherwise returns:
        {
            "mounts": [(host_path, virtual_path), ...],
            "read_paths": [host_path, ...],
            "output_dirs": [virtual_path, ...],
            "write_dir": str | None,  # host output base dir
            "output_field_map": {field_name: {"virtual_dir": str, "host_dir": str, "kind": str}},
            "input_mounts_for_instructions": {field_name: sandbox_path_str | [paths]},
            "output_dirs_for_instructions": {field_name: sandbox_dir_str},
            "instructions": str,
        }

    Raises:
        TypeError: If an input file reference has no ``path`` set.
    """
    if not input_file_fields and not output_file_fields:
        return None

    import tempfile

    mounts: list[tuple[str, str]] = []
    read_paths: list[str] = []
    input_mounts_for_instructions: dict[str, str | list[str]] = {}
    # virtual path -> host path already mounted there, used to detect clashes.
    claimed: dict[str, str] = {}

    def mount_input(field_name: str, file_ref: Any) -> str:
        """Register one input file and return the sandbox path it gets."""
        host_path = file_ref.path
        if host_path is None:
            # Same exception type os.path.basename(None) raised before, but
            # with a message that names the offending field.
            raise TypeError(
                f"Input file for field {field_name!r} has no path set"
            )
        basename = os.path.basename(host_path)
        virtual_path = f"/sandbox/input/{field_name}/{basename}"
        # Files from different host directories can share a basename (for
        # example after a recursive File.from_dir); give later ones a numeric
        # suffix rather than mounting two hosts onto one sandbox path.
        stem, ext = os.path.splitext(basename)
        suffix = 1
        while claimed.get(virtual_path, host_path) != host_path:
            virtual_path = f"/sandbox/input/{field_name}/{stem}_{suffix}{ext}"
            suffix += 1
        claimed[virtual_path] = host_path
        mounts.append((host_path, virtual_path))
        read_paths.append(host_path)
        return virtual_path

    # Process input file fields
    for field_name, kind in input_file_fields.items():
        value = input_args.get(field_name)
        if value is None:
            continue

        if kind == "list_file":
            # list[File] — mount each file
            input_mounts_for_instructions[field_name] = [
                mount_input(field_name, item) for item in value
            ]
        elif kind == "file":
            input_mounts_for_instructions[field_name] = mount_input(field_name, value)

    # Process output file fields
    output_field_map: dict[str, dict[str, str]] = {}
    output_dirs_virtual: list[str] = []
    output_dirs_for_instructions: dict[str, str] = {}

    # Determine host output base directory
    if output_file_fields:
        host_output_base = output_dir or tempfile.mkdtemp(prefix="predict-rlm-")
    else:
        host_output_base = None

    for field_name, kind in output_file_fields.items():
        virtual_dir = f"/sandbox/output/{field_name}"
        output_dirs_virtual.append(virtual_dir)
        output_dirs_for_instructions[field_name] = f"{virtual_dir}/"

        # Check if user specified a path on the File
        output_value = input_args.get(field_name)
        if output_value and hasattr(output_value, "path") and output_value.path:
            host_dir = output_value.path
        else:
            host_dir = os.path.join(host_output_base, field_name)

        output_field_map[field_name] = {
            "virtual_dir": virtual_dir,
            "host_dir": host_dir,
            "kind": kind,
        }

    instructions = build_file_instructions(
        input_mounts_for_instructions, output_dirs_for_instructions
    )

    return {
        "mounts": mounts,
        "read_paths": read_paths,
        "output_dirs": output_dirs_virtual,
        "write_dir": host_output_base,
        "output_field_map": output_field_map,
        # Documented above; included so the plan matches its stated contract.
        "input_mounts_for_instructions": input_mounts_for_instructions,
        "output_dirs_for_instructions": output_dirs_for_instructions,
        "instructions": instructions,
    }