lm-deluge 0.0.87__py3-none-any.whl → 0.0.89__py3-none-any.whl

This diff compares the contents of two publicly available package versions as published to a supported registry, and is provided for informational purposes only.
@@ -0,0 +1,349 @@
+ """
+ RLM (Recursive Language Model) code executor.
+
+ Executes Python code with access to a context variable and lm() function
+ for recursive language model calls.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from typing import TYPE_CHECKING, Any, Callable
+
+ from .parse import (
+     RLM_MODULES,
+     RLM_SAFE_BUILTINS,
+     RLMExecutionError,
+     validate_rlm_code,
+ )
+
+ if TYPE_CHECKING:
+     from lm_deluge.client import _LLMClient
+
+
+ class OutputCapture:
+     """Captures print() output during execution."""
+
+     def __init__(self):
+         self.outputs: list[str] = []
+
+     def print(self, *args, **kwargs):
+         """Replacement print function that captures output."""
+         output = kwargs.get("sep", " ").join(str(arg) for arg in args)  # honor sep
+         self.outputs.append(output)
+
+     def get_output(self) -> str:
+         return "\n".join(self.outputs)
+
+
+ class PendingLMResult:
+     """Placeholder for an lm() call result that hasn't completed yet."""
+
+     def __init__(self, call_id: int, results: dict[int, str]):
+         self._call_id = call_id
+         self._results = results
+
+     def _require_result(self) -> str:
+         if self._call_id not in self._results:
+             raise RuntimeError(f"LM result for call {self._call_id} not yet available")
+         return self._results[self._call_id]
+
+     def is_ready(self) -> bool:
+         return self._call_id in self._results
+
+     def __repr__(self) -> str:
+         return repr(self._require_result())
+
+     def __str__(self) -> str:
+         return str(self._require_result())
+
+     def __getattr__(self, name: str) -> Any:
+         return getattr(self._require_result(), name)
+
+     def __getitem__(self, key: Any) -> Any:
+         return self._require_result()[key]
+
+     def __iter__(self):
+         return iter(self._require_result())
+
+     def __len__(self) -> int:
+         return len(self._require_result())
+
+     def __bool__(self) -> bool:
+         return bool(self._require_result())
+
+     def __add__(self, other):
+         return self._require_result() + other
+
+     def __radd__(self, other):
+         return other + self._require_result()
+
+     def __contains__(self, item):
+         return item in self._require_result()
+
+
+ class FinalAnswer(Exception):
+     """Raised when final() or final_var() is called to signal completion."""
+
+     def __init__(self, answer: Any):
+         self.answer = answer
+         super().__init__("Final answer signaled")
+
+
+ def _resolve_value(value: Any, results: dict[int, str]) -> Any:
+     """Recursively resolve any PendingLMResult placeholders in a value."""
+     if isinstance(value, PendingLMResult):
+         return value._require_result()
+     if isinstance(value, list):
+         return [_resolve_value(v, results) for v in value]
+     if isinstance(value, tuple):
+         return tuple(_resolve_value(v, results) for v in value)
+     if isinstance(value, dict):
+         return {k: _resolve_value(v, results) for k, v in value.items()}
+     if isinstance(value, set):
+         return {_resolve_value(v, results) for v in value}
+     return value
+
+
+ def _contains_unresolved(value: Any) -> bool:
+     """Check if a value contains any unresolved PendingLMResult."""
+     if isinstance(value, PendingLMResult):
+         return not value.is_ready()
+     if isinstance(value, (list, tuple, set)):
+         return any(_contains_unresolved(item) for item in value)
+     if isinstance(value, dict):
+         return any(_contains_unresolved(v) for v in value.values())
+     return False
+
+
+ class RLMExecutor:
+     """Executes RLM code with access to context and lm() calls."""
+
+     def __init__(
+         self,
+         context: str,
+         client: _LLMClient,
+         context_var_name: str = "CONTEXT",
+         max_lm_calls_per_execution: int = 20,
+     ):
+         """Initialize the RLM executor.
+
+         Args:
+             context: The long context string to analyze
+             client: LLMClient for making recursive lm() calls
+             context_var_name: Variable name for the context (default: "CONTEXT")
+             max_lm_calls_per_execution: Maximum lm() calls allowed per execute() call
+         """
+         self.context = context
+         self.client = client
+         self.context_var_name = context_var_name
+         self.max_lm_calls_per_execution = max_lm_calls_per_execution
+
+         # Persistent state across execute() calls
+         self._persistent_locals: dict[str, Any] = {}
+
+     def _make_lm_wrapper(
+         self,
+         pending_lm_calls: list[dict],
+         lm_results: dict[int, str],
+         call_state: dict[str, int],
+         pending_call_ids: set[int],
+     ) -> Callable[[str], PendingLMResult]:
+         """Create the lm(prompt) wrapper function."""
+
+         def lm_call(prompt: str) -> PendingLMResult:
+             # Check for unresolved dependencies in the prompt
+             if _contains_unresolved(prompt):
+                 raise RuntimeError("LM result for call dependency not yet available")
+
+             call_id = call_state["next_lm_id"]
+             call_state["next_lm_id"] += 1
+
+             # Only queue if not already completed or pending
+             if call_id not in lm_results and call_id not in pending_call_ids:
+                 if len(pending_lm_calls) >= self.max_lm_calls_per_execution:
+                     raise RuntimeError(
+                         f"Too many lm() calls in single execution "
+                         f"(max {self.max_lm_calls_per_execution})"
+                     )
+                 pending_call_ids.add(call_id)
+                 pending_lm_calls.append(
+                     {
+                         "id": call_id,
+                         "prompt": str(prompt),
+                     }
+                 )
+
+             return PendingLMResult(call_id, lm_results)
+
+         return lm_call
+
+     def _make_final_func(
+         self, exec_namespace: dict[str, Any], lm_results: dict[int, str]
+     ) -> Callable[[Any], None]:
+         """Create final(answer) function."""
+
+         def final_func(answer: Any) -> None:
+             resolved = _resolve_value(answer, lm_results)
+             raise FinalAnswer(resolved)
+
+         return final_func
+
+     def _make_final_var_func(
+         self, exec_namespace: dict[str, Any], lm_results: dict[int, str]
+     ) -> Callable[[str], None]:
+         """Create final_var(varname) function."""
+
+         def final_var_func(varname: str) -> None:
+             if varname not in exec_namespace:
+                 raise RuntimeError(f"Variable '{varname}' not found")
+             value = exec_namespace[varname]
+             resolved = _resolve_value(value, lm_results)
+             raise FinalAnswer(resolved)
+
+         return final_var_func
+
+     async def _execute_pending_lm_calls(
+         self,
+         pending_calls: list[dict],
+         results: dict[int, str],
+     ) -> None:
+         """Execute all pending lm() calls in parallel."""
+         if not pending_calls:
+             return
+
+         from lm_deluge.prompt import Conversation
+
+         # Start all calls in parallel using start_nowait
+         task_mapping: list[tuple[int, int]] = []  # (call_id, task_id)
+         for call in pending_calls:
+             conv = Conversation.user(call["prompt"])
+             task_id = self.client.start_nowait(conv)
+             task_mapping.append((call["id"], task_id))
+
+         # Wait for all to complete
+         for call_id, task_id in task_mapping:
+             try:
+                 response = await self.client.wait_for(task_id)
+                 results[call_id] = response.completion or "(no response)"
+             except Exception as e:
+                 results[call_id] = f"Error: {e}"
+
+         # Clear the pending list
+         pending_calls.clear()
+
+     def _format_answer(self, value: Any) -> str:
+         """Format the final answer as a string."""
+         if isinstance(value, str):
+             return value
+         try:
+             return json.dumps(value, default=str, indent=2)
+         except Exception:
+             return str(value)
+
+     async def execute(self, code: str) -> tuple[str, bool]:
+         """Execute RLM code.
+
+         Args:
+             code: Python code to execute
+
+         Returns:
+             Tuple of (output_string, is_final) where is_final indicates
+             whether final()/final_var() was called.
+         """
+         # Validate the code
+         tree = validate_rlm_code(code)
+
+         # Set up execution environment
+         pending_lm_calls: list[dict] = []
+         lm_results: dict[int, str] = {}
+         pending_call_ids: set[int] = set()
+         call_state = {"next_lm_id": 0}
+         output_capture = OutputCapture()
+
+         # Create the lm() wrapper
+         lm_wrapper = self._make_lm_wrapper(
+             pending_lm_calls, lm_results, call_state, pending_call_ids
+         )
+
+         # Build a single namespace for execution
+         # Using a single dict for both globals and locals ensures that
+         # variables are visible inside nested scopes (list comprehensions, etc.)
+         exec_namespace: dict[str, Any] = {
+             "__builtins__": {**RLM_SAFE_BUILTINS, "print": output_capture.print},
+             self.context_var_name: self.context,
+             "lm": lm_wrapper,
+             "json": json,  # Explicitly include json
+             **RLM_MODULES,
+             # Include persistent state from previous calls
+             **self._persistent_locals,
+         }
+
+         # Add final and final_var (they need access to exec_namespace for final_var)
+         exec_namespace["final"] = self._make_final_func(exec_namespace, lm_results)
+         exec_namespace["final_var"] = self._make_final_var_func(
+             exec_namespace, lm_results
+         )
+
+         # Track which keys are "system" keys that shouldn't be persisted
+         system_keys = set(exec_namespace.keys())
+
+         # Execute with retry loop for deferred lm() resolution
+         max_iterations = 50
+         compiled = compile(tree, "<rlm>", "exec")
+
+         for iteration in range(max_iterations):
+             # Reset call sequencing for this pass
+             call_state["next_lm_id"] = 0
+             pending_call_ids.clear()
+
+             try:
+                 exec(compiled, exec_namespace)
+
+                 # Execution completed - run any remaining pending calls
+                 await self._execute_pending_lm_calls(pending_lm_calls, lm_results)
+
+                 # Update persistent locals (exclude system keys)
+                 for key, value in exec_namespace.items():
+                     if key not in system_keys:
+                         self._persistent_locals[key] = value
+
+                 break
+
+             except FinalAnswer as fa:
+                 # final() or final_var() was called
+                 for key, value in exec_namespace.items():
+                     if key not in system_keys:
+                         self._persistent_locals[key] = value
+                 return (self._format_answer(fa.answer), True)
+
+             except RuntimeError as e:
+                 if "not yet available" in str(e):
+                     # Need to resolve pending lm() calls and retry
+                     await self._execute_pending_lm_calls(pending_lm_calls, lm_results)
+                     pending_call_ids.clear()
+                     # Continue to retry
+                 else:
+                     raise RLMExecutionError(f"Runtime error: {e}") from e
+
+             except Exception as e:
+                 raise RLMExecutionError(f"Execution error: {type(e).__name__}: {e}") from e
+
+         else:
+             raise RLMExecutionError(
+                 f"Execution exceeded maximum iterations ({max_iterations})"
+             )
+
+         # Get output
+         output = output_capture.get_output()
+
+         # If no print output, check for result variable
+         if not output and "result" in exec_namespace:
+             result_value = _resolve_value(exec_namespace["result"], lm_results)
+             output = self._format_answer(result_value)
+
+         return (output or "Execution completed with no output", False)
+
+     def reset(self) -> None:
+         """Reset the persistent state."""
+         self._persistent_locals.clear()
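Below is a minimal usage sketch, not part of the diff or the package, showing how the new executor could be driven end to end. StubClient is a hypothetical stand-in that implements only the start_nowait/wait_for surface the executor relies on, and the lm_deluge.rlm.executor import path is an assumption inferred from the relative imports above. The sketch illustrates the deferred-resolution design: the first exec pass queues the lm() call, touching the placeholder raises the internal "not yet available" error, the queued call is resolved, and the retry pass completes.

import asyncio

from lm_deluge.rlm.executor import RLMExecutor  # assumed module path


class StubResponse:
    def __init__(self, completion: str):
        self.completion = completion


class StubClient:
    """Hypothetical client exposing only what RLMExecutor touches."""

    def __init__(self):
        self._prompts: dict[int, str] = {}
        self._next_id = 0

    def start_nowait(self, conv) -> int:
        # Record the prompt and hand back a task id, like the real client would.
        task_id = self._next_id
        self._next_id += 1
        self._prompts[task_id] = str(conv)
        return task_id

    async def wait_for(self, task_id: int) -> StubResponse:
        # Return a canned completion for whatever was queued.
        return StubResponse(f"(canned answer to: {self._prompts[task_id][:40]}...)")


async def main():
    executor = RLMExecutor(context="...a very long document...", client=StubClient())
    # Pass 1 queues the lm() call; printing the placeholder raises the internal
    # "not yet available" RuntimeError; the call resolves; pass 2 prints.
    output, is_final = await executor.execute(
        'summary = lm("Summarize: " + CONTEXT[:50])\nprint("summary:", summary)'
    )
    print(is_final)  # False: final()/final_var() was not called
    print(output)


asyncio.run(main())

Note that the retry loop only works because lm() calls are issued in a deterministic order across passes; that is why call_state["next_lm_id"] is reset at the top of each iteration, so the re-executed call lands on the same id and finds its cached result.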
@@ -0,0 +1,144 @@
+ """
+ RLM (Recursive Language Model) code parsing and validation.
+
+ Extends OTC's security model with additional modules for context analysis.
+ """
+
+ import ast
+ import collections
+ import json
+ import math
+ import re
+
+ # Import OTC's base security definitions
+ from ..otc.parse import (
+     FORBIDDEN_CALLS,
+     SAFE_BUILTINS,
+     ASTValidator,
+     OTCSecurityError,
+ )
+
+ # RLM uses the same builtins as OTC
+ RLM_SAFE_BUILTINS = SAFE_BUILTINS.copy()
+
+ # Modules available in RLM - imports of these are stripped (no-ops)
+ RLM_ALLOWED_IMPORTS = {"re", "math", "collections", "json"}
+
+ # Modules and common imports available in RLM (injected into globals)
+ RLM_MODULES = {
+     # Full modules
+     "re": re,
+     "math": math,
+     "collections": collections,
+     "json": json,
+     # Common imports from collections
+     "Counter": collections.Counter,
+     "defaultdict": collections.defaultdict,
+     "deque": collections.deque,
+     "namedtuple": collections.namedtuple,
+     "OrderedDict": collections.OrderedDict,
+ }
+
+
+ class RLMSecurityError(OTCSecurityError):
+     """Raised when RLM code violates security constraints."""
+
+     pass
+
+
+ class RLMExecutionError(Exception):
+     """Raised when RLM code execution fails."""
+
+     pass
+
+
+ class RLMASTValidator(ASTValidator):
+     """Validates RLM code with additional checks.
+
+     Import statements for allowed modules are stripped (no-ops).
+     Imports of disallowed modules raise errors.
+     """
+
+     def __init__(self, allowed_names: set[str] | None = None):
+         super().__init__(allowed_tool_names=set())
+         self.allowed_names = allowed_names or set()
+
+     def visit(self, node: ast.AST) -> None:
+         # Check imports - allowed ones will be stripped later, disallowed ones error
+         if isinstance(node, ast.Import):
+             for alias in node.names:
+                 if alias.name not in RLM_ALLOWED_IMPORTS:
+                     self.errors.append(
+                         f"Forbidden import: {alias.name} at line {node.lineno}. "
+                         f"Available modules: {', '.join(sorted(RLM_ALLOWED_IMPORTS))}"
+                     )
+             self.generic_visit(node)
+             return
+
+         if isinstance(node, ast.ImportFrom):
+             if node.module not in RLM_ALLOWED_IMPORTS:
+                 self.errors.append(
+                     f"Forbidden import: from {node.module} at line {node.lineno}. "
+                     f"Available modules: {', '.join(sorted(RLM_ALLOWED_IMPORTS))}"
+                 )
+             self.generic_visit(node)
+             return
+
+         # For all other nodes, use parent validation
+         super().visit(node)
+
+     def visit_Call(self, node: ast.Call) -> None:
+         if isinstance(node.func, ast.Name):
+             if node.func.id in FORBIDDEN_CALLS:
+                 self.errors.append(
+                     f"Forbidden function call: {node.func.id} at line {node.lineno}"
+                 )
+         self.generic_visit(node)
+
+
+ class ImportStripper(ast.NodeTransformer):
+     """Strips import statements for allowed modules from the AST."""
+
+     def visit_Import(self, node: ast.Import) -> ast.AST | None:
+         # Drop allowed-module imports; validation has already rejected any others
+         remaining = [
+             alias for alias in node.names if alias.name not in RLM_ALLOWED_IMPORTS
+         ]
+         if not remaining:
+             return None  # Remove the entire import statement
+         node.names = remaining
+         return node
+
+     def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
+         # Strip imports from allowed modules
+         if node.module in RLM_ALLOWED_IMPORTS:
+             return None  # Remove the entire import statement
+         return node
+
+
+ def validate_rlm_code(code: str) -> ast.Module:
+     """Parse and validate RLM code, returning AST if valid.
+
+     Import statements for allowed modules (re, math, collections, json) are
+     stripped from the AST since these modules are already in the namespace.
+     """
+     try:
+         tree = ast.parse(code)
+     except SyntaxError as e:
+         raise RLMSecurityError(f"Syntax error: {e}") from e
+
+     # Validate first (before stripping)
+     validator = RLMASTValidator()
+     errors = validator.validate(tree)
+
+     if errors:
+         raise RLMSecurityError(
+             "Security violations:\n" + "\n".join(f" - {e}" for e in errors)
+         )
+
+     # Strip allowed imports (they're no-ops since modules are pre-loaded)
+     stripper = ImportStripper()
+     tree = stripper.visit(tree)
+     ast.fix_missing_locations(tree)
+
+     return tree
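A companion sketch for the parser, again illustrative rather than part of the package: the lm_deluge.rlm.parse import path is assumed, and it presumes print and Counter are not in OTC's FORBIDDEN_CALLS. It shows the two behaviors of validate_rlm_code: imports of the pre-injected modules are stripped as no-ops, while anything else is rejected before execution.

import ast

from lm_deluge.rlm.parse import RLMSecurityError, validate_rlm_code  # assumed path

tree = validate_rlm_code(
    "import json\n"
    "from collections import Counter\n"
    "counts = Counter('aab')\n"
    "print(json.dumps(counts))"
)
# Both import statements are gone; at exec time Counter and json come from
# RLM_MODULES, which the executor spreads into the namespace.
print(ast.unparse(tree))

try:
    validate_rlm_code("import os")
except RLMSecurityError as e:
    # Security violations: ... Forbidden import: os at line 1.
    # Available modules: collections, json, math, re
    print(e)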