lm-deluge 0.0.87__py3-none-any.whl → 0.0.89__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lm_deluge/api_requests/gemini.py +19 -7
- lm_deluge/models/google.py +13 -0
- lm_deluge/tool/prefab/__init__.py +9 -1
- lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
- lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
- lm_deluge/tool/prefab/rlm/__init__.py +296 -0
- lm_deluge/tool/prefab/rlm/executor.py +349 -0
- lm_deluge/tool/prefab/rlm/parse.py +144 -0
- lm_deluge/tool/prefab/sandbox.py +908 -0
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/METADATA +12 -1
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/RECORD +14 -9
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import shutil
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterable, Set, TypedDict
|
|
8
|
+
|
|
9
|
+
from lenlp import normalizer
|
|
10
|
+
from tantivy import Document, Index, SchemaBuilder
|
|
11
|
+
from tqdm.auto import tqdm
|
|
12
|
+
|
|
13
|
+
# Pattern to match field-like constructs (X:Y).
# Group 1 is the candidate field name (no whitespace, colons, or double quotes),
# group 2 is the value (no whitespace, quotes, or closing paren). The negative
# lookbehind skips tokens whose prefix is backslash-escaped.
FIELDLIKE = re.compile(r'(?<!\\)\b([^\s:"]+)\s*:\s*([^\s")]+)')
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SearchResultContent(TypedDict):
    """Shape of the stored document payload returned for a search hit."""

    title: str
    url: str
    description: str
    keywords: str  # comma-delimited :)
    content: str
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class SearchResult:
    """A single search hit: document id, relevance score, and stored fields."""

    id: str
    score: float
    content: SearchResultContent | dict[str, Any]  # for our purposes should be SRC
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SearchIndex(ABC):
    """Abstract interface for a full-text document search index."""

    @abstractmethod
    def build_index(
        self,
        records: list[dict],
        deduplicate_by: str | None = None,
    ):
        """Build the index from `records`, optionally deduplicating on one field."""
        raise NotImplementedError

    @abstractmethod
    def search(
        self, queries: list[str], fields: list[str], limit: int = 8, escape: bool = True
    ) -> list[SearchResult]:
        """Run `queries` against `fields` and return up to `limit` results."""
        raise NotImplementedError
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def deduplicate_exact(records: list[dict], deduplicate_by: str) -> list[dict]:
    """Drop records whose normalized `deduplicate_by` value was already seen.

    Keeps the first occurrence of each normalized key and preserves the
    original record order. Prints a summary of how many records were removed.
    """
    kept: list[dict] = []
    seen_keys: set = set()
    for rec in tqdm(records):
        normalized = normalizer.normalize(rec[deduplicate_by])
        if normalized in seen_keys:
            continue
        seen_keys.add(normalized)
        kept.append(rec)
    print(
        f"Deduplication removed {len(records) - len(kept)} of {len(records)} records."
    )
    return kept
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _fix_unbalanced_quotes(s: str) -> str:
|
|
63
|
+
"""Keep balanced quotes; if an odd number of quotes, drop the strays.
|
|
64
|
+
|
|
65
|
+
Handles both double quotes (") and single quotes (') since Tantivy
|
|
66
|
+
treats both as phrase delimiters.
|
|
67
|
+
"""
|
|
68
|
+
# Handle double quotes
|
|
69
|
+
if s.count('"') % 2 != 0:
|
|
70
|
+
s = s.replace('"', " ")
|
|
71
|
+
|
|
72
|
+
# Handle single quotes
|
|
73
|
+
if s.count("'") % 2 != 0:
|
|
74
|
+
s = s.replace("'", " ")
|
|
75
|
+
|
|
76
|
+
return s
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _quote_nonfield_colons(
    q: str, known_fields: Set[str], json_prefixes: Set[str]
) -> str:
    """Quote X:Y tokens whose prefix X is not a known field name.

    Prevents plain text like "3:12" from being parsed as a field query while
    leaving genuine field queries such as "title:roof" untouched.
    """

    def _wrap_if_unknown(match: re.Match) -> str:
        # Un-escape escaped dots (a\.b) before comparing against field names
        prefix = match.group(1).replace(r"\.", ".")
        if prefix in known_fields or prefix.split(".", 1)[0] in json_prefixes:
            # Genuine field query; leave as-is
            return match.group(0)
        # Not a real field: quote the whole token so the colon is literal
        return f'"{match.group(0)}"'

    return FIELDLIKE.sub(_wrap_if_unknown, q)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _strip_syntax_to_words(q: str) -> str:
|
|
102
|
+
"""Last-resort sanitizer: remove operators but keep words/numbers.
|
|
103
|
+
|
|
104
|
+
This is used as a fallback when all else fails to ensure we never crash.
|
|
105
|
+
"""
|
|
106
|
+
q = re.sub(r'([+\-!(){}\[\]^"~*?:\\\/]|&&|\|\|)', " ", q)
|
|
107
|
+
q = re.sub(r"\s+", " ", q).strip()
|
|
108
|
+
return q or "*" # Never return empty (match all if nothing left)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def sanitize_query_for_tantivy(
    query: str, known_fields: Iterable[str], json_field_prefixes: Iterable[str] = ()
) -> str:
    """Context-aware query sanitization for Tantivy.

    Heuristically decides whether special characters are intended as Tantivy
    syntax or are just ordinary text:
    - Parentheses count as syntax only when AND/OR/&&/|| is present
    - Square brackets count as syntax only when TO or IN is present
    - Curly braces count as syntax only when TO is present

    X:Y tokens whose prefix is not a known field are quoted, and unbalanced
    quotes are repaired.
    """
    # Detect whether query syntax is actually in play
    uses_boolean = bool(re.search(r"\b(AND|OR)\b|&&|\|\|", query))
    uses_range = bool(re.search(r"\bTO\b", query))
    uses_set = bool(re.search(r"\bIN\b", query))

    # Drop grouping parens unless boolean operators make them meaningful
    if not uses_boolean:
        query = query.replace("(", " ").replace(")", " ")

    # Curly braces only mean exclusive range bounds next to TO
    if not uses_range:
        query = query.replace("{", " ").replace("}", " ")

    # Square brackets are syntax only for ranges ([a TO b]) or IN sets
    if not (uses_range or uses_set):
        query = query.replace("[", " ").replace("]", " ")

    # Repair unbalanced quotes (both " and '), then protect stray colons
    query = _fix_unbalanced_quotes(query)
    return _quote_nonfield_colons(query, set(known_fields), set(json_field_prefixes))
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class TantivySearch(SearchIndex):
    """Full-text search index backed by the tantivy-py library.

    All record fields are indexed as stemmed text; the "id" field uses the
    raw tokenizer so documents can be fetched back by exact id. Which fields
    may be *searched* is controlled separately via include/exclude lists.
    """

    def __init__(
        self,
        index_path: str,
        include_fields: list[str] | None = None,
        exclude_fields: list[str] | None = None,
    ):
        """Create a (not yet built) index rooted at `index_path`.

        Args:
            index_path: Directory where the Tantivy index lives.
            include_fields: If given, only these fields are searchable.
            exclude_fields: If given, these fields are excluded from search.
        """
        self.index_path = Path(index_path)
        self.schema = None
        self.index = None
        self.include_fields = include_fields
        self.exclude_fields = exclude_fields
        self.searchable_fields: set[str] | None = (
            None  # Will be set during schema creation
        )
        self.id_to_record: dict[str, dict] = {}  # Cache for efficient document fetching

    def create_schema(self, sample_record: dict):
        """Derive the Tantivy schema from one sample record's keys."""
        schema_builder = SchemaBuilder()

        # Determine which fields should be searchable
        all_fields = set(sample_record.keys()) - {"id"}

        if self.include_fields is not None and self.exclude_fields is not None:
            # Both provided: searchable = include - exclude
            self.searchable_fields = set(self.include_fields) - set(self.exclude_fields)
        elif self.include_fields is not None:
            # Only include provided: use only those
            self.searchable_fields = set(self.include_fields)
        elif self.exclude_fields is not None:
            # Only exclude provided: use all except those
            self.searchable_fields = all_fields - set(self.exclude_fields)
        else:
            # Neither provided: all fields searchable
            # Set explicitly so we have field names for query sanitization
            self.searchable_fields = all_fields

        # Add all fields as text fields (they're all indexed, but we'll control search access)
        for key in sample_record.keys():
            if key == "id":
                continue
            schema_builder.add_text_field(key, stored=True, tokenizer_name="en_stem")
        # "id" uses the raw tokenizer so exact-match id lookups work in fetch()
        schema_builder.add_text_field(
            "id", stored=True, tokenizer_name="raw", index_option="basic"
        )
        self.schema = schema_builder.build()

    @classmethod
    def from_index(cls, index_path: str):
        """Alternate constructor: open an existing on-disk index.

        NOTE(review): the loaded instance has searchable_fields=None, so
        search() skips field validation and sanitization sees no known
        fields — confirm this is intended for reopened indexes.
        """
        index = cls(index_path)
        index.load_index()
        return index

    def load_index(self):
        """Open the Tantivy index at `index_path`; raises if it doesn't exist."""
        if not self.index_path.exists():
            raise FileNotFoundError(f"Index not found at {self.index_path}")
        self.index = Index.open(str(self.index_path))

    def _get_schema_field_names(self) -> set[str]:
        """Get field names for query sanitization."""
        # searchable_fields is always set after create_schema() is called
        return self.searchable_fields if self.searchable_fields is not None else set()

    def build_index(
        self,
        records: list[dict],
        deduplicate_by: str | None = None,
    ):
        """Build (or rebuild) the on-disk index from `records`.

        Any existing index directory is removed first. Field values are
        stringified before indexing.
        """
        if not records:
            raise ValueError("No records to index")

        if not self.schema:
            self.create_schema(records[0])
        assert self.schema is not None, "Schema not created"

        # Create index (destructively replaces any previous one)
        if self.index_path.exists():
            shutil.rmtree(self.index_path)
        os.makedirs(self.index_path, exist_ok=True)
        self.index = Index(self.schema, str(self.index_path))

        # Deduplicate if requested
        if deduplicate_by is not None:
            records = deduplicate_exact(records, deduplicate_by)

        # Index documents
        writer = self.index.writer()
        for record in tqdm(records):
            writer.add_document(Document(**{k: [str(v)] for k, v in record.items()}))
        writer.commit()
        writer.wait_merging_threads()

    def search(
        self, queries: list[str], fields: list[str], limit: int = 8, escape: bool = True
    ) -> list[SearchResult]:
        """Search `fields` for each query and fuse the per-query hit lists.

        With escape=True (default), a three-tier fallback is used so malformed
        queries never raise: sanitized query -> bag-of-words -> match nothing.
        """
        assert self.index is not None, "Index not built"

        # Validate that requested fields are searchable
        if self.searchable_fields is not None:
            non_searchable = set(fields) - self.searchable_fields
            if non_searchable:
                raise ValueError(
                    f"Cannot search non-searchable fields: {sorted(non_searchable)}. "
                    f"Searchable fields are: {sorted(self.searchable_fields)}"
                )

        self.index.reload()
        searcher = self.index.searcher()
        results = []

        # Get field names for query sanitization
        known_fields = self._get_schema_field_names()

        for query in queries:
            if escape:
                # Three-tier fallback strategy
                # Tier 1: Try sanitized query with context-aware cleaning
                sanitized = sanitize_query_for_tantivy(query, known_fields)
                try:
                    query_obj = self.index.parse_query(sanitized, fields)
                except ValueError:
                    # Tier 2: Bag-of-words fallback - strip all operators
                    bag_of_words = _strip_syntax_to_words(sanitized)
                    try:
                        query_obj = self.index.parse_query(bag_of_words, fields)
                    except ValueError:
                        # This should never happen, but if it does, match nothing
                        query_obj = self.index.parse_query("", fields)
            else:
                # User disabled escaping - use query as-is
                try:
                    query_obj = self.index.parse_query(query, fields)
                except Exception as e:
                    print(f"Error parsing query: {e}")
                    query_obj = None

            if query_obj:
                hits = searcher.search(query_obj, limit=limit).hits
            else:
                hits = []

            for score, doc_address in hits:
                doc = searcher.doc(doc_address)
                content = {k: v[0] for k, v in doc.to_dict().items()}
                results.append(
                    SearchResult(
                        id=content["id"],
                        score=score,
                        content=content,
                    )
                )

        # Rank fusion using reciprocal rank fusion
        # NOTE(review): rank is enumerated over the concatenated hit lists of
        # ALL queries, so later queries start at higher ranks (lower RRF
        # contribution). Standard RRF restarts the rank per query list —
        # confirm whether the cross-query bias here is intentional.
        doc_scores = {}
        for rank, result in enumerate(results, 1):
            doc_key = str(result.id)
            if doc_key not in doc_scores:
                doc_scores[doc_key] = 0
            doc_scores[doc_key] += 1 / (60 + rank)

        # Sort and deduplicate (keep first-seen SearchResult per doc id)
        unique_results = {}
        for result in results:
            doc_key = str(result.id)
            if doc_key not in unique_results:
                unique_results[doc_key] = result

        sorted_results = sorted(
            unique_results.values(),
            key=lambda x: doc_scores[str(x.id)],
            reverse=True,
        )

        return sorted_results[:limit]

    def cache_documents_for_fetch(
        self, documents: list[dict], id_column: str = "id"
    ) -> None:
        """Cache documents for efficient retrieval by ID.

        This builds an in-memory lookup table for O(1) document fetching.
        Should be called after build_index() with the original document data.

        Args:
            documents: List of document records (typically original unfiltered data)
            id_column: Name of the ID field in the documents (default: "id")
        """
        self.id_to_record = {}
        for record in documents:
            # Records missing the id column all collapse onto the "" key
            doc_id = str(record.get(id_column, ""))
            self.id_to_record[doc_id] = record

    def fetch(self, document_ids: list[str]) -> tuple[list[dict], list[str], list[str]]:
        """Fetch documents by ID using cached lookup.

        Returns full original documents from the cache built by cache_documents_for_fetch().
        If cache is empty, falls back to querying the Tantivy index (slower, returns only indexed fields).

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            Tuple of (documents, found_ids, missing_ids) where:
            - documents: List of document dicts with all fields
            - found_ids: List of IDs that were successfully found
            - missing_ids: List of IDs that were not found
        """
        results = []
        found_ids = []
        missing_ids = []

        # Use cached lookup if available
        if self.id_to_record:
            for doc_id in document_ids:
                if doc_id in self.id_to_record:
                    results.append(self.id_to_record[doc_id])
                    found_ids.append(doc_id)
                else:
                    missing_ids.append(doc_id)
            return results, found_ids, missing_ids

        # Fallback: query Tantivy index (slower, returns only indexed fields)
        assert self.index is not None, "Index not built and no cached documents"

        self.index.reload()
        searcher = self.index.searcher()

        for doc_id in document_ids:
            # NOTE(review): doc_id is interpolated unescaped; an id containing
            # a double quote would break the query and land in missing_ids.
            query_str = f'id:"{doc_id}"'
            try:
                query = self.index.parse_query(query_str, ["id"])
                hits = searcher.search(query, limit=1).hits

                if hits:
                    _, doc_address = hits[0]
                    doc = searcher.doc(doc_address)
                    content = {k: v[0] for k, v in doc.to_dict().items()}
                    results.append(content)
                    found_ids.append(doc_id)
                else:
                    missing_ids.append(doc_id)
            except Exception:
                # Any parse/search failure is treated as "not found"
                missing_ids.append(doc_id)

        return results, found_ids, missing_ids
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RLM (Recursive Language Model) for lm-deluge.
|
|
3
|
+
|
|
4
|
+
Enables models to process long contexts through a REPL environment
|
|
5
|
+
with recursive LM calls, based on the RLM paper:
|
|
6
|
+
https://alexzhang13.github.io/blog/2025/rlm/
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
from lm_deluge.prompt import Conversation
|
|
15
|
+
from lm_deluge.tool import Tool
|
|
16
|
+
|
|
17
|
+
from .executor import RLMExecutionError, RLMExecutor
|
|
18
|
+
from .parse import RLMSecurityError
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from lm_deluge.api_requests.base import APIResponse
|
|
22
|
+
from lm_deluge.client import _LLMClient
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
RLM_SYSTEM_PROMPT = """You have access to a long context stored in the variable `{context_var}`.
|
|
26
|
+
You can write Python code to analyze this context using the `execute` tool.
|
|
27
|
+
|
|
28
|
+
IMPORTANT RULES:
|
|
29
|
+
1. You MUST use print() to see output. Bare expressions produce NO output.
|
|
30
|
+
2. You MUST call final(answer) when you have the answer. This is required!
|
|
31
|
+
|
|
32
|
+
Available in your code environment:
|
|
33
|
+
- `{context_var}`: The full context as a string ({context_len:,} characters)
|
|
34
|
+
- `lm(prompt)`: Make a recursive LLM call (runs in parallel when possible)
|
|
35
|
+
- `final(answer)`: Signal completion with the given answer - YOU MUST CALL THIS!
|
|
36
|
+
- `final_var(varname)`: Signal completion with a variable's value
|
|
37
|
+
- Modules: `re`, `math`, `collections`, `json` (imports are allowed but optional)
|
|
38
|
+
- From collections: `Counter`, `defaultdict`, `deque`, `namedtuple`, `OrderedDict`
|
|
39
|
+
- Standard builtins: `len`, `str`, `int`, `list`, `dict`, `sum`, `sorted`, `map`, `filter`, etc.
|
|
40
|
+
|
|
41
|
+
Example - count word occurrences:
|
|
42
|
+
```python
|
|
43
|
+
count = len(re.findall(r'\\bword\\b', {context_var}))
|
|
44
|
+
print(f"Found {{count}} occurrences")
|
|
45
|
+
final(count)
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
Example - use Counter:
|
|
49
|
+
```python
|
|
50
|
+
words = {context_var}.split()
|
|
51
|
+
counts = Counter(words)
|
|
52
|
+
print(counts.most_common(10))
|
|
53
|
+
final(counts.most_common(10))
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
Example - analyze with lm() calls:
|
|
57
|
+
```python
|
|
58
|
+
chunks = [{context_var}[i:i+2000] for i in range(0, len({context_var}), 2000)][:3]
|
|
59
|
+
summaries = [lm(f"Summarize: {{chunk}}") for chunk in chunks]
|
|
60
|
+
combined = "\\n".join(str(s) for s in summaries)
|
|
61
|
+
final(f"Summary:\\n{{combined}}")
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
Variables persist between execute() calls. Always call final() when you have the answer!
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class RLMManager:
    """Manages RLM execution for a long context.

    The RLMManager exposes a REPL-like interface as tools that allow an LLM
    to analyze a long context by writing Python code.

    Example:
        >>> manager = RLMManager(
        ...     context=long_document,
        ...     client=LLMClient("gpt-4.1-mini"),  # For lm() calls
        ... )
        >>> main_client = LLMClient("gpt-4.1")
        >>> conv = Conversation.system(manager.get_system_prompt())
        >>> conv = conv.user("What are the main themes in this document?")
        >>> conv, resp = await main_client.run_agent_loop(
        ...     conv,
        ...     tools=manager.get_tools(),
        ... )
        >>> if manager.is_complete:
        ...     print(manager.final_answer)
    """

    def __init__(
        self,
        context: str,
        client: _LLMClient,
        context_var_name: str = "CONTEXT",
        max_lm_calls_per_execution: int = 20,
    ):
        """Initialize the RLMManager.

        Args:
            context: The long context string to analyze
            client: LLMClient for making recursive lm() calls
            context_var_name: Variable name for the context (default: "CONTEXT")
            max_lm_calls_per_execution: Maximum lm() calls allowed per execute() call
        """
        self.context = context
        self.client = client
        self.context_var_name = context_var_name
        self.max_lm_calls_per_execution = max_lm_calls_per_execution

        # The executor owns the persistent code environment and lm()/final() hooks
        self.executor = RLMExecutor(
            context=context,
            client=client,
            context_var_name=context_var_name,
            max_lm_calls_per_execution=max_lm_calls_per_execution,
        )

        # Set once the model calls final()/final_var(); None until then
        self._final_answer: str | None = None
        # Lazily-built tool list, reused across get_tools() calls
        self._tools: list[Tool] | None = None

    async def _execute(self, code: str) -> str:
        """Execute code against the context.

        Returns the execution output, or a human-readable error string; never
        raises, so the agent loop always receives a tool result.
        """
        try:
            answer, is_final = await self.executor.execute(code)
            if is_final:
                self._final_answer = answer
                # Truncate for display but keep full answer stored
                display = answer[:1000] + "..." if len(answer) > 1000 else answer
                return f"[FINAL ANSWER SET]\n{display}"
            return answer
        except RLMSecurityError as e:
            return f"Security error: {e}"
        except RLMExecutionError as e:
            return f"Execution error: {e}"
        except Exception as e:
            return f"Unexpected error: {type(e).__name__}: {e}"

    def get_system_prompt(self) -> str:
        """Get the system prompt explaining the RLM environment."""
        return RLM_SYSTEM_PROMPT.format(
            context_var=self.context_var_name,
            context_len=len(self.context),
        )

    def get_tools(self) -> list[Tool]:
        """Get the tools for RLM execution (built once, then cached)."""
        if self._tools is not None:
            return self._tools

        self._tools = [
            Tool(
                name="execute",
                description=(
                    f"Execute Python code to analyze the context. "
                    f"The context ({len(self.context):,} chars) is available as `{self.context_var_name}`. "
                    f"Use `lm(prompt)` for recursive LLM calls (parallel when possible), and "
                    f"`final(answer)` or `final_var(varname)` to signal completion. "
                    f"Variables persist between calls. "
                    f"Modules available without import: re, math, collections, json."
                ),
                run=self._execute,
                parameters={
                    "code": {
                        "type": "string",
                        "description": "Python code to execute",
                    }
                },
                required=["code"],
            )
        ]
        return self._tools

    @property
    def is_complete(self) -> bool:
        """Check if final() was called."""
        return self._final_answer is not None

    @property
    def final_answer(self) -> str | None:
        """Get the final answer if set."""
        return self._final_answer

    def reset(self) -> None:
        """Reset the RLM state (executor environment and final answer)."""
        self.executor.reset()
        self._final_answer = None
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@dataclass
class RLMResult:
    """Result from RLMPipeline."""

    answer: str  # Final answer (from final(), or last completion as fallback)
    conversation: Conversation  # Full agent-loop transcript
    rounds_used: int  # Number of assistant turns taken
    final_response: APIResponse  # Raw response from the last model call
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class RLMPipeline:
    """High-level pipeline for RLM processing.

    A thin wrapper that takes a long context and question, sets up an RLMManager,
    runs an agent loop until final() is called, and returns the result.

    Example:
        >>> pipeline = RLMPipeline(
        ...     context=long_document,
        ...     client=LLMClient("gpt-4.1"),  # Smart orchestrator
        ...     lm_client=LLMClient("gpt-4.1-mini"),  # Cheaper model for lm() calls
        ...     question="What are the main themes in this document?",
        ... )
        >>> result = await pipeline.run()
        >>> print(result.answer)
    """

    def __init__(
        self,
        context: str,
        client: _LLMClient,
        question: str,
        *,
        lm_client: _LLMClient | None = None,
        context_var_name: str = "CONTEXT",
        max_rounds: int = 15,
        max_lm_calls_per_execution: int = 20,
    ):
        """Initialize the RLMPipeline.

        Args:
            context: The long context string to analyze
            client: LLMClient for the main agent (runs the execute loop)
            question: The question to answer about the context
            lm_client: LLMClient for lm() calls (defaults to same as client)
            context_var_name: Variable name for the context (default: "CONTEXT")
            max_rounds: Maximum agent loop rounds (default: 15)
            max_lm_calls_per_execution: Maximum lm() calls per execute() call
        """
        self.context = context
        self.client = client
        # Fall back to the orchestrator client for recursive lm() calls
        self.lm_client = lm_client or client
        self.question = question
        self.context_var_name = context_var_name
        self.max_rounds = max_rounds
        self.max_lm_calls_per_execution = max_lm_calls_per_execution

    async def run(self) -> RLMResult:
        """Run the RLM pipeline until completion."""
        # A fresh manager per run() keeps executor state isolated between runs
        manager = RLMManager(
            context=self.context,
            client=self.lm_client,
            context_var_name=self.context_var_name,
            max_lm_calls_per_execution=self.max_lm_calls_per_execution,
        )

        # Build conversation with system prompt and question
        conv = Conversation.system(manager.get_system_prompt())
        conv = conv.user(
            f"Question to answer about the context:\n\n{self.question}\n\n"
            "Use the execute tool to analyze the context and find the answer. "
            "Start by peeking at the context structure, then use appropriate "
            "techniques (regex, chunking, lm() calls) to find the answer. "
            "Call final(answer) when you have the answer."
        )

        # Run agent loop
        conv, resp = await self.client.run_agent_loop(
            conv,
            tools=manager.get_tools(),
            max_rounds=self.max_rounds,
        )

        # Extract answer
        if manager.is_complete:
            answer = manager.final_answer or "No answer produced"
        else:
            # Model stopped without calling final() - use last response
            answer = resp.completion or "No answer produced (final not called)"

        # Count rounds used
        rounds_used = sum(1 for m in conv.messages if m.role == "assistant")

        return RLMResult(
            answer=answer,
            conversation=conv,
            rounds_used=rounds_used,
            final_response=resp,
        )
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
# Public API of the rlm package.
__all__ = [
    "RLMManager",
    "RLMPipeline",
    "RLMResult",
    "RLMExecutor",
    "RLMExecutionError",
    "RLMSecurityError",
]
|