groknroll-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. groknroll/__init__.py +36 -0
  2. groknroll/__main__.py +9 -0
  3. groknroll/agents/__init__.py +18 -0
  4. groknroll/agents/agent_manager.py +187 -0
  5. groknroll/agents/base_agent.py +118 -0
  6. groknroll/agents/build_agent.py +231 -0
  7. groknroll/agents/plan_agent.py +215 -0
  8. groknroll/cli/__init__.py +7 -0
  9. groknroll/cli/enhanced_cli.py +372 -0
  10. groknroll/cli/large_codebase_cli.py +413 -0
  11. groknroll/cli/main.py +331 -0
  12. groknroll/cli/rlm_commands.py +258 -0
  13. groknroll/clients/__init__.py +63 -0
  14. groknroll/clients/anthropic.py +112 -0
  15. groknroll/clients/azure_openai.py +142 -0
  16. groknroll/clients/base_lm.py +33 -0
  17. groknroll/clients/gemini.py +162 -0
  18. groknroll/clients/litellm.py +105 -0
  19. groknroll/clients/openai.py +129 -0
  20. groknroll/clients/portkey.py +94 -0
  21. groknroll/core/__init__.py +9 -0
  22. groknroll/core/agent.py +339 -0
  23. groknroll/core/comms_utils.py +264 -0
  24. groknroll/core/context.py +251 -0
  25. groknroll/core/exceptions.py +181 -0
  26. groknroll/core/large_codebase.py +564 -0
  27. groknroll/core/lm_handler.py +206 -0
  28. groknroll/core/rlm.py +446 -0
  29. groknroll/core/rlm_codebase.py +448 -0
  30. groknroll/core/rlm_integration.py +256 -0
  31. groknroll/core/types.py +276 -0
  32. groknroll/environments/__init__.py +34 -0
  33. groknroll/environments/base_env.py +182 -0
  34. groknroll/environments/constants.py +32 -0
  35. groknroll/environments/docker_repl.py +336 -0
  36. groknroll/environments/local_repl.py +388 -0
  37. groknroll/environments/modal_repl.py +502 -0
  38. groknroll/environments/prime_repl.py +588 -0
  39. groknroll/logger/__init__.py +4 -0
  40. groknroll/logger/rlm_logger.py +63 -0
  41. groknroll/logger/verbose.py +393 -0
  42. groknroll/operations/__init__.py +15 -0
  43. groknroll/operations/bash_ops.py +447 -0
  44. groknroll/operations/file_ops.py +473 -0
  45. groknroll/operations/git_ops.py +620 -0
  46. groknroll/oracle/__init__.py +11 -0
  47. groknroll/oracle/codebase_indexer.py +238 -0
  48. groknroll/oracle/oracle_agent.py +278 -0
  49. groknroll/setup.py +34 -0
  50. groknroll/storage/__init__.py +14 -0
  51. groknroll/storage/database.py +272 -0
  52. groknroll/storage/models.py +128 -0
  53. groknroll/utils/__init__.py +0 -0
  54. groknroll/utils/parsing.py +168 -0
  55. groknroll/utils/prompts.py +146 -0
  56. groknroll/utils/rlm_utils.py +19 -0
  57. groknroll-2.0.0.dist-info/METADATA +246 -0
  58. groknroll-2.0.0.dist-info/RECORD +62 -0
  59. groknroll-2.0.0.dist-info/WHEEL +5 -0
  60. groknroll-2.0.0.dist-info/entry_points.txt +3 -0
  61. groknroll-2.0.0.dist-info/licenses/LICENSE +21 -0
  62. groknroll-2.0.0.dist-info/top_level.txt +1 -0
groknroll/core/lm_handler.py ADDED
@@ -0,0 +1,206 @@
+ """
+ LMHandler - Routes LLM requests from the RLM process and environment subprocesses.
+
+ Uses a multi-threaded socket server. Protocol: 4-byte length prefix + JSON payload.
+ """
+
+ import asyncio
+ import time
+ from socketserver import StreamRequestHandler, ThreadingTCPServer
+ from threading import Thread
+
+ from groknroll.clients.base_lm import BaseLM
+ from groknroll.core.comms_utils import LMRequest, LMResponse, socket_recv, socket_send
+ from groknroll.core.types import RLMChatCompletion, UsageSummary
+
+
+ class LMRequestHandler(StreamRequestHandler):
+     """Socket handler for LLM completion requests."""
+
+     def handle(self):
+         try:
+             request_data = socket_recv(self.connection)
+             if not isinstance(request_data, dict):
+                 response = LMResponse.error_response("Request must be a JSON object")
+                 socket_send(self.connection, response.to_dict())
+                 return
+
+             request = LMRequest.from_dict(request_data)
+             handler: LMHandler = self.server.lm_handler  # type: ignore
+
+             if request.is_batched:
+                 # Batched request: process multiple prompts concurrently
+                 response = self._handle_batched(request, handler)
+             elif request.prompt:
+                 # Single request: process one prompt
+                 response = self._handle_single(request, handler)
+             else:
+                 response = LMResponse.error_response("Missing 'prompt' or 'prompts' in request.")
+
+             socket_send(self.connection, response.to_dict())
+
+         except Exception as e:
+             response = LMResponse.error_response(str(e))
+             socket_send(self.connection, response.to_dict())
+
+     def _handle_single(self, request: LMRequest, handler: "LMHandler") -> LMResponse:
+         """Handle a single prompt request."""
+         client = handler.get_client(request.model, request.depth)
+
+         start_time = time.perf_counter()
+         content = client.completion(request.prompt)
+         end_time = time.perf_counter()
+
+         usage_summary = client.get_last_usage()
+         return LMResponse.success_response(
+             chat_completion=RLMChatCompletion(
+                 root_model=request.model or client.model_name,
+                 prompt=request.prompt,
+                 response=content,
+                 usage_summary=usage_summary,
+                 execution_time=end_time - start_time,
+             )
+         )
+
+     def _handle_batched(self, request: LMRequest, handler: "LMHandler") -> LMResponse:
+         """Handle a batched prompts request using async for concurrency."""
+         client = handler.get_client(request.model, request.depth)
+
+         start_time = time.perf_counter()
+
+         async def run_all():
+             tasks = [client.acompletion(prompt) for prompt in request.prompts]
+             return await asyncio.gather(*tasks)
+
+         results = asyncio.run(run_all())
+         end_time = time.perf_counter()
+
+         total_time = end_time - start_time
+         usage_summary = client.get_last_usage()
+
+         chat_completions = [
+             RLMChatCompletion(
+                 root_model=request.model or client.model_name,
+                 prompt=prompt,
+                 response=content,
+                 usage_summary=usage_summary,
+                 execution_time=total_time / len(request.prompts),  # approximate per-prompt time
+             )
+             for prompt, content in zip(request.prompts, results, strict=True)
+         ]
+
+         return LMResponse.batched_success_response(chat_completions=chat_completions)
+
+
+ class ThreadingLMServer(ThreadingTCPServer):
+     """Multi-threaded TCP server for LM requests."""
+
+     daemon_threads = True
+     allow_reuse_address = True
+
+
+ class LMHandler:
+     """
+     Handles all LM calls from the RLM main process and environment subprocesses.
+
+     Uses a multi-threaded socket server for concurrent requests.
+     Protocol: 4-byte big-endian length prefix + JSON payload.
+     """
+
+     def __init__(
+         self,
+         client: BaseLM,
+         host: str = "127.0.0.1",
+         port: int = 0,  # auto-assign available port
+         other_backend_client: BaseLM | None = None,
+     ):
+         self.default_client = client
+         self.other_backend_client = other_backend_client
+         self.clients: dict[str, BaseLM] = {}
+         self.host = host
+         self._server: ThreadingLMServer | None = None
+         self._thread: Thread | None = None
+         self._port = port
+
+         self.register_client(client.model_name, client)
+
+     def register_client(self, model_name: str, client: BaseLM) -> None:
+         """Register a client for a specific model name."""
+         self.clients[model_name] = client
+
+     def get_client(self, model: str | None = None, depth: int = 0) -> BaseLM:
+         """Get client by model name or depth, or return default.
+
+         Routing logic:
+         - depth=0: use default_client (main backend)
+         - depth=1: use other_backend_client if it exists, otherwise default_client
+         - If model is specified and exists in clients, use that (overrides depth routing)
+         """
+         if model and model in self.clients:
+             return self.clients[model]
+
+         # Route based on depth
+         if depth == 1 and self.other_backend_client is not None:
+             return self.other_backend_client
+
+         return self.default_client
+
+     @property
+     def port(self) -> int:
+         """Get the actual port (useful when auto-assigned)."""
+         if self._server:
+             return self._server.server_address[1]
+         return self._port
+
+     @property
+     def address(self) -> tuple[str, int]:
+         """Get (host, port) tuple for connecting."""
+         return (self.host, self.port)
+
+     def start(self) -> tuple[str, int]:
+         """Start the socket server in a background thread. Returns (host, port)."""
+         if self._server is not None:
+             return self.address
+
+         self._server = ThreadingLMServer((self.host, self._port), LMRequestHandler)
+         self._server.lm_handler = self  # type: ignore
+
+         self._thread = Thread(target=self._server.serve_forever, daemon=True)
+         self._thread.start()
+
+         return self.address
+
+     def stop(self):
+         """Stop the socket server."""
+         if self._server:
+             self._server.shutdown()
+             self._server = None
+             self._thread = None
+
+     def completion(self, prompt: str, model: str | None = None) -> str:
+         """Direct completion call (for main process use)."""
+         return self.get_client(model).completion(prompt)
+
+     def __enter__(self):
+         self.start()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.stop()
+         return False
+
+     def get_usage_summary(self) -> UsageSummary:
+         """Get the usage summary for all clients, merged into a single dict."""
+         merged = {}
+         # Include default client
+         default_summary = self.default_client.get_usage_summary()
+         merged.update(default_summary.model_usage_summaries)
+         # Include other backend client if it exists
+         if self.other_backend_client is not None:
+             other_summary = self.other_backend_client.get_usage_summary()
+             merged.update(other_summary.model_usage_summaries)
+         # Include all registered clients
+         for client in self.clients.values():
+             client_summary = client.get_usage_summary()
+             merged.update(client_summary.model_usage_summaries)
+         return UsageSummary(model_usage_summaries=merged)
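
Note: the handler above frames every message as a 4-byte big-endian length prefix followed by a UTF-8 JSON payload, with one request and one response per connection. The standalone client below is only an illustration of that wire format; it re-implements the framing by hand, assumes the request keys mirror LMRequest's attributes ("prompt", "model", "depth"), and uses a placeholder port. Callers inside the package use socket_send/socket_recv from groknroll.core.comms_utils instead.

    # Illustrative client for the LMHandler socket protocol (not part of the package).
    # Assumptions: 4-byte big-endian length prefix + UTF-8 JSON body, request keys
    # mirroring LMRequest ("prompt", "model", "depth"), and a placeholder port.
    import json
    import socket
    import struct

    def _recv_exact(sock: socket.socket, n: int) -> bytes:
        buf = b""
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("socket closed mid-message")
            buf += chunk
        return buf

    def send_request(address: tuple[str, int], payload: dict) -> dict:
        with socket.create_connection(address) as sock:
            body = json.dumps(payload).encode("utf-8")
            sock.sendall(struct.pack(">I", len(body)) + body)  # length prefix + JSON payload
            (length,) = struct.unpack(">I", _recv_exact(sock, 4))
            return json.loads(_recv_exact(sock, length).decode("utf-8"))

    # host/port come from LMHandler.start(); 5555 is a placeholder here.
    print(send_request(("127.0.0.1", 5555), {"prompt": "Summarize section 3.", "model": None, "depth": 1}))
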
groknroll/core/rlm.py ADDED
@@ -0,0 +1,446 @@
+ import time
+ from contextlib import contextmanager
+ from typing import Any
+
+ from groknroll.clients import BaseLM, get_client
+ from groknroll.core.exceptions import (
+     CompletionTimeoutError,
+     CostLimitExceededError,
+     FinalAnswerNotFoundError,
+     IterationLimitExceededError,
+ )
+ from groknroll.core.lm_handler import LMHandler
+ from groknroll.core.types import (
+     ClientBackend,
+     CodeBlock,
+     EnvironmentType,
+     REPLResult,
+     RLMChatCompletion,
+     RLMIteration,
+     RLMMetadata,
+ )
+ from groknroll.environments import BaseEnv, SupportsPersistence, get_environment
+ from groknroll.logger import RLMLogger, VerbosePrinter
+ from groknroll.utils.parsing import (
+     find_code_blocks,
+     find_final_answer,
+     format_iteration,
+ )
+ from groknroll.utils.prompts import (
+     RLM_SYSTEM_PROMPT,
+     QueryMetadata,
+     build_rlm_system_prompt,
+     build_user_prompt,
+ )
+ from groknroll.utils.rlm_utils import filter_sensitive_keys
+
+
+ class RLM:
+     """
+     Recursive Language Model class that the user instantiates and runs on their tasks.
+
+     Each completion() call spawns its own environment and LM handler, which are
+     cleaned up when the call completes.
+     """
+
+     def __init__(
+         self,
+         backend: ClientBackend = "openai",
+         backend_kwargs: dict[str, Any] | None = None,
+         environment: EnvironmentType = "local",
+         environment_kwargs: dict[str, Any] | None = None,
+         depth: int = 0,
+         max_depth: int = 1,
+         max_iterations: int = 30,
+         custom_system_prompt: str | None = None,
+         other_backends: list[ClientBackend] | None = None,
+         other_backend_kwargs: list[dict[str, Any]] | None = None,
+         logger: RLMLogger | None = None,
+         verbose: bool = False,
+         persistent: bool = False,
+         max_cost: float | None = None,
+         timeout_seconds: float | None = None,
+         iteration_timeout_seconds: float | None = None,
+     ):
+         """
+         Args:
+             backend: The backend to use for the RLM.
+             backend_kwargs: The kwargs to pass to the backend.
+             environment: The environment to use for the RLM.
+             environment_kwargs: The kwargs to pass to the environment.
+             depth: The current depth of the RLM (0-indexed).
+             max_depth: The maximum depth of the RLM. Currently, only depth 1 is supported.
+             max_iterations: The maximum number of iterations of the RLM.
+             custom_system_prompt: The custom system prompt to use for the RLM.
+             other_backends: A list of other client backends that the environments can use to make sub-calls.
+             other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends).
+             logger: The logger to use for the RLM.
+             verbose: Whether to print verbose output to the console using rich.
+             persistent: If True, reuse the environment across completion() calls for multi-turn conversations.
+             max_cost: Maximum allowed cost (in USD) for a single completion. If exceeded, raises CostLimitExceededError.
+             timeout_seconds: Maximum total time (in seconds) for a single completion. If exceeded, raises CompletionTimeoutError.
+             iteration_timeout_seconds: Maximum time (in seconds) for a single iteration. If exceeded, raises CompletionTimeoutError.
+         """
+         # Store config for spawning per-completion
+         self.backend = backend
+         self.backend_kwargs = backend_kwargs
+         self.environment_type = environment
+         self.environment_kwargs = (
+             environment_kwargs.copy() if environment_kwargs is not None else {}
+         )
+         # Validate other_backends: currently only support one additional backend
+         if other_backends is not None:
+             if len(other_backends) != 1:
+                 raise ValueError(
+                     "We currently only support one additional backend for the recursive sub-calls! "
+                     "This model will be the model used for recursive sub-calls, but this will change in the future"
+                 )
+
+         self.other_backends = other_backends
+         self.other_backend_kwargs = other_backend_kwargs
+
+         self.depth = depth
+         self.max_depth = max_depth
+         self.max_iterations = max_iterations
+         self.system_prompt = custom_system_prompt if custom_system_prompt else RLM_SYSTEM_PROMPT
+         self.logger = logger
+         self.verbose = VerbosePrinter(enabled=verbose)
+
+         # Cost and timeout limits
+         self.max_cost = max_cost
+         self.timeout_seconds = timeout_seconds
+         self.iteration_timeout_seconds = iteration_timeout_seconds
+
+         # Persistence support
+         self.persistent = persistent
+         self._persistent_env: SupportsPersistence | None = None
+
+         # Validate persistence support at initialization
+         if self.persistent:
+             self._validate_persistent_environment_support()
+
+         # Log metadata if a logger or verbose output is enabled
+         if self.logger or verbose:
+             metadata = RLMMetadata(
+                 root_model=backend_kwargs.get("model_name", "unknown")
+                 if backend_kwargs
+                 else "unknown",
+                 max_depth=max_depth,
+                 max_iterations=max_iterations,
+                 backend=backend,
+                 backend_kwargs=filter_sensitive_keys(backend_kwargs) if backend_kwargs else {},
+                 environment_type=environment,
+                 environment_kwargs=filter_sensitive_keys(environment_kwargs)
+                 if environment_kwargs
+                 else {},
+                 other_backends=other_backends,
+             )
+             if self.logger:
+                 self.logger.log_metadata(metadata)
+             self.verbose.print_metadata(metadata)
+
+     @contextmanager
+     def _spawn_completion_context(self, prompt: str | dict[str, Any]):
+         """
+         Spawn an LM handler and environment for a single completion call.
+
+         When persistent=True, the environment is reused across calls.
+         When persistent=False (default), a fresh environment is created for each call.
+         """
+         # Create client and wrap in handler
+         client: BaseLM = get_client(self.backend, self.backend_kwargs)
+
+         # Create other_backend_client if provided (for depth=1 routing)
+         other_backend_client: BaseLM | None = None
+         if self.other_backends and self.other_backend_kwargs:
+             other_backend_client = get_client(self.other_backends[0], self.other_backend_kwargs[0])
+
+         lm_handler = LMHandler(client, other_backend_client=other_backend_client)
+
+         # Register other clients to be available as sub-call options (by model name)
+         if self.other_backends and self.other_backend_kwargs:
+             for backend, kwargs in zip(self.other_backends, self.other_backend_kwargs, strict=True):
+                 other_client: BaseLM = get_client(backend, kwargs)
+                 lm_handler.register_client(other_client.model_name, other_client)
+
+         lm_handler.start()
+
+         # Environment: reuse if persistent, otherwise create fresh
+         if self.persistent and self._persistent_env is not None:
+             environment = self._persistent_env
+             # Defensive check: ensure environment supports persistence methods
+             if not self._env_supports_persistence(environment):
+                 raise RuntimeError(
+                     f"Persistent environment of type '{type(environment).__name__}' does not "
+                     f"implement required methods (update_handler_address, add_context, get_context_count). "
+                     f"This should have been caught at initialization."
+                 )
+             environment.update_handler_address((lm_handler.host, lm_handler.port))
+             environment.add_context(prompt)
+         else:
+             env_kwargs = self.environment_kwargs.copy()
+             env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
+             env_kwargs["context_payload"] = prompt
+             env_kwargs["depth"] = self.depth + 1  # Environment depth is RLM depth + 1
+             environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
+
+         if self.persistent:
+             self._persistent_env = environment
+
+         try:
+             yield lm_handler, environment
+         finally:
+             lm_handler.stop()
+             if not self.persistent and hasattr(environment, "cleanup"):
+                 environment.cleanup()
+
+     def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
+         """
+         Set up the system prompt for the RLM. Also include metadata about the prompt and build
+         up the initial message history.
+         """
+         metadata = QueryMetadata(prompt)
+         message_history = build_rlm_system_prompt(
+             system_prompt=self.system_prompt, query_metadata=metadata
+         )
+
+         return message_history
+
+     def completion(
+         self, prompt: str | dict[str, Any], root_prompt: str | None = None
+     ) -> RLMChatCompletion:
+         """
+         Recursive Language Model completion call. This is the main entry point for querying an RLM, and
+         can replace a regular LM completion call.
+
+         Spawns its own environment and LM handler for the duration of this call.
+
+         Args:
+             prompt: A single string or a dictionary of messages to pass as context to the model.
+             root_prompt: A (small) prompt, specified by the user, that the RLM's root LM is allowed to see directly.
+                 A common example: if the user asks the RLM to answer a question, the question itself can be passed
+                 as the root prompt.
+         Returns:
+             An RLMChatCompletion containing the final answer, usage summary, and execution time.
+         """
+         time_start = time.perf_counter()
+
+         # If we're at max depth, the RLM is an LM, so we fall back to the regular LM.
+         if self.depth >= self.max_depth:
+             return self._fallback_answer(prompt)
+
+         with self._spawn_completion_context(prompt) as (lm_handler, environment):
+             message_history = self._setup_prompt(prompt)
+
+             for i in range(self.max_iterations):
+                 # Check timeout
+                 if self.timeout_seconds is not None:
+                     elapsed = time.perf_counter() - time_start
+                     if elapsed > self.timeout_seconds:
+                         raise CompletionTimeoutError(
+                             f"Completion exceeded timeout of {self.timeout_seconds}s",
+                             timeout_seconds=self.timeout_seconds,
+                             elapsed_seconds=elapsed,
+                         )
+
+                 # Check cost limit
+                 if self.max_cost is not None:
+                     current_cost = lm_handler.get_usage_summary().total_cost
+                     if current_cost > self.max_cost:
+                         raise CostLimitExceededError(
+                             f"Completion cost ${current_cost:.4f} exceeded limit ${self.max_cost:.4f}",
+                             current_cost=current_cost,
+                             cost_limit=self.max_cost,
+                         )
+
+                 # Current prompt = message history + additional prompt suffix
+                 context_count = (
+                     environment.get_context_count()
+                     if isinstance(environment, SupportsPersistence)
+                     else 1
+                 )
+                 history_count = (
+                     environment.get_history_count()
+                     if isinstance(environment, SupportsPersistence)
+                     else 0
+                 )
+                 current_prompt = message_history + [
+                     build_user_prompt(root_prompt, i, context_count, history_count)
+                 ]
+
+                 iteration: RLMIteration = self._completion_turn(
+                     prompt=current_prompt,
+                     lm_handler=lm_handler,
+                     environment=environment,
+                 )
+
+                 # Check if the RLM is done and has a final answer.
+                 final_answer = find_final_answer(iteration.response, environment=environment)
+                 iteration.final_answer = final_answer
+
+                 # If a logger is used, log the iteration.
+                 if self.logger:
+                     self.logger.log(iteration)
+
+                 # Verbose output for this iteration
+                 self.verbose.print_iteration(iteration, i + 1)
+
+                 if final_answer is not None:
+                     time_end = time.perf_counter()
+                     usage = lm_handler.get_usage_summary()
+                     self.verbose.print_final_answer(final_answer)
+                     self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict())
+
+                     # Store message history in persistent environment
+                     if self.persistent and isinstance(environment, SupportsPersistence):
+                         environment.add_history(message_history)
+
+                     return RLMChatCompletion(
+                         root_model=self.backend_kwargs.get("model_name", "unknown")
+                         if self.backend_kwargs
+                         else "unknown",
+                         prompt=prompt,
+                         response=final_answer,
+                         usage_summary=usage,
+                         execution_time=time_end - time_start,
+                     )
+
+                 # Format the iteration for the next prompt.
+                 new_messages = format_iteration(iteration)
+
+                 # Update message history with the new messages.
+                 message_history.extend(new_messages)
+
+             # Default behavior: we ran out of iterations, so produce one final answer.
+             time_end = time.perf_counter()
+             final_answer = self._default_answer(message_history, lm_handler)
+             usage = lm_handler.get_usage_summary()
+             self.verbose.print_final_answer(final_answer)
+             self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())
+
+             # Store message history in persistent environment
+             if self.persistent and isinstance(environment, SupportsPersistence):
+                 environment.add_history(message_history)
+
+             return RLMChatCompletion(
+                 root_model=self.backend_kwargs.get("model_name", "unknown")
+                 if self.backend_kwargs
+                 else "unknown",
+                 prompt=prompt,
+                 response=final_answer,
+                 usage_summary=usage,
+                 execution_time=time_end - time_start,
+             )
+
+     def _completion_turn(
+         self,
+         prompt: str | dict[str, Any],
+         lm_handler: LMHandler,
+         environment: BaseEnv,
+     ) -> RLMIteration:
+         """
+         Perform a single iteration of the RLM, including prompting the model
+         and executing any code blocks it produces.
+         """
+         iter_start = time.perf_counter()
+         response = lm_handler.completion(prompt)
+         code_block_strs = find_code_blocks(response)
+         code_blocks = []
+
+         for code_block_str in code_block_strs:
+             # Check iteration timeout
+             if self.iteration_timeout_seconds is not None:
+                 elapsed = time.perf_counter() - iter_start
+                 if elapsed > self.iteration_timeout_seconds:
+                     raise CompletionTimeoutError(
+                         f"Iteration exceeded timeout of {self.iteration_timeout_seconds}s",
+                         timeout_seconds=self.iteration_timeout_seconds,
+                         elapsed_seconds=elapsed,
+                     )
+
+             code_result: REPLResult = environment.execute_code(code_block_str)
+             code_blocks.append(CodeBlock(code=code_block_str, result=code_result))
+
+         iteration_time = time.perf_counter() - iter_start
+         return RLMIteration(
+             prompt=prompt,
+             response=response,
+             code_blocks=code_blocks,
+             iteration_time=iteration_time,
+         )
+
+     def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMHandler) -> str:
+         """
+         Default behavior if the RLM runs out of iterations and does not find a final answer.
+         It takes the message history and tries to generate a final answer from it.
+         """
+         current_prompt = message_history + [
+             {
+                 "role": "assistant",
+                 "content": "Please provide a final answer to the user's question based on the information provided.",
+             }
+         ]
+         response = lm_handler.completion(current_prompt)
+
+         if self.logger:
+             self.logger.log(
+                 RLMIteration(
+                     prompt=current_prompt,
+                     response=response,
+                     final_answer=response,
+                     code_blocks=[],
+                 )
+             )
+
+         return response
+
+     def _fallback_answer(self, message: str | dict[str, Any]) -> str:
+         """
+         Fallback behavior when the RLM is already at max depth and should be treated as a plain LM.
+         """
+         client: BaseLM = get_client(self.backend, self.backend_kwargs)
+         response = client.completion(message)
+         return response
+
+     def _validate_persistent_environment_support(self) -> None:
+         """
+         Validate that the configured environment type supports persistent mode.
+
+         Persistent mode requires environments to implement:
+         - update_handler_address(address): Update LM handler address between calls
+         - add_context(payload, index): Add new context for multi-turn conversations
+         - get_context_count(): Return the number of loaded contexts
+
+         Currently only 'local' (LocalREPL) supports these methods.
+
+         Raises:
+             ValueError: If the environment type does not support persistent mode.
+         """
+         # Known environments that support persistence
+         persistent_supported_environments = {"local"}
+
+         if self.environment_type not in persistent_supported_environments:
+             raise ValueError(
+                 f"persistent=True is not supported for environment type '{self.environment_type}'. "
+                 f"Persistent mode requires environments that implement update_handler_address(), "
+                 f"add_context(), and get_context_count(). "
+                 f"Supported environments: {sorted(persistent_supported_environments)}"
+             )
+
+     @staticmethod
+     def _env_supports_persistence(env: BaseEnv) -> bool:
+         """Check if an environment instance supports persistent mode methods."""
+         return isinstance(env, SupportsPersistence)
+
+     def close(self) -> None:
+         """Clean up the persistent environment. Call when done with multi-turn conversations."""
+         if self._persistent_env is not None:
+             if hasattr(self._persistent_env, "cleanup"):
+                 self._persistent_env.cleanup()
+             self._persistent_env = None
+
+     def __enter__(self) -> "RLM":
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+         self.close()
+         return False
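
Note: rlm.py is the user-facing entry point of the package. The minimal usage sketch below is illustrative only; it assumes OpenAI credentials are configured through the usual environment variables, uses the "model_name" backend kwarg seen in rlm.py above, and the model string and input file are placeholders.

    # Minimal usage sketch (assumptions: OpenAI credentials in the environment,
    # "model_name" as the backend kwarg used in rlm.py, placeholder model and file).
    from groknroll.core.rlm import RLM

    long_context = open("meeting_transcript.txt").read()  # hypothetical large input

    with RLM(
        backend="openai",
        backend_kwargs={"model_name": "gpt-4o-mini"},
        environment="local",   # LocalREPL; the only type that currently supports persistent=True
        max_iterations=10,
        verbose=True,
    ) as rlm:
        result = rlm.completion(long_context, root_prompt="Who approved the budget change?")
        print(result.response)        # final answer text
        print(result.execution_time)  # seconds for the whole completion

Since __exit__ only tears down a persistent environment via close(), the context-manager form is mainly useful together with persistent=True for multi-turn sessions; for one-off calls each completion() already cleans up its own environment and LM handler.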