gdmcode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gdmcode-0.1.0.dist-info/METADATA +240 -0
- gdmcode-0.1.0.dist-info/RECORD +131 -0
- gdmcode-0.1.0.dist-info/WHEEL +4 -0
- gdmcode-0.1.0.dist-info/entry_points.txt +2 -0
- src/__init__.py +1 -0
- src/_internal/__init__.py +0 -0
- src/_internal/constants.py +244 -0
- src/_internal/domain_skills.py +339 -0
- src/agent/__init__.py +0 -0
- src/agent/commit_classifier.py +91 -0
- src/agent/context_budget.py +391 -0
- src/agent/daemon.py +681 -0
- src/agent/dag_validator.py +153 -0
- src/agent/debug_loop.py +473 -0
- src/agent/impact_analyzer.py +149 -0
- src/agent/impact_graph.py +117 -0
- src/agent/loop.py +1410 -0
- src/agent/orchestrator.py +141 -0
- src/agent/regression_guard.py +251 -0
- src/agent/review_gate.py +648 -0
- src/agent/risk_scorer.py +169 -0
- src/agent/self_healing.py +145 -0
- src/agent/smart_test_selector.py +89 -0
- src/agent/system_prompt.py +226 -0
- src/agent/task_tracker.py +320 -0
- src/agent/test_validator.py +210 -0
- src/agent/tool_orchestrator.py +402 -0
- src/agent/transcript.py +230 -0
- src/agent/verification_loop.py +133 -0
- src/agent/work_director.py +136 -0
- src/agent/worktree_manager.py +53 -0
- src/artifacts/__init__.py +16 -0
- src/artifacts/artifact_store.py +456 -0
- src/artifacts/verification_graph.py +75 -0
- src/auth.py +411 -0
- src/cli.py +1290 -0
- src/commands.py +1398 -0
- src/config.py +762 -0
- src/cost_tracker.py +348 -0
- src/db/__init__.py +4 -0
- src/db/migrations.py +337 -0
- src/enterprise/__init__.py +3 -0
- src/enterprise/audit_log.py +182 -0
- src/enterprise/identity.py +90 -0
- src/enterprise/rbac.py +100 -0
- src/enterprise/team_config.py +125 -0
- src/enterprise/usage_analytics.py +261 -0
- src/exceptions.py +207 -0
- src/git_workflow.py +651 -0
- src/integrations/__init__.py +6 -0
- src/integrations/github_actions.py +106 -0
- src/integrations/mcp_server.py +333 -0
- src/integrations/sentry_integration.py +100 -0
- src/integrations/sentry_server.py +82 -0
- src/integrations/webhook_security.py +19 -0
- src/main.py +27 -0
- src/memory/__init__.py +0 -0
- src/memory/code_index.py +376 -0
- src/memory/compressor.py +378 -0
- src/memory/context_memory.py +135 -0
- src/memory/continuous_memory.py +234 -0
- src/memory/conventions.py +495 -0
- src/memory/db.py +1119 -0
- src/memory/document_index.py +205 -0
- src/memory/file_cache.py +128 -0
- src/memory/project_scanner.py +178 -0
- src/memory/session_store.py +201 -0
- src/models/__init__.py +0 -0
- src/models/client.py +715 -0
- src/models/definitions.py +459 -0
- src/models/router.py +418 -0
- src/models/schemas.py +389 -0
- src/permissions.py +294 -0
- src/remote/__init__.py +5 -0
- src/remote/command_filter.py +33 -0
- src/remote/models.py +31 -0
- src/remote/permission_handler.py +79 -0
- src/remote/phone_ui.py +48 -0
- src/remote/protocol.py +59 -0
- src/remote/qr.py +65 -0
- src/remote/server.py +586 -0
- src/remote/token_manager.py +61 -0
- src/remote/tunnel.py +212 -0
- src/repl.py +475 -0
- src/runtime/__init__.py +1 -0
- src/runtime/branch_farm.py +372 -0
- src/runtime/replay.py +351 -0
- src/sandbox/__init__.py +2 -0
- src/sandbox/hermetic.py +214 -0
- src/sandbox/policy.py +44 -0
- src/sdk/__init__.py +3 -0
- src/sdk/plugin_base.py +39 -0
- src/sdk/plugin_host.py +100 -0
- src/sdk/plugin_loader.py +101 -0
- src/security.py +409 -0
- src/server/__init__.py +7 -0
- src/server/bridge.py +427 -0
- src/server/bridge_cli.py +103 -0
- src/server/bridge_client.py +170 -0
- src/server/protocol_version.py +103 -0
- src/session/__init__.py +10 -0
- src/session/event_fanout.py +46 -0
- src/session/input_broker.py +38 -0
- src/session/permission_bridge.py +100 -0
- src/tools/__init__.py +160 -0
- src/tools/_atomic.py +72 -0
- src/tools/agent_tools.py +423 -0
- src/tools/ask_user_tool.py +83 -0
- src/tools/bash_tool.py +384 -0
- src/tools/browser_tool.py +352 -0
- src/tools/browser_tools.py +179 -0
- src/tools/dep_tools.py +210 -0
- src/tools/document_reader.py +167 -0
- src/tools/document_tool.py +240 -0
- src/tools/document_writer.py +171 -0
- src/tools/impact_tools.py +240 -0
- src/tools/playwright_tool.py +172 -0
- src/tools/quality_tools.py +366 -0
- src/tools/read_tools.py +318 -0
- src/tools/result_cache.py +157 -0
- src/tools/search_tools.py +310 -0
- src/tools/shell_tools.py +311 -0
- src/tools/write_tools.py +337 -0
- src/voice/__init__.py +25 -0
- src/voice/audio_capture.py +92 -0
- src/voice/audio_playback.py +68 -0
- src/voice/errors.py +14 -0
- src/voice/models.py +35 -0
- src/voice/providers.py +143 -0
- src/voice/vad.py +55 -0
- src/voice/voice_loop.py +156 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""TaskTracker -- manages task decomposition and live progress tracking.
|
|
2
|
+
|
|
3
|
+
Tasks are stored in the ``tasks`` SQLite table (gdm.db) and survive process
|
|
4
|
+
restarts. The ``/tasks`` slash command reads from this tracker.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import sqlite3
|
|
11
|
+
import uuid
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import Final
|
|
14
|
+
|
|
15
|
+
from rich.console import Console
|
|
16
|
+
|
|
17
|
+
from src.exceptions import DatabaseError
|
|
18
|
+
from src.memory.db import GdmDatabase
|
|
19
|
+
|
|
20
|
+
__all__ = ["Subtask", "Task", "TaskTracker"]
|
|
21
|
+
|
|
22
|
+
log = logging.getLogger(__name__)
|
|
23
|
+
console = Console()
|
|
24
|
+
|
|
25
|
+
# ---------------------------------------------------------------------------
|
|
26
|
+
# Constants
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
|
|
29
|
+
_STATUS_PENDING: Final[str] = "pending"
|
|
30
|
+
_STATUS_IN_PROGRESS: Final[str] = "in_progress"
|
|
31
|
+
_STATUS_DONE: Final[str] = "done"
|
|
32
|
+
_STATUS_BLOCKED: Final[str] = "blocked"
|
|
33
|
+
_STATUS_FAILED: Final[str] = "failed"
|
|
34
|
+
|
|
35
|
+
_VALID_SUBTASK_STATUSES: Final[frozenset[str]] = frozenset(
|
|
36
|
+
{_STATUS_PENDING, _STATUS_IN_PROGRESS, _STATUS_DONE, _STATUS_BLOCKED, _STATUS_FAILED}
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
_ICON_DONE: Final[str] = "\u2713" # checkmark
|
|
40
|
+
_ICON_IN_PROGRESS: Final[str] = "\u27f3" # clockwise circle arrow
|
|
41
|
+
_ICON_FAILED: Final[str] = "\u2717" # ballot x (also for blocked)
|
|
42
|
+
_ICON_PENDING: Final[str] = "\u25cb" # white circle
|
|
43
|
+
|
|
44
|
+
_RESUME_PREFIX: Final[str] = "Resume previous task:"
|
|
45
|
+
_BLOCKED_TAG: Final[str] = "BLOCKED"
|
|
46
|
+
|
|
47
|
+
_SELECT_COLS: Final[str] = "task_id, session_id, title, status, subtasks"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# ---------------------------------------------------------------------------
|
|
51
|
+
# Dataclasses
|
|
52
|
+
# ---------------------------------------------------------------------------
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass
|
|
56
|
+
class Subtask:
|
|
57
|
+
"""A subtask within a Task.
|
|
58
|
+
|
|
59
|
+
Statuses: ``pending`` | ``in_progress`` | ``done`` | ``blocked`` | ``failed``
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
id: str
|
|
63
|
+
title: str
|
|
64
|
+
status: str = _STATUS_PENDING
|
|
65
|
+
depends_on: list[str] = field(default_factory=list)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
|
|
69
|
+
class Task:
|
|
70
|
+
"""A top-level task managed by :class:`TaskTracker`.
|
|
71
|
+
|
|
72
|
+
Statuses: ``pending`` | ``in_progress`` | ``done`` | ``blocked``
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
task_id: str
|
|
76
|
+
session_id: str
|
|
77
|
+
title: str
|
|
78
|
+
status: str
|
|
79
|
+
subtasks: list[Subtask]
|
|
80
|
+
|
|
81
|
+
def progress_str(self) -> str:
|
|
82
|
+
"""Return ``[done/total]`` progress string, e.g. ``[2/6]``."""
|
|
83
|
+
done = sum(1 for s in self.subtasks if s.status == _STATUS_DONE)
|
|
84
|
+
return f"[{done}/{len(self.subtasks)}]"
|
|
85
|
+
|
|
86
|
+
def to_rich_panel(self) -> str:
|
|
87
|
+
"""Return Rich markup string showing task and subtask statuses.
|
|
88
|
+
|
|
89
|
+
Example::
|
|
90
|
+
|
|
91
|
+
[bold]Add OAuth2 to auth module[/bold] [2/6]
|
|
92
|
+
checkmark Read current auth.ts structure
|
|
93
|
+
arrow Implementing OAuth2 callback handler...
|
|
94
|
+
circle Write unit tests
|
|
95
|
+
"""
|
|
96
|
+
progress = self.progress_str() if self.subtasks else ""
|
|
97
|
+
header = f"[bold]{self.title}[/bold]"
|
|
98
|
+
if progress:
|
|
99
|
+
header = f"{header} {progress}"
|
|
100
|
+
lines = [header]
|
|
101
|
+
for sub in self.subtasks:
|
|
102
|
+
lines.append(f"{_subtask_icon(sub.status)} {sub.title}")
|
|
103
|
+
return "\n".join(lines)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# ---------------------------------------------------------------------------
|
|
107
|
+
# Module-level helpers
|
|
108
|
+
# ---------------------------------------------------------------------------
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _subtask_icon(status: str) -> str:
|
|
112
|
+
"""Map subtask status to its single-character display icon."""
|
|
113
|
+
if status == _STATUS_DONE:
|
|
114
|
+
return _ICON_DONE
|
|
115
|
+
if status == _STATUS_IN_PROGRESS:
|
|
116
|
+
return _ICON_IN_PROGRESS
|
|
117
|
+
if status in (_STATUS_FAILED, _STATUS_BLOCKED):
|
|
118
|
+
return _ICON_FAILED
|
|
119
|
+
return _ICON_PENDING
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _row_to_task(row: sqlite3.Row) -> Task:
|
|
123
|
+
"""Convert a :class:`sqlite3.Row` from the tasks table to a :class:`Task`."""
|
|
124
|
+
raw: list[dict[str, str]] = json.loads(row["subtasks"] or "[]")
|
|
125
|
+
subtasks = [
|
|
126
|
+
Subtask(
|
|
127
|
+
id=s["id"],
|
|
128
|
+
title=s["title"],
|
|
129
|
+
status=s.get("status", _STATUS_PENDING),
|
|
130
|
+
depends_on=s.get("depends_on", []),
|
|
131
|
+
)
|
|
132
|
+
for s in raw
|
|
133
|
+
]
|
|
134
|
+
return Task(
|
|
135
|
+
task_id=row["task_id"],
|
|
136
|
+
session_id=row["session_id"],
|
|
137
|
+
title=row["title"],
|
|
138
|
+
status=row["status"],
|
|
139
|
+
subtasks=subtasks,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _subtasks_to_json(subtasks: list[Subtask]) -> str:
|
|
144
|
+
"""Serialize a list of :class:`Subtask` objects to a JSON string."""
|
|
145
|
+
return json.dumps(
|
|
146
|
+
[{"id": s.id, "title": s.title, "status": s.status} for s in subtasks]
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# ---------------------------------------------------------------------------
|
|
151
|
+
# TaskTracker
|
|
152
|
+
# ---------------------------------------------------------------------------
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class TaskTracker:
|
|
156
|
+
"""Manages task decomposition and live progress for the current session.
|
|
157
|
+
|
|
158
|
+
All state is persisted to gdm.db tasks table -- tasks survive restarts.
|
|
159
|
+
|
|
160
|
+
Usage::
|
|
161
|
+
|
|
162
|
+
tracker = TaskTracker(db, session_id)
|
|
163
|
+
task_id = tracker.create_task("Add OAuth2 to auth module")
|
|
164
|
+
subtask_id = tracker.add_subtask(task_id, "Read current auth.ts structure")
|
|
165
|
+
tracker.update_subtask(task_id, subtask_id, "in_progress")
|
|
166
|
+
tracker.update_subtask(task_id, subtask_id, "done")
|
|
167
|
+
active = tracker.get_active() # list of in_progress + pending tasks
|
|
168
|
+
"""
|
|
169
|
+
|
|
170
|
+
def __init__(self, db: GdmDatabase, session_id: str) -> None:
|
|
171
|
+
self._db = db
|
|
172
|
+
self._session_id = session_id
|
|
173
|
+
|
|
174
|
+
# ------------------------------------------------------------------
|
|
175
|
+
# Task lifecycle
|
|
176
|
+
# ------------------------------------------------------------------
|
|
177
|
+
|
|
178
|
+
def create_task(self, title: str) -> str:
|
|
179
|
+
"""Create a new task. Returns task_id (UUID). Status starts as ``pending``."""
|
|
180
|
+
task_id = str(uuid.uuid4())
|
|
181
|
+
self._db.execute(
|
|
182
|
+
f"INSERT INTO tasks ({_SELECT_COLS}) VALUES (?, ?, ?, ?, ?)",
|
|
183
|
+
(task_id, self._session_id, title, _STATUS_PENDING, "[]"),
|
|
184
|
+
)
|
|
185
|
+
log.debug("Created task %s: %s", task_id, title)
|
|
186
|
+
return task_id
|
|
187
|
+
|
|
188
|
+
def start_task(self, task_id: str) -> None:
|
|
189
|
+
"""Mark task as ``in_progress``."""
|
|
190
|
+
self._set_status(task_id, _STATUS_IN_PROGRESS)
|
|
191
|
+
|
|
192
|
+
def complete_task(self, task_id: str) -> None:
|
|
193
|
+
"""Mark task as ``done``."""
|
|
194
|
+
self._set_status(task_id, _STATUS_DONE)
|
|
195
|
+
|
|
196
|
+
def block_task(self, task_id: str, reason: str) -> None:
|
|
197
|
+
"""Mark task as ``blocked``. Appends *reason* to the task title."""
|
|
198
|
+
task = self._require_task(task_id)
|
|
199
|
+
new_title = f"{task.title} [{_BLOCKED_TAG}: {reason}]"
|
|
200
|
+
self._db.execute(
|
|
201
|
+
"UPDATE tasks SET status = ?, title = ? WHERE task_id = ?",
|
|
202
|
+
(_STATUS_BLOCKED, new_title, task_id),
|
|
203
|
+
)
|
|
204
|
+
log.debug("Blocked task %s: %s", task_id, reason)
|
|
205
|
+
|
|
206
|
+
# ------------------------------------------------------------------
|
|
207
|
+
# Subtask operations
|
|
208
|
+
# ------------------------------------------------------------------
|
|
209
|
+
|
|
210
|
+
def add_subtask(self, task_id: str, title: str) -> str:
|
|
211
|
+
"""Add a subtask to an existing task. Returns the new subtask_id."""
|
|
212
|
+
task = self._require_task(task_id)
|
|
213
|
+
subtask_id = str(uuid.uuid4())
|
|
214
|
+
task.subtasks.append(Subtask(id=subtask_id, title=title))
|
|
215
|
+
self._db.execute(
|
|
216
|
+
"UPDATE tasks SET subtasks = ? WHERE task_id = ?",
|
|
217
|
+
(_subtasks_to_json(task.subtasks), task_id),
|
|
218
|
+
)
|
|
219
|
+
log.debug("Added subtask %s to task %s: %s", subtask_id, task_id, title)
|
|
220
|
+
return subtask_id
|
|
221
|
+
|
|
222
|
+
def update_subtask(self, task_id: str, subtask_id: str, status: str) -> None:
|
|
223
|
+
"""Update a subtask status. Auto-completes parent when all subtasks done."""
|
|
224
|
+
if status not in _VALID_SUBTASK_STATUSES:
|
|
225
|
+
raise DatabaseError(
|
|
226
|
+
f"Invalid subtask status {status!r}. "
|
|
227
|
+
f"Must be one of: {sorted(_VALID_SUBTASK_STATUSES)}"
|
|
228
|
+
)
|
|
229
|
+
task = self._require_task(task_id)
|
|
230
|
+
found = False
|
|
231
|
+
for sub in task.subtasks:
|
|
232
|
+
if sub.id == subtask_id:
|
|
233
|
+
sub.status = status
|
|
234
|
+
found = True
|
|
235
|
+
break
|
|
236
|
+
if not found:
|
|
237
|
+
raise DatabaseError(
|
|
238
|
+
f"Subtask {subtask_id!r} not found in task {task_id!r}"
|
|
239
|
+
)
|
|
240
|
+
self._db.execute(
|
|
241
|
+
"UPDATE tasks SET subtasks = ? WHERE task_id = ?",
|
|
242
|
+
(_subtasks_to_json(task.subtasks), task_id),
|
|
243
|
+
)
|
|
244
|
+
self._auto_complete_parent(task_id, task.subtasks)
|
|
245
|
+
|
|
246
|
+
# ------------------------------------------------------------------
|
|
247
|
+
# Query methods
|
|
248
|
+
# ------------------------------------------------------------------
|
|
249
|
+
|
|
250
|
+
def get_active(self) -> list[Task]:
|
|
251
|
+
"""Return ``in_progress`` and ``pending`` tasks for this session."""
|
|
252
|
+
rows = self._db.execute_all(
|
|
253
|
+
f"SELECT {_SELECT_COLS} FROM tasks "
|
|
254
|
+
"WHERE session_id = ? AND status IN (?, ?) ORDER BY rowid",
|
|
255
|
+
(self._session_id, _STATUS_IN_PROGRESS, _STATUS_PENDING),
|
|
256
|
+
)
|
|
257
|
+
return [_row_to_task(r) for r in rows]
|
|
258
|
+
|
|
259
|
+
def get_all(self) -> list[Task]:
|
|
260
|
+
"""Return all tasks for this session, regardless of status."""
|
|
261
|
+
rows = self._db.execute_all(
|
|
262
|
+
f"SELECT {_SELECT_COLS} FROM tasks WHERE session_id = ? ORDER BY rowid",
|
|
263
|
+
(self._session_id,),
|
|
264
|
+
)
|
|
265
|
+
return [_row_to_task(r) for r in rows]
|
|
266
|
+
|
|
267
|
+
def get_task(self, task_id: str) -> Task | None:
|
|
268
|
+
"""Get a specific task by ID. Returns ``None`` if not found."""
|
|
269
|
+
row = self._db.execute_one(
|
|
270
|
+
f"SELECT {_SELECT_COLS} FROM tasks WHERE task_id = ?",
|
|
271
|
+
(task_id,),
|
|
272
|
+
)
|
|
273
|
+
return _row_to_task(row) if row is not None else None
|
|
274
|
+
|
|
275
|
+
def resume_prompt(self) -> str | None:
|
|
276
|
+
"""Return a resume prompt string if incomplete tasks exist, else ``None``.
|
|
277
|
+
|
|
278
|
+
Returns a string like::
|
|
279
|
+
|
|
280
|
+
'Resume previous task: [2/6] Add OAuth2...'
|
|
281
|
+
"""
|
|
282
|
+
active = self.get_active()
|
|
283
|
+
if not active:
|
|
284
|
+
return None
|
|
285
|
+
in_progress = [t for t in active if t.status == _STATUS_IN_PROGRESS]
|
|
286
|
+
task = in_progress[0] if in_progress else active[0]
|
|
287
|
+
parts: list[str] = [_RESUME_PREFIX]
|
|
288
|
+
if task.subtasks:
|
|
289
|
+
parts.append(task.progress_str())
|
|
290
|
+
parts.append(task.title)
|
|
291
|
+
return " ".join(parts)
|
|
292
|
+
|
|
293
|
+
# ------------------------------------------------------------------
|
|
294
|
+
# Private helpers
|
|
295
|
+
# ------------------------------------------------------------------
|
|
296
|
+
|
|
297
|
+
def _set_status(self, task_id: str, status: str) -> None:
|
|
298
|
+
"""Execute a status UPDATE for a task row."""
|
|
299
|
+
self._db.execute(
|
|
300
|
+
"UPDATE tasks SET status = ? WHERE task_id = ? AND session_id = ?",
|
|
301
|
+
(status, task_id, self._session_id),
|
|
302
|
+
)
|
|
303
|
+
log.debug("Task %s -> %s", task_id, status)
|
|
304
|
+
|
|
305
|
+
def _require_task(self, task_id: str) -> Task:
|
|
306
|
+
"""Return the task or raise :class:`DatabaseError` if not found."""
|
|
307
|
+
task = self.get_task(task_id)
|
|
308
|
+
if task is None:
|
|
309
|
+
raise DatabaseError(
|
|
310
|
+
f"Task {task_id!r} not found in session {self._session_id!r}"
|
|
311
|
+
)
|
|
312
|
+
return task
|
|
313
|
+
|
|
314
|
+
def _auto_complete_parent(self, task_id: str, subtasks: list[Subtask]) -> None:
|
|
315
|
+
"""Mark parent task ``done`` when every subtask is ``done``."""
|
|
316
|
+
if not subtasks:
|
|
317
|
+
return
|
|
318
|
+
if all(s.status == _STATUS_DONE for s in subtasks):
|
|
319
|
+
self._set_status(task_id, _STATUS_DONE)
|
|
320
|
+
log.debug("Auto-completed task %s (all subtasks done)", task_id)
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
"""Test validator — checks quality of LLM-generated tests before committing.
|
|
2
|
+
|
|
3
|
+
Validates:
|
|
4
|
+
1. Coverage threshold — new tests add net-positive coverage
|
|
5
|
+
2. Assertion density — tests must have ≥1 assertion per test function
|
|
6
|
+
3. No tautological asserts — e.g., assert True, assert x == x
|
|
7
|
+
4. Mock hygiene — mocks must target real paths (no mock of non-existent)
|
|
8
|
+
5. Test isolation — no global state mutation between tests
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import ast
|
|
13
|
+
import logging
|
|
14
|
+
import re
|
|
15
|
+
import subprocess
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
__all__ = ["TestValidator", "TestValidationResult"]
|
|
21
|
+
|
|
22
|
+
log = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
_MIN_ASSERTIONS_PER_TEST: int = 1
|
|
25
|
+
_COVERAGE_THRESHOLD_PCT: float = 0.0 # net-positive (≥0 change)
|
|
26
|
+
|
|
27
|
+
_TAUTOLOGICAL_PATTERNS: list[re.Pattern[str]] = [
|
|
28
|
+
re.compile(r"assert\s+True\s*$", re.MULTILINE),
|
|
29
|
+
re.compile(r"assert\s+1\s*==\s*1", re.MULTILINE),
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ---------------------------------------------------------------------------
|
|
34
|
+
# Result dataclass
|
|
35
|
+
# ---------------------------------------------------------------------------
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class TestValidationResult:
|
|
39
|
+
"""Result of validating a test file."""
|
|
40
|
+
|
|
41
|
+
file: str
|
|
42
|
+
passed: bool = True
|
|
43
|
+
issues: list[str] = field(default_factory=list)
|
|
44
|
+
test_count: int = 0
|
|
45
|
+
assertion_count: int = 0
|
|
46
|
+
tautological_count: int = 0
|
|
47
|
+
|
|
48
|
+
def add_issue(self, msg: str) -> None:
|
|
49
|
+
self.issues.append(msg)
|
|
50
|
+
self.passed = False
|
|
51
|
+
|
|
52
|
+
def summary(self) -> str:
|
|
53
|
+
status = "PASS" if self.passed else "FAIL"
|
|
54
|
+
parts = [f"[{status}] {self.file} — {self.test_count} tests, {self.assertion_count} assertions"]
|
|
55
|
+
if self.issues:
|
|
56
|
+
parts += [f" • {i}" for i in self.issues]
|
|
57
|
+
return "\n".join(parts)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# ---------------------------------------------------------------------------
|
|
61
|
+
# AST visitor
|
|
62
|
+
# ---------------------------------------------------------------------------
|
|
63
|
+
|
|
64
|
+
class _TestVisitor(ast.NodeVisitor):
|
|
65
|
+
"""Collect test function metadata via AST walk."""
|
|
66
|
+
|
|
67
|
+
def __init__(self) -> None:
|
|
68
|
+
self.test_functions: list[ast.FunctionDef] = []
|
|
69
|
+
self.assertions: int = 0
|
|
70
|
+
self.tautological: int = 0
|
|
71
|
+
|
|
72
|
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # noqa: N802
|
|
73
|
+
if node.name.startswith("test_"):
|
|
74
|
+
self.test_functions.append(node)
|
|
75
|
+
self._count_assertions(node)
|
|
76
|
+
self.generic_visit(node)
|
|
77
|
+
|
|
78
|
+
visit_AsyncFunctionDef = visit_FunctionDef # noqa: N815
|
|
79
|
+
|
|
80
|
+
def _count_assertions(self, node: ast.FunctionDef) -> None:
|
|
81
|
+
for child in ast.walk(node):
|
|
82
|
+
if isinstance(child, ast.Assert):
|
|
83
|
+
self.assertions += 1
|
|
84
|
+
if self._is_tautological(child):
|
|
85
|
+
self.tautological += 1
|
|
86
|
+
|
|
87
|
+
@staticmethod
|
|
88
|
+
def _is_tautological(node: ast.Assert) -> bool:
|
|
89
|
+
test = node.test
|
|
90
|
+
# assert True
|
|
91
|
+
if isinstance(test, ast.Constant) and test.value is True:
|
|
92
|
+
return True
|
|
93
|
+
# assert x == x (Compare where left == right)
|
|
94
|
+
if isinstance(test, ast.Compare) and len(test.comparators) == 1:
|
|
95
|
+
if isinstance(test.ops[0], ast.Eq):
|
|
96
|
+
if ast.dump(test.left) == ast.dump(test.comparators[0]):
|
|
97
|
+
return True
|
|
98
|
+
return False
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# ---------------------------------------------------------------------------
|
|
102
|
+
# TestValidator
|
|
103
|
+
# ---------------------------------------------------------------------------
|
|
104
|
+
|
|
105
|
+
class TestValidator:
|
|
106
|
+
"""Validate quality of test files before they are committed.
|
|
107
|
+
|
|
108
|
+
Usage::
|
|
109
|
+
|
|
110
|
+
validator = TestValidator()
|
|
111
|
+
result = validator.validate(Path("tests/test_auth.py"))
|
|
112
|
+
if not result.passed:
|
|
113
|
+
print(result.summary())
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
def validate(self, test_path: Path) -> TestValidationResult:
|
|
117
|
+
"""Run all checks on *test_path*. Returns a TestValidationResult."""
|
|
118
|
+
result = TestValidationResult(file=str(test_path))
|
|
119
|
+
|
|
120
|
+
if not test_path.exists():
|
|
121
|
+
result.add_issue(f"File not found: {test_path}")
|
|
122
|
+
return result
|
|
123
|
+
|
|
124
|
+
try:
|
|
125
|
+
source = test_path.read_text(encoding="utf-8")
|
|
126
|
+
except OSError as exc:
|
|
127
|
+
result.add_issue(f"Cannot read file: {exc}")
|
|
128
|
+
return result
|
|
129
|
+
|
|
130
|
+
try:
|
|
131
|
+
tree = ast.parse(source, filename=str(test_path))
|
|
132
|
+
except SyntaxError as exc:
|
|
133
|
+
result.add_issue(f"Syntax error: {exc}")
|
|
134
|
+
return result
|
|
135
|
+
|
|
136
|
+
visitor = _TestVisitor()
|
|
137
|
+
visitor.visit(tree)
|
|
138
|
+
|
|
139
|
+
result.test_count = len(visitor.test_functions)
|
|
140
|
+
result.assertion_count = visitor.assertions
|
|
141
|
+
result.tautological_count = visitor.tautological
|
|
142
|
+
|
|
143
|
+
self._check_test_count(result)
|
|
144
|
+
self._check_assertions(result, visitor)
|
|
145
|
+
self._check_tautological(result, source)
|
|
146
|
+
self._check_mock_hygiene(result, tree, test_path)
|
|
147
|
+
|
|
148
|
+
return result
|
|
149
|
+
|
|
150
|
+
def validate_many(self, test_dir: Path) -> list[TestValidationResult]:
|
|
151
|
+
"""Validate all test_*.py files in *test_dir*."""
|
|
152
|
+
results: list[TestValidationResult] = []
|
|
153
|
+
for path in sorted(test_dir.rglob("test_*.py")):
|
|
154
|
+
results.append(self.validate(path))
|
|
155
|
+
return results
|
|
156
|
+
|
|
157
|
+
# ------------------------------------------------------------------
|
|
158
|
+
# Checks
|
|
159
|
+
# ------------------------------------------------------------------
|
|
160
|
+
|
|
161
|
+
def _check_test_count(self, result: TestValidationResult) -> None:
|
|
162
|
+
if result.test_count == 0:
|
|
163
|
+
result.add_issue("No test functions found (must start with test_)")
|
|
164
|
+
|
|
165
|
+
def _check_assertions(
|
|
166
|
+
self, result: TestValidationResult, visitor: _TestVisitor
|
|
167
|
+
) -> None:
|
|
168
|
+
"""Each test function must have ≥1 assertion."""
|
|
169
|
+
for fn in visitor.test_functions:
|
|
170
|
+
fn_assertions = sum(
|
|
171
|
+
1 for node in ast.walk(fn) if isinstance(node, ast.Assert)
|
|
172
|
+
)
|
|
173
|
+
if fn_assertions < _MIN_ASSERTIONS_PER_TEST:
|
|
174
|
+
result.add_issue(
|
|
175
|
+
f"test '{fn.name}' has 0 assertions (line {fn.lineno})"
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
def _check_tautological(self, result: TestValidationResult, source: str) -> None:
|
|
179
|
+
for pat in _TAUTOLOGICAL_PATTERNS:
|
|
180
|
+
if pat.search(source):
|
|
181
|
+
result.tautological_count += 1
|
|
182
|
+
result.add_issue(f"Tautological assert detected: {pat.pattern!r}")
|
|
183
|
+
|
|
184
|
+
def _check_mock_hygiene(
|
|
185
|
+
self, result: TestValidationResult, tree: ast.AST, test_path: Path
|
|
186
|
+
) -> None:
|
|
187
|
+
"""Warn if mock.patch targets a path that does not exist in sys.modules context."""
|
|
188
|
+
for node in ast.walk(tree):
|
|
189
|
+
if not isinstance(node, (ast.Call,)):
|
|
190
|
+
continue
|
|
191
|
+
func = node.func
|
|
192
|
+
# look for @patch("some.path") or mocker.patch("some.path")
|
|
193
|
+
if not isinstance(func, ast.Attribute) or func.attr != "patch":
|
|
194
|
+
continue
|
|
195
|
+
for arg in node.args:
|
|
196
|
+
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
|
|
197
|
+
target = arg.value
|
|
198
|
+
if not self._mock_target_plausible(target):
|
|
199
|
+
result.add_issue(
|
|
200
|
+
f"Mock target may not exist: {target!r} (line {node.lineno})"
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
@staticmethod
|
|
204
|
+
def _mock_target_plausible(target: str) -> bool:
|
|
205
|
+
"""Return True if mock target looks like a valid dotted path."""
|
|
206
|
+
parts = target.split(".")
|
|
207
|
+
if len(parts) < 2: # noqa: PLR2004
|
|
208
|
+
return False
|
|
209
|
+
# basic sanity: all parts must be valid identifiers
|
|
210
|
+
return all(p.isidentifier() for p in parts)
|