groknroll 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groknroll/__init__.py +36 -0
- groknroll/__main__.py +9 -0
- groknroll/agents/__init__.py +18 -0
- groknroll/agents/agent_manager.py +187 -0
- groknroll/agents/base_agent.py +118 -0
- groknroll/agents/build_agent.py +231 -0
- groknroll/agents/plan_agent.py +215 -0
- groknroll/cli/__init__.py +7 -0
- groknroll/cli/enhanced_cli.py +372 -0
- groknroll/cli/large_codebase_cli.py +413 -0
- groknroll/cli/main.py +331 -0
- groknroll/cli/rlm_commands.py +258 -0
- groknroll/clients/__init__.py +63 -0
- groknroll/clients/anthropic.py +112 -0
- groknroll/clients/azure_openai.py +142 -0
- groknroll/clients/base_lm.py +33 -0
- groknroll/clients/gemini.py +162 -0
- groknroll/clients/litellm.py +105 -0
- groknroll/clients/openai.py +129 -0
- groknroll/clients/portkey.py +94 -0
- groknroll/core/__init__.py +9 -0
- groknroll/core/agent.py +339 -0
- groknroll/core/comms_utils.py +264 -0
- groknroll/core/context.py +251 -0
- groknroll/core/exceptions.py +181 -0
- groknroll/core/large_codebase.py +564 -0
- groknroll/core/lm_handler.py +206 -0
- groknroll/core/rlm.py +446 -0
- groknroll/core/rlm_codebase.py +448 -0
- groknroll/core/rlm_integration.py +256 -0
- groknroll/core/types.py +276 -0
- groknroll/environments/__init__.py +34 -0
- groknroll/environments/base_env.py +182 -0
- groknroll/environments/constants.py +32 -0
- groknroll/environments/docker_repl.py +336 -0
- groknroll/environments/local_repl.py +388 -0
- groknroll/environments/modal_repl.py +502 -0
- groknroll/environments/prime_repl.py +588 -0
- groknroll/logger/__init__.py +4 -0
- groknroll/logger/rlm_logger.py +63 -0
- groknroll/logger/verbose.py +393 -0
- groknroll/operations/__init__.py +15 -0
- groknroll/operations/bash_ops.py +447 -0
- groknroll/operations/file_ops.py +473 -0
- groknroll/operations/git_ops.py +620 -0
- groknroll/oracle/__init__.py +11 -0
- groknroll/oracle/codebase_indexer.py +238 -0
- groknroll/oracle/oracle_agent.py +278 -0
- groknroll/setup.py +34 -0
- groknroll/storage/__init__.py +14 -0
- groknroll/storage/database.py +272 -0
- groknroll/storage/models.py +128 -0
- groknroll/utils/__init__.py +0 -0
- groknroll/utils/parsing.py +168 -0
- groknroll/utils/prompts.py +146 -0
- groknroll/utils/rlm_utils.py +19 -0
- groknroll-2.0.0.dist-info/METADATA +246 -0
- groknroll-2.0.0.dist-info/RECORD +62 -0
- groknroll-2.0.0.dist-info/WHEEL +5 -0
- groknroll-2.0.0.dist-info/entry_points.txt +3 -0
- groknroll-2.0.0.dist-info/licenses/LICENSE +21 -0
- groknroll-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,588 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Prime Intellect REPL environment that runs Python code in Prime Sandboxes.
|
|
3
|
+
|
|
4
|
+
Uses the Prime SDK (https://docs.primeintellect.ai/sandboxes/sdk) for sandbox management.
|
|
5
|
+
Follows the same HTTP broker pattern as ModalREPL for LLM communication.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import base64
|
|
9
|
+
import json
|
|
10
|
+
import textwrap
|
|
11
|
+
import threading
|
|
12
|
+
import time
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import requests
|
|
16
|
+
from dotenv import load_dotenv
|
|
17
|
+
from prime_sandboxes import (
|
|
18
|
+
APIClient,
|
|
19
|
+
BackgroundJob,
|
|
20
|
+
CreateSandboxRequest,
|
|
21
|
+
SandboxClient,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from groknroll.core.comms_utils import LMRequest, send_lm_request, send_lm_request_batched
|
|
25
|
+
from groknroll.core.types import REPLResult, RLMChatCompletion
|
|
26
|
+
from groknroll.environments.base_env import IsolatedEnv
|
|
27
|
+
from groknroll.environments.constants import APT_PACKAGES, PIP_PACKAGES
|
|
28
|
+
|
|
29
|
+
load_dotenv()
|
|
30
|
+
|
|
31
|
+
# =============================================================================
|
|
32
|
+
# Broker Server Script (runs inside sandbox, handles LLM request queue)
|
|
33
|
+
# =============================================================================
|
|
34
|
+
|
|
35
|
+
# Source template for the Flask broker that is written to /tmp/broker.py and
# started inside the sandbox. It keeps an in-memory table of LLM requests and
# exposes four endpoints:
#   GET  /health   - liveness probe.
#   POST /enqueue  - called by sandbox code; registers a request and blocks
#                    (up to 300 s) until a response is attached, else 504.
#   GET  /pending  - lists requests that do not have a response yet.
#   POST /respond  - attaches a response to a request id and wakes the waiter.
#
# The template is rendered with ``str.format(broker_port=...)``: the only real
# placeholder is ``{broker_port}``, and every literal brace in the script is
# therefore doubled (``{{`` / ``}}``). Keep that escaping when editing.
_BROKER_SCRIPT = textwrap.dedent(
    '''
import json
import threading
import uuid
from flask import Flask, request, jsonify

app = Flask(__name__)

# Request queue: {{request_id: {{"request": {{...}}, "response": None, "event": Event}}}}
pending_requests = {{}}
lock = threading.Lock()

@app.route("/health")
def health():
    return jsonify({{"status": "ok"}})

@app.route("/enqueue", methods=["POST"])
def enqueue():
    """Called by sandbox code to submit an LLM request and wait for response."""
    data = request.json
    request_id = str(uuid.uuid4())
    event = threading.Event()

    with lock:
        pending_requests[request_id] = {{
            "request": data,
            "response": None,
            "event": event,
        }}

    # Wait for response (with timeout)
    event.wait(timeout=300)

    with lock:
        entry = pending_requests.pop(request_id, None)

    if entry and entry["response"] is not None:
        return jsonify(entry["response"])
    else:
        return jsonify({{"error": "Request timed out"}}), 504

@app.route("/pending")
def get_pending():
    """Called by PrimeREPL to get pending requests."""
    with lock:
        pending = [
            {{"id": rid, "request": entry["request"]}}
            for rid, entry in pending_requests.items()
            if entry["response"] is None
        ]
    return jsonify({{"pending": pending}})

@app.route("/respond", methods=["POST"])
def respond():
    """Called by PrimeREPL to submit a response."""
    data = request.json
    request_id = data.get("id")
    response = data.get("response")

    with lock:
        if request_id in pending_requests:
            pending_requests[request_id]["response"] = response
            pending_requests[request_id]["event"].set()
            return jsonify({{"status": "ok"}})

    return jsonify({{"error": "Request not found"}}), 404

if __name__ == "__main__":
    app.run(host="0.0.0.0", port={broker_port}, threaded=True)
'''
)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# =============================================================================
|
|
110
|
+
# Execution Script (runs inside the sandbox for each code block)
|
|
111
|
+
# =============================================================================
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _build_exec_script(code: str, broker_port: int = 8888, depth: int = 1) -> str:
    """
    Build a script that executes code with state persistence.

    The returned source is meant to be run as a standalone Python program
    inside the sandbox. It:

    - defines ``llm_query`` / ``llm_query_batched``, which POST to the local
      broker at ``http://127.0.0.1:<broker_port>/enqueue``;
    - restores variables from ``/tmp/rlm_state.dill`` (``dill`` if available,
      otherwise ``pickle``), executes ``code`` with stdout/stderr captured,
      and saves the picklable, non-underscore variables back;
    - prints one JSON object (``stdout``, ``stderr``, ``locals``) as its
      final line of output.

    Args:
        code: Python source to execute. It is embedded base64-encoded, so it
            needs no quoting or escaping.
        broker_port: Port of the in-sandbox broker the LLM helpers talk to.
        depth: Value sent as ``"depth"`` with every LLM request.

    Returns:
        The complete script source.
    """
    code_b64 = base64.b64encode(code.encode()).decode()

    # NOTE: this is an f-string, so literal braces in the generated script are
    # doubled; only {broker_port}, {depth} and {code_b64} are substituted.
    return textwrap.dedent(
        f'''
import sys
import io
import json
import base64
import traceback
import os
import requests

try:
    import dill
except ImportError:
    import pickle as dill

# =============================================================================
# LLM Query Functions (via local broker)
# =============================================================================

BROKER_URL = "http://127.0.0.1:{broker_port}"

def llm_query(prompt, model=None):
    """Query the LM via the broker."""
    try:
        response = requests.post(
            f"{{BROKER_URL}}/enqueue",
            json={{"type": "single", "prompt": prompt, "model": model, "depth": {depth}}},
            timeout=300,
        )
        data = response.json()
        if data.get("error"):
            return f"Error: {{data['error']}}"
        return data.get("response", "Error: No response")
    except Exception as e:
        return f"Error: LM query failed - {{e}}"


def llm_query_batched(prompts, model=None):
    """Query the LM with multiple prompts."""
    try:
        response = requests.post(
            f"{{BROKER_URL}}/enqueue",
            json={{"type": "batched", "prompts": prompts, "model": model, "depth": {depth}}},
            timeout=300,
        )
        data = response.json()
        if data.get("error"):
            return [f"Error: {{data['error']}}"] * len(prompts)
        return data.get("responses", ["Error: No response"] * len(prompts))
    except Exception as e:
        return [f"Error: LM query failed - {{e}}"] * len(prompts)


# =============================================================================
# State Management
# =============================================================================

STATE_FILE = "/tmp/rlm_state.dill"

def load_state():
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, "rb") as f:
                return dill.load(f)
        except Exception:
            # Unreadable/corrupt state: start from an empty namespace.
            pass
    return {{}}

def save_state(state):
    clean_state = {{}}
    for k, v in state.items():
        if k.startswith("_"):
            continue
        try:
            dill.dumps(v)
            clean_state[k] = v
        except Exception:
            # Value cannot be serialized; it is dropped from the saved state.
            pass
    with open(STATE_FILE, "wb") as f:
        dill.dump(clean_state, f)

def serialize_locals(state):
    result = {{}}
    for k, v in state.items():
        if k.startswith("_"):
            continue
        try:
            result[k] = repr(v)
        except Exception:
            result[k] = f"<{{type(v).__name__}}>"
    return result

# =============================================================================
# Execution
# =============================================================================

_locals = load_state()

def FINAL_VAR(variable_name):
    variable_name = variable_name.strip().strip("\\"\\'")
    if variable_name in _locals:
        return str(_locals[variable_name])
    return f"Error: Variable '{{variable_name}}' not found"

_globals = {{
    "__builtins__": __builtins__,
    "__name__": "__main__",
    "llm_query": llm_query,
    "llm_query_batched": llm_query_batched,
    "FINAL_VAR": FINAL_VAR,
}}

code = base64.b64decode("{code_b64}").decode()

stdout_buf = io.StringIO()
stderr_buf = io.StringIO()
old_stdout, old_stderr = sys.stdout, sys.stderr

try:
    sys.stdout = stdout_buf
    sys.stderr = stderr_buf
    combined = {{**_globals, **_locals}}
    exec(code, combined, combined)
    for key, value in combined.items():
        if key not in _globals and not key.startswith("_"):
            _locals[key] = value
except (Exception, SystemExit):
    # SystemExit is not an Exception subclass: user code calling sys.exit()
    # must not abort this wrapper before the JSON result below is printed.
    traceback.print_exc(file=stderr_buf)
finally:
    sys.stdout = old_stdout
    sys.stderr = old_stderr

save_state(_locals)

result = {{
    "stdout": stdout_buf.getvalue(),
    "stderr": stderr_buf.getvalue(),
    "locals": serialize_locals(_locals),
}}
print(json.dumps(result))
'''
    )
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class PrimeREPL(IsolatedEnv):
    """
    Prime Intellect REPL environment that runs Python code in Prime Sandboxes.

    Uses Prime's port exposure for LLM communication:
    - Sandbox runs a broker server exposed via sandboxes.expose()
    - PrimeREPL polls the broker for pending LLM requests
    - PrimeREPL forwards requests to the LM handler and posts responses back

    The constructor creates the sandbox immediately; call ``cleanup()`` (or
    use the instance as a context manager) to delete it.
    """

    BROKER_PORT = 8888

    def __init__(
        self,
        name: str = "rlm-sandbox",
        docker_image: str = "python:3.11-slim",
        timeout_minutes: int = 60,
        lm_handler_address: tuple[str, int] | None = None,
        context_payload: dict | list | str | None = None,
        setup_code: str | None = None,
        network_access: bool = True,
        persistent: bool = False,
        depth: int = 1,
        **kwargs: Any,
    ):
        """
        Create the sandbox, start the broker and optionally preload state.

        Args:
            name: Sandbox name passed to ``CreateSandboxRequest``.
            docker_image: Image the sandbox is created from.
            timeout_minutes: Sandbox timeout; also scales the creation wait.
            lm_handler_address: ``(host, port)`` of the LM handler. When it is
                not given, no polling thread is started.
            context_payload: If not ``None``, loaded into the sandbox as the
                variable ``context``.
            setup_code: If non-empty, executed in the sandbox after the
                context has been loaded.
            network_access: Forwarded to ``CreateSandboxRequest``.
            persistent: Not supported; must be ``False``.
            depth: Forwarded to the base class and sent with LLM requests.
            **kwargs: Forwarded to the base class.

        Raises:
            NotImplementedError: If ``persistent`` is true.
            RuntimeError: If the broker does not become healthy.
        """
        super().__init__(persistent=persistent, depth=depth, **kwargs)

        if persistent:
            raise NotImplementedError(
                "Persistent REPLs are currently not supported for environment: PrimeREPL"
            )

        self.name = name
        self.docker_image = docker_image
        self.timeout_minutes = timeout_minutes
        self.lm_handler_address = lm_handler_address
        self.network_access = network_access

        # Client and sandbox state
        self.client: SandboxClient | None = None
        self.sandbox_id: str | None = None
        self.broker_job: BackgroundJob | None = None
        self.broker_url: str | None = None
        self.broker_exposure_id: str | None = None

        # Polling thread for LLM requests
        self.poller_thread: threading.Thread | None = None
        self.poller_stop = threading.Event()
        self.pending_llm_calls: list[RLMChatCompletion] = []
        self._calls_lock = threading.Lock()

        try:
            self.setup()

            if context_payload is not None:
                self.load_context(context_payload)

            if setup_code:
                self.execute_code(setup_code)
        except BaseException:
            # Do not leave a half-configured sandbox running when construction
            # fails; release it now instead of waiting for __del__.
            self.cleanup()
            raise

    def setup(self):
        """Create the Prime sandbox, broker, and start polling."""
        # Create the client
        self.client = SandboxClient(APIClient())

        # Create the sandbox
        request = CreateSandboxRequest(
            name=self.name,
            docker_image=self.docker_image,
            timeout_minutes=self.timeout_minutes,
            network_access=self.network_access,
        )
        sandbox = self.client.create(request)
        self.sandbox_id = sandbox.id

        # Wait for sandbox to be ready
        self.client.wait_for_creation(self.sandbox_id, max_attempts=self.timeout_minutes * 60)

        # Install apt dependencies
        apt_cmd = "apt-get update && apt-get install -y " + " ".join(APT_PACKAGES)
        self.client.execute_command(self.sandbox_id, apt_cmd)

        # Install pip dependencies
        pip_cmd = "pip install " + " ".join(f'"{pkg}"' for pkg in PIP_PACKAGES)
        self.client.execute_command(self.sandbox_id, pip_cmd)

        # Write the broker script to the sandbox.
        # Unlike Modal's sandbox.exec() which accepts separate args, Prime's
        # start_background_job() takes a shell command string. We write to a file
        # to avoid shell escaping issues with quotes/special chars in the script.
        broker_script = _BROKER_SCRIPT.format(broker_port=self.BROKER_PORT)
        broker_script_b64 = base64.b64encode(broker_script.encode()).decode()
        self.client.execute_command(
            self.sandbox_id,
            f"echo '{broker_script_b64}' | base64 -d > /tmp/broker.py",
        )

        # Start the broker as a background job
        self.broker_job = self.client.start_background_job(
            self.sandbox_id,
            "python /tmp/broker.py",
        )

        # Wait for broker to be ready with health check
        self._wait_for_broker()

        # Expose the broker port
        exposed = self.client.expose(self.sandbox_id, port=self.BROKER_PORT, name="rlm-broker")
        self.broker_url = exposed.url
        self.broker_exposure_id = exposed.exposure_id

        # Start polling thread if we have an LM handler
        if self.lm_handler_address and self.broker_url:
            self.poller_stop.clear()
            self.poller_thread = threading.Thread(target=self._poll_broker, daemon=True)
            self.poller_thread.start()

    def _wait_for_broker(self, max_attempts: int = 30):
        """
        Wait for the broker to be ready by checking health endpoint.

        Probes once per second, up to ``max_attempts`` times, from inside the
        sandbox.

        Raises:
            RuntimeError: If the broker never answers; the message includes
                the broker job's log files when they can be read.
        """
        # Use Python to check health (curl may not be installed in slim images)
        health_check_cmd = (
            f'python -c "import requests; '
            f"r = requests.get('http://127.0.0.1:{self.BROKER_PORT}/health', timeout=2); "
            f'print(r.text)"'
        )

        for _ in range(max_attempts):
            time.sleep(1)
            try:
                result = self.client.execute_command(
                    self.sandbox_id,
                    health_check_cmd,
                )
                if "ok" in result.stdout.lower():
                    return
            except Exception:
                # Best-effort probe: the broker may simply not be listening
                # yet, so any failure just means "try again".
                pass

        # Get broker logs for debugging by reading log files directly
        error_info = "Broker failed to start."
        if self.broker_job:
            try:
                stdout_result = self.client.execute_command(
                    self.sandbox_id,
                    f"cat {self.broker_job.stdout_log_file} 2>/dev/null || echo 'No stdout log'",
                )
                stderr_result = self.client.execute_command(
                    self.sandbox_id,
                    f"cat {self.broker_job.stderr_log_file} 2>/dev/null || echo 'No stderr log'",
                )
                # Both values come from `cat`, hence `.stdout` for the stderr log too.
                error_info += f"\nstdout: {stdout_result.stdout}\nstderr: {stderr_result.stdout}"
            except Exception as e:
                error_info += f"\nFailed to read logs: {e}"
        raise RuntimeError(error_info)

    def _poll_broker(self):
        """
        Poll the broker for pending LLM requests and handle them.

        Runs in the background thread until ``poller_stop`` is set. Every
        request fetched from ``/pending`` is answered via ``/respond``; if
        handling it raises, an ``{"error": ...}`` payload is sent back so the
        sandbox-side caller is not left waiting for its timeout.
        """
        while not self.poller_stop.is_set():
            try:
                # Get pending requests
                resp = requests.get(
                    f"{self.broker_url}/pending",
                    timeout=10,
                )
                pending = resp.json().get("pending", [])
            except Exception:
                # Network hiccups or malformed replies must not kill the
                # poller thread; skip this round and poll again.
                pending = []

            for item in pending:
                if self.poller_stop.is_set():
                    break

                try:
                    request_id = item["id"]
                    req_data = item["request"]
                except (KeyError, TypeError):
                    # Malformed entry: nothing we can answer.
                    continue

                # Handle the request
                try:
                    response = self._handle_llm_request(req_data)
                except Exception as exc:
                    response = {"error": f"LM request failed: {exc}"}

                # Send response back
                try:
                    requests.post(
                        f"{self.broker_url}/respond",
                        json={"id": request_id, "response": response},
                        timeout=10,
                    )
                except Exception:
                    # Delivery failed; the request stays pending on the
                    # broker and will be seen again on a later poll.
                    pass

            # Event.wait returns as soon as cleanup() sets the stop flag.
            self.poller_stop.wait(0.1)

    def _handle_llm_request(self, req_data: dict) -> dict:
        """
        Handle an LLM request from the sandbox.

        Args:
            req_data: Request payload; ``"type"`` selects ``"single"``
                (uses ``"prompt"``) or ``"batched"`` (uses ``"prompts"``),
                and ``"model"`` is passed through.

        Returns:
            ``{"response": str}`` for a single request, ``{"responses": [...]}``
            for a batched one (failed items become ``"Error: ..."`` strings),
            or ``{"error": ...}`` on failure / unknown type.
        """
        req_type = req_data.get("type")
        model = req_data.get("model")

        if req_type == "single":
            prompt = req_data.get("prompt")
            request = LMRequest(prompt=prompt, model=model, depth=self.depth)
            response = send_lm_request(self.lm_handler_address, request)

            if not response.success:
                return {"error": response.error}

            # Track the call
            with self._calls_lock:
                self.pending_llm_calls.append(response.chat_completion)

            return {"response": response.chat_completion.response}

        elif req_type == "batched":
            prompts = req_data.get("prompts", [])
            responses = send_lm_request_batched(
                self.lm_handler_address, prompts, model=model, depth=self.depth
            )

            results = []
            for resp in responses:
                if not resp.success:
                    results.append(f"Error: {resp.error}")
                else:
                    with self._calls_lock:
                        self.pending_llm_calls.append(resp.chat_completion)
                    results.append(resp.chat_completion.response)

            return {"responses": results}

        return {"error": "Unknown request type"}

    def load_context(self, context_payload: dict | list | str):
        """
        Load context into the sandbox environment.

        Strings are assigned to ``context`` verbatim; any other payload is
        JSON-encoded here and decoded with ``json.loads`` in the sandbox.
        """
        # repr() always yields a valid Python string literal, whatever quotes,
        # backslashes or control characters the payload contains.
        if isinstance(context_payload, str):
            context_code = f"context = {context_payload!r}"
        else:
            context_json = json.dumps(context_payload)
            context_code = f"import json; context = json.loads({context_json!r})"

        self.execute_code(context_code)

    def execute_code(self, code: str) -> REPLResult:
        """
        Execute code in the Prime sandbox and return result.

        The code is wrapped by ``_build_exec_script``, written to
        ``/tmp/exec_script.py`` and run with a 10 minute command timeout. The
        last line of the command's stdout is expected to be the JSON result
        object; if it is not, the raw stdout/stderr are returned instead.
        """
        start_time = time.perf_counter()

        # Clear pending LLM calls
        with self._calls_lock:
            self.pending_llm_calls.clear()

        # Build and write the script
        script = _build_exec_script(code, self.BROKER_PORT, self.depth)
        script_b64 = base64.b64encode(script.encode()).decode()
        self.client.execute_command(
            self.sandbox_id,
            f"echo '{script_b64}' | base64 -d > /tmp/exec_script.py",
        )

        # Execute the script
        result = self.client.execute_command(
            self.sandbox_id, "python /tmp/exec_script.py", timeout=60 * 10
        )
        stdout = result.stdout
        stderr = result.stderr

        # Collect LLM calls made during this execution
        with self._calls_lock:
            pending_calls = self.pending_llm_calls.copy()
            self.pending_llm_calls.clear()

        execution_time = time.perf_counter() - start_time

        # Parse the JSON result (last stdout line)
        last_line = stdout.strip().split("\n")[-1]
        try:
            parsed = json.loads(last_line)
        except json.JSONDecodeError:
            parsed = None

        if not isinstance(parsed, dict):
            # Either not JSON at all, or JSON that is not the result object
            # (e.g. the script died and the last line was a bare number).
            return REPLResult(
                stdout=stdout,
                stderr=stderr or "Failed to parse execution result",
                locals={},
                execution_time=execution_time,
                rlm_calls=pending_calls,
            )

        return REPLResult(
            stdout=parsed.get("stdout", ""),
            stderr=parsed.get("stderr", "") + stderr,
            locals=parsed.get("locals", {}),
            execution_time=execution_time,
            rlm_calls=pending_calls,
        )

    def cleanup(self):
        """
        Terminate the sandbox and stop polling.

        Safe to call more than once and on an instance whose ``__init__`` did
        not finish (attributes that were never assigned are treated as unset).
        """
        # Stop the poller thread
        poller = getattr(self, "poller_thread", None)
        if poller is not None:
            self.poller_stop.set()
            poller.join(timeout=2)
            self.poller_thread = None

        # Cleanup sandbox resources
        client = getattr(self, "client", None)
        sandbox_id = getattr(self, "sandbox_id", None)
        if client is None or sandbox_id is None:
            return

        # Unexpose the broker port
        exposure_id = getattr(self, "broker_exposure_id", None)
        if exposure_id:
            try:
                client.unexpose(sandbox_id, exposure_id)
            except Exception:
                # Best effort: deleting the sandbox below is what matters.
                pass

        # Delete the sandbox
        try:
            client.delete(sandbox_id)
        except Exception:
            # Best effort: cleanup may run from __del__ and must not raise.
            pass

        self.sandbox_id = None

    def __enter__(self):
        """Return the environment itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Clean up on leaving the ``with`` block; exceptions propagate."""
        self.cleanup()
        return False

    def __del__(self):
        """Last-resort cleanup when the object is garbage collected."""
        try:
            self.cleanup()
        except Exception:
            # Finalizers can run during interpreter shutdown; never raise here.
            pass
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Logger for RLM iterations.
|
|
3
|
+
|
|
4
|
+
Writes RLMIteration data to JSON-lines files for analysis and debugging.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import uuid
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
|
|
12
|
+
from groknroll.core.types import RLMIteration, RLMMetadata
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RLMLogger:
    """Append-only JSON-lines writer for RLM metadata and iterations.

    Each instance writes to its own file inside ``log_dir``, named
    ``<file_name>_<timestamp>_<8-char run id>.jsonl``.
    """

    def __init__(self, log_dir: str, file_name: str = "rlm"):
        """Create ``log_dir`` if needed and pick a unique log file path."""
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        # First eight hex digits of a random UUID.
        run_id = uuid.uuid4().hex[:8]
        log_name = f"{file_name}_{timestamp}_{run_id}.jsonl"
        self.log_file_path = os.path.join(log_dir, log_name)

        self._iteration_count = 0
        self._metadata_logged = False

    def _append_record(self, record: dict) -> None:
        """Write ``record`` as one JSON line at the end of the log file."""
        with open(self.log_file_path, "a") as handle:
            json.dump(record, handle)
            handle.write("\n")

    def log_metadata(self, metadata: RLMMetadata):
        """Write the metadata record once; later calls do nothing."""
        if not self._metadata_logged:
            record = {"type": "metadata", "timestamp": datetime.now().isoformat()}
            record.update(metadata.to_dict())
            self._append_record(record)
            self._metadata_logged = True

    def log(self, iteration: RLMIteration):
        """Write one iteration record, numbered from 1 in call order."""
        self._iteration_count += 1

        record = {
            "type": "iteration",
            "iteration": self._iteration_count,
            "timestamp": datetime.now().isoformat(),
        }
        record.update(iteration.to_dict())
        self._append_record(record)

    @property
    def iteration_count(self) -> int:
        """Number of iterations passed to ``log`` so far."""
        return self._iteration_count
|