groknroll 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groknroll/__init__.py +36 -0
- groknroll/__main__.py +9 -0
- groknroll/agents/__init__.py +18 -0
- groknroll/agents/agent_manager.py +187 -0
- groknroll/agents/base_agent.py +118 -0
- groknroll/agents/build_agent.py +231 -0
- groknroll/agents/plan_agent.py +215 -0
- groknroll/cli/__init__.py +7 -0
- groknroll/cli/enhanced_cli.py +372 -0
- groknroll/cli/large_codebase_cli.py +413 -0
- groknroll/cli/main.py +331 -0
- groknroll/cli/rlm_commands.py +258 -0
- groknroll/clients/__init__.py +63 -0
- groknroll/clients/anthropic.py +112 -0
- groknroll/clients/azure_openai.py +142 -0
- groknroll/clients/base_lm.py +33 -0
- groknroll/clients/gemini.py +162 -0
- groknroll/clients/litellm.py +105 -0
- groknroll/clients/openai.py +129 -0
- groknroll/clients/portkey.py +94 -0
- groknroll/core/__init__.py +9 -0
- groknroll/core/agent.py +339 -0
- groknroll/core/comms_utils.py +264 -0
- groknroll/core/context.py +251 -0
- groknroll/core/exceptions.py +181 -0
- groknroll/core/large_codebase.py +564 -0
- groknroll/core/lm_handler.py +206 -0
- groknroll/core/rlm.py +446 -0
- groknroll/core/rlm_codebase.py +448 -0
- groknroll/core/rlm_integration.py +256 -0
- groknroll/core/types.py +276 -0
- groknroll/environments/__init__.py +34 -0
- groknroll/environments/base_env.py +182 -0
- groknroll/environments/constants.py +32 -0
- groknroll/environments/docker_repl.py +336 -0
- groknroll/environments/local_repl.py +388 -0
- groknroll/environments/modal_repl.py +502 -0
- groknroll/environments/prime_repl.py +588 -0
- groknroll/logger/__init__.py +4 -0
- groknroll/logger/rlm_logger.py +63 -0
- groknroll/logger/verbose.py +393 -0
- groknroll/operations/__init__.py +15 -0
- groknroll/operations/bash_ops.py +447 -0
- groknroll/operations/file_ops.py +473 -0
- groknroll/operations/git_ops.py +620 -0
- groknroll/oracle/__init__.py +11 -0
- groknroll/oracle/codebase_indexer.py +238 -0
- groknroll/oracle/oracle_agent.py +278 -0
- groknroll/setup.py +34 -0
- groknroll/storage/__init__.py +14 -0
- groknroll/storage/database.py +272 -0
- groknroll/storage/models.py +128 -0
- groknroll/utils/__init__.py +0 -0
- groknroll/utils/parsing.py +168 -0
- groknroll/utils/prompts.py +146 -0
- groknroll/utils/rlm_utils.py +19 -0
- groknroll-2.0.0.dist-info/METADATA +246 -0
- groknroll-2.0.0.dist-info/RECORD +62 -0
- groknroll-2.0.0.dist-info/WHEEL +5 -0
- groknroll-2.0.0.dist-info/entry_points.txt +3 -0
- groknroll-2.0.0.dist-info/licenses/LICENSE +21 -0
- groknroll-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Database management for groknroll
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, List, Dict, Any
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
|
|
9
|
+
from sqlalchemy import create_engine, desc
|
|
10
|
+
from sqlalchemy.orm import sessionmaker, Session as DBSession
|
|
11
|
+
|
|
12
|
+
from groknroll.storage.models import Base, Project, FileIndex, Execution, Session, Analysis
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Database:
    """
    SQLite database for groknroll project state

    Stores:
    - Project metadata
    - File index with AST data
    - RLM execution history
    - Session history
    - Analysis results
    """

    def __init__(self, db_path: Optional[Path] = None):
        """
        Initialize database

        Args:
            db_path: Path to SQLite database file.
                If None, uses ~/.groknroll/groknroll.db
        """
        if db_path is None:
            db_path = Path.home() / ".groknroll" / "groknroll.db"

        # Ensure the containing directory exists before SQLite opens the file.
        db_path.parent.mkdir(parents=True, exist_ok=True)

        self.db_path = db_path
        self.engine = create_engine(f"sqlite:///{db_path}", echo=False)
        self.SessionLocal = sessionmaker(bind=self.engine)

        # Create tables (idempotent: existing tables are left untouched)
        Base.metadata.create_all(self.engine)

    def get_session(self) -> DBSession:
        """Get a new database session; the caller is responsible for closing it."""
        return self.SessionLocal()

    # =========================================================================
    # Project Operations
    # =========================================================================

    def get_or_create_project(self, project_path: Path) -> Project:
        """Get existing project or create new one.

        Args:
            project_path: Filesystem path identifying the project.

        Returns:
            The (possibly newly created) Project row.
        """
        with self.get_session() as session:
            project = session.query(Project).filter_by(path=str(project_path)).first()

            if project is None:
                project = Project(
                    path=str(project_path),
                    name=project_path.name
                )
                session.add(project)
                session.commit()
                # Re-load server-side defaults (id, timestamps) before the
                # session closes so the detached object is fully populated.
                session.refresh(project)

            return project

    def update_project_stats(
        self,
        project_id: int,
        total_files: int,
        total_lines: int
    ) -> None:
        """Update project statistics and stamp last_indexed.

        Silently does nothing when the project id is unknown.

        Args:
            project_id: Project to update.
            total_files: New file count.
            total_lines: New total line count.
        """
        with self.get_session() as session:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.total_files = total_files
                project.total_lines = total_lines
                project.last_indexed = datetime.utcnow()
                session.commit()

    # =========================================================================
    # File Index Operations
    # =========================================================================

    def index_file(
        self,
        project_id: int,
        file_path: Path,
        relative_path: str,
        **metadata
    ) -> FileIndex:
        """Create or update the index entry for a file.

        Args:
            project_id: Owning project id.
            file_path: Absolute file path (used with project_id as the
                lookup key).
            relative_path: Path relative to the project root.
            **metadata: Optional FileIndex column values (language,
                size_bytes, lines_of_code, complexity, ast_data, ...).
                Keys that do not map to a FileIndex attribute are ignored.

        Returns:
            The refreshed FileIndex row.
        """
        with self.get_session() as session:
            file_index = session.query(FileIndex).filter_by(
                project_id=project_id,
                path=str(file_path)
            ).first()

            if file_index is None:
                file_index = FileIndex(
                    project_id=project_id,
                    path=str(file_path),
                    relative_path=relative_path
                )
                session.add(file_index)

            # Apply only metadata keys that map to real attributes; anything
            # else is silently dropped.
            for key, value in metadata.items():
                if hasattr(file_index, key):
                    setattr(file_index, key, value)

            file_index.updated_at = datetime.utcnow()
            session.commit()
            session.refresh(file_index)

            return file_index

    def get_project_files(self, project_id: int) -> List[FileIndex]:
        """Get all indexed files for project"""
        with self.get_session() as session:
            return session.query(FileIndex).filter_by(project_id=project_id).all()

    def search_files(
        self,
        project_id: int,
        query: str,
        language: Optional[str] = None
    ) -> List[FileIndex]:
        """Search indexed files by relative-path substring.

        Args:
            project_id: Project to search within.
            query: Substring matched against relative_path (LIKE %query%).
            language: Optional language filter.

        Returns:
            Matching FileIndex rows.
        """
        with self.get_session() as session:
            q = session.query(FileIndex).filter_by(project_id=project_id)

            if language:
                q = q.filter_by(language=language)

            q = q.filter(FileIndex.relative_path.like(f"%{query}%"))

            return q.all()

    # =========================================================================
    # Execution History
    # =========================================================================

    def log_execution(
        self,
        project_id: int,
        task: str,
        response: str,
        **metrics
    ) -> Execution:
        """Log RLM execution.

        Args:
            project_id: Owning project id.
            task: The task prompt that was executed.
            response: The final model response.
            **metrics: Optional Execution column values (status, total_cost,
                total_time, iterations, error_message, ...). status defaults
                to "success" when not supplied.

        Returns:
            The refreshed Execution row.
        """
        # Default via setdefault instead of passing status= alongside
        # **metrics: the old form raised TypeError ("got multiple values
        # for keyword argument 'status'") whenever callers supplied their
        # own status in metrics.
        metrics.setdefault("status", "success")
        with self.get_session() as session:
            execution = Execution(
                project_id=project_id,
                task=task,
                response=response,
                **metrics
            )
            session.add(execution)
            session.commit()
            session.refresh(execution)

            return execution

    def get_recent_executions(
        self,
        project_id: int,
        limit: int = 10
    ) -> List[Execution]:
        """Get the most recent executions, newest first."""
        with self.get_session() as session:
            return (
                session.query(Execution)
                .filter_by(project_id=project_id)
                .order_by(desc(Execution.started_at))
                .limit(limit)
                .all()
            )

    def get_execution_stats(self, project_id: int) -> Dict[str, Any]:
        """Aggregate execution statistics for a project.

        Returns:
            Dict with total_executions, successful, failed, total_cost,
            total_time, avg_cost and avg_time. Averages are 0.0 when the
            project has no executions.
        """
        with self.get_session() as session:
            executions = session.query(Execution).filter_by(project_id=project_id).all()

            # NULL cost/time values are treated as zero.
            total_cost = sum(e.total_cost or 0.0 for e in executions)
            total_time = sum(e.total_time or 0.0 for e in executions)
            total_count = len(executions)
            success_count = sum(1 for e in executions if e.status == "success")

            return {
                "total_executions": total_count,
                "successful": success_count,
                "failed": total_count - success_count,
                "total_cost": total_cost,
                "total_time": total_time,
                "avg_cost": total_cost / total_count if total_count > 0 else 0.0,
                "avg_time": total_time / total_count if total_count > 0 else 0.0,
            }

    # =========================================================================
    # Session Management
    # =========================================================================

    def create_session(
        self,
        project_id: int,
        session_type: str
    ) -> Session:
        """Create a new interactive session row.

        Args:
            project_id: Owning project id.
            session_type: Kind of session (e.g. chat, repl, dashboard).

        Returns:
            The refreshed Session row.
        """
        with self.get_session() as session:
            new_session = Session(
                project_id=project_id,
                session_type=session_type
            )
            session.add(new_session)
            session.commit()
            session.refresh(new_session)

            return new_session

    def end_session(self, session_id: int) -> None:
        """Mark a session as ended (no-op for unknown ids)."""
        with self.get_session() as session:
            sess = session.query(Session).filter_by(id=session_id).first()
            if sess:
                sess.ended_at = datetime.utcnow()
                session.commit()

    # =========================================================================
    # Analysis Results
    # =========================================================================

    def save_analysis(
        self,
        project_id: int,
        analysis_type: str,
        results: Dict[str, Any],
        **metadata
    ) -> Analysis:
        """Save analysis results.

        Args:
            project_id: Owning project id.
            analysis_type: Kind of analysis (security, complexity, ...).
            results: JSON-serializable analysis payload.
            **metadata: Optional Analysis column values (target_path,
                recommendations, issues, metrics, execution_time, ...).

        Returns:
            The refreshed Analysis row.
        """
        with self.get_session() as session:
            analysis = Analysis(
                project_id=project_id,
                analysis_type=analysis_type,
                results=results,
                **metadata
            )
            session.add(analysis)
            session.commit()
            session.refresh(analysis)

            return analysis

    def get_latest_analysis(
        self,
        project_id: int,
        analysis_type: str
    ) -> Optional[Analysis]:
        """Get latest analysis of given type, or None if none exists."""
        with self.get_session() as session:
            return (
                session.query(Analysis)
                .filter_by(project_id=project_id, analysis_type=analysis_type)
                .order_by(desc(Analysis.created_at))
                .first()
            )
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Database models for groknroll
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Optional
|
|
7
|
+
from sqlalchemy import Column, Integer, String, Float, DateTime, Text, ForeignKey, Boolean, JSON
|
|
8
|
+
from sqlalchemy.ext.declarative import declarative_base
|
|
9
|
+
from sqlalchemy.orm import relationship
|
|
10
|
+
|
|
11
|
+
Base = declarative_base()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Project(Base):
    """Project metadata.

    One row per indexed project directory; `path` is the unique lookup key.
    """
    __tablename__ = "projects"

    id = Column(Integer, primary_key=True)
    path = Column(String, unique=True, nullable=False, index=True)  # Project path string (unique lookup key)
    name = Column(String, nullable=False)  # Display name (directory name)
    language = Column(String)  # Primary language
    total_files = Column(Integer, default=0)  # Refreshed when the project is re-indexed
    total_lines = Column(Integer, default=0)
    last_indexed = Column(DateTime)  # When the file index was last rebuilt
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships — children are removed together with the project
    files = relationship("FileIndex", back_populates="project", cascade="all, delete-orphan")
    executions = relationship("Execution", back_populates="project", cascade="all, delete-orphan")
    sessions = relationship("Session", back_populates="project", cascade="all, delete-orphan")
    analyses = relationship("Analysis", back_populates="project", cascade="all, delete-orphan")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class FileIndex(Base):
    """Indexed file metadata.

    One row per file within a project; (project_id, path) identifies a file.
    """
    __tablename__ = "file_index"

    id = Column(Integer, primary_key=True)
    project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
    path = Column(String, nullable=False, index=True)  # Full file path string (lookup key with project_id)
    relative_path = Column(String, nullable=False)  # Path relative to the project root
    language = Column(String)
    size_bytes = Column(Integer)
    lines_of_code = Column(Integer)
    complexity = Column(Float)  # Cyclomatic complexity
    last_modified = Column(DateTime)
    ast_data = Column(JSON)  # Parsed AST metadata
    imports = Column(JSON)  # List of imports
    exports = Column(JSON)  # List of exports
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    project = relationship("Project", back_populates="files")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Execution(Base):
    """RLM execution history.

    One row per RLM run, linked to its project and (optionally) the
    interactive session it ran in.
    """
    __tablename__ = "executions"

    id = Column(Integer, primary_key=True)
    project_id = Column(Integer, ForeignKey("projects.id"), index=True)
    session_id = Column(Integer, ForeignKey("sessions.id"), index=True)

    # Request
    task = Column(Text, nullable=False)  # The task prompt given to the RLM
    context = Column(JSON)  # Context payload supplied with the task
    model = Column(String)  # Model identifier used for the run

    # Response
    response = Column(Text)  # Final model response
    trace_log = Column(Text)

    # Metrics — cost/time may be NULL; aggregations treat NULL as 0.0
    total_cost = Column(Float)
    total_time = Column(Float)
    iterations = Column(Integer)
    status = Column(String)  # success, failed, timeout
    error_message = Column(Text)

    # Timestamps
    started_at = Column(DateTime, default=datetime.utcnow, index=True)  # Indexed for recency queries
    completed_at = Column(DateTime)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationships
    project = relationship("Project", back_populates="executions")
    session = relationship("Session", back_populates="executions")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class Session(Base):
    """Interactive session history.

    NOTE: this is a user-facing session record, distinct from the
    sqlalchemy.orm Session used for database access.
    """
    __tablename__ = "sessions"

    id = Column(Integer, primary_key=True)
    project_id = Column(Integer, ForeignKey("projects.id"), index=True)

    session_type = Column(String)  # chat, repl, dashboard
    started_at = Column(DateTime, default=datetime.utcnow, index=True)
    ended_at = Column(DateTime)  # NULL while the session is still open
    message_count = Column(Integer, default=0)
    total_cost = Column(Float, default=0.0)  # Accumulated cost across the session

    # Relationships
    project = relationship("Project", back_populates="sessions")
    executions = relationship("Execution", back_populates="session", cascade="all, delete-orphan")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class Analysis(Base):
    """Code analysis results.

    One row per analysis run; the JSON columns hold the analyzer's
    structured output.
    """
    __tablename__ = "analyses"

    id = Column(Integer, primary_key=True)
    project_id = Column(Integer, ForeignKey("projects.id"), index=True)

    analysis_type = Column(String, nullable=False)  # security, complexity, review, etc
    target_path = Column(String)  # File or directory analyzed
    results = Column(JSON)  # Analysis results
    recommendations = Column(JSON)  # Recommendations
    issues = Column(JSON)  # Issues found
    metrics = Column(JSON)  # Metrics

    execution_time = Column(Float)  # Time spent producing the analysis
    created_at = Column(DateTime, default=datetime.utcnow, index=True)  # Indexed for latest-analysis queries

    # Relationships
    project = relationship("Project", back_populates="analyses")
|
|
File without changes
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Parsing utilities for RLM trajectories.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from groknroll.core.types import REPLResult, RLMIteration
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from groknroll.environments.base_env import BaseEnv
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def find_code_blocks(text: str) -> list[str]:
    """
    Find REPL code blocks in text wrapped in ```repl fences.

    Args:
        text: The text to scan for fenced ``repl`` blocks.

    Returns:
        A list of the stripped block contents. Returns an empty list
        (not None) when no code blocks are found.
    """
    # Non-greedy body match so adjacent fenced blocks stay separate.
    pattern = r"```repl\s*\n(.*?)\n```"
    return [match.group(1).strip() for match in re.finditer(pattern, text, re.DOTALL)]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str | None:
|
|
30
|
+
"""
|
|
31
|
+
Find FINAL(...) or FINAL_VAR(...) statement in response and return the final answer string.
|
|
32
|
+
|
|
33
|
+
If FINAL_VAR is found and an environment is provided, executes code to retrieve the variable value.
|
|
34
|
+
Returns None if neither pattern is found.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
text: The response text to parse
|
|
38
|
+
environment: Optional environment to execute code for FINAL_VAR retrieval
|
|
39
|
+
|
|
40
|
+
Returns:
|
|
41
|
+
The final answer string, or None if no final answer pattern is found
|
|
42
|
+
"""
|
|
43
|
+
# Check for FINAL_VAR pattern first - must be at start of line
|
|
44
|
+
final_var_pattern = r"^\s*FINAL_VAR\((.*?)\)"
|
|
45
|
+
match = re.search(final_var_pattern, text, re.MULTILINE | re.DOTALL)
|
|
46
|
+
if match:
|
|
47
|
+
variable_name = match.group(1).strip().strip('"').strip("'")
|
|
48
|
+
if environment is not None:
|
|
49
|
+
result = environment.execute_code(f"print(FINAL_VAR({variable_name!r}))")
|
|
50
|
+
final_answer = result.stdout.strip()
|
|
51
|
+
if final_answer == "":
|
|
52
|
+
final_answer = result.stderr.strip() or ""
|
|
53
|
+
return final_answer
|
|
54
|
+
return None
|
|
55
|
+
|
|
56
|
+
# Check for FINAL pattern - must be at start of line
|
|
57
|
+
final_pattern = r"^\s*FINAL\((.*?)\)"
|
|
58
|
+
match = re.search(final_pattern, text, re.MULTILINE | re.DOTALL)
|
|
59
|
+
if match:
|
|
60
|
+
return match.group(1).strip()
|
|
61
|
+
|
|
62
|
+
return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def format_iteration(
    iteration: RLMIteration, max_character_length: int = 20000
) -> list[dict[str, str]]:
    """
    Render one RLM iteration as chat messages for the next prompt.

    Produces the assistant's response followed by one user message per
    executed code block. Execution output longer than
    max_character_length is truncated with an elision marker.

    Args:
        iteration: The iteration to format
        max_character_length: The maximum character length of the result

    Returns:
        A list of messages to add to the next prompt
    """
    history: list[dict[str, str]] = [
        {"role": "assistant", "content": iteration.response}
    ]

    for block in iteration.code_blocks:
        code = block.code
        rendered = format_execution_result(block.result)

        # Clip oversized output, noting how much was dropped.
        overflow = len(rendered) - max_character_length
        if overflow > 0:
            rendered = (
                rendered[:max_character_length] + f"... + [{overflow} chars...]"
            )

        history.append(
            {
                "role": "user",
                "content": f"Code executed:\n```python\n{code}\n```\n\nREPL output:\n{rendered}",
            }
        )

    return history
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
################
|
|
101
|
+
# TODO: Remove and refactor these soon
|
|
102
|
+
################
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def format_execution_result(result: "REPLResult") -> str:
    """
    Format the execution result as a string for display.

    Combines stdout, stderr, and a summary of user-defined REPL variable
    names into one human-readable string.

    Args:
        result: The REPLResult object to format.

    Returns:
        The formatted output, or "No output" when there is nothing to show.
    """
    result_parts = []

    if result.stdout:
        result_parts.append(f"\n{result.stdout}")

    if result.stderr:
        result_parts.append(f"\n{result.stderr}")

    # Surface user-defined variables (skipping internal/dunder names);
    # only values of simple built-in types are considered displayable.
    # Was a dict filled with empty-string values when only the keys were
    # ever used — a plain list of names says what it means.
    hidden = {"__builtins__", "__name__", "__doc__"}
    simple_types = (str, int, float, bool, list, dict, tuple)
    important_vars = [
        key
        for key, value in result.locals.items()
        if not key.startswith("_")
        and key not in hidden
        and isinstance(value, simple_types)
    ]

    if important_vars:
        result_parts.append(f"REPL variables: {important_vars}\n")

    return "\n\n".join(result_parts) if result_parts else "No output"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def check_for_final_answer(response: str, repl_env, logger) -> str | None:
    """Check if response contains a final answer.

    Thin backward-compatibility wrapper around find_final_answer; the
    `logger` argument is accepted but unused.

    Args:
        response: The LM response text to inspect.
        repl_env: Environment used to resolve FINAL_VAR references.
        logger: Unused.

    Returns:
        The final answer string, or None when no FINAL/FINAL_VAR is found.
    """
    # Use the new find_final_answer function which handles both FINAL and FINAL_VAR
    return find_final_answer(response, environment=repl_env)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def convert_context_for_repl(context):
    """
    Normalize a raw context value into a (context_data, context_str) pair.

    The mapping is:
    - str                          -> (None, context)
    - list of message dicts whose
      first element has "content"  -> (list of each message's "content",
                                       missing content defaults to ""), None
    - anything else (dict, other
      lists, None, ...)            -> (context, None)

    Args:
        context: The raw context (str, dict, list, or anything else).

    Returns:
        Tuple of (context_data, context_str); at most one is derived from
        the input as a string.
    """
    if isinstance(context, str):
        return None, context

    if (
        isinstance(context, list)
        and context
        and isinstance(context[0], dict)
        and "content" in context[0]
    ):
        # Chat-style message list: keep only the message bodies.
        return [msg.get("content", "") for msg in context], None

    # Dicts, other lists, and all remaining types pass through unchanged.
    return context, None
|