switchforge 1.1.0.tar.gz → 2.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {switchforge-1.1.0 → switchforge-2.0.0}/PKG-INFO +1 -3
  2. switchforge-2.0.0/forge_core/ai/provider.py +142 -0
  3. switchforge-2.0.0/forge_core/ai/structured.py +108 -0
  4. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/models/config.py +17 -14
  5. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/models/dto.py +21 -19
  6. switchforge-2.0.0/forge_core/models/project.py +78 -0
  7. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/models/test_result.py +23 -19
  8. {switchforge-1.1.0 → switchforge-2.0.0}/pyproject.toml +1 -3
  9. switchforge-1.1.0/forge_core/ai/provider.py +0 -120
  10. switchforge-1.1.0/forge_core/ai/structured.py +0 -62
  11. switchforge-1.1.0/forge_core/models/project.py +0 -72
  12. {switchforge-1.1.0 → switchforge-2.0.0}/.gitignore +0 -0
  13. {switchforge-1.1.0 → switchforge-2.0.0}/README.md +0 -0
  14. {switchforge-1.1.0 → switchforge-2.0.0}/forge-core.spec +0 -0
  15. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/__init__.py +0 -0
  16. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/__main__.py +0 -0
  17. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/ai/__init__.py +0 -0
  18. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/ai/prompts.py +0 -0
  19. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/auth.py +0 -0
  20. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/cli.py +0 -0
  21. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/config.py +0 -0
  22. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/core/__init__.py +0 -0
  23. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/core/agent_manager.py +0 -0
  24. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/core/coverage.py +0 -0
  25. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/core/file_manager.py +0 -0
  26. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/models/__init__.py +0 -0
  27. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/orchestrator.py +0 -0
  28. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/__init__.py +0 -0
  29. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/analyze_project.py +0 -0
  30. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/audit_tests.py +0 -0
  31. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/compile_fix.py +0 -0
  32. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/coverage_report.py +0 -0
  33. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/detect_stack.py +0 -0
  34. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/exclusion_scan.py +0 -0
  35. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/fix_broken.py +0 -0
  36. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/generate_tests.py +0 -0
  37. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/journey_mapping.py +0 -0
  38. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/phases/self_learn.py +0 -0
  39. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/utils/__init__.py +0 -0
  40. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/utils/logger.py +0 -0
  41. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/utils/reporter.py +0 -0
  42. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/utils/shell.py +0 -0
  43. {switchforge-1.1.0 → switchforge-2.0.0}/forge_core/utils/tokens.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: switchforge
3
- Version: 1.1.0
3
+ Version: 2.0.0
4
4
  Summary: AI-powered backend test generation engine
5
5
  Project-URL: Homepage, https://theswitchcompany.online/products/forge/core
6
6
  Project-URL: Repository, https://github.com/switchcompany/forge-core
@@ -18,8 +18,6 @@ Classifier: Programming Language :: Python :: 3.12
18
18
  Classifier: Topic :: Software Development :: Testing
19
19
  Requires-Python: >=3.10
20
20
  Requires-Dist: httpx>=0.27.0
21
- Requires-Dist: openai>=1.30.0
22
- Requires-Dist: pydantic>=2.7.0
23
21
  Requires-Dist: pyyaml>=6.0
24
22
  Requires-Dist: rich>=13.7.0
25
23
  Requires-Dist: typer>=0.12.0
@@ -0,0 +1,142 @@
1
+ """AI provider — raw httpx calls, zero native dependencies."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from typing import Any
8
+
9
+ import httpx
10
+
11
+ from forge_core.models.config import AIConfig, AIProvider
12
+ from forge_core.utils import logger
13
+ from forge_core.utils.tokens import count_tokens
14
+
15
+ _TIMEOUT = httpx.Timeout(120.0, connect=10.0)
16
+
17
+
18
+ def _get_api_url(config: AIConfig) -> str:
19
+ """Resolve the API base URL."""
20
+ if config.base_url:
21
+ return config.base_url.rstrip("/")
22
+ if config.provider == AIProvider.ANTHROPIC:
23
+ return "https://api.anthropic.com/v1"
24
+ if config.provider == AIProvider.OLLAMA:
25
+ return "http://localhost:11434/v1"
26
+ if config.provider == AIProvider.AZURE:
27
+ return (os.environ.get("AZURE_OPENAI_ENDPOINT", "")).rstrip("/")
28
+ return "https://api.openai.com/v1"
29
+
30
+
31
+ def _get_api_key(config: AIConfig) -> str:
32
+ """Resolve the API key."""
33
+ if config.api_key:
34
+ return config.api_key
35
+ if config.provider == AIProvider.ANTHROPIC:
36
+ return os.environ.get("ANTHROPIC_API_KEY", "")
37
+ if config.provider == AIProvider.AZURE:
38
+ return os.environ.get("AZURE_OPENAI_API_KEY", "")
39
+ if config.provider == AIProvider.OLLAMA:
40
+ return "ollama"
41
+ return os.environ.get("OPENAI_API_KEY", "")
42
+
43
+
44
+ def _resolve_model(config: AIConfig) -> str:
45
+ """Resolve the model name."""
46
+ return config.model
47
+
48
+
49
+ def _call_chat_api(
50
+ config: AIConfig,
51
+ model: str,
52
+ messages: list[dict[str, str]],
53
+ temperature: float,
54
+ max_tokens: int,
55
+ json_mode: bool = False,
56
+ ) -> str:
57
+ """Make a raw HTTP POST to the chat completions endpoint."""
58
+ base_url = _get_api_url(config)
59
+ api_key = _get_api_key(config)
60
+
61
+ body: dict[str, Any] = {
62
+ "model": model,
63
+ "messages": messages,
64
+ "temperature": temperature,
65
+ "max_tokens": max_tokens,
66
+ }
67
+ if json_mode:
68
+ body["response_format"] = {"type": "json_object"}
69
+
70
+ headers = {
71
+ "Content-Type": "application/json",
72
+ "Authorization": f"Bearer {api_key}",
73
+ }
74
+
75
+ resp = httpx.post(
76
+ f"{base_url}/chat/completions",
77
+ json=body,
78
+ headers=headers,
79
+ timeout=_TIMEOUT,
80
+ )
81
+ resp.raise_for_status()
82
+ data = resp.json()
83
+ return data["choices"][0]["message"]["content"] or ""
84
+
85
+
86
+ def complete(
87
+ config: AIConfig,
88
+ system_prompt: str,
89
+ user_prompt: str,
90
+ json_mode: bool = False,
91
+ max_tokens: int | None = None,
92
+ ) -> str:
93
+ """Send a completion request to the configured AI provider."""
94
+ model = _resolve_model(config)
95
+ messages = [
96
+ {"role": "system", "content": system_prompt},
97
+ {"role": "user", "content": user_prompt},
98
+ ]
99
+
100
+ input_tokens = count_tokens(system_prompt + user_prompt, config.model)
101
+ logger.info(f"AI call → {model} ({input_tokens} input tokens)")
102
+
103
+ try:
104
+ content = _call_chat_api(
105
+ config, model, messages, config.temperature,
106
+ max_tokens or config.max_tokens, json_mode,
107
+ )
108
+ output_tokens = count_tokens(content, config.model)
109
+ logger.info(f"AI response ← {output_tokens} output tokens")
110
+ return content
111
+ except Exception as e:
112
+ logger.error(f"AI call failed: {e}")
113
+ raise
114
+
115
+
116
+ def complete_with_fallback(
117
+ config: AIConfig,
118
+ system_prompt: str,
119
+ user_prompt: str,
120
+ fallback_models: list[str] | None = None,
121
+ json_mode: bool = False,
122
+ ) -> str:
123
+ """Try primary model, fall back to alternatives on failure."""
124
+ models = [_resolve_model(config)] + (fallback_models or [])
125
+ messages = [
126
+ {"role": "system", "content": system_prompt},
127
+ {"role": "user", "content": user_prompt},
128
+ ]
129
+
130
+ last_error = None
131
+ for model in models:
132
+ try:
133
+ return _call_chat_api(
134
+ config, model, messages, config.temperature,
135
+ config.max_tokens, json_mode,
136
+ )
137
+ except Exception as e:
138
+ last_error = e
139
+ logger.warn(f"Model {model} failed, trying next: {e}")
140
+ continue
141
+
142
+ raise RuntimeError(f"All models failed. Last error: {last_error}")
@@ -0,0 +1,108 @@
1
+ """Structured AI outputs using JSON mode + dataclass parsing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import dataclasses
6
+ import json
7
+ from typing import Any, TypeVar, get_type_hints
8
+
9
+ from forge_core.models.config import AIConfig
10
+ from forge_core.ai.provider import complete
11
+ from forge_core.utils import logger
12
+
13
+ T = TypeVar("T")
14
+
15
+
16
+ def _schema_from_dataclass(cls: type) -> dict[str, Any]:
17
+ """Generate a minimal JSON schema from a dataclass."""
18
+ hints = get_type_hints(cls)
19
+ properties: dict[str, Any] = {}
20
+ for f in dataclasses.fields(cls):
21
+ hint = hints.get(f.name, str)
22
+ hint_str = str(hint)
23
+ if hint_str.startswith("list") or "list[" in hint_str.lower():
24
+ properties[f.name] = {"type": "array"}
25
+ elif hint_str.startswith("dict") or "dict[" in hint_str.lower():
26
+ properties[f.name] = {"type": "object"}
27
+ elif hint in (int,) or "int" in hint_str:
28
+ properties[f.name] = {"type": "integer"}
29
+ elif hint in (float,):
30
+ properties[f.name] = {"type": "number"}
31
+ elif hint in (bool,):
32
+ properties[f.name] = {"type": "boolean"}
33
+ else:
34
+ properties[f.name] = {"type": "string"}
35
+ return {"type": "object", "properties": properties}
36
+
37
+
38
+ def _dict_to_dataclass(cls: type[T], data: dict[str, Any]) -> T:
39
+ """Recursively convert a dict to a dataclass, handling nested dataclasses."""
40
+ hints = get_type_hints(cls)
41
+ kwargs: dict[str, Any] = {}
42
+ for f in dataclasses.fields(cls):
43
+ if f.name not in data:
44
+ continue
45
+ val = data[f.name]
46
+ hint = hints.get(f.name)
47
+ # Check if the hint is itself a dataclass
48
+ if dataclasses.is_dataclass(hint) and isinstance(val, dict):
49
+ kwargs[f.name] = _dict_to_dataclass(hint, val)
50
+ elif isinstance(val, list) and val:
51
+ # Try to detect nested dataclass in list
52
+ inner = getattr(hint, "__args__", [None])[0] if hasattr(hint, "__args__") else None
53
+ if inner and dataclasses.is_dataclass(inner):
54
+ kwargs[f.name] = [_dict_to_dataclass(inner, item) if isinstance(item, dict) else item for item in val]
55
+ else:
56
+ kwargs[f.name] = val
57
+ else:
58
+ kwargs[f.name] = val
59
+ return cls(**kwargs)
60
+
61
+
62
+ def extract(
63
+ config: AIConfig,
64
+ system_prompt: str,
65
+ user_prompt: str,
66
+ response_model: type[T],
67
+ max_retries: int = 2,
68
+ ) -> T:
69
+ """Extract structured data from AI using JSON mode + dataclass.
70
+
71
+ Forces the AI to return data matching a dataclass schema.
72
+ Retries on validation failure.
73
+ """
74
+ schema = _schema_from_dataclass(response_model)
75
+ structured_prompt = (
76
+ f"{system_prompt}\n\n"
77
+ f"You MUST respond with valid JSON matching this schema:\n"
78
+ f"```json\n{json.dumps(schema, indent=2)}\n```\n"
79
+ f"Respond ONLY with the JSON object, no other text."
80
+ )
81
+
82
+ logger.info(f"Structured extraction → {config.model} → {response_model.__name__}")
83
+
84
+ last_error = None
85
+ for attempt in range(max_retries + 1):
86
+ try:
87
+ raw = complete(config, structured_prompt, user_prompt, json_mode=True)
88
+ # Strip markdown code fences if present
89
+ cleaned = raw.strip()
90
+ if cleaned.startswith("```"):
91
+ cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
92
+ if cleaned.endswith("```"):
93
+ cleaned = cleaned[:-3]
94
+ cleaned = cleaned.strip()
95
+
96
+ data = json.loads(cleaned)
97
+ result = _dict_to_dataclass(response_model, data)
98
+ logger.success(f"Extracted {response_model.__name__} successfully")
99
+ return result
100
+ except Exception as e:
101
+ last_error = e
102
+ if attempt < max_retries:
103
+ logger.warn(f"Extraction attempt {attempt + 1} failed: {e}, retrying...")
104
+ continue
105
+
106
+ raise RuntimeError(
107
+ f"Structured extraction failed after {max_retries + 1} attempts: {last_error}"
108
+ )
@@ -1,13 +1,12 @@
1
- """Pydantic models for configuration and tenant/plan data."""
1
+ """Data models for configuration and tenant/plan data."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from dataclasses import dataclass, field
5
6
  from enum import Enum
6
7
  from pathlib import Path
7
8
  from typing import Optional
8
9
 
9
- from pydantic import BaseModel, Field
10
-
11
10
 
12
11
  class Plan(str, Enum):
13
12
  FREE = "free"
@@ -31,7 +30,8 @@ class RunMode(str, Enum):
31
30
  ANALYZE_REVIEW = "analyze_review"
32
31
 
33
32
 
34
- class TenantInfo(BaseModel):
33
+ @dataclass
34
+ class TenantInfo:
35
35
  """Multi-tenant identification — populated by SaaS or CLI."""
36
36
 
37
37
  org_id: str = ""
@@ -40,7 +40,8 @@ class TenantInfo(BaseModel):
40
40
  project_id: str = ""
41
41
 
42
42
 
43
- class PlanLimits(BaseModel):
43
+ @dataclass
44
+ class PlanLimits:
44
45
  """Usage limits per plan tier."""
45
46
 
46
47
  plan: Plan = Plan.FREE
@@ -55,7 +56,7 @@ class PlanLimits(BaseModel):
55
56
  if plan == Plan.PRO:
56
57
  return cls(
57
58
  plan=plan,
58
- max_tests_per_month=-1, # unlimited
59
+ max_tests_per_month=-1,
59
60
  max_repos=-1,
60
61
  ci_cd_enabled=True,
61
62
  cross_project_learning=True,
@@ -73,7 +74,8 @@ class PlanLimits(BaseModel):
73
74
  return cls()
74
75
 
75
76
 
76
- class AIConfig(BaseModel):
77
+ @dataclass
78
+ class AIConfig:
77
79
  """AI provider configuration."""
78
80
 
79
81
  provider: AIProvider = AIProvider.AUTO
@@ -82,27 +84,28 @@ class AIConfig(BaseModel):
82
84
  base_url: str = ""
83
85
  temperature: float = 0.1
84
86
  max_tokens: int = 4096
85
- use_saas_proxy: bool = False # True = use our API key via SaaS
87
+ use_saas_proxy: bool = False
86
88
 
87
89
 
88
- class ForgeConfig(BaseModel):
90
+ @dataclass
91
+ class ForgeConfig:
89
92
  """Top-level engine configuration."""
90
93
 
91
94
  # Project
92
- project_path: Path = Field(default_factory=lambda: Path("."))
95
+ project_path: Path = field(default_factory=lambda: Path("."))
93
96
  target_coverage: float = 90.0
94
97
  max_iterations: int = 10
95
98
  mode: RunMode = RunMode.FULL
96
- target_files: list[str] = Field(default_factory=list)
99
+ target_files: list[str] = field(default_factory=list)
97
100
 
98
101
  # AI
99
- ai: AIConfig = Field(default_factory=AIConfig)
102
+ ai: AIConfig = field(default_factory=AIConfig)
100
103
 
101
104
  # Tenant
102
- tenant: TenantInfo = Field(default_factory=TenantInfo)
105
+ tenant: TenantInfo = field(default_factory=TenantInfo)
103
106
 
104
107
  # Plan
105
- limits: PlanLimits = Field(default_factory=PlanLimits)
108
+ limits: PlanLimits = field(default_factory=PlanLimits)
106
109
 
107
110
  # Engine
108
111
  central_agent_path: str = ""
@@ -1,34 +1,35 @@
1
- """Pydantic models for DTO registry."""
1
+ """Data models for DTO registry."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from dataclasses import dataclass, field
5
6
  from typing import Optional
6
7
 
7
- from pydantic import BaseModel, Field
8
8
 
9
-
10
- class DTOParam(BaseModel):
9
+ @dataclass
10
+ class DTOParam:
11
11
  """A single constructor/field parameter of a DTO."""
12
12
 
13
- name: str
14
- type: str
13
+ name: str = ""
14
+ type: str = ""
15
15
  default: str = ""
16
16
  nullable: bool = False
17
17
 
18
18
 
19
- class DTOEntry(BaseModel):
19
+ @dataclass
20
+ class DTOEntry:
20
21
  """A single DTO class registered in the global registry."""
21
22
 
22
- class_name: str
23
- package: str
24
- file_path: str
25
- params: list[DTOParam] = Field(default_factory=list)
23
+ class_name: str = ""
24
+ package: str = ""
25
+ file_path: str = ""
26
+ params: list[DTOParam] = field(default_factory=list)
26
27
  has_builder: bool = False
27
28
  has_factory: bool = False
28
- validation_annotations: list[str] = Field(default_factory=list)
29
- nested_dtos: list[str] = Field(default_factory=list) # class names of nested DTOs
30
- used_in_journeys: list[str] = Field(default_factory=list)
31
- used_in_layers: list[str] = Field(default_factory=list)
29
+ validation_annotations: list[str] = field(default_factory=list)
30
+ nested_dtos: list[str] = field(default_factory=list)
31
+ used_in_journeys: list[str] = field(default_factory=list)
32
+ used_in_layers: list[str] = field(default_factory=list)
32
33
 
33
34
  def mock_snippet(self) -> str:
34
35
  """Generate a minimal mock/test instance snippet."""
@@ -39,15 +40,16 @@ class DTOEntry(BaseModel):
39
40
  return f"{self.class_name}({params_str})"
40
41
 
41
42
 
42
- class DTORegistry(BaseModel):
43
- """Global registry of all DTOs in the project — read once, shared everywhere."""
43
+ @dataclass
44
+ class DTORegistry:
45
+ """Global registry of all DTOs in the project."""
44
46
 
45
- entries: dict[str, DTOEntry] = Field(default_factory=dict) # class_name → DTOEntry
47
+ entries: dict[str, DTOEntry] = field(default_factory=dict)
46
48
 
47
49
  def register(self, entry: DTOEntry) -> None:
48
50
  self.entries[entry.class_name] = entry
49
51
 
50
- def get(self, class_name: str) -> "Optional[DTOEntry]":
52
+ def get(self, class_name: str) -> Optional[DTOEntry]:
51
53
  return self.entries.get(class_name)
52
54
 
53
55
  def for_journey(self, journey_name: str) -> list[DTOEntry]:
@@ -0,0 +1,78 @@
1
+ """Data models for project structure — 4-Level DAG."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+
8
+ @dataclass
9
+ class Component:
10
+ """A single class/file within a layer."""
11
+
12
+ name: str = ""
13
+ file_path: str = ""
14
+ layer: str = ""
15
+ dependencies: list[str] = field(default_factory=list)
16
+ is_tested: bool = False
17
+ existing_test_file: str = ""
18
+ coverage_pct: float = 0.0
19
+
20
+
21
+ @dataclass
22
+ class Journey:
23
+ """A traced user journey across layers."""
24
+
25
+ name: str = ""
26
+ entry_point: str = ""
27
+ entry_type: str = ""
28
+ components: list[str] = field(default_factory=list)
29
+ priority: int = 1
30
+ description: str = ""
31
+
32
+
33
+ @dataclass
34
+ class Layer:
35
+ """A functional layer within a module."""
36
+
37
+ name: str = ""
38
+ components: list[Component] = field(default_factory=list)
39
+
40
+
41
+ @dataclass
42
+ class Module:
43
+ """A module/package within the project."""
44
+
45
+ name: str = ""
46
+ path: str = ""
47
+ layers: list[Layer] = field(default_factory=list)
48
+ journeys: list[Journey] = field(default_factory=list)
49
+
50
+
51
+ @dataclass
52
+ class TechStack:
53
+ """Detected technology stack."""
54
+
55
+ language: str = ""
56
+ framework: str = ""
57
+ build_tool: str = ""
58
+ test_framework: str = ""
59
+ mock_library: str = ""
60
+ coverage_tool: str = ""
61
+ source_root: str = ""
62
+ test_root: str = ""
63
+ test_command: str = ""
64
+ coverage_command: str = ""
65
+ is_monorepo: bool = False
66
+ modules: list[str] = field(default_factory=list)
67
+
68
+
69
+ @dataclass
70
+ class ProjectGraph:
71
+ """4-Level DAG: Project → Modules → Layers → Components."""
72
+
73
+ name: str = ""
74
+ root_path: str = ""
75
+ tech_stack: TechStack = field(default_factory=TechStack)
76
+ modules: list[Module] = field(default_factory=list)
77
+ total_source_files: int = 0
78
+ total_test_files: int = 0
@@ -1,40 +1,42 @@
1
- """Pydantic models for test results and coverage reports."""
1
+ """Data models for test results and coverage reports."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from dataclasses import dataclass, field
5
6
  from datetime import datetime
6
7
  from typing import Optional
7
8
 
8
- from pydantic import BaseModel, Field
9
9
 
10
-
11
- class TestFileResult(BaseModel):
10
+ @dataclass
11
+ class TestFileResult:
12
12
  """Result of a single generated test file."""
13
13
 
14
- file_path: str
14
+ file_path: str = ""
15
15
  test_count: int = 0
16
16
  passed: int = 0
17
17
  failed: int = 0
18
- errors: list[str] = Field(default_factory=list)
19
- compile_errors: list[str] = Field(default_factory=list)
18
+ errors: list[str] = field(default_factory=list)
19
+ compile_errors: list[str] = field(default_factory=list)
20
20
 
21
21
 
22
- class CoverageEntry(BaseModel):
22
+ @dataclass
23
+ class CoverageEntry:
23
24
  """Coverage data for a single source file."""
24
25
 
25
- file_path: str
26
+ file_path: str = ""
26
27
  line_coverage: float = 0.0
27
28
  branch_coverage: float = 0.0
28
29
  lines_covered: int = 0
29
30
  lines_total: int = 0
30
31
 
31
32
 
32
- class CoverageReport(BaseModel):
33
+ @dataclass
34
+ class CoverageReport:
33
35
  """Aggregated coverage report."""
34
36
 
35
37
  line_coverage: float = 0.0
36
38
  branch_coverage: float = 0.0
37
- files: list[CoverageEntry] = Field(default_factory=list)
39
+ files: list[CoverageEntry] = field(default_factory=list)
38
40
  total_lines_covered: int = 0
39
41
  total_lines: int = 0
40
42
  total_tests: int = 0
@@ -42,11 +44,12 @@ class CoverageReport(BaseModel):
42
44
  tests_failed: int = 0
43
45
 
44
46
 
45
- class IterationResult(BaseModel):
47
+ @dataclass
48
+ class IterationResult:
46
49
  """Result of a single generation iteration."""
47
50
 
48
- iteration: int
49
- tests_generated: list[TestFileResult] = Field(default_factory=list)
51
+ iteration: int = 0
52
+ tests_generated: list[TestFileResult] = field(default_factory=list)
50
53
  coverage_before: float = 0.0
51
54
  coverage_after: float = 0.0
52
55
  coverage_delta: float = 0.0
@@ -54,14 +57,15 @@ class IterationResult(BaseModel):
54
57
  duration_seconds: float = 0.0
55
58
 
56
59
 
57
- class RunReport(BaseModel):
60
+ @dataclass
61
+ class RunReport:
58
62
  """Final report for a complete Forge Core run."""
59
63
 
60
64
  project_name: str = ""
61
65
  project_path: str = ""
62
66
  language: str = ""
63
67
  framework: str = ""
64
- started_at: datetime = Field(default_factory=datetime.now)
68
+ started_at: datetime = field(default_factory=datetime.now)
65
69
  completed_at: Optional[datetime] = None
66
70
  duration_seconds: float = 0.0
67
71
 
@@ -75,10 +79,10 @@ class RunReport(BaseModel):
75
79
  total_tests_after: int = 0
76
80
  tests_generated: int = 0
77
81
  tests_fixed: int = 0
78
- test_files_created: list[str] = Field(default_factory=list)
82
+ test_files_created: list[str] = field(default_factory=list)
79
83
 
80
84
  # Iterations
81
- iterations: list[IterationResult] = Field(default_factory=list)
85
+ iterations: list[IterationResult] = field(default_factory=list)
82
86
  total_iterations: int = 0
83
87
  rollbacks: int = 0
84
88
 
@@ -90,4 +94,4 @@ class RunReport(BaseModel):
90
94
  # Metadata
91
95
  mode: str = "full"
92
96
  target_coverage: float = 90.0
93
- production_files_changed: int = 0 # must always be 0
97
+ production_files_changed: int = 0
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "switchforge"
7
- version = "1.1.0"
7
+ version = "2.0.0"
8
8
  description = "AI-powered backend test generation engine"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -25,12 +25,10 @@ classifiers = [
25
25
  ]
26
26
 
27
27
  dependencies = [
28
- "openai>=1.30.0",
29
28
  "typer>=0.12.0",
30
29
  "rich>=13.7.0",
31
30
  "pyyaml>=6.0",
32
31
  "httpx>=0.27.0",
33
- "pydantic>=2.7.0",
34
32
  ]
35
33
 
36
34
  [project.optional-dependencies]
@@ -1,120 +0,0 @@
1
- """AI provider — direct OpenAI SDK, zero native dependencies."""
2
-
3
- from __future__ import annotations
4
-
5
- import os
6
- from typing import Any
7
-
8
- from openai import OpenAI
9
-
10
- from forge_core.models.config import AIConfig, AIProvider
11
- from forge_core.utils import logger
12
- from forge_core.utils.tokens import count_tokens
13
-
14
-
15
- def _get_client(config: AIConfig) -> OpenAI:
16
- """Create an OpenAI-compatible client for any provider."""
17
- api_key = config.api_key or os.environ.get("OPENAI_API_KEY", "")
18
-
19
- if config.provider == AIProvider.ANTHROPIC:
20
- base_url = config.base_url or "https://api.anthropic.com/v1"
21
- api_key = config.api_key or os.environ.get("ANTHROPIC_API_KEY", "")
22
- elif config.provider == AIProvider.OLLAMA:
23
- base_url = config.base_url or "http://localhost:11434/v1"
24
- api_key = "ollama"
25
- elif config.provider == AIProvider.AZURE:
26
- base_url = config.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT", "")
27
- api_key = config.api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
28
- elif config.base_url:
29
- base_url = config.base_url
30
- else:
31
- base_url = None # default OpenAI
32
-
33
- kwargs: dict[str, Any] = {"api_key": api_key}
34
- if base_url:
35
- kwargs["base_url"] = base_url
36
-
37
- return OpenAI(**kwargs)
38
-
39
-
40
- def _resolve_model(config: AIConfig) -> str:
41
- """Resolve the model name."""
42
- return config.model
43
-
44
-
45
- def complete(
46
- config: AIConfig,
47
- system_prompt: str,
48
- user_prompt: str,
49
- json_mode: bool = False,
50
- max_tokens: int | None = None,
51
- ) -> str:
52
- """Send a completion request to the configured AI provider."""
53
- model = _resolve_model(config)
54
- client = _get_client(config)
55
-
56
- messages = [
57
- {"role": "system", "content": system_prompt},
58
- {"role": "user", "content": user_prompt},
59
- ]
60
-
61
- kwargs: dict[str, Any] = {
62
- "model": model,
63
- "messages": messages,
64
- "temperature": config.temperature,
65
- "max_tokens": max_tokens or config.max_tokens,
66
- }
67
-
68
- if json_mode:
69
- kwargs["response_format"] = {"type": "json_object"}
70
-
71
- input_tokens = count_tokens(system_prompt + user_prompt, config.model)
72
- logger.info(f"AI call → {model} ({input_tokens} input tokens)")
73
-
74
- try:
75
- response = client.chat.completions.create(**kwargs)
76
- content = response.choices[0].message.content or ""
77
-
78
- output_tokens = count_tokens(content, config.model)
79
- logger.info(f"AI response ← {output_tokens} output tokens")
80
-
81
- return content
82
- except Exception as e:
83
- logger.error(f"AI call failed: {e}")
84
- raise
85
-
86
-
87
- def complete_with_fallback(
88
- config: AIConfig,
89
- system_prompt: str,
90
- user_prompt: str,
91
- fallback_models: list[str] | None = None,
92
- json_mode: bool = False,
93
- ) -> str:
94
- """Try primary model, fall back to alternatives on failure."""
95
- models = [_resolve_model(config)] + (fallback_models or [])
96
- client = _get_client(config)
97
-
98
- last_error = None
99
- for model in models:
100
- try:
101
- kwargs: dict[str, Any] = {
102
- "model": model,
103
- "messages": [
104
- {"role": "system", "content": system_prompt},
105
- {"role": "user", "content": user_prompt},
106
- ],
107
- "temperature": config.temperature,
108
- "max_tokens": config.max_tokens,
109
- }
110
- if json_mode:
111
- kwargs["response_format"] = {"type": "json_object"}
112
-
113
- response = client.chat.completions.create(**kwargs)
114
- return response.choices[0].message.content or ""
115
- except Exception as e:
116
- last_error = e
117
- logger.warn(f"Model {model} failed, trying next: {e}")
118
- continue
119
-
120
- raise RuntimeError(f"All models failed. Last error: {last_error}")
@@ -1,62 +0,0 @@
1
- """Structured AI outputs using OpenAI JSON mode + pydantic parsing."""
2
-
3
- from __future__ import annotations
4
-
5
- import json
6
- from typing import TypeVar
7
-
8
- from pydantic import BaseModel
9
-
10
- from forge_core.models.config import AIConfig
11
- from forge_core.ai.provider import complete
12
- from forge_core.utils import logger
13
-
14
- T = TypeVar("T", bound=BaseModel)
15
-
16
-
17
- def extract(
18
- config: AIConfig,
19
- system_prompt: str,
20
- user_prompt: str,
21
- response_model: type[T],
22
- max_retries: int = 2,
23
- ) -> T:
24
- """Extract structured data from AI using JSON mode + pydantic.
25
-
26
- Forces the AI to return data matching a pydantic model schema.
27
- Retries on validation failure.
28
- """
29
- schema = response_model.model_json_schema()
30
- structured_prompt = (
31
- f"{system_prompt}\n\n"
32
- f"You MUST respond with valid JSON matching this schema:\n"
33
- f"```json\n{json.dumps(schema, indent=2)}\n```\n"
34
- f"Respond ONLY with the JSON object, no other text."
35
- )
36
-
37
- logger.info(f"Structured extraction → {config.model} → {response_model.__name__}")
38
-
39
- last_error = None
40
- for attempt in range(max_retries + 1):
41
- try:
42
- raw = complete(config, structured_prompt, user_prompt, json_mode=True)
43
- # Strip markdown code fences if present
44
- cleaned = raw.strip()
45
- if cleaned.startswith("```"):
46
- cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
47
- if cleaned.endswith("```"):
48
- cleaned = cleaned[:-3]
49
- cleaned = cleaned.strip()
50
-
51
- result = response_model.model_validate_json(cleaned)
52
- logger.success(f"Extracted {response_model.__name__} successfully")
53
- return result
54
- except Exception as e:
55
- last_error = e
56
- if attempt < max_retries:
57
- logger.warn(f"Extraction attempt {attempt + 1} failed: {e}, retrying...")
58
- continue
59
-
60
- raise RuntimeError(
61
- f"Structured extraction failed after {max_retries + 1} attempts: {last_error}"
62
- )
@@ -1,72 +0,0 @@
1
- """Pydantic models for project structure — 4-Level DAG."""
2
-
3
- from __future__ import annotations
4
-
5
- from pydantic import BaseModel, Field
6
-
7
-
8
- class Component(BaseModel):
9
- """A single class/file within a layer."""
10
-
11
- name: str
12
- file_path: str
13
- layer: str = ""
14
- dependencies: list[str] = Field(default_factory=list)
15
- is_tested: bool = False
16
- existing_test_file: str = ""
17
- coverage_pct: float = 0.0
18
-
19
-
20
- class Journey(BaseModel):
21
- """A traced user journey across layers."""
22
-
23
- name: str
24
- entry_point: str
25
- entry_type: str = "" # route, consumer, job, grpc, cli
26
- components: list[str] = Field(default_factory=list) # ordered list of component names
27
- priority: int = 1 # 1 = critical, 5 = low
28
- description: str = ""
29
-
30
-
31
- class Layer(BaseModel):
32
- """A functional layer within a module (e.g., controllers, services, adapters)."""
33
-
34
- name: str
35
- components: list[Component] = Field(default_factory=list)
36
-
37
-
38
- class Module(BaseModel):
39
- """A module/package within the project."""
40
-
41
- name: str
42
- path: str
43
- layers: list[Layer] = Field(default_factory=list)
44
- journeys: list[Journey] = Field(default_factory=list)
45
-
46
-
47
- class TechStack(BaseModel):
48
- """Detected technology stack."""
49
-
50
- language: str = ""
51
- framework: str = ""
52
- build_tool: str = ""
53
- test_framework: str = ""
54
- mock_library: str = ""
55
- coverage_tool: str = ""
56
- source_root: str = ""
57
- test_root: str = ""
58
- test_command: str = ""
59
- coverage_command: str = ""
60
- is_monorepo: bool = False
61
- modules: list[str] = Field(default_factory=list)
62
-
63
-
64
- class ProjectGraph(BaseModel):
65
- """4-Level DAG: Project → Modules → Layers → Components."""
66
-
67
- name: str = ""
68
- root_path: str = ""
69
- tech_stack: TechStack = Field(default_factory=TechStack)
70
- modules: list[Module] = Field(default_factory=list)
71
- total_source_files: int = 0
72
- total_test_files: int = 0
File without changes
File without changes
File without changes