switchforge 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,111 @@
1
+ """Agent manager — heartbeat monitoring, stuck detection, scope splitting."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from dataclasses import dataclass, field
7
+
8
+ from forge_core.utils import logger
9
+
10
+
11
@dataclass
class AgentState:
    """Tracks the state of a generation agent.

    When ``last_progress_at`` is not supplied it is set to exactly
    ``started_at``. That way "no progress recorded yet" can be detected as
    ``last_progress_at == started_at``. Previously each field had its own
    ``time.time()`` default factory. The two values then differed by a clock
    tick, so ``last_progress_at > started_at`` could be true before any
    progress was ever recorded.
    """

    scope_id: str
    target_files: list[str]
    started_at: float = field(default_factory=time.time)
    tool_calls: int = 0
    # None only means "not supplied"; __post_init__ replaces it with started_at,
    # so after construction this is always a float.
    last_progress_at: float | None = None
    errors: list[str] = field(default_factory=list)
    consecutive_same_error: int = 0
    last_error: str = ""
    completed: bool = False

    def __post_init__(self) -> None:
        if self.last_progress_at is None:
            self.last_progress_at = self.started_at
24
+
25
+
26
class AgentManager:
    """Manages agent lifecycle with heartbeat and stuck detection.

    Rules from copilot-instructions.md:
    - Heartbeat every 10 tool calls
    - Fruitless detection: same error 3x
    - Auto-termination after 20 fruitless calls
    - Scope splitting for stuck agents

    A tool call counts as fruitless if it happened after the last recorded
    progress (or after registration). The counter therefore restarts at every
    ``record_progress``. It is tracked as a per-scope tool-call watermark,
    not by comparing ``last_progress_at`` with ``started_at``. Those two
    timestamps can come from separate ``time.time()`` calls, so "no progress
    yet" was not reliably detectable that way.
    """

    MAX_FRUITLESS_CALLS = 20
    SAME_ERROR_THRESHOLD = 3
    HEARTBEAT_INTERVAL = 10

    def __init__(self):
        self._agents: dict[str, AgentState] = {}
        # scope_id -> value of state.tool_calls when progress was last recorded.
        self._progress_marks: dict[str, int] = {}

    def register(self, scope_id: str, target_files: list[str]) -> AgentState:
        """Register a new agent scope, replacing any existing state for ``scope_id``."""
        state = AgentState(scope_id=scope_id, target_files=target_files)
        self._agents[scope_id] = state
        self._progress_marks[scope_id] = state.tool_calls
        return state

    def heartbeat(self, scope_id: str) -> str:
        """Record one tool call and decide what the agent should do next.

        Returns one of:

        - ``'terminate'``: the scope is unknown, or ``MAX_FRUITLESS_CALLS``
          calls have passed since the last recorded progress.
        - ``'split'``: the same error has repeated ``SAME_ERROR_THRESHOLD``
          times in a row.
        - ``'continue'``: otherwise.

        Every ``HEARTBEAT_INTERVAL``-th call that returns ``'continue'`` logs
        an info line.
        """
        state = self._agents.get(scope_id)
        if state is None:
            return "terminate"

        state.tool_calls += 1

        # Checked before the fruitless-call limit, so a repeated error yields
        # 'split' rather than 'terminate'.
        if state.consecutive_same_error >= self.SAME_ERROR_THRESHOLD:
            logger.warn(f"Agent {scope_id}: same error {state.consecutive_same_error}x — splitting scope")
            return "split"

        fruitless_calls = state.tool_calls - self._progress_marks.get(scope_id, 0)
        if fruitless_calls >= self.MAX_FRUITLESS_CALLS:
            logger.warn(f"Agent {scope_id}: {fruitless_calls} calls with no progress — terminating")
            return "terminate"

        # Periodic heartbeat log
        if state.tool_calls % self.HEARTBEAT_INTERVAL == 0:
            elapsed = time.time() - state.started_at
            logger.info(f"Agent {scope_id}: {state.tool_calls} calls, {elapsed:.0f}s elapsed")

        return "continue"

    def record_error(self, scope_id: str, error: str) -> None:
        """Record an error from an agent and update its same-error streak.

        Unknown scopes are ignored. A repeat of the previous error message
        increments ``consecutive_same_error``; any other message restarts
        the streak at 1.
        """
        state = self._agents.get(scope_id)
        if state is None:
            return

        state.errors.append(error)
        if error == state.last_error:
            state.consecutive_same_error += 1
        else:
            state.consecutive_same_error = 1
            state.last_error = error

    def record_progress(self, scope_id: str) -> None:
        """Record that an agent made progress (wrote a file, coverage increased).

        Updates ``last_progress_at``, clears the same-error streak and
        restarts the fruitless-call counter. Unknown scopes are ignored.
        """
        state = self._agents.get(scope_id)
        if state is None:
            return
        state.last_progress_at = time.time()
        state.consecutive_same_error = 0
        self._progress_marks[scope_id] = state.tool_calls

    def split_scope(self, scope_id: str) -> list[list[str]]:
        """Split a stuck agent's target files into two halves.

        Returns ``[]`` when the scope is unknown or has at most one file.
        With an odd count the second half gets the extra file.
        """
        state = self._agents.get(scope_id)
        if state is None or len(state.target_files) <= 1:
            return []

        files = state.target_files
        mid = len(files) // 2
        return [files[:mid], files[mid:]]

    def complete(self, scope_id: str) -> None:
        """Mark an agent as completed (no-op for unknown scopes)."""
        state = self._agents.get(scope_id)
        if state is not None:
            state.completed = True

    @property
    def active_count(self) -> int:
        """Number of registered agents not yet marked completed."""
        return sum(1 for s in self._agents.values() if not s.completed)
@@ -0,0 +1,241 @@
1
+ """Coverage runner — execute test commands and parse coverage reports."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from forge_core.models.test_result import CoverageEntry, CoverageReport
9
+ from forge_core.utils import logger
10
+ from forge_core.utils.shell import run
11
+
12
+
13
def run_tests(project_path: Path, test_command: str) -> tuple[bool, str]:
    """Execute ``test_command`` with ``project_path`` as the working directory.

    The command is delegated to ``run`` with ``timeout=600`` (units as defined
    by ``forge_core.utils.shell.run``, presumably seconds).

    Returns:
        ``(success, output)`` copied from the ``run`` result.
    """
    outcome = run(test_command, cwd=project_path, timeout=600)
    succeeded = outcome.success
    captured = outcome.output
    return succeeded, captured
17
+
18
+
19
def run_coverage(project_path: Path, coverage_command: str) -> CoverageReport:
    """Run the project's coverage command and parse the resulting report.

    Report files on disk are tried first (JaCoCo, Cobertura, LCOV), then the
    command's console output (pytest-cov, ``go test -cover``, and finally any
    generic percentage). The first parser that yields data wins. Test counts
    are always extracted from the console output.
    """
    result = run(coverage_command, cwd=project_path, timeout=600)
    if not result.success:
        # The on-disk reports parsed below may be left over from an earlier run.
        logger.warn(f"Coverage command failed; parsed coverage may be stale: {coverage_command}")

    output = result.output
    parsers = (
        lambda r: _try_parse_jacoco(project_path, r),
        lambda r: _try_parse_cobertura(project_path, r),
        lambda r: _try_parse_lcov(project_path, r),
        lambda r: _try_parse_pytest_cov(output, r),
        lambda r: _try_parse_go_cover(output, r),
        lambda r: _try_parse_generic(output, r),  # fallback: any percentage in output
    )

    report = CoverageReport()
    for parse in parsers:
        report = parse(report)
        # Several parsers only set line_coverage, not total_lines, so check
        # both. Checking total_lines alone let later parsers overwrite a
        # result that had already been found.
        if report.total_lines or report.line_coverage:
            break

    # Count tests from output
    test_counts = _parse_test_counts(output)
    report.total_tests = test_counts.get("total", 0)
    report.tests_passed = test_counts.get("passed", 0)
    report.tests_failed = test_counts.get("failed", 0)

    return report
46
+
47
+
48
def _try_parse_jacoco(project_path: Path, report: CoverageReport) -> CoverageReport:
    """Parse a JaCoCo report (Kotlin/Java), preferring CSV over XML.

    Only the first existing CSV location is parsed. If it produces no line
    data (for example, it is empty or unparsable), the XML locations are
    tried instead of giving up. The XML candidates include the same
    ``build/reports/jacoco/`` location that the CSV candidates already check.
    """
    csv_candidates = (
        "build/reports/jacoco/test/jacocoTestReport.csv",
        "build/reports/jacoco/jacocoTestReport.csv",
        "target/site/jacoco/jacoco.csv",
    )
    for rel in csv_candidates:
        csv_path = project_path / rel
        if csv_path.exists():
            report = _parse_jacoco_csv(csv_path, report)
            break
    if report.total_lines:
        return report

    xml_candidates = (
        "build/reports/jacoco/test/jacocoTestReport.xml",
        "build/reports/jacoco/jacocoTestReport.xml",
        "target/site/jacoco/jacoco.xml",
    )
    for rel in xml_candidates:
        xml_path = project_path / rel
        if xml_path.exists():
            return _parse_jacoco_xml(xml_path, report)

    return report
72
+
73
+
74
def _parse_jacoco_csv(csv_path: Path, report: CoverageReport) -> CoverageReport:
    """Parse a JaCoCo CSV report into ``report``.

    Adds one ``CoverageEntry`` per row that has at least one line. It then
    sets ``total_lines``, ``total_lines_covered`` and, when there are any
    lines, ``line_coverage``.

    The update is all-or-nothing. If the file cannot be read or a row is
    malformed, a warning is logged and ``report`` is returned untouched.
    Previously, entries for rows before the bad one were left in
    ``report.files`` with the totals unset. Blank ``LINE_*`` cells count as 0
    instead of aborting the parse.
    """
    import csv

    entries = []
    total_missed = 0
    total_covered = 0

    try:
        # newline="" is the mode the csv module documents for reading files.
        with open(csv_path, encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                # ``or 0`` covers both a missing column (None) and a blank cell ("").
                missed = int(row.get("LINE_MISSED") or 0)
                covered = int(row.get("LINE_COVERED") or 0)
                total_missed += missed
                total_covered += covered

                lines = missed + covered
                if lines > 0:
                    pkg = row.get("PACKAGE") or ""
                    cls = row.get("CLASS") or ""
                    entries.append(
                        CoverageEntry(
                            file_path=f"{pkg}/{cls}",
                            line_coverage=round(covered / lines * 100, 1),
                            lines_covered=covered,
                            lines_total=lines,
                        )
                    )
    except Exception as e:  # best-effort: a broken report must not abort the run
        logger.warn(f"Failed to parse JaCoCo CSV: {e}")
        return report

    total = total_missed + total_covered
    report.files.extend(entries)
    report.total_lines = total
    report.total_lines_covered = total_covered
    if total > 0:
        report.line_coverage = round(total_covered / total * 100, 1)
    return report
112
+
113
+
114
+ def _parse_jacoco_xml(xml_path: Path, report: CoverageReport) -> CoverageReport:
115
+ """Parse JaCoCo XML report — extract LINE counter from root."""
116
+ import xml.etree.ElementTree as ET
117
+
118
+ try:
119
+ tree = ET.parse(xml_path)
120
+ root = tree.getroot()
121
+
122
+ for counter in root.findall("counter"):
123
+ if counter.get("type") == "LINE":
124
+ missed = int(counter.get("missed", 0))
125
+ covered = int(counter.get("covered", 0))
126
+ total = missed + covered
127
+ report.total_lines = total
128
+ report.total_lines_covered = covered
129
+ if total > 0:
130
+ report.line_coverage = round(covered / total * 100, 1)
131
+ break
132
+ except Exception as e:
133
+ logger.warn(f"Failed to parse JaCoCo XML: {e}")
134
+
135
+ return report
136
+
137
+
138
+ def _try_parse_cobertura(project_path: Path, report: CoverageReport) -> CoverageReport:
139
+ """Parse Cobertura XML report."""
140
+ cobertura_paths = [
141
+ "coverage.xml",
142
+ "build/reports/cobertura/coverage.xml",
143
+ "target/site/cobertura/coverage.xml",
144
+ ]
145
+ for pattern in cobertura_paths:
146
+ xml_path = project_path / pattern
147
+ if xml_path.exists():
148
+ try:
149
+ import xml.etree.ElementTree as ET
150
+
151
+ tree = ET.parse(xml_path)
152
+ root = tree.getroot()
153
+ rate = float(root.get("line-rate", 0))
154
+ report.line_coverage = round(rate * 100, 1)
155
+ return report
156
+ except Exception:
157
+ continue
158
+ return report
159
+
160
+
161
+ def _try_parse_lcov(project_path: Path, report: CoverageReport) -> CoverageReport:
162
+ """Parse LCOV info file."""
163
+ lcov_paths = ["coverage/lcov.info", "lcov.info"]
164
+ for pattern in lcov_paths:
165
+ lcov_path = project_path / pattern
166
+ if lcov_path.exists():
167
+ try:
168
+ content = lcov_path.read_text(encoding="utf-8")
169
+ lf = sum(int(m.group(1)) for m in re.finditer(r"LF:(\d+)", content))
170
+ lh = sum(int(m.group(1)) for m in re.finditer(r"LH:(\d+)", content))
171
+ if lf > 0:
172
+ report.total_lines = lf
173
+ report.total_lines_covered = lh
174
+ report.line_coverage = round(lh / lf * 100, 1)
175
+ return report
176
+ except Exception:
177
+ continue
178
+ return report
179
+
180
+
181
+ def _try_parse_pytest_cov(output: str, report: CoverageReport) -> CoverageReport:
182
+ """Parse pytest-cov output for TOTAL line."""
183
+ match = re.search(r"TOTAL\s+\d+\s+\d+\s+(\d+)%", output)
184
+ if match:
185
+ report.line_coverage = float(match.group(1))
186
+ return report
187
+
188
+
189
+ def _try_parse_go_cover(output: str, report: CoverageReport) -> CoverageReport:
190
+ """Parse go test -cover output."""
191
+ match = re.search(r"coverage:\s+([\d.]+)%", output)
192
+ if match:
193
+ report.line_coverage = float(match.group(1))
194
+ return report
195
+
196
+
197
+ def _try_parse_generic(output: str, report: CoverageReport) -> CoverageReport:
198
+ """Try to extract any coverage percentage from output."""
199
+ patterns = [
200
+ r"(?:line|statement|code)\s*coverage[:\s]+(\d+(?:\.\d+)?)\s*%",
201
+ r"(\d+(?:\.\d+)?)\s*%\s*(?:line|statement|code)\s*coverage",
202
+ r"Total[:\s]+(\d+(?:\.\d+)?)\s*%",
203
+ ]
204
+ for pattern in patterns:
205
+ match = re.search(pattern, output, re.IGNORECASE)
206
+ if match:
207
+ report.line_coverage = float(match.group(1))
208
+ return report
209
+ return report
210
+
211
+
212
+ def _parse_test_counts(output: str) -> dict[str, int]:
213
+ """Extract test pass/fail counts from build output."""
214
+ counts: dict[str, int] = {"total": 0, "passed": 0, "failed": 0}
215
+
216
+ # Gradle/JUnit style
217
+ m = re.search(r"(\d+)\s+tests?\s+completed,\s+(\d+)\s+failed", output)
218
+ if m:
219
+ counts["total"] = int(m.group(1))
220
+ counts["failed"] = int(m.group(2))
221
+ counts["passed"] = counts["total"] - counts["failed"]
222
+ return counts
223
+
224
+ # pytest style
225
+ m = re.search(r"(\d+)\s+passed", output)
226
+ if m:
227
+ counts["passed"] = int(m.group(1))
228
+ m = re.search(r"(\d+)\s+failed", output)
229
+ if m:
230
+ counts["failed"] = int(m.group(1))
231
+ counts["total"] = counts["passed"] + counts["failed"]
232
+
233
+ # Go style
234
+ m = re.search(r"ok\s+.+\s+([\d.]+)s", output)
235
+ if m and counts["total"] == 0:
236
+ # go test doesn't show counts easily, estimate from "--- PASS:" lines
237
+ counts["passed"] = len(re.findall(r"--- PASS:", output))
238
+ counts["failed"] = len(re.findall(r"--- FAIL:", output))
239
+ counts["total"] = counts["passed"] + counts["failed"]
240
+
241
+ return counts
@@ -0,0 +1,104 @@
1
+ """Safe file read/write with rollback protection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import shutil
6
+ from pathlib import Path
7
+
8
+ from forge_core.utils import logger
9
+
10
+
11
class FileManager:
    """Manages file operations with rollback support.

    Tracks all written files so they can be reverted if coverage drops.
    Original contents are snapshotted as raw bytes, not decoded text. A
    rollback therefore restores files byte-for-byte, including CRLF line
    endings and non-UTF-8 content, which a ``read_text``/``write_text``
    round-trip does not preserve.
    """

    def __init__(self, project_path: Path):
        self.project_path = project_path
        # relative path -> original bytes (None if the file did not exist before)
        self._written_files: dict[str, bytes | None] = {}
        self._backup_dir: Path | None = None  # NOTE(review): never used in this class

    def read_file(self, relative_path: str) -> str:
        """Read a project file as UTF-8 text; return "" if it does not exist."""
        full_path = self.project_path / relative_path
        if not full_path.exists():
            return ""
        return full_path.read_text(encoding="utf-8")

    def read_files(self, glob_pattern: str) -> dict[str, str]:
        """Read files matching ``glob_pattern``. Returns {relative_path: content}.

        Files that are not valid UTF-8 or cannot be read for permission
        reasons are skipped.
        """
        results: dict[str, str] = {}
        for path in sorted(self.project_path.glob(glob_pattern)):
            if path.is_file():
                rel = str(path.relative_to(self.project_path))
                try:
                    results[rel] = path.read_text(encoding="utf-8")
                except (UnicodeDecodeError, PermissionError):
                    continue
        return results

    def write_file(self, relative_path: str, content: str) -> None:
        """Write ``content`` as UTF-8, remembering the file's prior state for rollback.

        Only the first write to a path since the last checkpoint/rollback is
        snapshotted, so a rollback returns to the pre-checkpoint state.
        NOTE(review): ``relative_path`` is not checked to stay inside
        ``project_path``. Add a guard if it can come from untrusted input.
        """
        full_path = self.project_path / relative_path

        if relative_path not in self._written_files:
            # Raw bytes: exact restoration, and no UnicodeDecodeError on binary files.
            self._written_files[relative_path] = (
                full_path.read_bytes() if full_path.exists() else None
            )

        # Ensure parent dirs exist
        full_path.parent.mkdir(parents=True, exist_ok=True)
        full_path.write_text(content, encoding="utf-8")

    def rollback(self) -> int:
        """Revert every file written since the last checkpoint.

        Newly created files are deleted if still present; modified files get
        their original bytes back. The history is cleared afterwards.

        Returns:
            Number of files deleted or restored.
        """
        count = 0
        for rel_path, original in self._written_files.items():
            full_path = self.project_path / rel_path
            if original is None:
                # File was newly created — delete it
                if full_path.exists():
                    full_path.unlink()
                    count += 1
            else:
                # Its directory may have been removed since the write.
                full_path.parent.mkdir(parents=True, exist_ok=True)
                full_path.write_bytes(original)
                count += 1

        if count:
            logger.warn(f"Rolled back {count} files")
        self._written_files.clear()
        return count

    def checkpoint(self) -> None:
        """Accept current state — clear rollback history."""
        self._written_files.clear()

    def _list_files(self, root_rel: str) -> list[str]:
        """List files under ``root_rel`` (recursive, sorted, relative to the project).

        Files whose own name starts with "." are excluded.
        """
        root = self.project_path / root_rel
        if not root.exists():
            return []
        return [
            str(p.relative_to(self.project_path))
            for p in sorted(root.rglob("*"))
            if p.is_file() and not p.name.startswith(".")
        ]

    def list_source_files(self, source_root: str = "src") -> list[str]:
        """List all source files under the source root."""
        return self._list_files(source_root)

    def list_test_files(self, test_root: str = "src/test") -> list[str]:
        """List all test files under the test root."""
        return self._list_files(test_root)

    @property
    def pending_rollback_count(self) -> int:
        """Number of distinct files written since the last checkpoint/rollback."""
        return len(self._written_files)
File without changes
@@ -0,0 +1,115 @@
1
+ """Pydantic models for configuration and tenant/plan data."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from enum import Enum
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
class Plan(str, Enum):
    """Subscription tier. Members are ``str`` subclasses with lowercase values."""

    FREE = "free"  # also the default tier in PlanLimits
    PRO = "pro"
    ENTERPRISE = "enterprise"
16
+
17
+
18
class AIProvider(str, Enum):
    """AI backend selector. Members are ``str`` subclasses with lowercase values."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    AZURE = "azure"
    BEDROCK = "bedrock"
    OLLAMA = "ollama"
    AUTO = "auto"  # AIConfig's default; presumably "pick automatically" — confirm
25
+
26
+
27
class RunMode(str, Enum):
    """Engine run mode. Each member's value is its name in lowercase."""

    FULL = "full"  # ForgeConfig's default
    TARGETED = "targeted"
    ANALYZE_ONLY = "analyze_only"
    ANALYZE_REVIEW = "analyze_review"
32
+
33
+
34
class TenantInfo(BaseModel):
    """Identifies the organisation, user and project a run belongs to.

    Filled in by the SaaS backend or the CLI; every field defaults to "".
    """

    org_id: str = ""
    org_name: str = ""
    user_id: str = ""
    project_id: str = ""
41
+
42
+
43
class PlanLimits(BaseModel):
    """Usage limits for a plan tier. In the numeric limits, ``-1`` means unlimited."""

    plan: Plan = Plan.FREE
    max_tests_per_month: int = 500
    max_repos: int = 1
    ci_cd_enabled: bool = False
    cross_project_learning: bool = False
    max_parallel_agents: int = 2

    @classmethod
    def for_plan(cls, plan: Plan) -> PlanLimits:
        """Build the limits for ``plan``.

        PRO and ENTERPRISE are unlimited in tests and repos and enable CI/CD
        and cross-project learning. They differ only in parallel agents
        (4 vs 8). Any other plan gets the class defaults.
        """
        # Membership uses ==, matching the original equality checks.
        if plan not in (Plan.PRO, Plan.ENTERPRISE):
            return cls()

        agent_slots = 4 if plan == Plan.PRO else 8
        return cls(
            plan=plan,
            max_tests_per_month=-1,
            max_repos=-1,
            ci_cd_enabled=True,
            cross_project_learning=True,
            max_parallel_agents=agent_slots,
        )
74
+
75
+
76
class AIConfig(BaseModel):
    """AI provider configuration.

    ``api_key`` is excluded from the model's ``repr()``/``str()``, so printing
    or logging a config does not leak the secret. Serialisation of the field
    is unaffected.
    """

    provider: AIProvider = AIProvider.AUTO
    model: str = "gpt-4o"
    api_key: str = Field(default="", repr=False)
    base_url: str = ""
    temperature: float = 0.1
    max_tokens: int = 4096
    use_saas_proxy: bool = False  # True = use our API key via SaaS
86
+
87
+
88
class ForgeConfig(BaseModel):
    """Top-level engine configuration.

    ``auth_token`` is excluded from the model's ``repr()``/``str()``, so
    printing or logging the config does not expose the credential.
    Serialisation of the field is unaffected.
    """

    # Project
    project_path: Path = Field(default_factory=lambda: Path("."))
    target_coverage: float = 90.0
    max_iterations: int = 10
    mode: RunMode = RunMode.FULL
    target_files: list[str] = Field(default_factory=list)

    # AI
    ai: AIConfig = Field(default_factory=AIConfig)

    # Tenant
    tenant: TenantInfo = Field(default_factory=TenantInfo)

    # Plan
    limits: PlanLimits = Field(default_factory=PlanLimits)

    # Engine
    central_agent_path: str = ""
    knowledge_packs_dir: str = "knowledge-packs"
    cache_dir: str = ".forge-cache"
    prompts_dir: str = ".github/prompts"

    # SaaS
    saas_api_url: str = "https://api.theswitchcompany.online"
    auth_token: str = Field(default="", repr=False)
@@ -0,0 +1,58 @@
1
+ """Pydantic models for DTO registry."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class DTOParam(BaseModel):
    """One constructor parameter or field of a DTO."""

    name: str
    type: str  # the declared type, stored as a string
    default: str = ""  # "" means no default was recorded
    nullable: bool = False
17
+
18
+
19
class DTOEntry(BaseModel):
    """A single DTO class registered in the global registry."""

    class_name: str
    package: str
    file_path: str
    params: list[DTOParam] = Field(default_factory=list)
    has_builder: bool = False
    has_factory: bool = False
    validation_annotations: list[str] = Field(default_factory=list)
    nested_dtos: list[str] = Field(default_factory=list)  # class names of nested DTOs
    used_in_journeys: list[str] = Field(default_factory=list)
    used_in_layers: list[str] = Field(default_factory=list)

    def mock_snippet(self) -> str:
        """Generate a minimal mock/test instance snippet.

        Produces ``ClassName(a=..., b=...)`` with named arguments. Each value is:

        - ``null`` for a nullable param;
        - otherwise the param's recorded default, inserted verbatim;
        - otherwise the placeholder ``TODO``.

        The default used to be wrapped in Python ``repr()``. That emitted
        single-quoted Python string literals (``'0'``) into a snippet whose
        other tokens (``null``) are not Python. NOTE(review): this assumes
        ``DTOParam.default`` holds the default as source text. Confirm with
        the code that populates it.
        """
        args = ", ".join(
            f"{p.name}={'null' if p.nullable else (p.default or 'TODO')}"
            for p in self.params
        )
        return f"{self.class_name}({args})"
40
+
41
+
42
class DTORegistry(BaseModel):
    """Project-wide index of DTOs keyed by class name — built once, shared everywhere."""

    entries: dict[str, DTOEntry] = Field(default_factory=dict)  # class_name → DTOEntry

    def register(self, entry: DTOEntry) -> None:
        """Add ``entry``, replacing any earlier entry with the same class name."""
        self.entries.update({entry.class_name: entry})

    def get(self, class_name: str) -> "Optional[DTOEntry]":
        """Return the entry for ``class_name``, or None when it is not registered."""
        try:
            return self.entries[class_name]
        except KeyError:
            return None

    def for_journey(self, journey_name: str) -> list[DTOEntry]:
        """Entries whose ``used_in_journeys`` contains ``journey_name``, in insertion order."""
        return list(
            filter(lambda dto: journey_name in dto.used_in_journeys, self.entries.values())
        )

    @property
    def count(self) -> int:
        """Number of registered DTOs."""
        return len(self.entries)