groknroll-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groknroll/__init__.py +36 -0
- groknroll/__main__.py +9 -0
- groknroll/agents/__init__.py +18 -0
- groknroll/agents/agent_manager.py +187 -0
- groknroll/agents/base_agent.py +118 -0
- groknroll/agents/build_agent.py +231 -0
- groknroll/agents/plan_agent.py +215 -0
- groknroll/cli/__init__.py +7 -0
- groknroll/cli/enhanced_cli.py +372 -0
- groknroll/cli/large_codebase_cli.py +413 -0
- groknroll/cli/main.py +331 -0
- groknroll/cli/rlm_commands.py +258 -0
- groknroll/clients/__init__.py +63 -0
- groknroll/clients/anthropic.py +112 -0
- groknroll/clients/azure_openai.py +142 -0
- groknroll/clients/base_lm.py +33 -0
- groknroll/clients/gemini.py +162 -0
- groknroll/clients/litellm.py +105 -0
- groknroll/clients/openai.py +129 -0
- groknroll/clients/portkey.py +94 -0
- groknroll/core/__init__.py +9 -0
- groknroll/core/agent.py +339 -0
- groknroll/core/comms_utils.py +264 -0
- groknroll/core/context.py +251 -0
- groknroll/core/exceptions.py +181 -0
- groknroll/core/large_codebase.py +564 -0
- groknroll/core/lm_handler.py +206 -0
- groknroll/core/rlm.py +446 -0
- groknroll/core/rlm_codebase.py +448 -0
- groknroll/core/rlm_integration.py +256 -0
- groknroll/core/types.py +276 -0
- groknroll/environments/__init__.py +34 -0
- groknroll/environments/base_env.py +182 -0
- groknroll/environments/constants.py +32 -0
- groknroll/environments/docker_repl.py +336 -0
- groknroll/environments/local_repl.py +388 -0
- groknroll/environments/modal_repl.py +502 -0
- groknroll/environments/prime_repl.py +588 -0
- groknroll/logger/__init__.py +4 -0
- groknroll/logger/rlm_logger.py +63 -0
- groknroll/logger/verbose.py +393 -0
- groknroll/operations/__init__.py +15 -0
- groknroll/operations/bash_ops.py +447 -0
- groknroll/operations/file_ops.py +473 -0
- groknroll/operations/git_ops.py +620 -0
- groknroll/oracle/__init__.py +11 -0
- groknroll/oracle/codebase_indexer.py +238 -0
- groknroll/oracle/oracle_agent.py +278 -0
- groknroll/setup.py +34 -0
- groknroll/storage/__init__.py +14 -0
- groknroll/storage/database.py +272 -0
- groknroll/storage/models.py +128 -0
- groknroll/utils/__init__.py +0 -0
- groknroll/utils/parsing.py +168 -0
- groknroll/utils/prompts.py +146 -0
- groknroll/utils/rlm_utils.py +19 -0
- groknroll-2.0.0.dist-info/METADATA +246 -0
- groknroll-2.0.0.dist-info/RECORD +62 -0
- groknroll-2.0.0.dist-info/WHEEL +5 -0
- groknroll-2.0.0.dist-info/entry_points.txt +3 -0
- groknroll-2.0.0.dist-info/licenses/LICENSE +21 -0
- groknroll-2.0.0.dist-info/top_level.txt +1 -0
groknroll/clients/portkey.py
ADDED
@@ -0,0 +1,94 @@
from collections import defaultdict
from typing import Any

from portkey_ai import AsyncPortkey, Portkey
from portkey_ai.api_resources.types.chat_complete_type import ChatCompletions

from groknroll.clients.base_lm import BaseLM
from groknroll.core.types import ModelUsageSummary, UsageSummary


class PortkeyClient(BaseLM):
    """
    LM Client for running models with the Portkey API.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str | None = None,
        base_url: str | None = "https://api.portkey.ai/v1",
        **kwargs,
    ):
        super().__init__(model_name=model_name, **kwargs)
        self.client = Portkey(api_key=api_key, base_url=base_url)
        self.async_client = AsyncPortkey(api_key=api_key, base_url=base_url)
        self.model_name = model_name

        # Per-model usage tracking
        self.model_call_counts: dict[str, int] = defaultdict(int)
        self.model_input_tokens: dict[str, int] = defaultdict(int)
        self.model_output_tokens: dict[str, int] = defaultdict(int)
        self.model_total_tokens: dict[str, int] = defaultdict(int)

    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Portkey client.")

        response = self.client.chat.completions.create(
            model=model,
            messages=messages,
        )
        self._track_cost(response, model)
        return response.choices[0].message.content

    async def acompletion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            messages = prompt
        else:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")

        model = model or self.model_name
        if not model:
            raise ValueError("Model name is required for Portkey client.")

        response = await self.async_client.chat.completions.create(model=model, messages=messages)
        self._track_cost(response, model)
        return response.choices[0].message.content

    def _track_cost(self, response: ChatCompletions, model: str):
        self.model_call_counts[model] += 1
        self.model_input_tokens[model] += response.usage.prompt_tokens
        self.model_output_tokens[model] += response.usage.completion_tokens
        self.model_total_tokens[model] += response.usage.total_tokens

        # Track last call for handler to read
        self.last_prompt_tokens = response.usage.prompt_tokens
        self.last_completion_tokens = response.usage.completion_tokens

    def get_usage_summary(self) -> UsageSummary:
        model_summaries = {}
        for model in self.model_call_counts:
            model_summaries[model] = ModelUsageSummary(
                total_calls=self.model_call_counts[model],
                total_input_tokens=self.model_input_tokens[model],
                total_output_tokens=self.model_output_tokens[model],
            )
        return UsageSummary(model_usage_summaries=model_summaries)

    def get_last_usage(self) -> ModelUsageSummary:
        return ModelUsageSummary(
            total_calls=1,
            total_input_tokens=self.last_prompt_tokens,
            total_output_tokens=self.last_completion_tokens,
        )
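For orientation, a minimal usage sketch of the client above. This is a hypothetical script, not part of the package; it assumes a valid Portkey API key, that the named model is reachable through the Portkey gateway, and that groknroll and portkey-ai are installed.

# Hypothetical example; not shipped with the package.
from groknroll.clients.portkey import PortkeyClient

client = PortkeyClient(api_key="pk-...", model_name="gpt-4o-mini")  # placeholder key

# A plain string is wrapped into a single user message; a list of chat
# messages would be passed through unchanged.
answer = client.completion("Explain what a length-prefixed socket protocol is.")
print(answer)

# Per-model counters accumulated by _track_cost() on every call.
print(dict(client.model_call_counts))
print(dict(client.model_input_tokens), dict(client.model_output_tokens))

# Aggregated view consumed by the rest of the framework.
print(client.get_usage_summary())
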
groknroll/core/__init__.py
ADDED
@@ -0,0 +1,9 @@
"""
groknroll Core - Agent logic with integrated RLM
"""

from groknroll.core.agent import GroknrollAgent
from groknroll.core.rlm_integration import RLMIntegration
from groknroll.core.context import ProjectContext

__all__ = ["GroknrollAgent", "RLMIntegration", "ProjectContext"]

groknroll/core/agent.py
ADDED
@@ -0,0 +1,339 @@
"""
GroknrollAgent - The Ultimate CLI Coding Agent

Main orchestrator combining RLM, project context, and autonomous workflows.
"""

from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass

from groknroll.core.rlm_integration import RLMIntegration, RLMConfig, RLMResult
from groknroll.core.context import ProjectContext
from groknroll.storage.database import Database


@dataclass
class AgentConfig:
    """Configuration for GroknrollAgent"""
    model: str = "gpt-4o-mini"
    max_cost: float = 5.0
    timeout: int = 300
    auto_index: bool = True


class GroknrollAgent:
    """
    The Ultimate CLI Coding Agent

    Features:
    - Unlimited context via RLM
    - Project-aware with persistent state
    - Autonomous multi-step workflows
    - Self-correcting and verifying
    - Local and private
    """

    def __init__(
        self,
        project_path: Optional[Path] = None,
        config: Optional[AgentConfig] = None
    ):
        """
        Initialize groknroll agent

        Args:
            project_path: Path to project (uses current dir if None)
            config: Agent configuration
        """
        self.config = config or AgentConfig()

        # Set project path
        self.project_path = project_path or Path.cwd()
        if not self.project_path.exists():
            raise ValueError(f"Project path does not exist: {self.project_path}")

        # Initialize components
        self.db = Database()
        self.context = ProjectContext(self.project_path, self.db)

        # Initialize RLM
        rlm_config = RLMConfig(
            model=self.config.model,
            max_cost=self.config.max_cost,
            timeout_seconds=self.config.timeout
        )
        self.rlm = RLMIntegration(rlm_config)

        # Auto-index if configured
        if self.config.auto_index:
            self._ensure_indexed()

    def _ensure_indexed(self) -> None:
        """Ensure project is indexed"""
        overview = self.context.get_project_overview()

        # Index if never indexed or very old
        if overview["total_files"] == 0:
            print("Indexing project for first time...")
            stats = self.context.index_project()
            print(f"✓ Indexed {stats['indexed']} files ({stats['total_lines']:,} lines)")

    # =========================================================================
    # Core Chat & Analysis
    # =========================================================================

    def chat(self, message: str, context: Optional[Dict[str, Any]] = None) -> str:
        """
        Chat with agent about project

        Args:
            message: User message
            context: Additional context

        Returns:
            Agent response
        """
        # Add project context
        project_context = {
            "project_name": self.context.project.name,
            "project_path": str(self.project_path),
            "total_files": self.context.project.total_files,
            **(context or {})
        }

        # Execute with RLM
        result = self.rlm.complete(
            task=message,
            context=project_context
        )

        # Log execution
        if result.success:
            self.context.log_execution(
                task=message,
                response=result.response,
                total_cost=result.total_cost,
                total_time=result.total_time,
                iterations=result.iterations,
                status="success"
            )

        return result.response if result.success else f"Error: {result.error}"

    def analyze_code(
        self,
        file_path: Optional[Path] = None,
        code: Optional[str] = None,
        analysis_type: str = "review"
    ) -> Dict[str, Any]:
        """
        Analyze code

        Args:
            file_path: Path to file (or provide code directly)
            code: Code string (or provide file_path)
            analysis_type: Type of analysis (review, security, complexity, etc)

        Returns:
            Analysis results
        """
        if file_path and not code:
            with open(file_path, 'r') as f:
                code = f.read()

        if not code:
            raise ValueError("Must provide either file_path or code")

        # Determine language
        language = "python"  # Default
        if file_path:
            language = self.context._detect_language(file_path.suffix)

        # Analyze with RLM
        result = self.rlm.analyze_code(
            code=code,
            analysis_type=analysis_type,
            language=language
        )

        if result.success:
            # Parse response into structured format
            analysis_results = {
                "analysis": result.response,
                "file": str(file_path) if file_path else None,
                "language": language,
                "type": analysis_type
            }

            # Save analysis
            self.context.save_analysis(
                analysis_type=analysis_type,
                results=analysis_results,
                target_path=str(file_path) if file_path else None,
                execution_time=result.total_time
            )

            return analysis_results

        return {"error": result.error}

    def analyze_project(self, detailed: bool = False) -> Dict[str, Any]:
        """
        Analyze entire project

        Args:
            detailed: Include detailed analysis

        Returns:
            Project analysis
        """
        overview = self.context.get_project_overview()

        analysis_prompt = f"""Analyze this project:

Name: {overview['name']}
Files: {overview['total_files']}
Lines of Code: {overview['total_lines']:,}

Languages:
{self._format_language_stats(overview['languages'])}

Provide:
1. Project structure assessment
2. Code quality overview
3. Potential issues
4. Recommendations
"""

        if detailed:
            analysis_prompt += "\n5. Detailed analysis of key files\n6. Architecture suggestions"

        result = self.rlm.complete(
            task=analysis_prompt,
            context=overview
        )

        if result.success:
            analysis_results = {
                "overview": overview,
                "analysis": result.response,
                "metrics": {
                    "cost": result.total_cost,
                    "time": result.total_time,
                    "iterations": result.iterations
                }
            }

            self.context.save_analysis(
                analysis_type="project_overview",
                results=analysis_results,
                execution_time=result.total_time
            )

            return analysis_results

        return {"error": result.error}

    # =========================================================================
    # Code Search & Understanding
    # =========================================================================

    def search_code(self, query: str, language: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Search codebase

        Args:
            query: Search query
            language: Filter by language

        Returns:
            List of matching files
        """
        return self.context.search_files(query, language)

    def explain_code(self, file_path: Path) -> str:
        """Explain what code does"""
        with open(file_path, 'r') as f:
            code = f.read()

        language = self.context._detect_language(file_path.suffix)

        result = self.rlm.analyze_code(
            code=code,
            analysis_type="explain",
            language=language
        )

        return result.response if result.success else f"Error: {result.error}"

    # =========================================================================
    # Statistics & History
    # =========================================================================

    def get_stats(self) -> Dict[str, Any]:
        """Get usage statistics"""
        exec_stats = self.context.get_execution_stats()
        overview = self.context.get_project_overview()

        return {
            "project": {
                "name": overview["name"],
                "files": overview["total_files"],
                "lines": overview["total_lines"],
                "languages": overview["languages"]
            },
            "executions": exec_stats,
            "rlm": self.rlm.get_stats()
        }

    def get_history(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent execution history"""
        executions = self.db.get_recent_executions(
            project_id=self.context.project.id,
            limit=limit
        )

        return [
            {
                "task": e.task[:100] + "..." if len(e.task) > 100 else e.task,
                "status": e.status,
                "cost": e.total_cost,
                "time": e.total_time,
                "timestamp": e.started_at
            }
            for e in executions
        ]

    # =========================================================================
    # Project Management
    # =========================================================================

    def reindex_project(self, force: bool = True) -> Dict[str, Any]:
        """Re-index project"""
        print("Re-indexing project...")
        stats = self.context.index_project(force=force)
        print(f"✓ Indexed {stats['indexed']} files")
        return stats

    def get_project_info(self) -> Dict[str, Any]:
        """Get project information"""
        return self.context.get_project_overview()

    # =========================================================================
    # Utilities
    # =========================================================================

    def _format_language_stats(self, language_stats: Dict[str, Dict[str, int]]) -> str:
        """Format language statistics"""
        lines = []
        for lang, stats in sorted(language_stats.items(), key=lambda x: x[1]['lines'], reverse=True):
            lines.append(f"  {lang}: {stats['files']} files, {stats['lines']:,} lines")
        return "\n".join(lines)

    def reset_rlm(self) -> None:
        """Reset RLM environment"""
        self.rlm.reset()

    def __repr__(self) -> str:
        return f"GroknrollAgent(project={self.project_path.name})"
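For orientation, a minimal usage sketch of the agent above. This is a hypothetical script, not part of the package; it assumes API credentials for the configured model are available and that the current directory is a project worth indexing.

# Hypothetical example; not shipped with the package.
from pathlib import Path

from groknroll.core.agent import GroknrollAgent, AgentConfig

# Point the agent at a project; auto_index=True triggers a first-time index.
agent = GroknrollAgent(
    project_path=Path("."),
    config=AgentConfig(model="gpt-4o-mini", max_cost=2.0, timeout=120),
)

# Ask a question grounded in the indexed project context.
print(agent.chat("What does this project do?"))

# Review a single file and print the structured result.
report = agent.analyze_code(file_path=Path("groknroll/core/agent.py"), analysis_type="review")
print(report.get("analysis", report))

# Usage and history helpers.
print(agent.get_stats())
for entry in agent.get_history(limit=5):
    print(entry["status"], entry["task"])
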
groknroll/core/comms_utils.py
ADDED
@@ -0,0 +1,264 @@
"""
Communication utilities for RLM socket protocol.

Protocol: 4-byte big-endian length prefix + JSON payload.
Used for communication between LMHandler and environment subprocesses.
"""

import json
import socket
import struct
from dataclasses import dataclass
from typing import Any

from groknroll.core.types import RLMChatCompletion

# =============================================================================
# Message Dataclasses
# =============================================================================


@dataclass
class LMRequest:
    """Request message sent to the LM Handler.

    Supports both single prompt (prompt field) and batched prompts (prompts field).
    """

    prompt: str | dict[str, Any] | None = None
    prompts: list[str | dict[str, Any]] | None = None
    model: str | None = None
    depth: int = 0

    @property
    def is_batched(self) -> bool:
        """Check if this is a batched request."""
        return self.prompts is not None and len(self.prompts) > 0

    def to_dict(self) -> dict:
        """Convert to dict, excluding None values."""
        d = {}
        if self.prompt is not None:
            d["prompt"] = self.prompt
        if self.prompts is not None:
            d["prompts"] = self.prompts
        if self.model is not None:
            d["model"] = self.model
        d["depth"] = self.depth
        return d

    @classmethod
    def from_dict(cls, data: dict) -> "LMRequest":
        """Create from dict."""
        return cls(
            prompt=data.get("prompt"),
            prompts=data.get("prompts"),
            model=data.get("model"),
            depth=data.get("depth", -1),  # TODO: Default should throw an error
        )


@dataclass
class LMResponse:
    """Response message from the LM Handler.

    Supports both single response (chat_completion) and batched responses (chat_completions).
    """

    error: str | None = None
    chat_completion: RLMChatCompletion | None = None
    chat_completions: list[RLMChatCompletion] | None = None

    @property
    def success(self) -> bool:
        """Check if response was successful."""
        return self.error is None

    @property
    def is_batched(self) -> bool:
        """Check if this is a batched response."""
        return self.chat_completions is not None

    def to_dict(self) -> dict:
        """Convert to dict, excluding None values."""
        if self.error is not None:
            return {
                "error": self.error,
                "chat_completion": None,
                "chat_completions": None,
            }
        if self.chat_completions is not None:
            return {
                "chat_completions": [c.to_dict() for c in self.chat_completions],
                "chat_completion": None,
                "error": None,
            }
        if self.chat_completion is not None:
            return {
                "chat_completion": self.chat_completion.to_dict(),
                "chat_completions": None,
                "error": None,
            }
        return {
            "error": "No chat completion or error provided.",
            "chat_completion": None,
            "chat_completions": None,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "LMResponse":
        """Create from dict."""
        chat_completions = None
        if data.get("chat_completions"):
            chat_completions = [RLMChatCompletion.from_dict(c) for c in data["chat_completions"]]

        chat_completion = None
        if data.get("chat_completion"):
            chat_completion = RLMChatCompletion.from_dict(data["chat_completion"])

        return cls(
            error=data.get("error"),
            chat_completion=chat_completion,
            chat_completions=chat_completions,
        )

    @classmethod
    def success_response(cls, chat_completion: RLMChatCompletion) -> "LMResponse":
        """Create a successful single response."""
        return cls(chat_completion=chat_completion)

    @classmethod
    def batched_success_response(cls, chat_completions: list[RLMChatCompletion]) -> "LMResponse":
        """Create a successful batched response."""
        return cls(chat_completions=chat_completions)

    @classmethod
    def error_response(cls, error: str) -> "LMResponse":
        """Create an error response."""
        return cls(error=error)


# =============================================================================
# Socket Protocol Helpers
# =============================================================================


def socket_send(sock: socket.socket, data: dict) -> None:
    """Send a length-prefixed JSON message over socket.

    Protocol: 4-byte big-endian length prefix + UTF-8 JSON payload.
    """
    payload = json.dumps(data).encode("utf-8")
    sock.sendall(struct.pack(">I", len(payload)) + payload)


def socket_recv(sock: socket.socket) -> dict:
    """Receive a length-prefixed JSON message from socket.

    Protocol: 4-byte big-endian length prefix + UTF-8 JSON payload.
    Returns empty dict if connection closed before length received.

    Raises:
        ConnectionError: If connection closes mid-message.
    """
    raw_len = sock.recv(4)
    if not raw_len:
        return {}

    length = struct.unpack(">I", raw_len)[0]
    payload = b""
    while len(payload) < length:
        chunk = sock.recv(length - len(payload))
        if not chunk:
            raise ConnectionError("Connection closed before message complete")
        payload += chunk

    return json.loads(payload.decode("utf-8"))


def socket_request(address: tuple[str, int], data: dict, timeout: int = 300) -> dict:
    """Send a request and receive a response over a new socket connection.

    Opens a new TCP connection, sends the request, waits for response, then closes.

    Args:
        address: (host, port) tuple to connect to.
        data: Dictionary to send as JSON.
        timeout: Socket timeout in seconds (default 300).

    Returns:
        Response dictionary.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        sock.connect(address)
        socket_send(sock, data)
        return socket_recv(sock)


# =============================================================================
# Typed Request Helpers
# =============================================================================


def send_lm_request(
    address: tuple[str, int], request: LMRequest, timeout: int = 300, depth: int | None = None
) -> LMResponse:
    """Send an LM request and return typed response.

    Args:
        address: (host, port) tuple of LM Handler server.
        request: LMRequest to send.
        timeout: Socket timeout in seconds.
        depth: Optional depth to override request depth.

    Returns:
        LMResponse with content or error.
    """
    try:
        if depth is not None:
            request.depth = depth
        response_data = socket_request(address, request.to_dict(), timeout)
        return LMResponse.from_dict(response_data)
    except Exception as e:
        return LMResponse.error_response(f"Request failed: {e}")


def send_lm_request_batched(
    address: tuple[str, int],
    prompts: list[str | dict[str, Any]],
    model: str | None = None,
    timeout: int = 300,
    depth: int = 0,
) -> list[LMResponse]:
    """Send a batched LM request and return a list of typed responses.

    Args:
        address: (host, port) tuple of LM Handler server.
        prompts: List of prompts to send.
        model: Optional model name to use.
        timeout: Socket timeout in seconds.
        depth: Depth for routing (default 0).

    Returns:
        List of LMResponse objects, one per prompt, in the same order.
    """
    try:
        request = LMRequest(prompts=prompts, model=model, depth=depth)
        response_data = socket_request(address, request.to_dict(), timeout)
        response = LMResponse.from_dict(response_data)

        if not response.success:
            # Return error responses for all prompts
            return [LMResponse.error_response(response.error)] * len(prompts)

        if response.chat_completions is None:
            return [LMResponse.error_response("No completions returned")] * len(prompts)

        # Convert batched response to list of individual responses
        return [
            LMResponse.success_response(chat_completion)
            for chat_completion in response.chat_completions
        ]
    except Exception as e:
        return [LMResponse.error_response(f"Request failed: {e}")] * len(prompts)
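To make the wire format concrete, a minimal round-trip sketch using only the framing helpers above. This is a hypothetical demo, not shipped with the package; the host, port, and echo handler are made up for illustration, and it exercises only the 4-byte length prefix + JSON layer with plain dicts so it does not depend on RLMChatCompletion.

# Hypothetical example; not shipped with the package.
import socket
import threading

from groknroll.core.comms_utils import socket_recv, socket_send, socket_request

HOST, PORT = "127.0.0.1", 8765  # arbitrary local port for the demo

# Bind and listen up front so the client cannot race the server thread.
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.bind((HOST, PORT))
server_sock.listen(1)


def handle_one() -> None:
    """Accept one connection, read one framed message, echo it back with a flag."""
    conn, _ = server_sock.accept()
    with conn:
        message = socket_recv(conn)              # unpack length prefix + JSON
        socket_send(conn, {"echoed": True, **message})


threading.Thread(target=handle_one, daemon=True).start()

# socket_request opens a fresh connection, sends one frame, and waits for the reply.
reply = socket_request((HOST, PORT), {"prompt": "hello", "depth": 0}, timeout=5)
print(reply)  # {'echoed': True, 'prompt': 'hello', 'depth': 0}
server_sock.close()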