lm-deluge 0.0.67__py3-none-any.whl → 0.0.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge/__init__.py +1 -2
- lm_deluge/api_requests/anthropic.py +117 -22
- lm_deluge/api_requests/base.py +84 -11
- lm_deluge/api_requests/bedrock.py +30 -6
- lm_deluge/api_requests/chat_reasoning.py +4 -0
- lm_deluge/api_requests/gemini.py +166 -20
- lm_deluge/api_requests/openai.py +145 -25
- lm_deluge/batches.py +15 -45
- lm_deluge/client.py +309 -50
- lm_deluge/config.py +15 -3
- lm_deluge/models/__init__.py +14 -1
- lm_deluge/models/anthropic.py +29 -14
- lm_deluge/models/arcee.py +16 -0
- lm_deluge/models/deepseek.py +36 -4
- lm_deluge/models/google.py +42 -0
- lm_deluge/models/grok.py +24 -0
- lm_deluge/models/kimi.py +36 -0
- lm_deluge/models/minimax.py +18 -0
- lm_deluge/models/openai.py +100 -0
- lm_deluge/models/openrouter.py +133 -7
- lm_deluge/models/together.py +11 -0
- lm_deluge/models/zai.py +50 -0
- lm_deluge/pipelines/gepa/__init__.py +95 -0
- lm_deluge/pipelines/gepa/core.py +354 -0
- lm_deluge/pipelines/gepa/docs/samples.py +705 -0
- lm_deluge/pipelines/gepa/examples/01_synthetic_keywords.py +140 -0
- lm_deluge/pipelines/gepa/examples/02_gsm8k_math.py +261 -0
- lm_deluge/pipelines/gepa/examples/03_hotpotqa_multihop.py +300 -0
- lm_deluge/pipelines/gepa/examples/04_batch_classification.py +271 -0
- lm_deluge/pipelines/gepa/examples/simple_qa.py +129 -0
- lm_deluge/pipelines/gepa/optimizer.py +435 -0
- lm_deluge/pipelines/gepa/proposer.py +235 -0
- lm_deluge/pipelines/gepa/util.py +165 -0
- lm_deluge/{llm_tools → pipelines}/score.py +2 -2
- lm_deluge/{llm_tools → pipelines}/translate.py +5 -3
- lm_deluge/prompt.py +537 -88
- lm_deluge/request_context.py +7 -2
- lm_deluge/server/__init__.py +24 -0
- lm_deluge/server/__main__.py +144 -0
- lm_deluge/server/adapters.py +369 -0
- lm_deluge/server/app.py +388 -0
- lm_deluge/server/auth.py +71 -0
- lm_deluge/server/model_policy.py +215 -0
- lm_deluge/server/models_anthropic.py +172 -0
- lm_deluge/server/models_openai.py +175 -0
- lm_deluge/tool/__init__.py +1130 -0
- lm_deluge/tool/builtin/anthropic/__init__.py +300 -0
- lm_deluge/tool/builtin/anthropic/bash.py +0 -0
- lm_deluge/tool/builtin/anthropic/computer_use.py +0 -0
- lm_deluge/tool/builtin/gemini.py +59 -0
- lm_deluge/tool/builtin/openai.py +74 -0
- lm_deluge/tool/cua/__init__.py +173 -0
- lm_deluge/tool/cua/actions.py +148 -0
- lm_deluge/tool/cua/base.py +27 -0
- lm_deluge/tool/cua/batch.py +215 -0
- lm_deluge/tool/cua/converters.py +466 -0
- lm_deluge/tool/cua/kernel.py +702 -0
- lm_deluge/tool/cua/trycua.py +989 -0
- lm_deluge/tool/prefab/__init__.py +45 -0
- lm_deluge/tool/prefab/batch_tool.py +156 -0
- lm_deluge/tool/prefab/docs.py +1119 -0
- lm_deluge/tool/prefab/email.py +294 -0
- lm_deluge/tool/prefab/filesystem.py +1711 -0
- lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
- lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
- lm_deluge/tool/prefab/memory.py +458 -0
- lm_deluge/tool/prefab/otc/__init__.py +165 -0
- lm_deluge/tool/prefab/otc/executor.py +281 -0
- lm_deluge/tool/prefab/otc/parse.py +188 -0
- lm_deluge/tool/prefab/random.py +212 -0
- lm_deluge/tool/prefab/rlm/__init__.py +296 -0
- lm_deluge/tool/prefab/rlm/executor.py +349 -0
- lm_deluge/tool/prefab/rlm/parse.py +144 -0
- lm_deluge/tool/prefab/sandbox/__init__.py +19 -0
- lm_deluge/tool/prefab/sandbox/daytona_sandbox.py +483 -0
- lm_deluge/tool/prefab/sandbox/docker_sandbox.py +609 -0
- lm_deluge/tool/prefab/sandbox/fargate_sandbox.py +546 -0
- lm_deluge/tool/prefab/sandbox/modal_sandbox.py +469 -0
- lm_deluge/tool/prefab/sandbox/seatbelt_sandbox.py +827 -0
- lm_deluge/tool/prefab/sheets.py +385 -0
- lm_deluge/tool/prefab/skills.py +0 -0
- lm_deluge/tool/prefab/subagents.py +233 -0
- lm_deluge/tool/prefab/todos.py +342 -0
- lm_deluge/tool/prefab/tool_search.py +169 -0
- lm_deluge/tool/prefab/web_search.py +199 -0
- lm_deluge/tracker.py +16 -13
- lm_deluge/util/schema.py +412 -0
- lm_deluge/warnings.py +8 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/METADATA +23 -9
- lm_deluge-0.0.90.dist-info/RECORD +132 -0
- lm_deluge/built_in_tools/anthropic/__init__.py +0 -128
- lm_deluge/built_in_tools/openai.py +0 -28
- lm_deluge/presets/cerebras.py +0 -17
- lm_deluge/presets/meta.py +0 -13
- lm_deluge/tool.py +0 -849
- lm_deluge-0.0.67.dist-info/RECORD +0 -72
- lm_deluge/{llm_tools → pipelines}/__init__.py +1 -1
- /lm_deluge/{llm_tools → pipelines}/classify.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/extract.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/locate.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/ocr.py +0 -0
- /lm_deluge/{built_in_tools/anthropic/bash.py → skills/anthropic.py} +0 -0
- /lm_deluge/{built_in_tools/anthropic/computer_use.py → skills/compat.py} +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/anthropic/editor.py +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/base.py +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/top_level.txt +0 -0
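The listing above shows the release's main reorganization: the `lm_deluge/llm_tools/` helpers move to `lm_deluge/pipelines/`, the single `lm_deluge/tool.py` module is replaced by a `lm_deluge/tool/` package (with `builtin`, `cua`, and `prefab` subpackages), and new `server/` and `pipelines/gepa/` trees are added. The sketch below is inferred only from these renames, not from documented API; the exact import forms are assumptions, but code that imported `lm_deluge.llm_tools` directly would likely need a similar update.

```python
# Hedged sketch of how import paths appear to shift between 0.0.67 and 0.0.90,
# based solely on the file renames listed above (assumed, not verified).

# 0.0.67 layout: helpers under lm_deluge.llm_tools, a single tool.py module.
# from lm_deluge.llm_tools import classify, extract
# from lm_deluge.tool import Tool

# 0.0.90 layout: the same files live under lm_deluge.pipelines, and tool.py
# has become the lm_deluge.tool package (tool/__init__.py).
from lm_deluge.pipelines import classify, extract  # moved from llm_tools/
from lm_deluge.tool import Tool  # still importable; now a package, not a module
```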
lm_deluge/tool/prefab/rlm/__init__.py

@@ -0,0 +1,296 @@

"""
RLM (Recursive Language Model) for lm-deluge.

Enables models to process long contexts through a REPL environment
with recursive LM calls, based on the RLM paper:
https://alexzhang13.github.io/blog/2025/rlm/
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from lm_deluge.prompt import Conversation
from lm_deluge.tool import Tool

from .executor import RLMExecutionError, RLMExecutor
from .parse import RLMSecurityError

if TYPE_CHECKING:
    from lm_deluge.api_requests.base import APIResponse
    from lm_deluge.client import _LLMClient


RLM_SYSTEM_PROMPT = """You have access to a long context stored in the variable `{context_var}`.
You can write Python code to analyze this context using the `execute` tool.

IMPORTANT RULES:
1. You MUST use print() to see output. Bare expressions produce NO output.
2. You MUST call final(answer) when you have the answer. This is required!

Available in your code environment:
- `{context_var}`: The full context as a string ({context_len:,} characters)
- `lm(prompt)`: Make a recursive LLM call (runs in parallel when possible)
- `final(answer)`: Signal completion with the given answer - YOU MUST CALL THIS!
- `final_var(varname)`: Signal completion with a variable's value
- Modules: `re`, `math`, `collections`, `json` (imports are allowed but optional)
- From collections: `Counter`, `defaultdict`, `deque`, `namedtuple`, `OrderedDict`
- Standard builtins: `len`, `str`, `int`, `list`, `dict`, `sum`, `sorted`, `map`, `filter`, etc.

Example - count word occurrences:
```python
count = len(re.findall(r'\\bword\\b', {context_var}))
print(f"Found {{count}} occurrences")
final(count)
```

Example - use Counter:
```python
words = {context_var}.split()
counts = Counter(words)
print(counts.most_common(10))
final(counts.most_common(10))
```

Example - analyze with lm() calls:
```python
chunks = [{context_var}[i:i+2000] for i in range(0, len({context_var}), 2000)][:3]
summaries = [lm(f"Summarize: {{chunk}}") for chunk in chunks]
combined = "\\n".join(str(s) for s in summaries)
final(f"Summary:\\n{{combined}}")
```

Variables persist between execute() calls. Always call final() when you have the answer!
"""


class RLMManager:
    """Manages RLM execution for a long context.

    The RLMManager exposes a REPL-like interface as tools that allow an LLM
    to analyze a long context by writing Python code.

    Example:
        >>> manager = RLMManager(
        ...     context=long_document,
        ...     client=LLMClient("gpt-4.1-mini"),  # For lm() calls
        ... )
        >>> main_client = LLMClient("gpt-4.1")
        >>> conv = Conversation.system(manager.get_system_prompt())
        >>> conv = conv.user("What are the main themes in this document?")
        >>> conv, resp = await main_client.run_agent_loop(
        ...     conv,
        ...     tools=manager.get_tools(),
        ... )
        >>> if manager.is_complete:
        ...     print(manager.final_answer)
    """

    def __init__(
        self,
        context: str,
        client: _LLMClient,
        context_var_name: str = "CONTEXT",
        max_lm_calls_per_execution: int = 20,
    ):
        """Initialize the RLMManager.

        Args:
            context: The long context string to analyze
            client: LLMClient for making recursive lm() calls
            context_var_name: Variable name for the context (default: "CONTEXT")
            max_lm_calls_per_execution: Maximum lm() calls allowed per execute() call
        """
        self.context = context
        self.client = client
        self.context_var_name = context_var_name
        self.max_lm_calls_per_execution = max_lm_calls_per_execution

        self.executor = RLMExecutor(
            context=context,
            client=client,
            context_var_name=context_var_name,
            max_lm_calls_per_execution=max_lm_calls_per_execution,
        )

        self._final_answer: str | None = None
        self._tools: list[Tool] | None = None

    async def _execute(self, code: str) -> str:
        """Execute code against the context."""
        try:
            answer, is_final = await self.executor.execute(code)
            if is_final:
                self._final_answer = answer
                # Truncate for display but keep full answer stored
                display = answer[:1000] + "..." if len(answer) > 1000 else answer
                return f"[FINAL ANSWER SET]\n{display}"
            return answer
        except RLMSecurityError as e:
            return f"Security error: {e}"
        except RLMExecutionError as e:
            return f"Execution error: {e}"
        except Exception as e:
            return f"Unexpected error: {type(e).__name__}: {e}"

    def get_system_prompt(self) -> str:
        """Get the system prompt explaining the RLM environment."""
        return RLM_SYSTEM_PROMPT.format(
            context_var=self.context_var_name,
            context_len=len(self.context),
        )

    def get_tools(self) -> list[Tool]:
        """Get the tools for RLM execution."""
        if self._tools is not None:
            return self._tools

        self._tools = [
            Tool(
                name="execute",
                description=(
                    f"Execute Python code to analyze the context. "
                    f"The context ({len(self.context):,} chars) is available as `{self.context_var_name}`. "
                    f"Use `lm(prompt)` for recursive LLM calls (parallel when possible), and "
                    f"`final(answer)` or `final_var(varname)` to signal completion. "
                    f"Variables persist between calls. "
                    f"Modules available without import: re, math, collections, json."
                ),
                run=self._execute,
                parameters={
                    "code": {
                        "type": "string",
                        "description": "Python code to execute",
                    }
                },
                required=["code"],
            )
        ]
        return self._tools

    @property
    def is_complete(self) -> bool:
        """Check if FINAL() was called."""
        return self._final_answer is not None

    @property
    def final_answer(self) -> str | None:
        """Get the final answer if set."""
        return self._final_answer

    def reset(self) -> None:
        """Reset the RLM state."""
        self.executor.reset()
        self._final_answer = None


@dataclass
class RLMResult:
    """Result from RLMPipeline."""

    answer: str
    conversation: Conversation
    rounds_used: int
    final_response: APIResponse


class RLMPipeline:
    """High-level pipeline for RLM processing.

    A thin wrapper that takes a long context and question, sets up an RLMManager,
    runs an agent loop until final() is called, and returns the result.

    Example:
        >>> pipeline = RLMPipeline(
        ...     context=long_document,
        ...     client=LLMClient("gpt-4.1"),  # Smart orchestrator
        ...     lm_client=LLMClient("gpt-4.1-mini"),  # Cheaper model for lm() calls
        ...     question="What are the main themes in this document?",
        ... )
        >>> result = await pipeline.run()
        >>> print(result.answer)
    """

    def __init__(
        self,
        context: str,
        client: _LLMClient,
        question: str,
        *,
        lm_client: _LLMClient | None = None,
        context_var_name: str = "CONTEXT",
        max_rounds: int = 15,
        max_lm_calls_per_execution: int = 20,
    ):
        """Initialize the RLMPipeline.

        Args:
            context: The long context string to analyze
            client: LLMClient for the main agent (runs the execute loop)
            question: The question to answer about the context
            lm_client: LLMClient for lm() calls (defaults to same as client)
            context_var_name: Variable name for the context (default: "CONTEXT")
            max_rounds: Maximum agent loop rounds (default: 15)
            max_lm_calls_per_execution: Maximum lm() calls per execute() call
        """
        self.context = context
        self.client = client
        self.lm_client = lm_client or client
        self.question = question
        self.context_var_name = context_var_name
        self.max_rounds = max_rounds
        self.max_lm_calls_per_execution = max_lm_calls_per_execution

    async def run(self) -> RLMResult:
        """Run the RLM pipeline until completion."""
        manager = RLMManager(
            context=self.context,
            client=self.lm_client,
            context_var_name=self.context_var_name,
            max_lm_calls_per_execution=self.max_lm_calls_per_execution,
        )

        # Build conversation with system prompt and question
        conv = Conversation.system(manager.get_system_prompt())
        conv = conv.user(
            f"Question to answer about the context:\n\n{self.question}\n\n"
            "Use the execute tool to analyze the context and find the answer. "
            "Start by peeking at the context structure, then use appropriate "
            "techniques (regex, chunking, lm() calls) to find the answer. "
            "Call final(answer) when you have the answer."
        )

        # Run agent loop
        conv, resp = await self.client.run_agent_loop(
            conv,
            tools=manager.get_tools(),
            max_rounds=self.max_rounds,
        )

        # Extract answer
        if manager.is_complete:
            answer = manager.final_answer or "No answer produced"
        else:
            # Model stopped without calling final() - use last response
            answer = resp.completion or "No answer produced (final not called)"

        # Count rounds used
        rounds_used = sum(1 for m in conv.messages if m.role == "assistant")

        return RLMResult(
            answer=answer,
            conversation=conv,
            rounds_used=rounds_used,
            final_response=resp,
        )


__all__ = [
    "RLMManager",
    "RLMPipeline",
    "RLMResult",
    "RLMExecutor",
    "RLMExecutionError",
    "RLMSecurityError",
]
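`RLMPipeline` above is the high-level entry point for the new RLM feature. The following is a minimal usage sketch assembled from its docstring example and constructor signature; the `LLMClient` import path, the model names, and the input file are assumptions, not verified against the package's public API.

```python
# Minimal usage sketch based on the RLMPipeline docstring above.
# Assumes LLMClient is importable from lm_deluge and takes a model name,
# as the docstring example suggests; model names are placeholders.
import asyncio

from lm_deluge import LLMClient  # assumed import path
from lm_deluge.tool.prefab.rlm import RLMPipeline


async def main() -> None:
    with open("report.txt") as f:  # any long context string
        long_document = f.read()

    pipeline = RLMPipeline(
        context=long_document,
        client=LLMClient("gpt-4.1"),          # orchestrator that writes the code
        lm_client=LLMClient("gpt-4.1-mini"),  # cheaper model behind lm()
        question="What are the main themes in this document?",
        max_rounds=15,
    )
    result = await pipeline.run()
    print(result.answer)
    print(f"rounds used: {result.rounds_used}")


if __name__ == "__main__":
    asyncio.run(main())
```

For finer control over the agent loop, `RLMManager` exposes the same capability as plain tools, as its own docstring shows.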
lm_deluge/tool/prefab/rlm/executor.py

@@ -0,0 +1,349 @@

"""
RLM (Recursive Language Model) code executor.

Executes Python code with access to a context variable and lm() function
for recursive language model calls.
"""

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Callable

from .parse import (
    RLM_MODULES,
    RLM_SAFE_BUILTINS,
    RLMExecutionError,
    validate_rlm_code,
)

if TYPE_CHECKING:
    from lm_deluge.client import _LLMClient


class OutputCapture:
    """Captures print() output during execution."""

    def __init__(self):
        self.outputs: list[str] = []

    def print(self, *args, **kwargs):
        """Replacement print function that captures output."""
        output = " ".join(str(arg) for arg in args)
        self.outputs.append(output)

    def get_output(self) -> str:
        return "\n".join(self.outputs)


class PendingLMResult:
    """Placeholder for an lm() call result that hasn't completed yet."""

    def __init__(self, call_id: int, results: dict[int, str]):
        self._call_id = call_id
        self._results = results

    def _require_result(self) -> str:
        if self._call_id not in self._results:
            raise RuntimeError(f"LM result for call {self._call_id} not yet available")
        return self._results[self._call_id]

    def is_ready(self) -> bool:
        return self._call_id in self._results

    def __repr__(self) -> str:
        return repr(self._require_result())

    def __str__(self) -> str:
        return str(self._require_result())

    def __getattr__(self, name: str) -> Any:
        return getattr(self._require_result(), name)

    def __getitem__(self, key: Any) -> Any:
        return self._require_result()[key]

    def __iter__(self):
        return iter(self._require_result())

    def __len__(self) -> int:
        return len(self._require_result())

    def __bool__(self) -> bool:
        return bool(self._require_result())

    def __add__(self, other):
        return self._require_result() + other

    def __radd__(self, other):
        return other + self._require_result()

    def __contains__(self, item):
        return item in self._require_result()


class FinalAnswer(Exception):
    """Raised when FINAL() or FINAL_VAR() is called to signal completion."""

    def __init__(self, answer: Any):
        self.answer = answer
        super().__init__("Final answer signaled")


def _resolve_value(value: Any, results: dict[int, str]) -> Any:
    """Recursively resolve any PendingLMResult placeholders in a value."""
    if isinstance(value, PendingLMResult):
        return value._require_result()
    if isinstance(value, list):
        return [_resolve_value(v, results) for v in value]
    if isinstance(value, tuple):
        return tuple(_resolve_value(v, results) for v in value)
    if isinstance(value, dict):
        return {k: _resolve_value(v, results) for k, v in value.items()}
    if isinstance(value, set):
        return {_resolve_value(v, results) for v in value}
    return value


def _contains_unresolved(value: Any) -> bool:
    """Check if a value contains any unresolved PendingLMResult."""
    if isinstance(value, PendingLMResult):
        return not value.is_ready()
    if isinstance(value, (list, tuple, set)):
        return any(_contains_unresolved(item) for item in value)
    if isinstance(value, dict):
        return any(_contains_unresolved(v) for v in value.values())
    return False


class RLMExecutor:
    """Executes RLM code with access to context and lm() calls."""

    def __init__(
        self,
        context: str,
        client: _LLMClient,
        context_var_name: str = "CONTEXT",
        max_lm_calls_per_execution: int = 20,
    ):
        """Initialize the RLM executor.

        Args:
            context: The long context string to analyze
            client: LLMClient for making recursive lm() calls
            context_var_name: Variable name for the context (default: "CONTEXT")
            max_lm_calls_per_execution: Maximum lm() calls allowed per execute() call
        """
        self.context = context
        self.client = client
        self.context_var_name = context_var_name
        self.max_lm_calls_per_execution = max_lm_calls_per_execution

        # Persistent state across execute() calls
        self._persistent_locals: dict[str, Any] = {}

    def _make_lm_wrapper(
        self,
        pending_lm_calls: list[dict],
        lm_results: dict[int, str],
        call_state: dict[str, int],
        pending_call_ids: set[int],
    ) -> Callable[[str], PendingLMResult]:
        """Create the lm(prompt) wrapper function."""

        def lm_call(prompt: str) -> PendingLMResult:
            # Check for unresolved dependencies in the prompt
            if _contains_unresolved(prompt):
                raise RuntimeError("LM result for call dependency not yet available")

            call_id = call_state["next_lm_id"]
            call_state["next_lm_id"] += 1

            # Only queue if not already completed or pending
            if call_id not in lm_results and call_id not in pending_call_ids:
                if len(pending_lm_calls) >= self.max_lm_calls_per_execution:
                    raise RuntimeError(
                        f"Too many lm() calls in single execution "
                        f"(max {self.max_lm_calls_per_execution})"
                    )
                pending_call_ids.add(call_id)
                pending_lm_calls.append(
                    {
                        "id": call_id,
                        "prompt": str(prompt),
                    }
                )

            return PendingLMResult(call_id, lm_results)

        return lm_call

    def _make_final_func(
        self, exec_namespace: dict[str, Any], lm_results: dict[int, str]
    ) -> Callable[[Any], None]:
        """Create final(answer) function."""

        def final_func(answer: Any) -> None:
            resolved = _resolve_value(answer, lm_results)
            raise FinalAnswer(resolved)

        return final_func

    def _make_final_var_func(
        self, exec_namespace: dict[str, Any], lm_results: dict[int, str]
    ) -> Callable[[str], None]:
        """Create final_var(varname) function."""

        def final_var_func(varname: str) -> None:
            if varname not in exec_namespace:
                raise RuntimeError(f"Variable '{varname}' not found")
            value = exec_namespace[varname]
            resolved = _resolve_value(value, lm_results)
            raise FinalAnswer(resolved)

        return final_var_func

    async def _execute_pending_lm_calls(
        self,
        pending_calls: list[dict],
        results: dict[int, str],
    ) -> None:
        """Execute all pending lm() calls in parallel."""
        if not pending_calls:
            return

        from lm_deluge.prompt import Conversation

        # Start all calls in parallel using start_nowait
        task_mapping: list[tuple[int, int]] = []  # (call_id, task_id)
        for call in pending_calls:
            conv = Conversation.user(call["prompt"])
            task_id = self.client.start_nowait(conv)
            task_mapping.append((call["id"], task_id))

        # Wait for all to complete
        for call_id, task_id in task_mapping:
            try:
                response = await self.client.wait_for(task_id)
                results[call_id] = response.completion or "(no response)"
            except Exception as e:
                results[call_id] = f"Error: {e}"

        # Clear the pending list
        pending_calls.clear()

    def _format_answer(self, value: Any) -> str:
        """Format the final answer as a string."""
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value, default=str, indent=2)
        except Exception:
            return str(value)

    async def execute(self, code: str) -> tuple[str, bool]:
        """Execute RLM code.

        Args:
            code: Python code to execute

        Returns:
            Tuple of (output_string, is_final) where is_final indicates
            whether FINAL()/FINAL_VAR() was called.
        """
        # Validate the code
        tree = validate_rlm_code(code)

        # Set up execution environment
        pending_lm_calls: list[dict] = []
        lm_results: dict[int, str] = {}
        pending_call_ids: set[int] = set()
        call_state = {"next_lm_id": 0}
        output_capture = OutputCapture()

        # Create the lm() wrapper
        lm_wrapper = self._make_lm_wrapper(
            pending_lm_calls, lm_results, call_state, pending_call_ids
        )

        # Build a single namespace for execution
        # Using a single dict for both globals and locals ensures that
        # variables are visible inside nested scopes (list comprehensions, etc.)
        exec_namespace: dict[str, Any] = {
            "__builtins__": {**RLM_SAFE_BUILTINS, "print": output_capture.print},
            self.context_var_name: self.context,
            "lm": lm_wrapper,
            "json": json,  # Explicitly include json
            **RLM_MODULES,
            # Include persistent state from previous calls
            **self._persistent_locals,
        }

        # Add final and final_var (they need access to exec_namespace for final_var)
        exec_namespace["final"] = self._make_final_func(exec_namespace, lm_results)
        exec_namespace["final_var"] = self._make_final_var_func(
            exec_namespace, lm_results
        )

        # Track which keys are "system" keys that shouldn't be persisted
        system_keys = set(exec_namespace.keys())

        # Execute with retry loop for deferred lm() resolution
        max_iterations = 50
        compiled = compile(tree, "<rlm>", "exec")

        for iteration in range(max_iterations):
            # Reset call sequencing for this pass
            call_state["next_lm_id"] = 0
            pending_call_ids.clear()

            try:
                exec(compiled, exec_namespace)

                # Execution completed - run any remaining pending calls
                await self._execute_pending_lm_calls(pending_lm_calls, lm_results)

                # Update persistent locals (exclude system keys)
                for key, value in exec_namespace.items():
                    if key not in system_keys:
                        self._persistent_locals[key] = value

                break

            except FinalAnswer as fa:
                # FINAL() or FINAL_VAR() was called
                for key, value in exec_namespace.items():
                    if key not in system_keys:
                        self._persistent_locals[key] = value
                return (self._format_answer(fa.answer), True)

            except RuntimeError as e:
                if "not yet available" in str(e):
                    # Need to resolve pending lm() calls and retry
                    await self._execute_pending_lm_calls(pending_lm_calls, lm_results)
                    pending_call_ids.clear()
                    # Continue to retry
                else:
                    raise RLMExecutionError(f"Runtime error: {e}")

            except Exception as e:
                raise RLMExecutionError(f"Execution error: {type(e).__name__}: {e}")

        else:
            raise RLMExecutionError(
                f"Execution exceeded maximum iterations ({max_iterations})"
            )

        # Get output
        output = output_capture.get_output()

        # If no print output, check for result variable
        if not output and "result" in exec_namespace:
            result_value = _resolve_value(exec_namespace["result"], lm_results)
            output = self._format_answer(result_value)

        return (output or "Execution completed with no output", False)

    def reset(self) -> None:
        """Reset the persistent state."""
        self._persistent_locals.clear()