groknroll-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groknroll/__init__.py +36 -0
- groknroll/__main__.py +9 -0
- groknroll/agents/__init__.py +18 -0
- groknroll/agents/agent_manager.py +187 -0
- groknroll/agents/base_agent.py +118 -0
- groknroll/agents/build_agent.py +231 -0
- groknroll/agents/plan_agent.py +215 -0
- groknroll/cli/__init__.py +7 -0
- groknroll/cli/enhanced_cli.py +372 -0
- groknroll/cli/large_codebase_cli.py +413 -0
- groknroll/cli/main.py +331 -0
- groknroll/cli/rlm_commands.py +258 -0
- groknroll/clients/__init__.py +63 -0
- groknroll/clients/anthropic.py +112 -0
- groknroll/clients/azure_openai.py +142 -0
- groknroll/clients/base_lm.py +33 -0
- groknroll/clients/gemini.py +162 -0
- groknroll/clients/litellm.py +105 -0
- groknroll/clients/openai.py +129 -0
- groknroll/clients/portkey.py +94 -0
- groknroll/core/__init__.py +9 -0
- groknroll/core/agent.py +339 -0
- groknroll/core/comms_utils.py +264 -0
- groknroll/core/context.py +251 -0
- groknroll/core/exceptions.py +181 -0
- groknroll/core/large_codebase.py +564 -0
- groknroll/core/lm_handler.py +206 -0
- groknroll/core/rlm.py +446 -0
- groknroll/core/rlm_codebase.py +448 -0
- groknroll/core/rlm_integration.py +256 -0
- groknroll/core/types.py +276 -0
- groknroll/environments/__init__.py +34 -0
- groknroll/environments/base_env.py +182 -0
- groknroll/environments/constants.py +32 -0
- groknroll/environments/docker_repl.py +336 -0
- groknroll/environments/local_repl.py +388 -0
- groknroll/environments/modal_repl.py +502 -0
- groknroll/environments/prime_repl.py +588 -0
- groknroll/logger/__init__.py +4 -0
- groknroll/logger/rlm_logger.py +63 -0
- groknroll/logger/verbose.py +393 -0
- groknroll/operations/__init__.py +15 -0
- groknroll/operations/bash_ops.py +447 -0
- groknroll/operations/file_ops.py +473 -0
- groknroll/operations/git_ops.py +620 -0
- groknroll/oracle/__init__.py +11 -0
- groknroll/oracle/codebase_indexer.py +238 -0
- groknroll/oracle/oracle_agent.py +278 -0
- groknroll/setup.py +34 -0
- groknroll/storage/__init__.py +14 -0
- groknroll/storage/database.py +272 -0
- groknroll/storage/models.py +128 -0
- groknroll/utils/__init__.py +0 -0
- groknroll/utils/parsing.py +168 -0
- groknroll/utils/prompts.py +146 -0
- groknroll/utils/rlm_utils.py +19 -0
- groknroll-2.0.0.dist-info/METADATA +246 -0
- groknroll-2.0.0.dist-info/RECORD +62 -0
- groknroll-2.0.0.dist-info/WHEEL +5 -0
- groknroll-2.0.0.dist-info/entry_points.txt +3 -0
- groknroll-2.0.0.dist-info/licenses/LICENSE +21 -0
- groknroll-2.0.0.dist-info/top_level.txt +1 -0
groknroll/core/lm_handler.py
ADDED
@@ -0,0 +1,206 @@
+"""
+LMHandler - Routes LLM requests from the RLM process and environment subprocesses.
+
+Uses a multi-threaded socket server. Protocol: 4-byte length prefix + JSON payload.
+"""
+
+import asyncio
+import time
+from socketserver import StreamRequestHandler, ThreadingTCPServer
+from threading import Thread
+
+from groknroll.clients.base_lm import BaseLM
+from groknroll.core.comms_utils import LMRequest, LMResponse, socket_recv, socket_send
+from groknroll.core.types import RLMChatCompletion, UsageSummary
+
+
+class LMRequestHandler(StreamRequestHandler):
+    """Socket handler for LLM completion requests."""
+
+    def handle(self):
+        try:
+            request_data = socket_recv(self.connection)
+            if not isinstance(request_data, dict):
+                response = LMResponse.error_response("Request must be a JSON object")
+                socket_send(self.connection, response.to_dict())
+                return
+
+            request = LMRequest.from_dict(request_data)
+            handler: LMHandler = self.server.lm_handler  # type: ignore
+
+            if request.is_batched:
+                # Batched request: process multiple prompts concurrently
+                response = self._handle_batched(request, handler)
+            elif request.prompt:
+                # Single request: process one prompt
+                response = self._handle_single(request, handler)
+            else:
+                response = LMResponse.error_response("Missing 'prompt' or 'prompts' in request.")
+
+            socket_send(self.connection, response.to_dict())
+
+        except Exception as e:
+            response = LMResponse.error_response(str(e))
+            socket_send(self.connection, response.to_dict())
+
+    def _handle_single(self, request: LMRequest, handler: "LMHandler") -> LMResponse:
+        """Handle a single prompt request."""
+        client = handler.get_client(request.model, request.depth)
+
+        start_time = time.perf_counter()
+        content = client.completion(request.prompt)
+        end_time = time.perf_counter()
+
+        usage_summary = client.get_last_usage()
+        return LMResponse.success_response(
+            chat_completion=RLMChatCompletion(
+                root_model=request.model or client.model_name,
+                prompt=request.prompt,
+                response=content,
+                usage_summary=usage_summary,
+                execution_time=end_time - start_time,
+            )
+        )
+
+    def _handle_batched(self, request: LMRequest, handler: "LMHandler") -> LMResponse:
+        """Handle a batched prompts request using async for concurrency."""
+        client = handler.get_client(request.model, request.depth)
+
+        start_time = time.perf_counter()
+
+        async def run_all():
+            tasks = [client.acompletion(prompt) for prompt in request.prompts]
+            return await asyncio.gather(*tasks)
+
+        results = asyncio.run(run_all())
+        end_time = time.perf_counter()
+
+        total_time = end_time - start_time
+        usage_summary = client.get_last_usage()
+
+        chat_completions = [
+            RLMChatCompletion(
+                root_model=request.model or client.model_name,
+                prompt=prompt,
+                response=content,
+                usage_summary=usage_summary,
+                execution_time=total_time / len(request.prompts),  # approximate per-prompt time
+            )
+            for prompt, content in zip(request.prompts, results, strict=True)
+        ]
+
+        return LMResponse.batched_success_response(chat_completions=chat_completions)
+
+
+class ThreadingLMServer(ThreadingTCPServer):
+    """Multi-threaded TCP server for LM requests."""
+
+    daemon_threads = True
+    allow_reuse_address = True
+
+
+class LMHandler:
+    """
+    Handles all LM calls from the RLM main process and environment subprocesses.
+
+    Uses a multi-threaded socket server for concurrent requests.
+    Protocol: 4-byte big-endian length prefix + JSON payload.
+    """
+
+    def __init__(
+        self,
+        client: BaseLM,
+        host: str = "127.0.0.1",
+        port: int = 0,  # auto-assign available port
+        other_backend_client: BaseLM | None = None,
+    ):
+        self.default_client = client
+        self.other_backend_client = other_backend_client
+        self.clients: dict[str, BaseLM] = {}
+        self.host = host
+        self._server: ThreadingLMServer | None = None
+        self._thread: Thread | None = None
+        self._port = port
+
+        self.register_client(client.model_name, client)
+
+    def register_client(self, model_name: str, client: BaseLM) -> None:
+        """Register a client for a specific model name."""
+        self.clients[model_name] = client
+
+    def get_client(self, model: str | None = None, depth: int = 0) -> BaseLM:
+        """Get client by model name or depth, or return default.
+
+        Routing logic:
+        - depth=0: use default_client (main backend)
+        - depth=1: use other_backend_client if it exists, otherwise default_client
+        - If model is specified and exists in clients, use that (overrides depth routing)
+        """
+        if model and model in self.clients:
+            return self.clients[model]
+
+        # Route based on depth
+        if depth == 1 and self.other_backend_client is not None:
+            return self.other_backend_client
+
+        return self.default_client
+
+    @property
+    def port(self) -> int:
+        """Get the actual port (useful when auto-assigned)."""
+        if self._server:
+            return self._server.server_address[1]
+        return self._port
+
+    @property
+    def address(self) -> tuple[str, int]:
+        """Get (host, port) tuple for connecting."""
+        return (self.host, self.port)
+
+    def start(self) -> tuple[str, int]:
+        """Start the socket server in a background thread. Returns (host, port)."""
+        if self._server is not None:
+            return self.address
+
+        self._server = ThreadingLMServer((self.host, self._port), LMRequestHandler)
+        self._server.lm_handler = self  # type: ignore
+
+        self._thread = Thread(target=self._server.serve_forever, daemon=True)
+        self._thread.start()
+
+        return self.address
+
+    def stop(self):
+        """Stop the socket server."""
+        if self._server:
+            self._server.shutdown()
+            self._server = None
+            self._thread = None
+
+    def completion(self, prompt: str, model: str | None = None) -> str:
+        """Direct completion call (for main process use)."""
+        return self.get_client(model).completion(prompt)
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
+        return False
+
+    def get_usage_summary(self) -> UsageSummary:
+        """Get the usage summary for all clients, merged into a single dict."""
+        merged = {}
+        # Include default client
+        default_summary = self.default_client.get_usage_summary()
+        merged.update(default_summary.model_usage_summaries)
+        # Include other backend client if it exists
+        if self.other_backend_client is not None:
+            other_summary = self.other_backend_client.get_usage_summary()
+            merged.update(other_summary.model_usage_summaries)
+        # Include all registered clients
+        for client in self.clients.values():
+            client_summary = client.get_usage_summary()
+            merged.update(client_summary.model_usage_summaries)
+        return UsageSummary(model_usage_summaries=merged)
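
Note on the wire protocol: the handler above serves a small framed-JSON protocol (4-byte big-endian length prefix, then a JSON body) over a ThreadingTCPServer, so environment subprocesses can issue LM calls concurrently. The sketch below illustrates that framing only and is not part of the packaged code; the request field names ("prompt", "model", "depth") are assumptions inferred from the LMRequest attributes used above, and real subprocess code would go through comms_utils.socket_send/socket_recv instead.

import json
import socket
import struct


def send_framed(sock: socket.socket, payload: dict) -> None:
    # 4-byte big-endian length prefix, then the JSON body.
    body = json.dumps(payload).encode("utf-8")
    sock.sendall(struct.pack(">I", len(body)) + body)


def recv_framed(sock: socket.socket) -> dict:
    # Read the length prefix, then exactly that many bytes of JSON.
    raw = b""
    while len(raw) < 4:
        chunk = sock.recv(4 - len(raw))
        if not chunk:
            raise ConnectionError("socket closed while reading length prefix")
        raw += chunk
    (length,) = struct.unpack(">I", raw)
    body = b""
    while len(body) < length:
        chunk = sock.recv(length - len(body))
        if not chunk:
            raise ConnectionError("socket closed mid-message")
        body += chunk
    return json.loads(body)


# Hypothetical round trip against a running LMHandler at handler.address:
# with socket.create_connection(handler.address) as sock:
#     send_framed(sock, {"prompt": "Summarize the loaded context.", "model": None, "depth": 1})
#     reply = recv_framed(sock)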
groknroll/core/rlm.py
ADDED
@@ -0,0 +1,446 @@
+import time
+from contextlib import contextmanager
+from typing import Any
+
+from groknroll.clients import BaseLM, get_client
+from groknroll.core.exceptions import (
+    CompletionTimeoutError,
+    CostLimitExceededError,
+    FinalAnswerNotFoundError,
+    IterationLimitExceededError,
+)
+from groknroll.core.lm_handler import LMHandler
+from groknroll.core.types import (
+    ClientBackend,
+    CodeBlock,
+    EnvironmentType,
+    REPLResult,
+    RLMChatCompletion,
+    RLMIteration,
+    RLMMetadata,
+)
+from groknroll.environments import BaseEnv, SupportsPersistence, get_environment
+from groknroll.logger import RLMLogger, VerbosePrinter
+from groknroll.utils.parsing import (
+    find_code_blocks,
+    find_final_answer,
+    format_iteration,
+)
+from groknroll.utils.prompts import (
+    RLM_SYSTEM_PROMPT,
+    QueryMetadata,
+    build_rlm_system_prompt,
+    build_user_prompt,
+)
+from groknroll.utils.rlm_utils import filter_sensitive_keys
+
+
+class RLM:
+    """
+    Recursive Language Model class that the user instantiates and runs on their tasks.
+
+    Each completion() call spawns its own environment and LM handler, which are
+    cleaned up when the call completes.
+    """
+
+    def __init__(
+        self,
+        backend: ClientBackend = "openai",
+        backend_kwargs: dict[str, Any] | None = None,
+        environment: EnvironmentType = "local",
+        environment_kwargs: dict[str, Any] | None = None,
+        depth: int = 0,
+        max_depth: int = 1,
+        max_iterations: int = 30,
+        custom_system_prompt: str | None = None,
+        other_backends: list[ClientBackend] | None = None,
+        other_backend_kwargs: list[dict[str, Any]] | None = None,
+        logger: RLMLogger | None = None,
+        verbose: bool = False,
+        persistent: bool = False,
+        max_cost: float | None = None,
+        timeout_seconds: float | None = None,
+        iteration_timeout_seconds: float | None = None,
+    ):
+        """
+        Args:
+            backend: The backend to use for the RLM.
+            backend_kwargs: The kwargs to pass to the backend.
+            environment: The environment to use for the RLM.
+            environment_kwargs: The kwargs to pass to the environment.
+            depth: The current depth of the RLM (0-indexed).
+            max_depth: The maximum depth of the RLM. Currently, only depth 1 is supported.
+            max_iterations: The maximum number of iterations of the RLM.
+            custom_system_prompt: The custom system prompt to use for the RLM.
+            other_backends: A list of other client backends that the environments can use to make sub-calls.
+            other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends).
+            logger: The logger to use for the RLM.
+            verbose: Whether to print verbose output in rich to console.
+            persistent: If True, reuse the environment across completion() calls for multi-turn conversations.
+            max_cost: Maximum allowed cost (in USD) for a single completion. If exceeded, raises CostLimitExceededError.
+            timeout_seconds: Maximum total time (in seconds) for a single completion. If exceeded, raises CompletionTimeoutError.
+            iteration_timeout_seconds: Maximum time (in seconds) for a single iteration. If exceeded, raises CompletionTimeoutError.
+        """
+        # Store config for spawning per-completion
+        self.backend = backend
+        self.backend_kwargs = backend_kwargs
+        self.environment_type = environment
+        self.environment_kwargs = (
+            environment_kwargs.copy() if environment_kwargs is not None else {}
+        )
+        # Validate other_backends: currently only support one additional backend
+        if other_backends is not None:
+            if len(other_backends) != 1:
+                raise ValueError(
+                    "We currently only support one additional backend for the recursive sub-calls! "
+                    "This model will be the model used for recursive sub-calls, but this will change in the future"
+                )
+
+        self.other_backends = other_backends
+        self.other_backend_kwargs = other_backend_kwargs
+
+        self.depth = depth
+        self.max_depth = max_depth
+        self.max_iterations = max_iterations
+        self.system_prompt = custom_system_prompt if custom_system_prompt else RLM_SYSTEM_PROMPT
+        self.logger = logger
+        self.verbose = VerbosePrinter(enabled=verbose)
+
+        # Cost and timeout limits
+        self.max_cost = max_cost
+        self.timeout_seconds = timeout_seconds
+        self.iteration_timeout_seconds = iteration_timeout_seconds
+
+        # Persistence support
+        self.persistent = persistent
+        self._persistent_env: SupportsPersistence | None = None
+
+        # Validate persistence support at initialization
+        if self.persistent:
+            self._validate_persistent_environment_support()
+
+        # Log metadata if logger is provided
+        if self.logger or verbose:
+            metadata = RLMMetadata(
+                root_model=backend_kwargs.get("model_name", "unknown")
+                if backend_kwargs
+                else "unknown",
+                max_depth=max_depth,
+                max_iterations=max_iterations,
+                backend=backend,
+                backend_kwargs=filter_sensitive_keys(backend_kwargs) if backend_kwargs else {},
+                environment_type=environment,
+                environment_kwargs=filter_sensitive_keys(environment_kwargs)
+                if environment_kwargs
+                else {},
+                other_backends=other_backends,
+            )
+            if self.logger:
+                self.logger.log_metadata(metadata)
+            self.verbose.print_metadata(metadata)
+
+    @contextmanager
+    def _spawn_completion_context(self, prompt: str | dict[str, Any]):
+        """
+        Spawn an LM handler and environment for a single completion call.
+
+        When persistent=True, the environment is reused across calls.
+        When persistent=False (default), creates fresh environment each call.
+        """
+        # Create client and wrap in handler
+        client: BaseLM = get_client(self.backend, self.backend_kwargs)
+
+        # Create other_backend_client if provided (for depth=1 routing)
+        other_backend_client: BaseLM | None = None
+        if self.other_backends and self.other_backend_kwargs:
+            other_backend_client = get_client(self.other_backends[0], self.other_backend_kwargs[0])
+
+        lm_handler = LMHandler(client, other_backend_client=other_backend_client)
+
+        # Register other clients to be available as sub-call options (by model name)
+        if self.other_backends and self.other_backend_kwargs:
+            for backend, kwargs in zip(self.other_backends, self.other_backend_kwargs, strict=True):
+                other_client: BaseLM = get_client(backend, kwargs)
+                lm_handler.register_client(other_client.model_name, other_client)
+
+        lm_handler.start()
+
+        # Environment: reuse if persistent, otherwise create fresh
+        if self.persistent and self._persistent_env is not None:
+            environment = self._persistent_env
+            # Defensive check: ensure environment supports persistence methods
+            if not self._env_supports_persistence(environment):
+                raise RuntimeError(
+                    f"Persistent environment of type '{type(environment).__name__}' does not "
+                    f"implement required methods (update_handler_address, add_context, get_context_count). "
+                    f"This should have been caught at initialization."
+                )
+            environment.update_handler_address((lm_handler.host, lm_handler.port))
+            environment.add_context(prompt)
+        else:
+            env_kwargs = self.environment_kwargs.copy()
+            env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
+            env_kwargs["context_payload"] = prompt
+            env_kwargs["depth"] = self.depth + 1  # Environment depth is RLM depth + 1
+            environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
+
+            if self.persistent:
+                self._persistent_env = environment
+
+        try:
+            yield lm_handler, environment
+        finally:
+            lm_handler.stop()
+            if not self.persistent and hasattr(environment, "cleanup"):
+                environment.cleanup()
+
+    def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
+        """
+        Setup the system prompt for the RLM. Also include metadata about the prompt and build
+        up the initial message history.
+        """
+        metadata = QueryMetadata(prompt)
+        message_history = build_rlm_system_prompt(
+            system_prompt=self.system_prompt, query_metadata=metadata
+        )
+
+        return message_history
+
+    def completion(
+        self, prompt: str | dict[str, Any], root_prompt: str | None = None
+    ) -> RLMChatCompletion:
+        """
+        Recursive Language Model completion call. This is the main entry point for querying an RLM, and
+        can replace a regular LM completion call.
+
+        Spawns its own environment and LM handler for the duration of this call.
+
+        Args:
+            prompt: A single string or dictionary of messages to pass as context to the model.
+            root_prompt: We allow the RLM's root LM to see a (small) prompt that the user specifies. A common example of this
+                is if the user is asking the RLM to answer a question, we can pass the question as the root prompt.
+        Returns:
+            A final answer as a string.
+        """
+        time_start = time.perf_counter()
+
+        # If we're at max depth, the RLM is an LM, so we fallback to the regular LM.
+        if self.depth >= self.max_depth:
+            return self._fallback_answer(prompt)
+
+        with self._spawn_completion_context(prompt) as (lm_handler, environment):
+            message_history = self._setup_prompt(prompt)
+
+            for i in range(self.max_iterations):
+                # Check timeout
+                if self.timeout_seconds is not None:
+                    elapsed = time.perf_counter() - time_start
+                    if elapsed > self.timeout_seconds:
+                        raise CompletionTimeoutError(
+                            f"Completion exceeded timeout of {self.timeout_seconds}s",
+                            timeout_seconds=self.timeout_seconds,
+                            elapsed_seconds=elapsed,
+                        )
+
+                # Check cost limit
+                if self.max_cost is not None:
+                    current_cost = lm_handler.get_usage_summary().total_cost
+                    if current_cost > self.max_cost:
+                        raise CostLimitExceededError(
+                            f"Completion cost ${current_cost:.4f} exceeded limit ${self.max_cost:.4f}",
+                            current_cost=current_cost,
+                            cost_limit=self.max_cost,
+                        )
+
+                # Current prompt = message history + additional prompt suffix
+                context_count = (
+                    environment.get_context_count()
+                    if isinstance(environment, SupportsPersistence)
+                    else 1
+                )
+                history_count = (
+                    environment.get_history_count()
+                    if isinstance(environment, SupportsPersistence)
+                    else 0
+                )
+                current_prompt = message_history + [
+                    build_user_prompt(root_prompt, i, context_count, history_count)
+                ]
+
+                iteration: RLMIteration = self._completion_turn(
+                    prompt=current_prompt,
+                    lm_handler=lm_handler,
+                    environment=environment,
+                )
+
+                # Check if RLM is done and has a final answer.
+                final_answer = find_final_answer(iteration.response, environment=environment)
+                iteration.final_answer = final_answer
+
+                # If logger is used, log the iteration.
+                if self.logger:
+                    self.logger.log(iteration)
+
+                # Verbose output for this iteration
+                self.verbose.print_iteration(iteration, i + 1)
+
+                if final_answer is not None:
+                    time_end = time.perf_counter()
+                    usage = lm_handler.get_usage_summary()
+                    self.verbose.print_final_answer(final_answer)
+                    self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict())
+
+                    # Store message history in persistent environment
+                    if self.persistent and isinstance(environment, SupportsPersistence):
+                        environment.add_history(message_history)
+
+                    return RLMChatCompletion(
+                        root_model=self.backend_kwargs.get("model_name", "unknown")
+                        if self.backend_kwargs
+                        else "unknown",
+                        prompt=prompt,
+                        response=final_answer,
+                        usage_summary=usage,
+                        execution_time=time_end - time_start,
+                    )
+
+                # Format the iteration for the next prompt.
+                new_messages = format_iteration(iteration)
+
+                # Update message history with the new messages.
+                message_history.extend(new_messages)
+
+            # Default behavior: we run out of iterations, provide one final answer
+            time_end = time.perf_counter()
+            final_answer = self._default_answer(message_history, lm_handler)
+            usage = lm_handler.get_usage_summary()
+            self.verbose.print_final_answer(final_answer)
+            self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())
+
+            # Store message history in persistent environment
+            if self.persistent and isinstance(environment, SupportsPersistence):
+                environment.add_history(message_history)
+
+            return RLMChatCompletion(
+                root_model=self.backend_kwargs.get("model_name", "unknown")
+                if self.backend_kwargs
+                else "unknown",
+                prompt=prompt,
+                response=final_answer,
+                usage_summary=usage,
+                execution_time=time_end - time_start,
+            )
+
+    def _completion_turn(
+        self,
+        prompt: str | dict[str, Any],
+        lm_handler: LMHandler,
+        environment: BaseEnv,
+    ) -> RLMIteration:
+        """
+        Perform a single iteration of the RLM, including prompting the model
+        and code execution + tool execution.
+        """
+        iter_start = time.perf_counter()
+        response = lm_handler.completion(prompt)
+        code_block_strs = find_code_blocks(response)
+        code_blocks = []
+
+        for code_block_str in code_block_strs:
+            # Check iteration timeout
+            if self.iteration_timeout_seconds is not None:
+                elapsed = time.perf_counter() - iter_start
+                if elapsed > self.iteration_timeout_seconds:
+                    raise CompletionTimeoutError(
+                        f"Iteration exceeded timeout of {self.iteration_timeout_seconds}s",
+                        timeout_seconds=self.iteration_timeout_seconds,
+                        elapsed_seconds=elapsed,
+                    )
+
+            code_result: REPLResult = environment.execute_code(code_block_str)
+            code_blocks.append(CodeBlock(code=code_block_str, result=code_result))
+
+        iteration_time = time.perf_counter() - iter_start
+        return RLMIteration(
+            prompt=prompt,
+            response=response,
+            code_blocks=code_blocks,
+            iteration_time=iteration_time,
+        )
+
+    def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMHandler) -> str:
+        """
+        Default behavior if the RLM runs out of iterations and does not find a final answer.
+        It will take the message history, and try to generate a final answer from it.
+        """
+        current_prompt = message_history + [
+            {
+                "role": "assistant",
+                "content": "Please provide a final answer to the user's question based on the information provided.",
+            }
+        ]
+        response = lm_handler.completion(current_prompt)
+
+        if self.logger:
+            self.logger.log(
+                RLMIteration(
+                    prompt=current_prompt,
+                    response=response,
+                    final_answer=response,
+                    code_blocks=[],
+                )
+            )
+
+        return response
+
+    def _fallback_answer(self, message: str | dict[str, Any]) -> str:
+        """
+        Fallback behavior if the RLM is actually at max depth, and should be treated as an LM.
+        """
+        client: BaseLM = get_client(self.backend, self.backend_kwargs)
+        response = client.completion(message)
+        return response
+
+    def _validate_persistent_environment_support(self) -> None:
+        """
+        Validate that the configured environment type supports persistent mode.
+
+        Persistent mode requires environments to implement:
+        - update_handler_address(address): Update LM handler address between calls
+        - add_context(payload, index): Add new context for multi-turn conversations
+        - get_context_count(): Return the number of loaded contexts
+
+        Currently only 'local' (LocalREPL) supports these methods.
+
+        Raises:
+            ValueError: If the environment type does not support persistent mode.
+        """
+        # Known environments that support persistence
+        persistent_supported_environments = {"local"}
+
+        if self.environment_type not in persistent_supported_environments:
+            raise ValueError(
+                f"persistent=True is not supported for environment type '{self.environment_type}'. "
+                f"Persistent mode requires environments that implement update_handler_address(), "
+                f"add_context(), and get_context_count(). "
+                f"Supported environments: {sorted(persistent_supported_environments)}"
+            )
+
+    @staticmethod
+    def _env_supports_persistence(env: BaseEnv) -> bool:
+        """Check if an environment instance supports persistent mode methods."""
+        return isinstance(env, SupportsPersistence)
+
+    def close(self) -> None:
+        """Clean up persistent environment. Call when done with multi-turn conversations."""
+        if self._persistent_env is not None:
+            if hasattr(self._persistent_env, "cleanup"):
+                self._persistent_env.cleanup()
+            self._persistent_env = None
+
+    def __enter__(self) -> "RLM":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+        self.close()
+        return False
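
For orientation, a minimal usage sketch based only on the constructor and completion() signatures shown above. It is not taken from the package; the model name and any backend_kwargs keys beyond "model_name" are placeholders, and the result fields referenced are the RLMChatCompletion attributes used in this diff.

from groknroll.core.rlm import RLM

long_document = "...a context far larger than a single model call can hold..."

# Persistent local REPL: the environment is reused across completion() calls,
# and close() is invoked automatically when the with-block exits.
with RLM(
    backend="openai",
    backend_kwargs={"model_name": "gpt-4o-mini"},  # placeholder model name
    environment="local",
    persistent=True,
    max_iterations=10,
    max_cost=1.00,
    verbose=True,
) as rlm:
    result = rlm.completion(long_document, root_prompt="What changed in section 3?")
    print(result.response)                 # final answer text
    print(result.execution_time)           # wall-clock seconds for this call
    print(result.usage_summary.to_dict())  # merged per-model usage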