ctrlcode-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctrlcode/__init__.py +8 -0
- ctrlcode/agents/__init__.py +29 -0
- ctrlcode/agents/cleanup.py +388 -0
- ctrlcode/agents/communication.py +439 -0
- ctrlcode/agents/observability.py +421 -0
- ctrlcode/agents/react_loop.py +297 -0
- ctrlcode/agents/registry.py +211 -0
- ctrlcode/agents/result_parser.py +242 -0
- ctrlcode/agents/workflow.py +723 -0
- ctrlcode/analysis/__init__.py +28 -0
- ctrlcode/analysis/ast_diff.py +163 -0
- ctrlcode/analysis/bug_detector.py +149 -0
- ctrlcode/analysis/code_graphs.py +329 -0
- ctrlcode/analysis/semantic.py +205 -0
- ctrlcode/analysis/static.py +183 -0
- ctrlcode/analysis/synthesizer.py +281 -0
- ctrlcode/analysis/tests.py +189 -0
- ctrlcode/cleanup/__init__.py +16 -0
- ctrlcode/cleanup/auto_merge.py +350 -0
- ctrlcode/cleanup/doc_gardening.py +388 -0
- ctrlcode/cleanup/pr_automation.py +330 -0
- ctrlcode/cleanup/scheduler.py +356 -0
- ctrlcode/config.py +380 -0
- ctrlcode/embeddings/__init__.py +6 -0
- ctrlcode/embeddings/embedder.py +192 -0
- ctrlcode/embeddings/vector_store.py +213 -0
- ctrlcode/fuzzing/__init__.py +24 -0
- ctrlcode/fuzzing/analyzer.py +280 -0
- ctrlcode/fuzzing/budget.py +112 -0
- ctrlcode/fuzzing/context.py +665 -0
- ctrlcode/fuzzing/context_fuzzer.py +506 -0
- ctrlcode/fuzzing/derived_orchestrator.py +732 -0
- ctrlcode/fuzzing/oracle_adapter.py +135 -0
- ctrlcode/linters/__init__.py +11 -0
- ctrlcode/linters/hand_rolled_utils.py +221 -0
- ctrlcode/linters/yolo_parsing.py +217 -0
- ctrlcode/metrics/__init__.py +6 -0
- ctrlcode/metrics/dashboard.py +283 -0
- ctrlcode/metrics/tech_debt.py +663 -0
- ctrlcode/paths.py +68 -0
- ctrlcode/permissions.py +179 -0
- ctrlcode/providers/__init__.py +15 -0
- ctrlcode/providers/anthropic.py +138 -0
- ctrlcode/providers/base.py +77 -0
- ctrlcode/providers/openai.py +197 -0
- ctrlcode/providers/parallel.py +104 -0
- ctrlcode/server.py +871 -0
- ctrlcode/session/__init__.py +6 -0
- ctrlcode/session/baseline.py +57 -0
- ctrlcode/session/manager.py +967 -0
- ctrlcode/skills/__init__.py +10 -0
- ctrlcode/skills/builtin/commit.toml +29 -0
- ctrlcode/skills/builtin/docs.toml +25 -0
- ctrlcode/skills/builtin/refactor.toml +33 -0
- ctrlcode/skills/builtin/review.toml +28 -0
- ctrlcode/skills/builtin/test.toml +28 -0
- ctrlcode/skills/loader.py +111 -0
- ctrlcode/skills/registry.py +139 -0
- ctrlcode/storage/__init__.py +19 -0
- ctrlcode/storage/history_db.py +708 -0
- ctrlcode/tools/__init__.py +220 -0
- ctrlcode/tools/bash.py +112 -0
- ctrlcode/tools/browser.py +352 -0
- ctrlcode/tools/executor.py +153 -0
- ctrlcode/tools/explore.py +486 -0
- ctrlcode/tools/mcp.py +108 -0
- ctrlcode/tools/observability.py +561 -0
- ctrlcode/tools/registry.py +193 -0
- ctrlcode/tools/todo.py +291 -0
- ctrlcode/tools/update.py +266 -0
- ctrlcode/tools/webfetch.py +147 -0
- ctrlcode-0.1.0.dist-info/METADATA +93 -0
- ctrlcode-0.1.0.dist-info/RECORD +75 -0
- ctrlcode-0.1.0.dist-info/WHEEL +4 -0
- ctrlcode-0.1.0.dist-info/entry_points.txt +3 -0
ctrlcode/config.py
ADDED
@@ -0,0 +1,380 @@
"""Configuration management for ctrl-code."""

import os
import secrets
import sys
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional

from ctrlcode.paths import get_cache_dir, get_config_dir

try:
    import tomllib
except ImportError:
    import tomli as tomllib  # type: ignore # Python < 3.11


@dataclass
class ServerConfig:
    """Server configuration."""

    host: str = "127.0.0.1"
    port: int = 8765
    api_key: str | None = None  # API key for authentication (generated if not set)


@dataclass
class ProviderConfig:
    """Provider configuration."""

    api_key: str
    model: str
    base_url: str | None = None


@dataclass
class EmbeddingsConfig:
    """Embeddings configuration."""

    api_key: str
    base_url: str
    model: str = "text-embedding-3-small"


@dataclass
class ContextConfig:
    """Context window management configuration."""

    prune_protect: int = 40000
    default_limit: int = 200000


@dataclass
class FuzzingConfig:
    """Differential fuzzing configuration (derived context architecture)."""

    enabled: bool = True
    max_iterations: int = 10
    budget_tokens: int = 100000
    budget_seconds: int = 30

    # Test distribution
    input_fuzz_ratio: float = 0.3
    environment_fuzz_ratio: float = 0.4
    combined_fuzz_ratio: float = 0.2
    invariant_fuzz_ratio: float = 0.1

    # Oracle settings
    oracle_confidence_threshold: float = 0.8  # Min confidence to apply oracle corrections
    context_re_derivation_on_mismatch: bool = True

    # Skip fuzzing conditions
    simple_edits: bool = True
    under_lines: int = 10


@dataclass
class SecurityConfig:
    """Security configuration."""

    # Input validation
    max_input_length: int = 200000  # 200K chars (roughly matches context window), -1 = unlimited

    # Rate limiting
    rate_limit_enabled: bool = True
    rate_limit_window_seconds: float = 60.0  # 1 minute
    rate_limit_max_requests: int = 30  # Requests per window


@dataclass
class MCPServerConfig:
    """MCP server configuration."""

    name: str
    command: list[str]
    env: dict[str, str] = field(default_factory=dict)


@dataclass
class Config:
    """Main configuration."""

    server: ServerConfig = field(default_factory=ServerConfig)
    anthropic: Optional[ProviderConfig] = None
    openai: Optional[ProviderConfig] = None
    embeddings: Optional[EmbeddingsConfig] = None
    context: ContextConfig = field(default_factory=ContextConfig)
    fuzzing: FuzzingConfig = field(default_factory=FuzzingConfig)
    security: SecurityConfig = field(default_factory=SecurityConfig)
    mcp_servers: list[MCPServerConfig] = field(default_factory=list)
    skills_directory: Optional[Path] = None
    storage_path: Optional[Path] = None
    workspace_root: Optional[Path] = None

    @staticmethod
    def _validate_config_permissions(config_path: Path) -> None:
        """
        Validate config file has secure permissions.

        Args:
            config_path: Path to config file

        Raises:
            ValueError: If permissions are insecure
        """
        import stat

        # Check file permissions
        file_stat = config_path.stat()
        mode = file_stat.st_mode

        # Check if world-writable or group-writable
        if mode & stat.S_IWOTH:
            raise ValueError(
                f"Config file {config_path} is world-writable (insecure). "
                f"Run: chmod 600 {config_path}"
            )
        if mode & stat.S_IWGRP:
            raise ValueError(
                f"Config file {config_path} is group-writable (insecure). "
                f"Run: chmod 600 {config_path}"
            )

    @classmethod
    def load(cls, config_path: Optional[Path] = None) -> "Config":
        """
        Load configuration from TOML file.

        Args:
            config_path: Path to config file (default: ~/.config/ctrlcode/config.toml)

        Returns:
            Config instance
        """
        if config_path is None:
            config_path = get_config_dir() / "config.toml"

        if not config_path.exists():
            # Return default config
            return cls()

        # Security: validate file permissions before loading
        cls._validate_config_permissions(config_path)

        with open(config_path, "rb") as f:
            data = tomllib.load(f)

        # Parse server config
        server_data = data.get("server", {})

        # Generate API key if not present
        if "api_key" not in server_data or not server_data["api_key"]:
            server_data["api_key"] = secrets.token_urlsafe(32)
            # Save generated key back to config
            if config_path:
                cls._save_api_key(config_path, server_data["api_key"])

        server = ServerConfig(**server_data)

        # Parse provider configs
        anthropic = None
        if "providers" in data and "anthropic" in data["providers"]:
            anthropic_data = data["providers"]["anthropic"]
            # Get API key from env or config
            api_key = os.environ.get("ANTHROPIC_API_KEY", anthropic_data.get("api_key"))
            if api_key:
                anthropic = ProviderConfig(
                    api_key=api_key,
                    model=anthropic_data.get("model", "claude-sonnet-4-5-20250929"),
                    base_url=anthropic_data.get("base_url")
                )

        openai = None
        if "providers" in data and "openai" in data["providers"]:
            openai_data = data["providers"]["openai"]
            # Get API key from env or config
            api_key = os.environ.get("OPENAI_API_KEY", openai_data.get("api_key"))
            if api_key:
                openai = ProviderConfig(
                    api_key=api_key,
                    model=openai_data.get("model", "gpt-4"),
                    base_url=openai_data.get("base_url")
                )

        # Parse embeddings config
        embeddings = None
        if "embeddings" in data:
            embeddings_data = data["embeddings"]
            # Get API key from env or config
            api_key = os.environ.get("EMBEDDINGS_API_KEY", embeddings_data.get("api_key"))
            base_url = embeddings_data.get("base_url")
            if api_key and base_url:
                embeddings = EmbeddingsConfig(
                    api_key=api_key,
                    base_url=base_url,
                    model=embeddings_data.get("model", "text-embedding-3-small")
                )

        # Parse context config
        context = ContextConfig(**data.get("context", {}))

        # Parse security config
        security = SecurityConfig(**data.get("security", {}))

        # Parse fuzzing config
        fuzzing_data = data.get("fuzzing", {})
        fuzzing = FuzzingConfig(
            enabled=fuzzing_data.get("enabled", True),
            max_iterations=fuzzing_data.get("max_iterations", 10),
            budget_tokens=fuzzing_data.get("budget_tokens", 100000),
            budget_seconds=fuzzing_data.get("budget_seconds", 30),
            input_fuzz_ratio=fuzzing_data.get("distribution", {}).get("input_fuzz_ratio", 0.3),
            environment_fuzz_ratio=fuzzing_data.get("distribution", {}).get("environment_fuzz_ratio", 0.4),
            combined_fuzz_ratio=fuzzing_data.get("distribution", {}).get("combined_fuzz_ratio", 0.2),
            invariant_fuzz_ratio=fuzzing_data.get("distribution", {}).get("invariant_fuzz_ratio", 0.1),
            oracle_confidence_threshold=fuzzing_data.get("oracle", {}).get("confidence_threshold", 0.8),
            context_re_derivation_on_mismatch=fuzzing_data.get("oracle", {}).get("re_derivation_on_mismatch", True),
            simple_edits=fuzzing_data.get("skip", {}).get("simple_edits", True),
            under_lines=fuzzing_data.get("skip", {}).get("under_lines", 10),
        )

        # Parse MCP servers with security validation
        mcp_servers = []
        for server_data in data.get("mcp", {}).get("servers", []):
            # Security: validate MCP server executable
            command = server_data["command"]
            if not command:
                raise ValueError(f"MCP server '{server_data['name']}' has empty command")

            # IMPORTANT: Expand all paths FIRST, then validate the expanded result
            # This prevents TOCTOU (Time of Check Time of Use) vulnerabilities
            expanded_command = [
                str(Path(part).expanduser()) if part.startswith("~") else part
                for part in command
            ]

            executable = expanded_command[0]

            # Resolve to absolute path (follows symlinks, removes ..)
            exe_path = Path(executable).resolve()

            # Security: Check for path traversal attempts
            # After resolve(), if path still contains .., something is wrong
            if ".." in exe_path.parts:
                raise ValueError(
                    f"MCP server '{server_data['name']}' executable contains path traversal: {executable}"
                )

            # Allow executables from trusted locations
            trusted_prefixes = [
                Path(sys.prefix).resolve() / "bin",  # Virtual environment
                Path(sys.prefix).resolve() / "Scripts",  # Windows venv
                (Path.home() / ".local" / "bin").resolve(),  # User installs
                Path("/usr/local/bin").resolve(),  # System-wide
                Path("/usr/bin").resolve(),  # System
                (Path(get_config_dir()) / "mcp-servers").resolve(),  # User MCP servers directory
            ]

            # Check if resolved path is within any trusted prefix
            is_trusted = any(
                exe_path.is_relative_to(prefix)
                for prefix in trusted_prefixes
                if prefix.exists()
            )

            if not is_trusted:
                raise ValueError(
                    f"MCP server '{server_data['name']}' executable '{exe_path}' "
                    f"is not in a trusted location. Expected in: "
                    f"{', '.join(str(p) for p in trusted_prefixes if p.exists())}"
                )

            # Verify it's an executable file
            if not exe_path.is_file():
                raise ValueError(f"MCP server '{server_data['name']}' executable '{exe_path}' is not a file")

            if not os.access(exe_path, os.X_OK):
                raise ValueError(f"MCP server '{server_data['name']}' executable '{exe_path}' is not executable")

            # Update command with resolved absolute path (prevents symlink attacks)
            expanded_command[0] = str(exe_path)

            mcp_servers.append(MCPServerConfig(
                name=server_data["name"],
                command=expanded_command,
                env=server_data.get("env", {})
            ))

        # Parse skills directory
        skills_dir = None
        if "skills" in data and "directory" in data["skills"]:
            skills_dir = Path(data["skills"]["directory"]).expanduser()

        # Storage path
        storage_path = get_cache_dir() / "conversations"

        # Workspace root (default to cwd, can be overridden in config)
        workspace_root = Path.cwd()
        if "workspace" in data and "root" in data["workspace"]:
            workspace_root = Path(data["workspace"]["root"]).expanduser()

        return cls(
            server=server,
            anthropic=anthropic,
            openai=openai,
            embeddings=embeddings,
            context=context,
            fuzzing=fuzzing,
            security=security,
            mcp_servers=mcp_servers,
            skills_directory=skills_dir,
            storage_path=storage_path,
            workspace_root=workspace_root,
        )

    @staticmethod
    def _save_api_key(config_path: Path, api_key: str) -> None:
        """
        Save generated API key to config file.

        Args:
            config_path: Path to config file
            api_key: Generated API key
        """
        import tomllib

        # Read existing config
        if config_path.exists():
            with open(config_path, "rb") as f:
                data = tomllib.load(f)
        else:
            data = {}

        # Add API key to server section
        if "server" not in data:
            data["server"] = {}
        data["server"]["api_key"] = api_key

        # Write back as TOML
        with open(config_path, "w") as f:
            f.write("# Ctrl+Code Configuration\n\n")
            f.write("[server]\n")
            for key, value in data.get("server", {}).items():
                if isinstance(value, str):
                    f.write(f'{key} = "{value}"\n')
                else:
                    f.write(f'{key} = {value}\n')

            # Write other sections if they exist
            for section, content in data.items():
                if section != "server" and isinstance(content, dict):
                    f.write(f"\n[{section}]\n")
                    for key, value in content.items():
                        if isinstance(value, str):
                            f.write(f'{key} = "{value}"\n')
                        else:
                            f.write(f'{key} = {value}\n')

        # Set secure file permissions (user read/write only)
        config_path.chmod(0o600)
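Editor's note: the sketch below is not part of the package. It illustrates the config.toml layout that Config.load reads above (section and key names are taken from the parser; the values are examples only), using a temporary file so a real ~/.config/ctrlcode/config.toml is not touched. It assumes the wheel is installed under the package name ctrlcode.

import tempfile
from pathlib import Path

from ctrlcode.config import Config

toml_text = """\
[server]
host = "127.0.0.1"
port = 8765

[providers.openai]
model = "gpt-4"
# api_key may instead come from the OPENAI_API_KEY environment variable

[fuzzing]
enabled = true

[fuzzing.distribution]
input_fuzz_ratio = 0.3
environment_fuzz_ratio = 0.4

[fuzzing.skip]
under_lines = 10

[workspace]
root = "~/projects/example"
"""

with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "config.toml"
    cfg_path.write_text(toml_text)
    cfg_path.chmod(0o600)  # _validate_config_permissions rejects group/world-writable files

    config = Config.load(cfg_path)
    print(config.server.port)               # 8765
    print(config.fuzzing.input_fuzz_ratio)  # 0.3
    print(config.workspace_root)            # ~/projects/example, expanded

Because no server api_key is present, load() generates one with secrets.token_urlsafe(32) and writes it back into the [server] section of the same file.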
ctrlcode/embeddings/embedder.py
ADDED
@@ -0,0 +1,192 @@
"""Code embedding generation using provider's OpenAI-compatible embeddings endpoint."""

import logging
from typing import Optional, Literal

import numpy as np

logger = logging.getLogger(__name__)


class CodeEmbedder:
    """Generate semantic embeddings for code, oracles, tests, and bug patterns.

    Uses OpenAI-compatible embeddings endpoint (/v1/embeddings).
    Works with any OpenAI-compatible server (vLLM, LM Studio, Ollama, etc.).
    """

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model_name: str = "text-embedding-3-small",
        normalize: bool = True,
    ):
        """Initialize embedder with embeddings endpoint configuration.

        Args:
            api_key: API key for embeddings endpoint
            base_url: Base URL for OpenAI-compatible embeddings API
            model_name: Embedding model name (default: text-embedding-3-small)
            normalize: L2-normalize embeddings for cosine similarity
        """
        if not api_key or not base_url:
            raise ValueError("api_key and base_url required for embeddings")

        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.normalize = normalize
        self._embedding_dim: Optional[int] = None
        self.backend = "api"

    @property
    def embedding_dim(self) -> int:
        """Dimension of embedding vectors."""
        if self._embedding_dim is None:
            # text-embedding-3-small: 1536 dimensions
            # text-embedding-ada-002: 1536 dimensions
            # Most OpenAI-compatible endpoints use 1536
            self._embedding_dim = 1536
        return self._embedding_dim

    def embed_code(self, code: str) -> np.ndarray:
        """Embed source code into semantic vector.

        Args:
            code: Source code string

        Returns:
            Normalized embedding vector (numpy array)
        """
        return self._embed_single(code)

    def embed_oracle(self, oracle: str) -> np.ndarray:
        """Embed behavioral oracle/invariants.

        Args:
            oracle: Oracle specification text

        Returns:
            Normalized embedding vector
        """
        return self._embed_single(oracle)

    def embed_test_case(self, test_case: str) -> np.ndarray:
        """Embed test case code.

        Args:
            test_case: Test code string

        Returns:
            Normalized embedding vector
        """
        return self._embed_single(test_case)

    def embed_bug_pattern(self, bug_description: str) -> np.ndarray:
        """Embed bug pattern description.

        Args:
            bug_description: Bug description and context

        Returns:
            Normalized embedding vector
        """
        return self._embed_single(bug_description)

    def embed_batch(self, texts: list[str]) -> np.ndarray:
        """Embed multiple texts efficiently.

        Args:
            texts: List of text strings to embed

        Returns:
            Array of embeddings (shape: [len(texts), embedding_dim])
        """
        if not texts:
            return np.array([])

        return self._embed_batch_api(texts)

    def _embed_single(self, text: str) -> np.ndarray:
        """Internal method to embed single text."""
        return self._embed_single_api(text)

    def _embed_single_api(self, text: str) -> np.ndarray:
        """Embed using OpenAI-compatible API endpoint."""
        import openai

        client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)

        try:
            response = client.embeddings.create(
                input=text,
                model=self.model_name
            )

            # Check if response has data
            if not hasattr(response, 'data') or not response.data:
                logger.error(f"No embedding data received. Response: {response}")
                raise ValueError("No embedding data received")

            # Extract embedding
            embedding = np.array(response.data[0].embedding, dtype=np.float32)

            # Normalize if requested
            if self.normalize:
                norm = np.linalg.norm(embedding)
                if norm > 0:
                    embedding = embedding / norm

            return embedding

        except Exception as e:
            logger.error(f"Embedding API error: {e}")
            # Return zero vector as fallback
            return np.zeros(self.embedding_dim, dtype=np.float32)

    def _embed_batch_api(self, texts: list[str]) -> np.ndarray:
        """Embed batch using OpenAI-compatible API endpoint."""
        import openai

        client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)

        try:
            response = client.embeddings.create(
                input=texts,
                model=self.model_name
            )

            # Check if response has data
            if not hasattr(response, 'data') or not response.data:
                logger.error(f"No embedding data received. Response: {response}")
                raise ValueError("No embedding data received")

            # Extract embeddings in order
            embeddings = [np.array(item.embedding, dtype=np.float32) for item in response.data]
            embeddings = np.array(embeddings)

            # Normalize if requested
            if self.normalize:
                norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
                embeddings = np.where(norms > 0, embeddings / norms, embeddings)

            return embeddings

        except Exception as e:
            logger.error(f"Batch embedding API error: {e}")
            # Return zero vectors as fallback
            return np.zeros((len(texts), self.embedding_dim), dtype=np.float32)

    def cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """Calculate cosine similarity between two embeddings.

        Args:
            emb1: First embedding (normalized)
            emb2: Second embedding (normalized)

        Returns:
            Similarity score in [0, 1] (assumes normalized embeddings)
        """
        # For normalized vectors, cosine similarity = dot product
        return float(np.dot(emb1, emb2))
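Editor's note: the sketch below is not part of the package. It shows how CodeEmbedder might be used, assuming an OpenAI-compatible embeddings server is reachable at the given base_url; the key, URL, and code snippets are placeholders.

from ctrlcode.embeddings.embedder import CodeEmbedder

embedder = CodeEmbedder(
    api_key="placeholder-key",
    base_url="http://localhost:8000/v1",  # e.g. a local vLLM or LM Studio endpoint
    model_name="text-embedding-3-small",
)

old_impl = "def add(a, b):\n    return a + b\n"
new_impl = "def add(a, b):\n    return a - b\n"  # behavioural change

e1 = embedder.embed_code(old_impl)
e2 = embedder.embed_code(new_impl)

# Embeddings are L2-normalized by default, so cosine similarity reduces to a dot product.
print(embedder.cosine_similarity(e1, e2))

If the endpoint cannot be reached, the class logs the error and falls back to zero vectors, so the printed similarity would be 0.0 rather than raising.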