recursive-llm-ts 2.0.12 → 3.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,322 +0,0 @@
1
- """Core RLM implementation."""
2
-
3
- import asyncio
4
- import re
5
- from typing import Optional, Dict, Any, List
6
-
7
- import litellm
8
-
9
- from .types import Message
10
- from .repl import REPLExecutor, REPLError
11
- from .prompts import build_system_prompt
12
- from .parser import parse_response, is_final
13
-
14
-
15
class RLMError(Exception):
    """Root of the RLM exception hierarchy."""


class MaxIterationsError(RLMError):
    """Raised when the REPL loop uses up ``max_iterations`` without a FINAL answer."""


class MaxDepthError(RLMError):
    """Raised when an RLM is asked to run at or beyond ``max_depth``."""
28
-
29
-
30
class RLM:
    """Recursive Language Model.

    Drives the LLM through a loop in which it writes Python code that is run by
    a ``REPLExecutor``. The context is exposed to that code as the variable
    ``context`` (the prompt only mentions its length), and the code may call
    ``recursive_llm(sub_query, sub_context)`` to run a child RLM one level
    deeper. The loop ends when a response contains a parseable ``FINAL(...)``
    or ``FINAL_VAR(...)``.
    """

    # User-turn content sent when the caller supplied no explicit query (the
    # task is then expected to be stated inside the context). Some providers
    # reject an empty user message, e.g. Anthropic requires non-empty content.
    _DEFAULT_QUERY = (
        "No explicit query was provided. The task is described in the "
        "`context` variable: inspect it and complete that task."
    )

    def __init__(
        self,
        model: str,
        recursive_model: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        max_depth: int = 5,
        max_iterations: int = 30,
        _current_depth: int = 0,
        **llm_kwargs: Any
    ):
        """
        Initialize RLM.

        Args:
            model: Model name (e.g., "gpt-4o", "claude-sonnet-4", "ollama/llama3.2")
            recursive_model: Optional cheaper model for recursive calls. For
                compatibility with the recursive-llm-ts bridge this may also be
                a config dict; its keys then override the arguments below.
            api_base: Optional API base URL
            api_key: Optional API key
            max_depth: Maximum recursion depth
            max_iterations: Maximum REPL iterations per call
            _current_depth: Internal current depth tracker
            **llm_kwargs: Additional LiteLLM parameters
        """
        # Patch for recursive-llm-ts bug where config is passed as 2nd positional arg
        if isinstance(recursive_model, dict):
            config = recursive_model
            # `or` (rather than a .get default) so that an explicit None in the
            # config still falls back to the main model.
            self.recursive_model = config.get('recursive_model') or model
            self.api_base = config.get('api_base', api_base)
            self.api_key = config.get('api_key', api_key)
            self.max_depth = int(config.get('max_depth', max_depth))
            self.max_iterations = int(config.get('max_iterations', max_iterations))

            # Remaining config keys are forwarded to LiteLLM. 'model' is
            # excluded too: it is always passed explicitly as model=..., so
            # leaving it in llm_kwargs would raise "got multiple values for
            # keyword argument 'model'" in litellm.acompletion() and RLM().
            excluded = {
                'model', 'recursive_model', 'api_base', 'api_key',
                'max_depth', 'max_iterations',
            }
            self.llm_kwargs = {k: v for k, v in config.items() if k not in excluded}
            # Explicit keyword arguments take precedence over the config dict.
            self.llm_kwargs.update(llm_kwargs)
        else:
            self.recursive_model = recursive_model or model
            self.api_base = api_base
            self.api_key = api_key
            self.max_depth = max_depth
            self.max_iterations = max_iterations
            self.llm_kwargs = llm_kwargs

        self._current_depth = _current_depth
        self.model = model

        self.repl = REPLExecutor()

        # Stats
        self._llm_calls = 0
        self._iterations = 0

    def completion(
        self,
        query: str = "",
        context: str = "",
        **kwargs: Any
    ) -> str:
        """
        Sync wrapper for acompletion.

        Uses ``asyncio.run``, so it cannot be called from a thread that already
        has a running event loop; use ``acompletion`` there.

        Args:
            query: User query (optional if query is in context)
            context: Context to process (optional, can pass query here)
            **kwargs: Additional LiteLLM parameters

        Returns:
            Final answer string

        Examples:
            # Standard usage
            rlm.completion(query="Summarize this", context=document)

            # Query in context (RLM will extract task)
            rlm.completion(context="Summarize this document: ...")

            # Single string (treat as context)
            rlm.completion("Process this text and extract dates")
        """
        # If only one argument provided, treat it as context
        if query and not context:
            context = query
            query = ""

        return asyncio.run(self.acompletion(query, context, **kwargs))

    async def acompletion(
        self,
        query: str = "",
        context: str = "",
        **kwargs: Any
    ) -> str:
        """
        Main async completion method.

        Args:
            query: User query (optional if query is in context)
            context: Context to process (optional, can pass query here)
            **kwargs: Additional LiteLLM parameters

        Returns:
            Final answer string

        Raises:
            MaxIterationsError: If max iterations exceeded
            MaxDepthError: If max recursion depth exceeded

        Examples:
            # Explicit query and context
            await rlm.acompletion(query="What is this?", context=doc)

            # Query embedded in context
            await rlm.acompletion(context="Extract all dates from: ...")

            # LLM will figure out the task
            await rlm.acompletion(context=document_with_instructions)
        """
        # If only query provided, treat it as context
        if query and not context:
            context = query
            query = ""
        if self._current_depth >= self.max_depth:
            raise MaxDepthError(f"Max recursion depth ({self.max_depth}) exceeded")

        # Initialize REPL environment
        repl_env = self._build_repl_env(query, context)

        # Build initial messages. An empty query would produce an empty user
        # message, which some providers reject, so substitute an instruction.
        system_prompt = build_system_prompt(len(context), self._current_depth)
        messages: List[Message] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query or self._DEFAULT_QUERY},
        ]

        # Main loop
        for iteration in range(self.max_iterations):
            self._iterations = iteration + 1

            # Call LLM
            response = await self._call_llm(messages, **kwargs)

            # Check for FINAL; an unparseable FINAL falls through to execution.
            if is_final(response):
                answer = parse_response(response, repl_env)
                if answer is not None:
                    return answer

            # Execute code in REPL; errors are fed back to the model as text.
            try:
                exec_result = self.repl.execute(response, repl_env)
            except REPLError as e:
                exec_result = f"Error: {str(e)}"
            except Exception as e:
                exec_result = f"Unexpected error: {str(e)}"

            # Add to conversation
            messages.append({"role": "assistant", "content": response})
            messages.append({"role": "user", "content": exec_result})

        raise MaxIterationsError(
            f"Max iterations ({self.max_iterations}) exceeded without FINAL()"
        )

    async def _call_llm(
        self,
        messages: "List[Message]",
        **kwargs: Any
    ) -> str:
        """
        Call LLM API.

        Args:
            messages: Conversation messages
            **kwargs: Additional parameters (can override model here)

        Returns:
            LLM response text ("" when the provider returns no text content)
        """
        self._llm_calls += 1

        # Choose model based on depth
        default_model = self.model if self._current_depth == 0 else self.recursive_model

        # Allow override via kwargs
        model = kwargs.pop('model', default_model)

        # Merge kwargs (per-call kwargs win over constructor llm_kwargs)
        call_kwargs = {**self.llm_kwargs, **kwargs}
        if self.api_base:
            call_kwargs['api_base'] = self.api_base
        if self.api_key:
            call_kwargs['api_key'] = self.api_key

        # Call LiteLLM
        response = await litellm.acompletion(
            model=model,
            messages=messages,
            **call_kwargs
        )

        # message.content may be None (e.g. refusals or tool-call-only replies);
        # callers run substring checks on the result, so normalise to "".
        return response.choices[0].message.content or ""

    def _build_repl_env(self, query: str, context: str) -> Dict[str, Any]:
        """
        Build REPL environment.

        Args:
            query: User query
            context: Context string

        Returns:
            Environment dict
        """
        env: Dict[str, Any] = {
            'context': context,
            'query': query,
            'recursive_llm': self._make_recursive_fn(),
            're': re,  # Whitelist re module
        }
        return env

    def _make_recursive_fn(self) -> Any:
        """
        Create recursive LLM function for REPL.

        Returns:
            Synchronous function ``(sub_query, sub_context) -> str`` that runs a
            child RLM at depth + 1, or returns a message string when the depth
            limit would be reached.
        """
        async def recursive_llm(sub_query: str, sub_context: str) -> str:
            """
            Recursively process sub-context.

            Args:
                sub_query: Query for sub-context
                sub_context: Sub-context to process

            Returns:
                Answer from recursive call
            """
            if self._current_depth + 1 >= self.max_depth:
                return f"Max recursion depth ({self.max_depth}) reached"

            # Create sub-RLM with increased depth
            sub_rlm = RLM(
                model=self.recursive_model,
                recursive_model=self.recursive_model,
                api_base=self.api_base,
                api_key=self.api_key,
                max_depth=self.max_depth,
                max_iterations=self.max_iterations,
                _current_depth=self._current_depth + 1,
                **self.llm_kwargs
            )

            return await sub_rlm.acompletion(sub_query, sub_context)

        # The REPL calls this synchronously, so wrap the coroutine.
        def sync_recursive_llm(sub_query: str, sub_context: str) -> str:
            """Sync wrapper for recursive_llm."""
            return self._run_sync(lambda: recursive_llm(sub_query, sub_context))

        return sync_recursive_llm

    @staticmethod
    def _run_sync(make_coro: Any) -> Any:
        """
        Run the coroutine returned by ``make_coro()`` to completion from sync code.

        With no event loop running in this thread, ``asyncio.run`` is used
        directly. Otherwise (the REPL executes synchronously inside
        ``acompletion``'s loop), the coroutine runs on a fresh loop in a worker
        thread while this thread blocks on the result.

        Only the loop probe is guarded by ``except RuntimeError``: exceptions
        raised by the coroutine itself propagate unchanged, instead of
        triggering a second attempt via ``asyncio.run`` on the running loop.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: safe to start one here.
            return asyncio.run(make_coro())

        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            # The coroutine is created inside the worker, next to its new loop.
            return executor.submit(lambda: asyncio.run(make_coro())).result()

    @property
    def stats(self) -> Dict[str, int]:
        """Get execution statistics for this instance (child RLMs not included)."""
        return {
            'llm_calls': self._llm_calls,
            'iterations': self._iterations,
            'depth': self._current_depth,
        }
@@ -1,93 +0,0 @@
1
- """Parse FINAL() and FINAL_VAR() statements from LLM responses."""
2
-
3
- import re
4
- from typing import Optional, Dict, Any
5
-
6
-
7
def extract_final(response: str) -> Optional[str]:
    """
    Extract answer from FINAL() statement.

    Recognised argument forms, tried in this order: triple-double-quoted,
    triple-single-quoted, double-quoted and single-quoted strings. The
    single-line quote forms stop at the first matching quote character
    (backslash escapes are not interpreted).

    Args:
        response: LLM response text

    Returns:
        Extracted answer (whitespace-stripped) or None if not found
    """
    patterns = [
        # Triple-quoted forms are non-greedy so that an unrelated triple quote
        # later in the response (e.g. a docstring in code) is not swallowed
        # into the answer.
        r'FINAL\s*\(\s*"""(.*?)"""',  # FINAL("""answer""") - triple double quotes
        r"FINAL\s*\(\s*'''(.*?)'''",  # FINAL('''answer''') - triple single quotes
        r'FINAL\s*\(\s*"([^"]*)"',    # FINAL("answer") - double quotes
        r"FINAL\s*\(\s*'([^']*)'",    # FINAL('answer') - single quotes
    ]

    for pattern in patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match:
            return match.group(1).strip()

    return None
31
-
32
-
33
def extract_final_var(response: str, env: Dict[str, Any]) -> Optional[str]:
    """
    Resolve a ``FINAL_VAR(name)`` statement against the REPL environment.

    Args:
        response: LLM response text
        env: REPL environment mapping variable names to values

    Returns:
        ``str()`` of the named variable, or None when there is no
        FINAL_VAR(...) in the response or the name is not defined in env
    """
    found = re.search(r'FINAL_VAR\s*\(\s*(\w+)\s*\)', response)
    if found is None:
        return None

    name = found.group(1)
    return str(env[name]) if name in env else None
57
-
58
-
59
def is_final(response: str) -> bool:
    """
    Check if response contains FINAL() or FINAL_VAR().

    Whitespace is allowed between the keyword and the opening parenthesis,
    matching what extract_final / extract_final_var accept (``\\s*\\(``), so
    that e.g. ``FINAL ("x")`` is detected rather than silently ignored.

    Args:
        response: LLM response text

    Returns:
        True if response contains final statement
    """
    return re.search(r'FINAL(?:_VAR)?\s*\(', response) is not None
70
-
71
-
72
def parse_response(response: str, env: Dict[str, Any]) -> Optional[str]:
    """
    Extract the final answer from a response, if it contains one.

    ``FINAL(...)`` takes precedence; ``FINAL_VAR(...)`` is consulted only
    when no FINAL(...) answer could be extracted.

    Args:
        response: LLM response text
        env: REPL environment (used to resolve FINAL_VAR names)

    Returns:
        Final answer or None
    """
    answer = extract_final(response)
    if answer is None:
        answer = extract_final_var(response, env)
    return answer
@@ -1,50 +0,0 @@
1
- """System prompt templates for RLM."""
2
-
3
-
4
def build_system_prompt(context_size: int, depth: int = 0) -> str:
    """
    Return the RLM system prompt.

    The context itself is not embedded; only its size (formatted with
    thousands separators) and the current depth are interpolated.

    Args:
        context_size: Size of context in characters
        depth: Current recursion depth

    Returns:
        System prompt string
    """
    # `{{query}}` renders as the literal text `{query}`.
    return f"""You are a Recursive Language Model. You interact with context through a Python REPL environment.

The context is stored in variable `context` (not in this prompt). Size: {context_size:,} characters.

Available in environment:
- context: str (the document to analyze)
- query: str (the question: "{{query}}")
- recursive_llm(sub_query, sub_context) -> str (recursively process sub-context)
- re: already imported regex module (use re.findall, re.search, etc.)

Write Python code to answer the query. The last expression or print() output will be shown to you.

Examples:
- print(context[:100]) # See first 100 chars
- errors = re.findall(r'ERROR', context) # Find all ERROR
- count = len(errors); print(count) # Count and show

When you have the answer, use FINAL("answer") - this is NOT a function, just write it as text.

Depth: {depth}"""
38
-
39
-
40
def build_user_prompt(query: str) -> str:
    """
    Return the user-turn text for *query*.

    The query is currently passed through unchanged.

    Args:
        query: User's question

    Returns:
        User prompt string
    """
    user_prompt = query
    return user_prompt
@@ -1,235 +0,0 @@
1
- """Safe REPL executor using RestrictedPython."""
2
-
3
- import io
4
- import sys
5
- from typing import Dict, Any, Optional
6
- from RestrictedPython import compile_restricted_exec, safe_globals, limited_builtins, utility_builtins
7
- from RestrictedPython.Guards import guarded_iter_unpack_sequence, safer_getattr
8
- from RestrictedPython.PrintCollector import PrintCollector
9
-
10
-
11
class REPLError(Exception):
    """Raised when code submitted to the REPL fails to compile or run."""
14
-
15
-
16
class REPLExecutor:
    """Safe Python code executor.

    Runs LLM-written code through RestrictedPython and reports what it
    printed, plus the value of a trailing bare expression (REPL-style).
    """

    def __init__(self, timeout: int = 5, max_output_chars: int = 2000):
        """
        Initialize REPL executor.

        Args:
            timeout: Execution timeout in seconds (stored, not currently enforced)
            max_output_chars: Maximum characters to return (truncate if longer)
        """
        self.timeout = timeout
        self.max_output_chars = max_output_chars

    def execute(self, code: str, env: Dict[str, Any]) -> str:
        """
        Execute Python code in restricted environment.

        ``env`` is used as the local namespace, so variables assigned by the
        code persist in it across calls. If the last top-level statement is a
        bare expression that does not involve ``print``, it is evaluated
        exactly once, under the same RestrictedPython policy, and a non-None
        value is appended to the output.

        Args:
            code: Python code to execute (may be wrapped in a markdown fence)
            env: Environment with context, query, recursive_llm, etc.

        Returns:
            String result of execution (stdout, print() output, last expression
            value), a placeholder message when there is no code or no output,
            or a truncated version when longer than ``max_output_chars``.

        Raises:
            REPLError: If code compilation or execution fails
        """
        # Filter out code blocks if present (LLM might wrap code)
        code = self._extract_code(code)

        if not code.strip():
            return "No code to execute"

        # Build restricted globals
        restricted_globals = self._build_globals(env)

        # Separate a trailing bare expression so its value can be echoed after
        # evaluating it once (never re-running it, never outside the sandbox).
        body_src, last_expr_src = self._split_trailing_expression(code)

        # RestrictedPython keeps its print collector as `_print` in the local
        # namespace (env). Code that does not print would otherwise re-report
        # the collector left behind by a previous execution.
        env.pop('_print', None)

        # Capture stdout
        old_stdout = sys.stdout
        sys.stdout = captured_output = io.StringIO()

        try:
            if body_src.strip():
                self._run_restricted(body_src, 'exec', restricted_globals, env)

            result = None
            if last_expr_src is not None:
                # Parenthesised so multi-line or bare-tuple expressions are
                # valid in eval mode.
                result = self._run_restricted(
                    f"(\n{last_expr_src}\n)", 'eval', restricted_globals, env
                )

            # Get output from stdout
            output = captured_output.getvalue()

            # PrintCollector stores prints in its txt attribute
            print_collector = env.get('_print')
            if print_collector is not None:
                output += ''.join(getattr(print_collector, 'txt', ()))

            if result is not None:
                output += str(result) + '\n'

            if not output:
                return "Code executed successfully (no output)"

            # Truncate output if too long (as per paper: "truncated version of output")
            if len(output) > self.max_output_chars:
                truncated = output[:self.max_output_chars]
                return f"{truncated}\n\n[Output truncated: {len(output)} chars total, showing first {self.max_output_chars}]"

            return output.strip()

        except REPLError:
            # Compilation errors are already descriptive; don't re-wrap them.
            raise
        except Exception as e:
            raise REPLError(f"Execution error: {str(e)}") from e

        finally:
            sys.stdout = old_stdout

    @staticmethod
    def _run_restricted(
        source: str,
        mode: str,
        restricted_globals: Dict[str, Any],
        env: Dict[str, Any],
    ) -> Any:
        """
        Compile *source* with RestrictedPython in *mode* and run it.

        Args:
            source: Code ('exec') or a single expression ('eval')
            mode: 'exec' or 'eval'
            restricted_globals: Globals from _build_globals
            env: Local namespace

        Returns:
            The expression value in 'eval' mode, None in 'exec' mode

        Raises:
            REPLError: If RestrictedPython reports compilation errors
        """
        from RestrictedPython import compile_restricted_eval

        if mode == 'exec':
            compiled = compile_restricted_exec(source)
        else:
            compiled = compile_restricted_eval(source)

        if compiled.errors:
            raise REPLError(f"Compilation error: {', '.join(compiled.errors)}")

        if mode == 'exec':
            exec(compiled.code, restricted_globals, env)
            return None
        return eval(compiled.code, restricted_globals, env)

    @staticmethod
    def _split_trailing_expression(code: str) -> "tuple[str, Optional[str]]":
        """
        Split *code* into (statements, trailing_expression_source).

        The second item is the source of the last top-level statement when it
        is a bare expression that does not reference ``print``; the first item
        is then everything before it. Otherwise the result is (code, None).
        Code that does not parse is returned whole so the restricted compiler
        reports the error in its usual format.
        """
        import ast

        try:
            tree = ast.parse(code)
        except SyntaxError:
            return code, None

        if not tree.body or not isinstance(tree.body[-1], ast.Expr):
            return code, None

        last = tree.body[-1]
        # RestrictedPython rewrites print() to use a collector that is only
        # set up in exec mode, and print() returns None anyway: run it whole.
        if any(isinstance(node, ast.Name) and node.id == 'print' for node in ast.walk(last)):
            return code, None

        segment = ast.get_source_segment(code, last)
        if segment is None:
            return code, None

        lines = code.split('\n')
        # col_offset counts UTF-8 bytes, not characters.
        head = lines[last.lineno - 1].encode('utf-8')[:last.col_offset].decode('utf-8')
        body = '\n'.join(lines[:last.lineno - 1] + [head])
        return body, segment

    def _extract_code(self, text: str) -> str:
        """
        Extract code from markdown code blocks if present.

        A ```python ... ``` block is preferred; otherwise the first fenced
        block is used. An info string on the opening fence line (``python``,
        ``py``, ``python3``, ...) is dropped rather than executed as code.

        Args:
            text: Raw text that might contain code

        Returns:
            Extracted (stripped) code, or the text unchanged when it has no
            complete fenced block
        """
        import re

        blocks = re.findall(r"```(.*?)```", text, re.DOTALL)
        if not blocks:
            return text

        block = next((b for b in blocks if b.startswith("python")), blocks[0])

        first_line, newline, rest = block.partition("\n")
        if newline and re.fullmatch(r"[\w.+-]*", first_line.strip()):
            # The opening fence line holds only a language tag (or nothing).
            block = rest
        elif block.startswith("python"):
            # Single-line form: ```python code```
            block = block[len("python"):]

        return block.strip()

    def _build_globals(self, env: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build restricted globals for safe execution.

        Args:
            env: User environment

        Returns:
            Safe globals dict
        """
        import operator
        from RestrictedPython.Guards import full_write_guard

        restricted_globals = safe_globals.copy()
        restricted_globals.update(limited_builtins)
        restricted_globals.update(utility_builtins)

        # In-place operators RestrictedPython passes (as strings) to
        # _inplacevar_ when it rewrites `name op= value`.
        inplace_ops = {
            '+=': operator.iadd, '-=': operator.isub, '*=': operator.imul,
            '/=': operator.itruediv, '//=': operator.ifloordiv,
            '%=': operator.imod, '**=': operator.ipow,
            '<<=': operator.ilshift, '>>=': operator.irshift,
            '&=': operator.iand, '^=': operator.ixor, '|=': operator.ior,
            '@=': operator.imatmul,
        }

        def inplacevar(op: str, target: Any, value: Any) -> Any:
            # `n += 1` is compiled to `n = _inplacevar_('+=', n, 1)`.
            try:
                func = inplace_ops[op]
            except KeyError:
                raise TypeError(f"Unsupported in-place operator: {op}") from None
            return func(target, value)

        def write_guard(ob: Any) -> Any:
            # `d[k] = v` is compiled to `_write_(d)[k] = v`. Item writes to
            # plain containers (including Counter/defaultdict, which are
            # provided below) are allowed; anything else goes through
            # RestrictedPython's write-guard wrapper.
            if isinstance(ob, (dict, list)):
                return ob
            return full_write_guard(ob)

        # Add guards
        restricted_globals['_iter_unpack_sequence_'] = guarded_iter_unpack_sequence
        restricted_globals['_getattr_'] = safer_getattr
        restricted_globals['_getitem_'] = lambda obj, index: obj[index]
        restricted_globals['_getiter_'] = iter
        restricted_globals['_print_'] = PrintCollector
        restricted_globals['_write_'] = write_guard
        restricted_globals['_inplacevar_'] = inplacevar

        # Add additional safe builtins
        restricted_globals.update({
            # Types
            'len': len,
            'str': str,
            'int': int,
            'float': float,
            'bool': bool,
            'list': list,
            'dict': dict,
            'tuple': tuple,
            'set': set,
            'frozenset': frozenset,
            'bytes': bytes,
            'bytearray': bytearray,

            # Iteration
            'range': range,
            'enumerate': enumerate,
            'zip': zip,
            'map': map,
            'filter': filter,
            'reversed': reversed,
            'iter': iter,
            'next': next,

            # Aggregation
            'sorted': sorted,
            'sum': sum,
            'min': min,
            'max': max,
            'any': any,
            'all': all,

            # Math
            'abs': abs,
            'round': round,
            'pow': pow,
            'divmod': divmod,

            # String/repr
            'chr': chr,
            'ord': ord,
            'hex': hex,
            'oct': oct,
            'bin': bin,
            'repr': repr,
            'ascii': ascii,
            'format': format,

            # Type checking
            'isinstance': isinstance,
            'issubclass': issubclass,
            'callable': callable,
            'type': type,
            'hasattr': hasattr,

            # Constants
            'True': True,
            'False': False,
            'None': None,
        })

        # Add safe standard library modules
        # These are read-only and don't allow file/network access
        import re
        import json
        import math
        from datetime import datetime, timedelta
        from collections import Counter, defaultdict

        restricted_globals.update({
            're': re,  # Regex (read-only)
            'json': json,  # JSON parsing (read-only)
            'math': math,  # Math functions
            'datetime': datetime,  # Date parsing
            'timedelta': timedelta,  # Time deltas
            'Counter': Counter,  # Counting helper
            'defaultdict': defaultdict,  # Dict with defaults
        })

        return restricted_globals