switchforge 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- forge_core/__init__.py +3 -0
- forge_core/__main__.py +6 -0
- forge_core/ai/__init__.py +0 -0
- forge_core/ai/prompts.py +87 -0
- forge_core/ai/provider.py +121 -0
- forge_core/ai/structured.py +68 -0
- forge_core/auth.py +72 -0
- forge_core/cli.py +209 -0
- forge_core/config.py +118 -0
- forge_core/core/__init__.py +0 -0
- forge_core/core/agent_manager.py +111 -0
- forge_core/core/coverage.py +241 -0
- forge_core/core/file_manager.py +104 -0
- forge_core/models/__init__.py +0 -0
- forge_core/models/config.py +115 -0
- forge_core/models/dto.py +58 -0
- forge_core/models/project.py +72 -0
- forge_core/models/test_result.py +93 -0
- forge_core/orchestrator.py +233 -0
- forge_core/phases/__init__.py +0 -0
- forge_core/phases/analyze_project.py +149 -0
- forge_core/phases/audit_tests.py +35 -0
- forge_core/phases/compile_fix.py +79 -0
- forge_core/phases/coverage_report.py +19 -0
- forge_core/phases/detect_stack.py +66 -0
- forge_core/phases/exclusion_scan.py +74 -0
- forge_core/phases/fix_broken.py +83 -0
- forge_core/phases/generate_tests.py +218 -0
- forge_core/phases/journey_mapping.py +120 -0
- forge_core/phases/self_learn.py +93 -0
- forge_core/utils/__init__.py +0 -0
- forge_core/utils/logger.py +89 -0
- forge_core/utils/reporter.py +70 -0
- forge_core/utils/shell.py +93 -0
- forge_core/utils/tokens.py +67 -0
- switchforge-1.0.0.dist-info/METADATA +65 -0
- switchforge-1.0.0.dist-info/RECORD +39 -0
- switchforge-1.0.0.dist-info/WHEEL +4 -0
- switchforge-1.0.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Agent manager — heartbeat monitoring, stuck detection, scope splitting."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
from forge_core.utils import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class AgentState:
    """Tracks the state of a generation agent."""

    scope_id: str
    target_files: list[str]
    started_at: float = field(default_factory=time.time)
    tool_calls: int = 0
    last_progress_at: float = field(default_factory=time.time)
    errors: list[str] = field(default_factory=list)
    consecutive_same_error: int = 0
    last_error: str = ""
    completed: bool = False
    # Heartbeats received since the last recorded progress. Kept as an explicit
    # counter because comparing the two timestamps above is unreliable: both
    # default to separate time.time() calls made at construction.
    fruitless_calls: int = 0


class AgentManager:
    """Manages agent lifecycle with heartbeat and stuck detection.

    Rules from copilot-instructions.md:
    - Heartbeat every 10 tool calls
    - Fruitless detection: same error 3x
    - Auto-termination after 20 fruitless calls
    - Scope splitting for stuck agents
    """

    MAX_FRUITLESS_CALLS = 20
    SAME_ERROR_THRESHOLD = 3
    HEARTBEAT_INTERVAL = 10

    def __init__(self):
        self._agents: dict[str, AgentState] = {}

    def register(self, scope_id: str, target_files: list[str]) -> AgentState:
        """Register a new agent scope, replacing any state stored under the same id."""
        state = AgentState(scope_id=scope_id, target_files=target_files)
        self._agents[scope_id] = state
        return state

    def _decide(self, state: AgentState) -> str:
        """Return the action for *state* without logging or mutating anything."""
        if state.consecutive_same_error >= self.SAME_ERROR_THRESHOLD:
            return "split"
        if state.fruitless_calls >= self.MAX_FRUITLESS_CALLS:
            return "terminate"
        return "continue"

    def heartbeat(self, scope_id: str) -> str:
        """Record a heartbeat. Returns action: 'continue', 'split', or 'terminate'.

        An unknown ``scope_id`` yields 'terminate'. Each call counts as one tool
        call and one fruitless call; ``record_progress`` resets the latter.
        """
        state = self._agents.get(scope_id)
        if state is None:
            return "terminate"

        state.tool_calls += 1
        state.fruitless_calls += 1

        action = self._decide(state)
        if action == "split":
            logger.warn(f"Agent {scope_id}: same error {state.consecutive_same_error}x — splitting scope")
        elif action == "terminate":
            logger.warn(f"Agent {scope_id}: {state.fruitless_calls} calls with no progress — terminating")
        elif state.tool_calls % self.HEARTBEAT_INTERVAL == 0:
            # Periodic heartbeat log
            elapsed = time.time() - state.started_at
            logger.info(f"Agent {scope_id}: {state.tool_calls} calls, {elapsed:.0f}s elapsed")
        return action

    def record_error(self, scope_id: str, error: str) -> None:
        """Record an error from an agent and update the same-error streak."""
        state = self._agents.get(scope_id)
        if state is None:
            return

        state.errors.append(error)
        if error == state.last_error:
            state.consecutive_same_error += 1
        else:
            state.consecutive_same_error = 1
            state.last_error = error

    def record_progress(self, scope_id: str) -> None:
        """Record that an agent made progress (wrote a file, coverage increased)."""
        state = self._agents.get(scope_id)
        if state is not None:
            state.last_progress_at = time.time()
            state.consecutive_same_error = 0
            state.fruitless_calls = 0

    def split_scope(self, scope_id: str) -> list[list[str]]:
        """Split a stuck agent's scope into two halves.

        Returns an empty list when the scope is unknown or has at most one file.
        """
        state = self._agents.get(scope_id)
        if state is None or len(state.target_files) <= 1:
            return []

        files = state.target_files
        mid = len(files) // 2
        return [files[:mid], files[mid:]]

    def complete(self, scope_id: str) -> None:
        """Mark an agent as completed."""
        state = self._agents.get(scope_id)
        if state is not None:
            state.completed = True

    @property
    def active_count(self) -> int:
        """Number of registered agents not yet marked completed."""
        return sum(1 for s in self._agents.values() if not s.completed)
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""Coverage runner — execute test commands and parse coverage reports."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from forge_core.models.test_result import CoverageEntry, CoverageReport
|
|
9
|
+
from forge_core.utils import logger
|
|
10
|
+
from forge_core.utils.shell import run
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def run_tests(project_path: Path, test_command: str) -> tuple[bool, str]:
    """Execute *test_command* with *project_path* as the working directory.

    Returns ``(success, output)`` as reported by the shell helper's result.
    """
    outcome = run(test_command, cwd=project_path, timeout=600)
    succeeded, captured = outcome.success, outcome.output
    return succeeded, captured
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def run_coverage(project_path: Path, coverage_command: str) -> CoverageReport:
    """Run the project's coverage command and parse the report.

    Parsers are tried in a fixed order (JaCoCo, Cobertura, LCOV, pytest-cov,
    go cover, generic percentage) and the chain stops at the first one that
    produces a result. Test counts are then extracted from the command output.
    """
    result = run(coverage_command, cwd=project_path, timeout=600)

    report = CoverageReport()
    # Several parsers only fill in line_coverage, so "total_lines is still 0"
    # alone cannot tell whether one succeeded. Remember the untouched value
    # and treat any change to it as a successful parse too; otherwise a later
    # regex fallback could overwrite a figure an earlier parser already found.
    baseline = getattr(report, "line_coverage", None)

    def _parsed(current: CoverageReport) -> bool:
        return current.total_lines != 0 or getattr(current, "line_coverage", None) != baseline

    # Try parsing various coverage report formats, most structured first.
    parsers = (
        lambda r: _try_parse_jacoco(project_path, r),
        lambda r: _try_parse_cobertura(project_path, r),
        lambda r: _try_parse_lcov(project_path, r),
        lambda r: _try_parse_pytest_cov(result.output, r),
        lambda r: _try_parse_go_cover(result.output, r),
        # Fallback: try to extract a percentage from the output
        lambda r: _try_parse_generic(result.output, r),
    )
    for parse in parsers:
        report = parse(report)
        if _parsed(report):
            break

    # Count tests from output
    test_counts = _parse_test_counts(result.output)
    report.total_tests = test_counts.get("total", 0)
    report.tests_passed = test_counts.get("passed", 0)
    report.tests_failed = test_counts.get("failed", 0)

    return report
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _try_parse_jacoco(project_path: Path, report: CoverageReport) -> CoverageReport:
|
|
49
|
+
"""Parse JaCoCo CSV or XML report (Kotlin/Java)."""
|
|
50
|
+
# Look for CSV report first
|
|
51
|
+
csv_patterns = [
|
|
52
|
+
"build/reports/jacoco/test/jacocoTestReport.csv",
|
|
53
|
+
"build/reports/jacoco/jacocoTestReport.csv",
|
|
54
|
+
"target/site/jacoco/jacoco.csv",
|
|
55
|
+
]
|
|
56
|
+
for pattern in csv_patterns:
|
|
57
|
+
csv_path = project_path / pattern
|
|
58
|
+
if csv_path.exists():
|
|
59
|
+
return _parse_jacoco_csv(csv_path, report)
|
|
60
|
+
|
|
61
|
+
# Try XML
|
|
62
|
+
xml_patterns = [
|
|
63
|
+
"build/reports/jacoco/test/jacocoTestReport.xml",
|
|
64
|
+
"target/site/jacoco/jacoco.xml",
|
|
65
|
+
]
|
|
66
|
+
for pattern in xml_patterns:
|
|
67
|
+
xml_path = project_path / pattern
|
|
68
|
+
if xml_path.exists():
|
|
69
|
+
return _parse_jacoco_xml(xml_path, report)
|
|
70
|
+
|
|
71
|
+
return report
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _parse_jacoco_csv(csv_path: Path, report: CoverageReport) -> CoverageReport:
    """Parse a JaCoCo CSV report into *report*.

    Sums the LINE_MISSED / LINE_COVERED columns over all rows and appends one
    CoverageEntry per row that has at least one line. On any failure a warning
    is logged and *report* is returned unmodified, so a half-read file never
    leaves partial entries behind.
    """
    import csv

    entries = []
    total_missed = 0
    total_covered = 0

    try:
        # newline="" lets the csv module handle line endings itself.
        with open(csv_path, encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                # `or 0` also covers short rows, where DictReader yields None.
                missed = int(row.get("LINE_MISSED") or 0)
                covered = int(row.get("LINE_COVERED") or 0)
                total_missed += missed
                total_covered += covered

                pkg = row.get("PACKAGE", "")
                cls = row.get("CLASS", "")
                row_total = missed + covered
                if row_total > 0:
                    entries.append(
                        CoverageEntry(
                            file_path=f"{pkg}/{cls}",
                            line_coverage=round(covered / row_total * 100, 1),
                            lines_covered=covered,
                            lines_total=row_total,
                        )
                    )
    except Exception as e:
        logger.warn(f"Failed to parse JaCoCo CSV: {e}")
        return report

    # Commit only after the whole file was read successfully.
    report.files.extend(entries)
    total = total_missed + total_covered
    report.total_lines = total
    report.total_lines_covered = total_covered
    if total > 0:
        report.line_coverage = round(total_covered / total * 100, 1)

    return report
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _parse_jacoco_xml(xml_path: Path, report: CoverageReport) -> CoverageReport:
|
|
115
|
+
"""Parse JaCoCo XML report — extract LINE counter from root."""
|
|
116
|
+
import xml.etree.ElementTree as ET
|
|
117
|
+
|
|
118
|
+
try:
|
|
119
|
+
tree = ET.parse(xml_path)
|
|
120
|
+
root = tree.getroot()
|
|
121
|
+
|
|
122
|
+
for counter in root.findall("counter"):
|
|
123
|
+
if counter.get("type") == "LINE":
|
|
124
|
+
missed = int(counter.get("missed", 0))
|
|
125
|
+
covered = int(counter.get("covered", 0))
|
|
126
|
+
total = missed + covered
|
|
127
|
+
report.total_lines = total
|
|
128
|
+
report.total_lines_covered = covered
|
|
129
|
+
if total > 0:
|
|
130
|
+
report.line_coverage = round(covered / total * 100, 1)
|
|
131
|
+
break
|
|
132
|
+
except Exception as e:
|
|
133
|
+
logger.warn(f"Failed to parse JaCoCo XML: {e}")
|
|
134
|
+
|
|
135
|
+
return report
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _try_parse_cobertura(project_path: Path, report: CoverageReport) -> CoverageReport:
|
|
139
|
+
"""Parse Cobertura XML report."""
|
|
140
|
+
cobertura_paths = [
|
|
141
|
+
"coverage.xml",
|
|
142
|
+
"build/reports/cobertura/coverage.xml",
|
|
143
|
+
"target/site/cobertura/coverage.xml",
|
|
144
|
+
]
|
|
145
|
+
for pattern in cobertura_paths:
|
|
146
|
+
xml_path = project_path / pattern
|
|
147
|
+
if xml_path.exists():
|
|
148
|
+
try:
|
|
149
|
+
import xml.etree.ElementTree as ET
|
|
150
|
+
|
|
151
|
+
tree = ET.parse(xml_path)
|
|
152
|
+
root = tree.getroot()
|
|
153
|
+
rate = float(root.get("line-rate", 0))
|
|
154
|
+
report.line_coverage = round(rate * 100, 1)
|
|
155
|
+
return report
|
|
156
|
+
except Exception:
|
|
157
|
+
continue
|
|
158
|
+
return report
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _try_parse_lcov(project_path: Path, report: CoverageReport) -> CoverageReport:
|
|
162
|
+
"""Parse LCOV info file."""
|
|
163
|
+
lcov_paths = ["coverage/lcov.info", "lcov.info"]
|
|
164
|
+
for pattern in lcov_paths:
|
|
165
|
+
lcov_path = project_path / pattern
|
|
166
|
+
if lcov_path.exists():
|
|
167
|
+
try:
|
|
168
|
+
content = lcov_path.read_text(encoding="utf-8")
|
|
169
|
+
lf = sum(int(m.group(1)) for m in re.finditer(r"LF:(\d+)", content))
|
|
170
|
+
lh = sum(int(m.group(1)) for m in re.finditer(r"LH:(\d+)", content))
|
|
171
|
+
if lf > 0:
|
|
172
|
+
report.total_lines = lf
|
|
173
|
+
report.total_lines_covered = lh
|
|
174
|
+
report.line_coverage = round(lh / lf * 100, 1)
|
|
175
|
+
return report
|
|
176
|
+
except Exception:
|
|
177
|
+
continue
|
|
178
|
+
return report
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _try_parse_pytest_cov(output: str, report: CoverageReport) -> CoverageReport:
|
|
182
|
+
"""Parse pytest-cov output for TOTAL line."""
|
|
183
|
+
match = re.search(r"TOTAL\s+\d+\s+\d+\s+(\d+)%", output)
|
|
184
|
+
if match:
|
|
185
|
+
report.line_coverage = float(match.group(1))
|
|
186
|
+
return report
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _try_parse_go_cover(output: str, report: CoverageReport) -> CoverageReport:
|
|
190
|
+
"""Parse go test -cover output."""
|
|
191
|
+
match = re.search(r"coverage:\s+([\d.]+)%", output)
|
|
192
|
+
if match:
|
|
193
|
+
report.line_coverage = float(match.group(1))
|
|
194
|
+
return report
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _try_parse_generic(output: str, report: CoverageReport) -> CoverageReport:
|
|
198
|
+
"""Try to extract any coverage percentage from output."""
|
|
199
|
+
patterns = [
|
|
200
|
+
r"(?:line|statement|code)\s*coverage[:\s]+(\d+(?:\.\d+)?)\s*%",
|
|
201
|
+
r"(\d+(?:\.\d+)?)\s*%\s*(?:line|statement|code)\s*coverage",
|
|
202
|
+
r"Total[:\s]+(\d+(?:\.\d+)?)\s*%",
|
|
203
|
+
]
|
|
204
|
+
for pattern in patterns:
|
|
205
|
+
match = re.search(pattern, output, re.IGNORECASE)
|
|
206
|
+
if match:
|
|
207
|
+
report.line_coverage = float(match.group(1))
|
|
208
|
+
return report
|
|
209
|
+
return report
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _parse_test_counts(output: str) -> dict[str, int]:
|
|
213
|
+
"""Extract test pass/fail counts from build output."""
|
|
214
|
+
counts: dict[str, int] = {"total": 0, "passed": 0, "failed": 0}
|
|
215
|
+
|
|
216
|
+
# Gradle/JUnit style
|
|
217
|
+
m = re.search(r"(\d+)\s+tests?\s+completed,\s+(\d+)\s+failed", output)
|
|
218
|
+
if m:
|
|
219
|
+
counts["total"] = int(m.group(1))
|
|
220
|
+
counts["failed"] = int(m.group(2))
|
|
221
|
+
counts["passed"] = counts["total"] - counts["failed"]
|
|
222
|
+
return counts
|
|
223
|
+
|
|
224
|
+
# pytest style
|
|
225
|
+
m = re.search(r"(\d+)\s+passed", output)
|
|
226
|
+
if m:
|
|
227
|
+
counts["passed"] = int(m.group(1))
|
|
228
|
+
m = re.search(r"(\d+)\s+failed", output)
|
|
229
|
+
if m:
|
|
230
|
+
counts["failed"] = int(m.group(1))
|
|
231
|
+
counts["total"] = counts["passed"] + counts["failed"]
|
|
232
|
+
|
|
233
|
+
# Go style
|
|
234
|
+
m = re.search(r"ok\s+.+\s+([\d.]+)s", output)
|
|
235
|
+
if m and counts["total"] == 0:
|
|
236
|
+
# go test doesn't show counts easily, estimate from "--- PASS:" lines
|
|
237
|
+
counts["passed"] = len(re.findall(r"--- PASS:", output))
|
|
238
|
+
counts["failed"] = len(re.findall(r"--- FAIL:", output))
|
|
239
|
+
counts["total"] = counts["passed"] + counts["failed"]
|
|
240
|
+
|
|
241
|
+
return counts
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Safe file read/write with rollback protection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import shutil
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from forge_core.utils import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FileManager:
    """Manages file operations with rollback support.

    Tracks all written files so they can be reverted if coverage drops.
    """

    def __init__(self, project_path: Path):
        self.project_path = project_path
        self._written_files: dict[str, str | None] = {}  # path → original content (None if new)
        self._backup_dir: Path | None = None

    @staticmethod
    def _rollback_key(relative_path: str) -> str:
        """Normalise a relative path for use as a rollback-history key.

        Collapses ``./`` components and duplicate separators so that aliases
        of the same file (``a.py`` and ``./a.py``) share one history entry.
        Without this, the second alias would record already-modified content
        as the "original" and rollback could restore the wrong text.
        """
        return Path(relative_path).as_posix()

    def _list_files(self, root_dir: str) -> list[str]:
        """Return sorted project-relative paths of non-dotfiles under *root_dir*."""
        root = self.project_path / root_dir
        if not root.exists():
            return []
        return [
            str(p.relative_to(self.project_path))
            for p in sorted(root.rglob("*"))
            if p.is_file() and not p.name.startswith(".")
        ]

    def read_file(self, relative_path: str) -> str:
        """Read a file from the project.

        Returns an empty string when the path is missing or is not a regular file.
        """
        full_path = self.project_path / relative_path
        if not full_path.is_file():
            return ""
        return full_path.read_text(encoding="utf-8")

    def read_files(self, glob_pattern: str) -> dict[str, str]:
        """Read multiple files matching a glob pattern. Returns {relative_path: content}.

        Files that cannot be decoded as UTF-8 or read due to permissions are skipped.
        """
        results: dict[str, str] = {}
        for path in sorted(self.project_path.glob(glob_pattern)):
            if not path.is_file():
                continue
            rel = str(path.relative_to(self.project_path))
            try:
                results[rel] = path.read_text(encoding="utf-8")
            except (UnicodeDecodeError, PermissionError):
                continue
        return results

    def write_file(self, relative_path: str, content: str) -> None:
        """Write a file, tracking it for potential rollback.

        Only the first write to a given file since the last checkpoint or
        rollback records the original state.
        """
        full_path = self.project_path / relative_path
        key = self._rollback_key(relative_path)

        # Store original state for rollback
        if key not in self._written_files:
            if full_path.exists():
                self._written_files[key] = full_path.read_text(encoding="utf-8")
            else:
                self._written_files[key] = None

        # Ensure parent dirs exist
        full_path.parent.mkdir(parents=True, exist_ok=True)
        full_path.write_text(content, encoding="utf-8")

    def rollback(self) -> int:
        """Rollback all files written since last checkpoint. Returns count of files rolled back."""
        count = 0
        for rel_path, original in self._written_files.items():
            full_path = self.project_path / rel_path
            if original is None:
                # File was newly created — delete it
                if full_path.exists():
                    full_path.unlink()
                    count += 1
            else:
                # File was modified — restore original (recreate the parent
                # directory in case it was removed in the meantime)
                full_path.parent.mkdir(parents=True, exist_ok=True)
                full_path.write_text(original, encoding="utf-8")
                count += 1

        if count:
            logger.warn(f"Rolled back {count} files")
        self._written_files.clear()
        return count

    def checkpoint(self) -> None:
        """Accept current state — clear rollback history."""
        self._written_files.clear()

    def list_source_files(self, source_root: str = "src") -> list[str]:
        """List all source files under the source root."""
        return self._list_files(source_root)

    def list_test_files(self, test_root: str = "src/test") -> list[str]:
        """List all test files under the test root."""
        return self._list_files(test_root)

    @property
    def pending_rollback_count(self) -> int:
        """Number of files currently tracked for rollback."""
        return len(self._written_files)
|
|
File without changes
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Pydantic models for configuration and tenant/plan data."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Plan(str, Enum):
    """String-valued enumeration of the available plan tiers."""

    FREE = "free"
    PRO = "pro"
    ENTERPRISE = "enterprise"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AIProvider(str, Enum):
    """String-valued enumeration of selectable AI providers, including ``auto``."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    AZURE = "azure"
    BEDROCK = "bedrock"
    OLLAMA = "ollama"
    AUTO = "auto"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RunMode(str, Enum):
    """String-valued enumeration of engine run modes."""

    FULL = "full"
    TARGETED = "targeted"
    ANALYZE_ONLY = "analyze_only"
    ANALYZE_REVIEW = "analyze_review"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TenantInfo(BaseModel):
    """Multi-tenant identification — populated by SaaS or CLI.

    Every identifier defaults to an empty string.
    """

    # Organisation
    org_id: str = ""
    org_name: str = ""

    # User and project within the organisation
    user_id: str = ""
    project_id: str = ""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PlanLimits(BaseModel):
    """Usage limits per plan tier.

    The field defaults below are the limits returned for any plan other than
    PRO or ENTERPRISE.
    """

    plan: Plan = Plan.FREE
    max_tests_per_month: int = 500
    max_repos: int = 1
    ci_cd_enabled: bool = False
    cross_project_learning: bool = False
    max_parallel_agents: int = 2

    @classmethod
    def for_plan(cls, plan: Plan) -> PlanLimits:
        """Build the limits for *plan*.

        PRO and ENTERPRISE share the same settings (``-1`` marks unlimited)
        and differ only in the number of parallel agents.
        """
        paid_tiers = ((Plan.PRO, 4), (Plan.ENTERPRISE, 8))
        # Compared with == rather than looked up in a dict, so an argument
        # that merely equals a member is matched the same way as a member.
        for tier, parallel_agents in paid_tiers:
            if plan == tier:
                return cls(
                    plan=plan,
                    max_tests_per_month=-1,  # unlimited
                    max_repos=-1,
                    ci_cd_enabled=True,
                    cross_project_learning=True,
                    max_parallel_agents=parallel_agents,
                )
        return cls()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AIConfig(BaseModel):
    """AI provider configuration."""

    # Which provider and model to use
    provider: AIProvider = AIProvider.AUTO
    model: str = "gpt-4o"

    # Connection settings (empty by default)
    api_key: str = ""
    base_url: str = ""

    # Generation parameters
    temperature: float = 0.1
    max_tokens: int = 4096

    use_saas_proxy: bool = False  # True = use our API key via SaaS
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ForgeConfig(BaseModel):
    """Top-level engine configuration.

    Groups project settings with the nested AI, tenant and plan-limit models;
    every field has a default.
    """

    # --- Project ---
    project_path: Path = Field(default_factory=lambda: Path("."))
    target_coverage: float = 90.0
    max_iterations: int = 10
    mode: RunMode = RunMode.FULL
    target_files: list[str] = Field(default_factory=list)

    # --- AI ---
    ai: AIConfig = Field(default_factory=AIConfig)

    # --- Tenant ---
    tenant: TenantInfo = Field(default_factory=TenantInfo)

    # --- Plan ---
    limits: PlanLimits = Field(default_factory=PlanLimits)

    # --- Engine ---
    central_agent_path: str = ""
    knowledge_packs_dir: str = "knowledge-packs"
    cache_dir: str = ".forge-cache"
    prompts_dir: str = ".github/prompts"

    # --- SaaS ---
    saas_api_url: str = "https://api.theswitchcompany.online"
    auth_token: str = ""
|
forge_core/models/dto.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Pydantic models for DTO registry."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DTOParam(BaseModel):
    """One constructor/field parameter belonging to a DTO."""

    # Required
    name: str
    type: str

    # Optional; an empty default means no default value was recorded
    default: str = ""
    nullable: bool = False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DTOEntry(BaseModel):
    """A single DTO class registered in the global registry."""

    class_name: str
    package: str
    file_path: str
    params: list[DTOParam] = Field(default_factory=list)
    has_builder: bool = False
    has_factory: bool = False
    validation_annotations: list[str] = Field(default_factory=list)
    nested_dtos: list[str] = Field(default_factory=list)  # class names of nested DTOs
    used_in_journeys: list[str] = Field(default_factory=list)
    used_in_layers: list[str] = Field(default_factory=list)

    def mock_snippet(self) -> str:
        """Render a minimal ``ClassName(arg=value, ...)`` test-instance snippet.

        Nullable parameters render as ``null``, parameters with a non-empty
        default as the ``repr`` of that default, everything else as ``TODO``.
        """

        def _render(param: DTOParam) -> str:
            if param.nullable:
                value = "null"
            elif param.default:
                value = repr(param.default)
            else:
                value = "TODO"
            return f"{param.name}={value}"

        arguments = ", ".join(_render(p) for p in self.params)
        return f"{self.class_name}({arguments})"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class DTORegistry(BaseModel):
    """Global registry of all DTOs in the project — read once, shared everywhere."""

    entries: dict[str, DTOEntry] = Field(default_factory=dict)  # class_name → DTOEntry

    def register(self, entry: DTOEntry) -> None:
        """Store *entry* under its class name, replacing any previous entry."""
        key = entry.class_name
        self.entries[key] = entry

    def get(self, class_name: str) -> "Optional[DTOEntry]":
        """Return the entry registered for *class_name*, or ``None``."""
        return self.entries.get(class_name)

    def for_journey(self, journey_name: str) -> list[DTOEntry]:
        """Return every entry that lists *journey_name* in ``used_in_journeys``."""
        matches: list[DTOEntry] = []
        for candidate in self.entries.values():
            if journey_name in candidate.used_in_journeys:
                matches.append(candidate)
        return matches

    @property
    def count(self) -> int:
        """Number of registered DTO entries."""
        return len(self.entries)
|