lcp 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lcp/__init__.py +135 -0
- lcp/ai/__init__.py +18 -0
- lcp/ai/agent.py +539 -0
- lcp/ai/connectors/__init__.py +6 -0
- lcp/ai/connectors/anthropic.py +110 -0
- lcp/ai/connectors/openai.py +140 -0
- lcp/ai/hierarchy.py +385 -0
- lcp/ai/models.py +78 -0
- lcp/ai/prompts.py +144 -0
- lcp/ai/provider.py +24 -0
- lcp/ai/writer.py +235 -0
- lcp/cli.py +789 -0
- lcp/coverage.py +271 -0
- lcp/diff.py +173 -0
- lcp/generator.py +189 -0
- lcp/mcp_server.py +1425 -0
- lcp/models.py +283 -0
- lcp/publish.py +501 -0
- lcp/scanner.py +542 -0
- lcp/schema.json +386 -0
- lcp/validator.py +113 -0
- lcp-1.0.0.dist-info/METADATA +327 -0
- lcp-1.0.0.dist-info/RECORD +26 -0
- lcp-1.0.0.dist-info/WHEEL +4 -0
- lcp-1.0.0.dist-info/entry_points.txt +2 -0
- lcp-1.0.0.dist-info/licenses/LICENSE +21 -0
lcp/__init__.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""LCP Python SDK - Generate Library Context Protocol files from Python packages."""
|
|
2
|
+
|
|
3
|
+
from .coverage import (
|
|
4
|
+
CoverageReport,
|
|
5
|
+
CoverageSummary,
|
|
6
|
+
UndocumentedSymbol,
|
|
7
|
+
generate_coverage,
|
|
8
|
+
generate_coverage_from_scanned,
|
|
9
|
+
)
|
|
10
|
+
from .diff import DiffResult, SymbolDiff, diff_documents, load_lcp_document, update_document
|
|
11
|
+
from .generator import generate_lcp
|
|
12
|
+
from .publish import PublishError, PublishResult, publish_manifest
|
|
13
|
+
from .mcp_server import (
|
|
14
|
+
LCPIndex,
|
|
15
|
+
MultiLibraryIndex,
|
|
16
|
+
create_server,
|
|
17
|
+
create_universal_server,
|
|
18
|
+
resolve_library_document,
|
|
19
|
+
run_server,
|
|
20
|
+
run_universal_server,
|
|
21
|
+
)
|
|
22
|
+
from .models import (
|
|
23
|
+
LCPDocument,
|
|
24
|
+
Library,
|
|
25
|
+
Manifest,
|
|
26
|
+
Param,
|
|
27
|
+
Semantics,
|
|
28
|
+
Signature,
|
|
29
|
+
Symbol,
|
|
30
|
+
SymbolKind,
|
|
31
|
+
)
|
|
32
|
+
from .scanner import scan_package
|
|
33
|
+
from .validator import (
|
|
34
|
+
LCPValidationError,
|
|
35
|
+
is_valid,
|
|
36
|
+
validate_document,
|
|
37
|
+
validate_file,
|
|
38
|
+
validate_or_raise,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
from importlib.metadata import PackageNotFoundError as _PackageNotFoundError
from importlib.metadata import version as _version

try:
    # Read the version from the installed distribution's metadata.
    __version__ = _version("lcp")
except _PackageNotFoundError:
    # No distribution metadata for "lcp" is available (e.g. not installed).
    __version__ = "0.0.0+unknown"

# Public API of the package; order is kept stable for consumers of __all__.
__all__ = [
    # Main functions
    "scan", "generate_coverage", "generate_coverage_from_scanned",
    # Coverage models
    "CoverageReport", "CoverageSummary", "UndocumentedSymbol",
    # LCP Models
    "LCPDocument", "Library", "Manifest", "Param",
    "Semantics", "Signature", "Symbol", "SymbolKind",
    # Scanner
    "scan_package",
    # Generator
    "generate_lcp",
    # MCP Server
    "LCPIndex", "MultiLibraryIndex", "create_server", "create_universal_server",
    "resolve_library_document", "run_server", "run_universal_server",
    # Diff
    "DiffResult", "SymbolDiff", "diff_documents", "load_lcp_document", "update_document",
    # Publish
    "PublishError", "PublishResult", "publish_manifest",
    # Validator
    "validate_document", "validate_file", "validate_or_raise", "is_valid",
    "LCPValidationError",
]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def scan(
    package_name: str,
    *,
    include_private: bool = False,
    recursive: bool = True,
    validate: bool = True,
) -> LCPDocument:
    """Scan an installed Python package and build its LCP document.

    This is the main entry point for the SDK: it runs ``scan_package`` on
    the package, converts the scan with ``generate_lcp`` and, optionally,
    checks the result with ``validate_or_raise``.

    Args:
        package_name: The name of an installed Python package to scan.
        include_private: Include private symbols (starting with _).
        recursive: Scan submodules recursively.
        validate: Validate the output against the LCP schema.

    Returns:
        An LCPDocument containing the scanned library information.

    Raises:
        ImportError: If the package cannot be imported.
        LCPValidationError: If validation is enabled and the output is invalid.

    Example:
        >>> from lcp import scan
        >>> doc = scan("json")
        >>> doc.to_file("json.lcp.json")
    """
    document = generate_lcp(
        scan_package(
            package_name,
            include_private=include_private,
            recursive=recursive,
        )
    )
    if validate:
        # Raises instead of returning a boolean, so invalid output never escapes.
        validate_or_raise(document)
    return document
|
lcp/ai/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""AI documentation generation module for LCP."""
|
|
2
|
+
|
|
3
|
+
from .agent import DocGenAgent
|
|
4
|
+
from .connectors import AnthropicProvider, OpenAIProvider
|
|
5
|
+
from .models import DocGenConfig, DocGenResult, HierarchicalConfig, SymbolResult, TokenUsage
|
|
6
|
+
from .provider import LLMProvider
|
|
7
|
+
|
|
8
|
+
# Public API of lcp.ai (order preserved).
__all__ = [
    # Orchestrator
    "DocGenAgent",
    # Configuration and result models
    "DocGenConfig", "DocGenResult", "HierarchicalConfig", "SymbolResult", "TokenUsage",
    # Provider interface and concrete connectors
    "LLMProvider", "OpenAIProvider", "AnthropicProvider",
]
|
lcp/ai/agent.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
1
|
+
"""DocGenAgent - orchestrator for AI documentation generation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import ast
|
|
6
|
+
import asyncio
|
|
7
|
+
import json
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
from .hierarchy import ModuleTree, SymbolNode, build_context, build_hierarchy
|
|
12
|
+
from .models import DocGenConfig, DocGenResult, HierarchicalConfig, SymbolResult, TokenUsage
|
|
13
|
+
from .prompts import build_system_prompt, build_user_prompt, build_user_prompt_hierarchical
|
|
14
|
+
from .provider import LLMProvider
|
|
15
|
+
from .writer import inject_docstrings_batch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DocGenAgent:
    """Agent that generates docstrings for undocumented Python symbols.

    Input is coverage data whose ``"undocumented"`` entry is a list of
    symbol dicts (keys read here: ``kind``, ``module``, ``entity``,
    ``source_file``). Two engines are available:

    * :meth:`run` -- sequential; calls the provider's ``generate`` once per
      symbol and writes all docstrings for a source file in one batch.
    * :meth:`run_async` -- hierarchical; processes level 0 (leaves), then
      level 1 (classes), then level 2 (modules), calling the provider's
      ``agenerate`` concurrently under a semaphore. :meth:`run_sync` picks
      this engine when the config is a :class:`HierarchicalConfig`.

    Args:
        provider: LLM provider to use for generating docstrings.
        config: Configuration for the generation run. A default
            ``DocGenConfig()`` is used when omitted.
    """

    # Number of leading source lines used as context for a module symbol.
    _MODULE_CONTEXT_LINES = 30
    # Upper bound on source lines used as context for a class/function.
    _SYMBOL_CONTEXT_LINES = 50
    # Error recorded when inject_docstrings_batch reports a failed write.
    _WRITE_FAILED_ERROR = "Failed to write docstring to source file"

    def __init__(
        self,
        provider: LLMProvider,
        config: DocGenConfig | None = None,
    ) -> None:
        self._provider = provider
        self._config = config or DocGenConfig()

    # ------------------------------------------------------------------
    # Sequential engine
    # ------------------------------------------------------------------

    def run(self, coverage_input: str | dict) -> DocGenResult:
        """Execute the documentation generation sequentially.

        Symbols are grouped by ``source_file``. Each symbol gets one LLM
        call, then all docstrings for a file are written in a single batch
        (unless ``config.dry_run`` is set). Symbols without a
        ``source_file`` are reported as skipped.

        Args:
            coverage_input: Path to coverage JSON file or parsed dict.

        Returns:
            DocGenResult with statistics and per-symbol results; ``results``
            follows the order of the (kind-filtered) input symbols.
        """
        coverage_data = self._load_coverage(coverage_input)
        symbols = self._filter_symbols(coverage_data.get("undocumented", []))
        if not symbols:
            return self._empty_result()

        # Group symbols by source file, keeping each one's input index so the
        # results list can be reported in input order.
        by_file: dict[str, list[tuple[int, dict]]] = defaultdict(list)
        no_file: list[tuple[int, dict]] = []
        for idx, sym in enumerate(symbols):
            source_file = sym.get("source_file")
            if source_file:
                by_file[source_file].append((idx, sym))
            else:
                no_file.append((idx, sym))

        all_results: list[SymbolResult] = [None] * len(symbols)  # type: ignore[list-item]
        total_usage = TokenUsage()

        for source_file, file_symbols in by_file.items():
            # (kind, entity, docstring, result) for each docstring to inject.
            file_injections: list[tuple[str, str, str, SymbolResult]] = []

            for idx, sym in file_symbols:
                result = self._process_symbol(sym, source_file)
                all_results[idx] = result
                if result.usage:
                    total_usage = total_usage + result.usage
                if result.status in ("updated", "dry_run") and result.docstring:
                    file_injections.append(
                        (sym.get("kind", ""), sym.get("entity", ""), result.docstring, result)
                    )

            if file_injections and not self._config.dry_run:
                write_results = inject_docstrings_batch(
                    source_file,
                    [(kind, entity, doc) for kind, entity, doc, _ in file_injections],
                )
                # Downgrade symbols whose docstring could not be written.
                for (_, _, _, sym_result), success in zip(file_injections, write_results):
                    if not success:
                        sym_result.status = "skipped"
                        sym_result.error = self._WRITE_FAILED_ERROR

        for idx, sym in no_file:
            all_results[idx] = SymbolResult(
                symbol_id=f"{sym.get('module', '')}:{sym.get('entity', '')}",
                kind=sym.get("kind", ""),
                source_file=None,
                status="skipped",
                error="No source file available",
            )

        return self._summarize(all_results, len(symbols), total_usage)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _empty_result() -> DocGenResult:
        """Return the DocGenResult reported when there is nothing to process."""
        return DocGenResult(
            symbols_processed=0,
            symbols_updated=0,
            symbols_skipped=0,
            symbols_failed=0,
            total_usage=TokenUsage(),
            results=[],
        )

    @staticmethod
    def _summarize(
        results: list[SymbolResult], processed: int, total_usage: TokenUsage
    ) -> DocGenResult:
        """Aggregate per-symbol statuses into a DocGenResult.

        Dry-run results are counted as updated, since they are the
        docstrings that would have been written.
        """
        updated = sum(1 for r in results if r.status in ("updated", "dry_run"))
        skipped = sum(1 for r in results if r.status == "skipped")
        failed = sum(1 for r in results if r.status == "failed")
        return DocGenResult(
            symbols_processed=processed,
            symbols_updated=updated,
            symbols_skipped=skipped,
            symbols_failed=failed,
            total_usage=total_usage,
            results=results,
        )

    def _system_prompt(self):
        """Build the system prompt from the configured style and description."""
        return build_system_prompt(
            docstring_style=self._config.docstring_style,
            description=self._config.description,
        )

    def _load_coverage(self, coverage_input: str | dict) -> dict:
        """Load coverage data from a file path (read as UTF-8 JSON) or a dict.

        A dict is returned as-is, without copying.
        """
        if isinstance(coverage_input, dict):
            return coverage_input

        with open(Path(coverage_input), "r", encoding="utf-8") as f:
            return json.load(f)

    def _filter_symbols(self, undocumented: list[dict]) -> list[dict]:
        """Keep only symbols whose ``kind`` is in ``config.kinds``.

        A falsy ``config.kinds`` disables filtering.
        """
        if not self._config.kinds:
            return undocumented

        return [s for s in undocumented if s.get("kind") in self._config.kinds]

    def _read_source_context(
        self, source_file: str, kind: str, entity: str
    ) -> str:
        """Read the source code context for a symbol.

        For ``kind == "module"`` this is the first ``_MODULE_CONTEXT_LINES``
        lines of the file. Otherwise it is the symbol's definition,
        including any decorators, capped at ``_SYMBOL_CONTEXT_LINES`` lines.

        Args:
            source_file: Path to the source file.
            kind: Symbol kind.
            entity: Entity name (``"Class#method"`` for methods).

        Returns:
            Source code string for the symbol, or ``""`` if the file cannot
            be read/parsed or the symbol is not found.
        """
        try:
            source = Path(source_file).read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            return ""

        try:
            tree = ast.parse(source)
        except SyntaxError:
            return ""

        source_lines = source.splitlines()

        if kind == "module":
            return "\n".join(source_lines[: self._MODULE_CONTEXT_LINES])

        node = self._find_source_node(tree, kind, entity)
        if node is None:
            return ""

        # On Python 3.8+ a decorated def/class reports ``lineno`` at the
        # ``def``/``class`` line; start at the first decorator so context such
        # as @property or @staticmethod reaches the LLM.
        decorator_lines = [d.lineno for d in getattr(node, "decorator_list", [])]
        start = min([node.lineno, *decorator_lines]) - 1
        # end_lineno may be absent or None (e.g. synthesized nodes); fall back
        # to the definition line rather than crashing in min().
        end = getattr(node, "end_lineno", None) or node.lineno
        end = min(end, start + self._SYMBOL_CONTEXT_LINES)

        return "\n".join(source_lines[start:end])

    def _find_source_node(
        self, tree: ast.Module, kind: str, entity: str
    ) -> ast.AST | None:
        """Find the AST node for a symbol.

        For ``entity`` of the form ``"Class#method"``, classes named
        ``Class`` are visited in ``ast.walk`` order (at any depth) and the
        first one whose body directly defines a (async) function ``method``
        supplies the result. Otherwise only top-level definitions are
        searched: a class for ``kind == "class"`` or a (async) function for
        ``kind == "function"``.

        Returns:
            The matching node, or None if nothing matches.
        """
        if "#" in entity:
            class_name, method_name = entity.split("#", 1)
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef) and node.name == class_name:
                    for item in node.body:
                        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                            if item.name == method_name:
                                return item
            return None

        for node in ast.iter_child_nodes(tree):
            if kind == "class" and isinstance(node, ast.ClassDef) and node.name == entity:
                return node
            if kind == "function" and isinstance(
                node, (ast.FunctionDef, ast.AsyncFunctionDef)
            ):
                if node.name == entity:
                    return node

        return None

    def _generate_docstring(
        self, symbol: dict, source_context: str
    ) -> tuple[str, TokenUsage]:
        """Call the LLM (synchronously) to generate a docstring.

        Args:
            symbol: Symbol dict from coverage data.
            source_context: Source code context.

        Returns:
            Tuple of (stripped docstring text, token usage).
        """
        prompt = build_user_prompt(
            kind=symbol.get("kind", ""),
            module=symbol.get("module", ""),
            entity=symbol.get("entity", ""),
            source_context=source_context,
        )

        response = self._provider.generate(self._system_prompt(), prompt)
        return response.content.strip(), response.usage

    def _process_symbol(self, symbol: dict, source_file: str) -> SymbolResult:
        """Process a single symbol: read context, generate docstring.

        Nothing is written to disk here; writing is done in batches by the
        caller. Any exception is converted into a ``"failed"`` result so one
        bad symbol does not abort the whole run.

        Args:
            symbol: Symbol dict from coverage data.
            source_file: Path to source file.

        Returns:
            SymbolResult for this symbol.
        """
        symbol_id = f"{symbol.get('module', '')}:{symbol.get('entity', '')}"
        kind = symbol.get("kind", "")

        try:
            source_context = self._read_source_context(
                source_file, kind, symbol.get("entity", "")
            )

            if not source_context:
                return SymbolResult(
                    symbol_id=symbol_id,
                    kind=kind,
                    source_file=source_file,
                    status="skipped",
                    error="Could not read source context",
                )

            docstring, usage = self._generate_docstring(symbol, source_context)

            if not docstring:
                return SymbolResult(
                    symbol_id=symbol_id,
                    kind=kind,
                    source_file=source_file,
                    status="skipped",
                    usage=usage,
                    error="LLM returned empty docstring",
                )

            status = "dry_run" if self._config.dry_run else "updated"

            return SymbolResult(
                symbol_id=symbol_id,
                kind=kind,
                source_file=source_file,
                status=status,
                docstring=docstring,
                usage=usage,
            )

        except Exception as e:
            return SymbolResult(
                symbol_id=symbol_id,
                kind=kind,
                source_file=source_file,
                status="failed",
                error=str(e),
            )

    # ------------------------------------------------------------------
    # Hierarchical async engine
    # ------------------------------------------------------------------

    def run_sync(self, coverage_input: str | dict) -> DocGenResult:
        """Execute documentation generation synchronously.

        Uses the hierarchical async engine (via ``asyncio.run``) if the
        config is a HierarchicalConfig, otherwise falls back to :meth:`run`.
        Because of ``asyncio.run``, the hierarchical path cannot be called
        from inside an already-running event loop.

        Args:
            coverage_input: Path to coverage JSON file or parsed dict.

        Returns:
            DocGenResult with statistics and per-symbol results.
        """
        if not isinstance(self._config, HierarchicalConfig):
            return self.run(coverage_input)
        return asyncio.run(self.run_async(coverage_input))

    async def run_async(self, coverage_input: str | dict) -> DocGenResult:
        """Execute hierarchical bottom-up documentation generation.

        Processes symbols level-by-level (0 = leaves, 1 = classes,
        2 = modules), using ``asyncio.gather`` with a semaphore of
        ``config.max_workers`` for concurrency control. After each level,
        parents whose children failed too often are skipped.

        Args:
            coverage_input: Path to coverage JSON file or parsed dict.

        Returns:
            DocGenResult with statistics and per-symbol results.

        Raises:
            TypeError: If there are symbols to process and the agent's
                config is not a HierarchicalConfig.
        """
        coverage_data = self._load_coverage(coverage_input)
        symbols = self._filter_symbols(coverage_data.get("undocumented", []))
        if not symbols:
            return self._empty_result()

        config = self._config
        # Explicit check (not assert) so it survives ``python -O``.
        if not isinstance(config, HierarchicalConfig):
            raise TypeError(
                "run_async requires a HierarchicalConfig; use run() for a plain DocGenConfig"
            )

        trees = build_hierarchy(symbols)
        sem = asyncio.Semaphore(config.max_workers)

        all_results: list[SymbolResult] = []
        total_usage = TokenUsage()

        for level in (0, 1, 2):
            # All still-pending nodes at this level, across every module tree.
            level_nodes: list[tuple[SymbolNode, ModuleTree]] = [
                (node, tree)
                for tree in trees.values()
                for node in tree.levels.get(level, [])
                if node.status == "pending"
            ]
            if not level_nodes:
                continue

            results = await asyncio.gather(
                *(self._process_node(node, tree, sem) for node, tree in level_nodes),
                return_exceptions=True,
            )

            for (node, _tree), result in zip(level_nodes, results):
                # gather(return_exceptions=True) can also return BaseException
                # instances such as asyncio.CancelledError, which is not an
                # Exception subclass on Python 3.8+; treat them all as failures.
                if isinstance(result, BaseException):
                    sym_result = SymbolResult(
                        symbol_id=node.symbol_id,
                        kind=node.kind,
                        source_file=node.symbol.get("source_file"),
                        status="failed",
                        error=str(result) or type(result).__name__,
                    )
                    node.status = "failed"
                else:
                    sym_result = result
                    node.status = sym_result.status
                    if sym_result.docstring:
                        # Store the docstring on the node for later levels;
                        # presumably build_context reads it -- verify in hierarchy.
                        node.docstring = sym_result.docstring

                all_results.append(sym_result)
                if sym_result.usage:
                    total_usage = total_usage + sym_result.usage

            self._propagate_failures(trees, level, config.failure_threshold)

        self._collect_propagated_skips(trees, all_results)

        if not config.dry_run:
            self._write_results(all_results)

        return self._summarize(all_results, len(all_results), total_usage)

    @staticmethod
    def _collect_propagated_skips(
        trees: dict[str, ModuleTree], all_results: list[SymbolResult]
    ) -> None:
        """Append results for nodes marked skipped by failure propagation.

        Such nodes were never processed, so they have no result yet. Nodes
        whose symbol_id already has a result are not added again.
        """
        seen_ids = {r.symbol_id for r in all_results}
        for tree in trees.values():
            for level_nodes in tree.levels.values():
                for node in level_nodes:
                    if node.status != "skipped" or node.symbol_id in seen_ids:
                        continue
                    seen_ids.add(node.symbol_id)
                    all_results.append(SymbolResult(
                        symbol_id=node.symbol_id,
                        kind=node.kind,
                        source_file=node.symbol.get("source_file"),
                        status="skipped",
                        error="Skipped due to child failure propagation",
                    ))

    async def _process_node(
        self,
        node: SymbolNode,
        tree: ModuleTree,
        semaphore: asyncio.Semaphore,
    ) -> SymbolResult:
        """Process a single SymbolNode asynchronously.

        Context building and the ``agenerate`` call both run while holding
        ``semaphore``. Exceptions are not caught here; ``run_async`` turns
        them into failed results.

        Args:
            node: The symbol node to process.
            tree: The module tree containing this node.
            semaphore: Semaphore for concurrency control.

        Returns:
            SymbolResult for this node.
        """
        source_file = node.symbol.get("source_file")
        async with semaphore:
            context = build_context(node, tree)
            if not context:
                return SymbolResult(
                    symbol_id=node.symbol_id,
                    kind=node.kind,
                    source_file=source_file,
                    status="skipped",
                    error="Could not build context",
                )

            prompt = build_user_prompt_hierarchical(node, context)
            response = await self._provider.agenerate(self._system_prompt(), prompt)

            docstring = response.content.strip()
            if not docstring:
                return SymbolResult(
                    symbol_id=node.symbol_id,
                    kind=node.kind,
                    source_file=source_file,
                    status="skipped",
                    usage=response.usage,
                    error="LLM returned empty docstring",
                )

            status = "dry_run" if self._config.dry_run else "updated"
            return SymbolResult(
                symbol_id=node.symbol_id,
                kind=node.kind,
                source_file=source_file,
                status=status,
                docstring=docstring,
                usage=response.usage,
            )

    def _propagate_failures(
        self,
        trees: dict[str, ModuleTree],
        completed_level: int,
        threshold: float,
    ) -> None:
        """Mark parent nodes as skipped if too many children failed.

        After processing a level, check the next level's pending nodes. A
        parent is skipped when at least one child failed and the failed
        fraction of its children is >= ``threshold``. Parents without
        children, or with no failed child, are left pending (so a threshold
        of 0.0 never skips a parent whose children all succeeded).

        Args:
            trees: All module trees.
            completed_level: The level that was just processed.
            threshold: Failure ratio threshold (0.0 to 1.0).
        """
        next_level = completed_level + 1
        for tree in trees.values():
            for parent in tree.levels.get(next_level, []):
                if parent.status != "pending":
                    continue
                children = parent.children
                if not children:
                    continue
                failed = sum(1 for c in children if c.status == "failed")
                if failed and failed / len(children) >= threshold:
                    parent.status = "skipped"

    def _write_results(self, results: list[SymbolResult]) -> None:
        """Batch write all generated docstrings grouped by file.

        Only results with status ``"updated"``, a docstring and a source
        file are written. Results whose write fails are downgraded to
        ``"skipped"`` with an explanatory error.

        Args:
            results: List of SymbolResult from the async run.
        """
        by_file: dict[str, list[SymbolResult]] = defaultdict(list)
        for r in results:
            if r.status == "updated" and r.docstring and r.source_file:
                by_file[r.source_file].append(r)

        for source_file, file_results in by_file.items():
            # symbol_id has the format "module:entity"; recover the entity.
            injections: list[tuple[str, str, str]] = [
                (r.kind, r.symbol_id.split(":", 1)[-1], r.docstring)
                for r in file_results
            ]
            write_results = inject_docstrings_batch(source_file, injections)

            for r, success in zip(file_results, write_results):
                if not success:
                    r.status = "skipped"
                    r.error = self._WRITE_FAILED_ERROR
|