testgap 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
testgap/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from importlib.metadata import PackageNotFoundError, version
2
+
3
+ try:
4
+ __version__ = version("testgap")
5
+ except PackageNotFoundError:
6
+ __version__ = "0.0.0+unknown"
7
+
8
+ __all__ = ["__version__"]
testgap/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from testgap.cli import app
2
+
3
+ if __name__ == "__main__":
4
+ app()
testgap/cli.py ADDED
@@ -0,0 +1,274 @@
1
+ from pathlib import Path
2
+
3
+ import typer
4
+ from rich.console import Console
5
+ from rich.markup import escape
6
+ from rich.prompt import Confirm, Prompt
7
+ from rich.table import Table
8
+
9
+ from testgap import __version__
10
+ from testgap.config.init_wizard import (
11
+ analyze,
12
+ build_config,
13
+ ensure_gitignore_entry,
14
+ provider_status,
15
+ suggest_model,
16
+ write_config,
17
+ )
18
+ from testgap.config.loader import CONFIG_FILENAME, ConfigError, load_config
19
+ from testgap.generator import LLMClient
20
+ from testgap.pipeline import DiffRunReport, FunctionSuggestion, run_diff
21
+
22
+ app = typer.Typer(
23
+ name="testgap",
24
+ help="AI-powered test generator that closes coverage gaps in your PRs.",
25
+ add_completion=False,
26
+ no_args_is_help=True,
27
+ )
28
+ console = Console()
29
+
30
+
31
def _version_callback(value: bool) -> None:
    """Eager ``--version`` handler: print the package version and exit.

    Does nothing when the flag was not supplied (``value`` is falsy).

    Raises:
        typer.Exit: after the version line has been printed.
    """
    if not value:
        return
    console.print(f"testgap {__version__}")
    raise typer.Exit()
35
+
36
+
37
@app.callback()
def main(
    version: bool = typer.Option(
        None, "--version", callback=_version_callback, is_eager=True, help="Show version and exit."
    ),
) -> None:
    # Root callback: its only job is to attach the eager --version option to
    # the application; the option's own callback does all the work.
    return None
44
+
45
+
46
@app.command()
def init(
    path: Path | None = typer.Option(
        None, "--path", "-p", help="Project root to initialize.", file_okay=False
    ),
    yes: bool = typer.Option(
        False, "--yes", "-y", help="Accept all detected defaults without prompts."
    ),
) -> None:
    """Initialize TestGap in the current project (creates .testgap.yml)."""
    # Flow: resolve the project root, run detection, refuse to continue without
    # pytest, settle what to do with an existing config, choose source/test
    # paths and a model, then write the config and update .gitignore.
    root = (path or Path.cwd()).resolve()
    # Filesystem paths can contain "[...]", which Rich would parse as markup
    # and drop from the output, so escape them before interpolating.
    root_label = escape(str(root))
    if not root.is_dir():
        console.print(f"[red]✗[/] {root_label} is not a directory")
        raise typer.Exit(code=1)

    console.print(f"[bold]Analyzing[/] {root_label}")
    report = analyze(root)

    if not report.pytest_signals:
        console.print("[red]✗[/] No pytest project detected.")
        console.print(" Install pytest first: [cyan]pip install pytest[/]")
        raise typer.Exit(code=1)

    console.print(f"[green]✓[/] pytest detected ({escape(report.pytest_signals[0])})")

    if not report.has_git:
        console.print(
            "[yellow]![/] Not a git repository — `testgap diff` will not work until you `git init`."
        )

    existing = root / CONFIG_FILENAME
    # With --yes an existing config is overwritten without asking.
    if existing.is_file() and not yes:
        action = Prompt.ask(
            f"[yellow]{CONFIG_FILENAME} already exists.[/] Action?",
            choices=["overwrite", "backup", "cancel"],
            default="cancel",
        )
        if action == "cancel":
            console.print("Aborted.")
            raise typer.Exit(code=0)
        if action == "backup":
            backup_path = existing.with_suffix(existing.suffix + ".bak")
            backup_path.write_bytes(existing.read_bytes())
            console.print(f" Backed up to {backup_path.name}")

    source_paths = _choose_source_paths(report, yes=yes)
    test_paths = report.test_paths or ["tests/"]
    if not report.test_paths:
        console.print(f"[yellow]![/] No tests/ directory found — defaulting to {test_paths[0]}")
    else:
        console.print(f"[green]✓[/] test directory: {escape(test_paths[0])}")

    model = _choose_model(yes=yes)

    config = build_config(source_paths=source_paths, test_paths=test_paths, model=model)
    config_path = write_config(config, root)
    console.print(f"[green]✓[/] wrote {config_path.relative_to(root)}")

    if ensure_gitignore_entry(root):
        console.print("[green]✓[/] added .testgap/ to .gitignore")

    console.print()
    console.print("[bold]Next steps:[/]")
    # The `diff` command defines no --review option, so advertise the plain
    # command rather than an invocation that would be rejected.
    console.print(" [cyan]testgap diff[/] suggest tests for uncovered changes")
110
+
111
+
112
def _choose_source_paths(report, *, yes: bool) -> list[str]:
    """Decide which source paths go into the generated config.

    Args:
        report: detection result; only ``source_paths`` and
            ``layout_ambiguous`` are read here.
        yes: when true, never prompt and fall back to defaults.

    Returns:
        The detected paths when the layout is unambiguous (or prompting is
        disabled), the single path the user picked when it is ambiguous, or a
        one-element list holding the entered/default ``src/`` path when
        nothing was detected.
    """
    if report.source_paths and not report.layout_ambiguous:
        # Detected paths come from the filesystem; escape them so a name
        # containing "[...]" is not swallowed as Rich markup.
        console.print(f"[green]✓[/] source path: {escape(report.source_paths[0])}")
        return report.source_paths

    if report.layout_ambiguous and not yes:
        console.print("[yellow]?[/] multiple source candidates found:")
        for number, candidate in enumerate(report.source_paths, 1):
            console.print(f" [{number}] {escape(candidate)}")
        choice = Prompt.ask(
            " pick one",
            choices=[str(number) for number in range(1, len(report.source_paths) + 1)],
            default="1",
        )
        # The prompt is 1-based; convert back to a list index.
        return [report.source_paths[int(choice) - 1]]

    if not report.source_paths:
        if yes:
            console.print("[yellow]![/] no source layout detected — defaulting to src/")
            return ["src/"]
        custom = Prompt.ask(
            "[yellow]?[/] no source layout detected. Source path?", default="src/"
        )
        return [custom]

    # Ambiguous layout with --yes: keep every detected candidate.
    return report.source_paths
138
+
139
+
140
@app.command()
def diff(
    base: str | None = typer.Option(
        None, "--base", "-b", help="Base git ref. Defaults to origin/HEAD then main/master."
    ),
    head: str = typer.Option("HEAD", "--head", help="Head ref. Defaults to HEAD."),
    max_functions: int | None = typer.Option(
        None, "--max-functions", "-n", help="Limit number of functions processed."
    ),
    path: Path | None = typer.Option(None, "--path", "-p", file_okay=False),
) -> None:
    """Analyze the diff and propose tests for uncovered changes (non-interactive)."""
    # Imported locally because the module-level import from
    # testgap.config.loader does not include find_config.
    from testgap.config.loader import find_config

    root = (path or Path.cwd()).resolve()

    try:
        # Search for the config starting at the requested project root so that
        # --path is honoured; a bare load_config() searches from the CWD and
        # could pair one project's config with another project's sources.
        config = load_config(find_config(root))
    except ConfigError as e:
        console.print(f"[red]✗[/] {escape(str(e))}")
        raise typer.Exit(code=1) from e

    console.print(f"[bold]Analyzing diff[/] in {root}")

    llm_client = LLMClient(model=config.llm.model, max_retries=config.llm.max_retries)

    try:
        report = run_diff(
            project_root=root,
            config=config,
            llm_client=llm_client,
            base_ref=base,
            head_ref=head,
            max_functions=max_functions,
        )
    except Exception as e:  # surface user-facing errors from coverage/git layers
        console.print(f"[red]✗[/] {escape(str(e))}")
        raise typer.Exit(code=1) from e

    _print_diff_report(report)

    # Exit non-zero when at least one suggestion did not succeed.
    if report.suggestions and not all(s.succeeded for s in report.suggestions):
        raise typer.Exit(code=1)
181
+
182
+
183
def _print_diff_report(report: DiffRunReport) -> None:
    """Print the run header, the coverage summary and every suggestion.

    When the report carries a ``skipped_reason`` only the ref line and that
    reason are printed.
    """
    console.print(f"[dim]base[/] {report.base_ref} → [dim]head[/] {report.head_ref}")

    reason = report.skipped_reason
    if reason:
        console.print(f"[green]✓[/] {reason}")
        return

    console.print(
        f"changed lines: {report.changed_total} "
        f"covered: {report.covered_total} "
        f"diff coverage: {report.diff_coverage_pct}%"
    )
    console.print()

    position = 0
    for suggestion in report.suggestions:
        position += 1
        _print_suggestion(position, len(report.suggestions), suggestion)

    console.print()
    console.print(f"[dim]LLM cost this run:[/] ${report.cost_total:.4f}")
203
+
204
+
205
def _print_suggestion(idx: int, total: int, s: FunctionSuggestion) -> None:
    """Print the outcome for one function.

    Args:
        idx: 1-based position of this suggestion in the run.
        total: number of suggestions in the run.
        s: the suggestion to render.

    Output is a header, up to eight uncovered line numbers, then either an
    error line, a "no result" notice, or a pass/discard summary followed by
    at most three discarded test names.
    """
    file_label = escape(f"{s.function.file.name}::{s.function.qualname}")
    header = f"[{idx}/{total}] {file_label}"
    console.print(f"[bold]{header}[/]")
    lines_str = ", ".join(str(n) for n in s.function.uncovered_lines[:8])
    if len(s.function.uncovered_lines) > 8:
        lines_str += ", …"
    console.print(f" uncovered lines: {lines_str}")

    if s.error:
        console.print(f" [red]✗[/] {escape(s.error)}")
        return

    if s.validator_result is None or s.generated is None:
        console.print(" [yellow]![/] no result captured")
        return

    if s.validator_result.environment_error:
        console.print(f" [red]✗[/] {escape(s.validator_result.environment_error)}")
        return

    accepted_n = len(s.accepted_cases)
    discarded_n = len(s.discarded_cases)
    total_n = accepted_n + discarded_n
    cost_label = f"${s.cost_usd:.4f}" if s.cost_usd > 0 else "$0 (cost unknown)"
    # "[retried]" looks like a Rich style tag and would be silently removed
    # from the output; escape it so the marker is actually displayed.
    # NOTE(review): only attempts == 2 is flagged — confirm attempts can never
    # exceed 2, otherwise later retries go unmarked.
    retried_marker = escape(" [retried]") if s.attempts == 2 else ""

    if s.fully_passed:
        console.print(
            f" [green]✓[/] {accepted_n}/{total_n} tests passed {cost_label}{retried_marker}"
        )
    elif s.succeeded:
        console.print(
            f" [yellow]![/] {accepted_n} kept / {discarded_n} discarded "
            f"{cost_label}{retried_marker}"
        )
    else:
        console.print(
            f" [yellow]![/] {accepted_n} pass / {discarded_n} fail of {total_n} "
            f"{cost_label}{retried_marker}"
        )

    if s.retry_skipped_reason:
        console.print(f" [yellow]![/] retry skipped: {escape(s.retry_skipped_reason)}")

    # Show only the first few discarded tests to keep the report short.
    for case in s.discarded_cases[:3]:
        console.print(f" [red]·[/] {escape(case.name)}")
252
+
253
+
254
def _choose_model(*, yes: bool) -> str:
    """Return the model id to write into the config.

    Args:
        yes: when true, return the suggested model without prompting.

    Returns:
        The suggested model, or the id typed by the user; an empty answer
        falls back to the suggestion.
    """
    suggested = suggest_model()
    if yes:
        return suggested

    table = Table(show_header=True, header_style="bold", title="Available LLM providers")
    table.add_column("model")
    table.add_column("status")
    # The previously collected `options` list was never read, so the rows are
    # rendered straight into the table.
    for model, status in provider_status():
        marker = "→" if model == suggested else " "
        table.add_row(f"{marker} {model}", status)
    console.print(table)

    if Confirm.ask(f"Use suggested model [cyan]{suggested}[/]?", default=True):
        return suggested
    choice = Prompt.ask("Enter model id", default=suggested)
    return choice or suggested
@@ -0,0 +1,15 @@
1
+ from testgap.config.schema import (
2
+ CoverageConfig,
3
+ GenerationConfig,
4
+ LLMConfig,
5
+ ProjectConfig,
6
+ TestGapConfig,
7
+ )
8
+
9
+ __all__ = [
10
+ "TestGapConfig",
11
+ "ProjectConfig",
12
+ "CoverageConfig",
13
+ "LLMConfig",
14
+ "GenerationConfig",
15
+ ]
@@ -0,0 +1,113 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+
5
+ from testgap.config.loader import CONFIG_FILENAME, dump_config
6
+ from testgap.config.schema import (
7
+ CoverageConfig,
8
+ GenerationConfig,
9
+ LLMConfig,
10
+ ProjectConfig,
11
+ TestGapConfig,
12
+ )
13
+ from testgap.detect import detect_layout, detect_pytest, detect_source_paths, detect_test_dirs
14
+
15
+
16
@dataclass
class DetectionReport:
    """Plain record of what project detection found under a root directory."""

    # Signals reported by pytest detection; an empty list means none found.
    pytest_signals: list[str]
    # Candidate source paths.
    source_paths: list[str]
    # Candidate test directories.
    test_paths: list[str]
    # Name of the detected layout kind.
    layout_kind: str
    # True when more than one source candidate competes.
    layout_ambiguous: bool
    # True when the root has a ".git" entry.
    has_git: bool
24
+
25
+
26
def analyze(root: Path) -> DetectionReport:
    """Run every detector against ``root`` and bundle the results.

    The layout counts as ambiguous only when it is "flat" and offers more
    than one candidate. Test directories are reported relative to ``root`` in
    POSIX form with a trailing slash, de-duplicated and sorted.
    """
    pytest_info = detect_pytest(root)
    layout = detect_layout(root)
    test_dirs = detect_test_dirs(root)
    source_paths = detect_source_paths(root)

    kind = layout.kind.value
    is_ambiguous = kind == "flat" and len(layout.candidates) > 1

    relative_tests = {directory.relative_to(root).as_posix() + "/" for directory in test_dirs.paths}

    return DetectionReport(
        pytest_signals=pytest_info.signals,
        source_paths=source_paths,
        test_paths=sorted(relative_tests),
        layout_kind=kind,
        layout_ambiguous=is_ambiguous,
        has_git=(root / ".git").exists(),
    )
44
+
45
+
46
+ _KNOWN_PROVIDERS = (
47
+ ("anthropic/claude-sonnet-4-6", "ANTHROPIC_API_KEY"),
48
+ ("openai/gpt-4o", "OPENAI_API_KEY"),
49
+ ("gemini/gemini-2.0-flash", "GEMINI_API_KEY"),
50
+ ("ollama/qwen2.5-coder", None),
51
+ )
52
+
53
+
54
def suggest_model() -> str:
    """Pick the first provider whose API key is set in env; fall back to Ollama."""
    # Providers without an env var (local models) are never picked by the scan.
    configured = (
        model
        for model, env_var in _KNOWN_PROVIDERS
        if env_var is not None and os.environ.get(env_var)
    )
    return next(configured, "ollama/qwen2.5-coder")
62
+
63
+
64
def provider_status() -> list[tuple[str, str]]:
    """Return (model, status) pairs for display in the wizard."""

    def describe(env_var: str | None) -> str:
        # Providers without an env var are local; the rest report whether
        # their key variable holds a non-empty value.
        if env_var is None:
            return "local model"
        state = "found" if os.environ.get(env_var) else "not set"
        return f"{env_var} {state}"

    return [(model, describe(env_var)) for model, env_var in _KNOWN_PROVIDERS]
75
+
76
+
77
def build_config(
    *,
    source_paths: list[str],
    test_paths: list[str],
    model: str,
) -> TestGapConfig:
    """Assemble a config from the wizard's answers.

    Empty path lists fall back to ``["src/"]`` and ``["tests/"]``; the
    coverage and generation sections are built with no arguments.
    """
    project = ProjectConfig(
        source_paths=source_paths or ["src/"],
        test_paths=test_paths or ["tests/"],
    )
    llm = LLMConfig(model=model)
    return TestGapConfig(
        project=project,
        coverage=CoverageConfig(),
        llm=llm,
        generation=GenerationConfig(),
    )
92
+
93
+
94
def write_config(config: TestGapConfig, root: Path) -> Path:
    """Dump ``config`` to the config file directly under ``root``; return its path."""
    destination = root / CONFIG_FILENAME
    dump_config(config, destination)
    return destination
98
+
99
+
100
def ensure_gitignore_entry(root: Path, entry: str = ".testgap/") -> bool:
    """Append entry to .gitignore if missing. Returns True if file was modified.

    Args:
        root: directory that holds (or will hold) the ``.gitignore`` file.
        entry: ignore pattern to ensure is present.

    Returns:
        True when the file was created or appended to, False when an
        equivalent pattern was already listed.

    An existing line counts as equivalent when it equals the entry, the entry
    without its trailing slash, or either of those anchored with a leading
    ``/``.
    """
    gitignore = root / ".gitignore"
    if not gitignore.is_file():
        gitignore.write_text(f"# TestGap\n{entry}\n", encoding="utf-8")
        return True

    content = gitignore.read_text(encoding="utf-8")
    present = {line.strip() for line in content.splitlines()}

    wanted = entry.strip()
    equivalents = {wanted, wanted.rstrip("/")}
    # Root-anchored spellings ("/.testgap/") ignore the same directory.
    equivalents |= {"/" + form.lstrip("/") for form in equivalents}
    # Never let an empty form match a blank line in the file.
    equivalents.discard("")
    equivalents.discard("/")
    if equivalents & present:
        return False

    if not content:
        separator = ""  # empty file: no stray leading blank line
    elif content.endswith("\n"):
        separator = "\n"
    else:
        separator = "\n\n"  # terminate the last line, then leave one blank line
    with gitignore.open("a", encoding="utf-8") as f:
        f.write(f"{separator}# TestGap\n{entry}\n")
    return True
@@ -0,0 +1,54 @@
1
+ from pathlib import Path
2
+
3
+ import yaml
4
+ from pydantic import ValidationError
5
+
6
+ from testgap.config.schema import TestGapConfig
7
+
8
+ CONFIG_FILENAME = ".testgap.yml"
9
+
10
+
11
class ConfigError(Exception):
    """Base class for every configuration-related failure."""


class ConfigNotFoundError(ConfigError):
    """No config file could be located."""


class ConfigInvalidError(ConfigError):
    """The config file exists but could not be parsed or validated."""
21
+
22
+
23
def find_config(start: Path | None = None) -> Path:
    """Locate the nearest config file at or above ``start``.

    Args:
        start: directory to begin from; the current working directory when
            omitted.

    Returns:
        Path of the first config file found while walking up the tree.

    Raises:
        ConfigNotFoundError: when no directory on the way up contains one.
    """
    origin = (start or Path.cwd()).resolve()
    hits = (
        directory / CONFIG_FILENAME
        for directory in (origin, *origin.parents)
        if (directory / CONFIG_FILENAME).is_file()
    )
    found = next(hits, None)
    if found is not None:
        return found
    raise ConfigNotFoundError(
        f"{CONFIG_FILENAME} not found in {origin} or any parent directory. "
        "Run `testgap init` to create one."
    )
33
+
34
+
35
def load_config(path: Path | None = None) -> TestGapConfig:
    """Read, parse and validate a config file.

    Args:
        path: config file to load; when omitted the file is located with
            ``find_config()``.

    Returns:
        The validated configuration. An empty document yields the defaults.

    Raises:
        ConfigNotFoundError: propagated from ``find_config()``.
        ConfigInvalidError: on YAML syntax errors, a non-mapping root, or
            schema validation failure.
    """
    config_path = path or find_config()
    try:
        loaded = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    except yaml.YAMLError as e:
        raise ConfigInvalidError(f"Failed to parse {config_path}: {e}") from e

    # Only an empty document (None) means "use defaults". Other falsy roots
    # such as [], 0, false or "" are not mappings and must be rejected rather
    # than silently replaced by an empty config.
    raw = {} if loaded is None else loaded

    if not isinstance(raw, dict):
        raise ConfigInvalidError(f"{config_path}: root must be a mapping, got {type(raw).__name__}")

    try:
        return TestGapConfig.model_validate(raw)
    except ValidationError as e:
        raise ConfigInvalidError(f"Invalid config at {config_path}:\n{e}") from e
49
+
50
+
51
def dump_config(config: TestGapConfig, path: Path) -> None:
    """Serialize ``config`` as block-style YAML and write it to ``path`` (UTF-8).

    Keys keep the model's field order and non-ASCII text is written as-is.
    """
    path.write_text(
        yaml.safe_dump(
            config.model_dump(mode="json"),
            sort_keys=False,
            allow_unicode=True,
            default_flow_style=False,
        ),
        encoding="utf-8",
    )
@@ -0,0 +1,48 @@
1
+ from typing import Literal
2
+
3
+ from pydantic import BaseModel, Field, field_validator
4
+
5
+
6
class ProjectConfig(BaseModel):
    # Project section of the config.
    # Only "python" and "pytest" are accepted for these two fields.
    language: Literal["python"] = "python"
    test_framework: Literal["pytest"] = "pytest"
    # Factories give every instance its own fresh default list.
    source_paths: list[str] = Field(default_factory=lambda: ["src/"])
    test_paths: list[str] = Field(default_factory=lambda: ["tests/"])
11
+
12
+
13
class CoverageConfig(BaseModel):
    # Coverage section of the config.
    # Both thresholds are validated to lie within 0..100 inclusive.
    threshold: int = Field(ge=0, le=100, default=80)
    diff_threshold: int = Field(ge=0, le=100, default=90)
    # Glob patterns excluded by default.
    exclude: list[str] = Field(default_factory=lambda: ["**/migrations/**", "**/__init__.py"])
19
+
20
+
21
class LLMConfig(BaseModel):
    # LLM section of the config.
    model: str = "anthropic/claude-sonnet-4-6"
    # Must be strictly positive.
    max_cost_per_run: float = Field(gt=0, default=2.0)
    # Validated to lie within 0..5 inclusive.
    max_retries: int = Field(ge=0, le=5, default=2)
25
+
26
+
27
class GenerationConfig(BaseModel):
    # Generation section of the config.
    # Only these two style names are accepted.
    style: Literal["match_existing", "minimal"] = "match_existing"
    include_docstrings: bool = True
    # Validated to lie within 1..10 inclusive.
    max_tests_per_function: int = Field(ge=1, le=10, default=3)
    # Validated to lie within 1..600 inclusive.
    test_timeout_seconds: int = Field(ge=1, le=600, default=30)
32
+
33
+
34
class TestGapConfig(BaseModel):
    # Root model of the config document; every section defaults to its own
    # model built with no arguments.
    __test__ = False  # not a pytest test class despite the name

    version: int = 1
    project: ProjectConfig = Field(default_factory=ProjectConfig)
    coverage: CoverageConfig = Field(default_factory=CoverageConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)
    generation: GenerationConfig = Field(default_factory=GenerationConfig)

    @field_validator("version")
    @classmethod
    def _check_version(cls, v: int) -> int:
        # Schema version 1 is the only one accepted.
        if v == 1:
            return v
        raise ValueError(f"Unsupported config version: {v}. Only version 1 is supported.")
@@ -0,0 +1,3 @@
1
+ from testgap.cost.tracker import BudgetExceeded, CostTracker
2
+
3
+ __all__ = ["CostTracker", "BudgetExceeded"]
@@ -0,0 +1,47 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
class BudgetExceeded(Exception):
    """Raised by ``CostTracker.record`` when an entry would exceed the budget."""


@dataclass
class CostEntry:
    """One recorded charge: a label, its dollar cost and optional token counts."""

    label: str
    cost_usd: float
    input_tokens: int = 0
    output_tokens: int = 0


# Decimal places every dollar total is rounded to before it is compared.
_USD_PRECISION = 6


@dataclass
class CostTracker:
    """Accumulates charges for one run and enforces ``max_cost_per_run``."""

    max_cost_per_run: float
    entries: list[CostEntry] = field(default_factory=list)

    @property
    def spent(self) -> float:
        """Total of all recorded costs, rounded to six decimals."""
        return round(sum(e.cost_usd for e in self.entries), _USD_PRECISION)

    @property
    def remaining(self) -> float:
        """Budget left, never negative."""
        return max(0.0, round(self.max_cost_per_run - self.spent, _USD_PRECISION))

    def would_exceed(self, estimated: float) -> bool:
        """Return True when spending ``estimated`` more would go over budget.

        The projected total is rounded like ``spent``; comparing the raw float
        sum rejects spending exactly the budget (0.1 + 0.2 > 0.3 in binary
        floating point).
        """
        return round(self.spent + estimated, _USD_PRECISION) > self.max_cost_per_run

    def near_limit(self, ratio: float = 0.8) -> bool:
        """Return True once spending reaches ``ratio`` of the budget."""
        return self.spent >= self.max_cost_per_run * ratio

    def record(
        self, *, label: str, cost_usd: float, input_tokens: int = 0, output_tokens: int = 0
    ) -> CostEntry:
        """Append a charge and return the stored entry.

        Raises:
            BudgetExceeded: when the charge would exceed the budget; nothing
                is recorded in that case.
        """
        # Same rounded comparison as would_exceed(), so both always agree.
        if self.would_exceed(cost_usd):
            raise BudgetExceeded(
                f"recording ${cost_usd:.4f} would exceed budget "
                f"${self.max_cost_per_run:.2f} (already spent ${self.spent:.4f})"
            )
        entry = CostEntry(
            label=label, cost_usd=cost_usd, input_tokens=input_tokens, output_tokens=output_tokens
        )
        self.entries.append(entry)
        return entry
@@ -0,0 +1,22 @@
1
+ from testgap.coverage.ast_grouping import UncoveredFunction, group_by_function
2
+ from testgap.coverage.diff_coverage import (
3
+ DiffCoverageReport,
4
+ UncoveredLine,
5
+ compute_diff_coverage,
6
+ )
7
+ from testgap.coverage.git_diff import GitDiffError, changed_lines, resolve_base_ref
8
+ from testgap.coverage.runner import CoverageError, CoverageRunResult, run_pytest_with_coverage
9
+
10
+ __all__ = [
11
+ "UncoveredFunction",
12
+ "group_by_function",
13
+ "DiffCoverageReport",
14
+ "UncoveredLine",
15
+ "compute_diff_coverage",
16
+ "GitDiffError",
17
+ "changed_lines",
18
+ "resolve_base_ref",
19
+ "CoverageError",
20
+ "CoverageRunResult",
21
+ "run_pytest_with_coverage",
22
+ ]