lean-lsp-mcp 0.14.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,43 +3,52 @@ from typing import Dict, List, Optional, Tuple
3
3
  from leanclient import LeanLSPClient
4
4
  from leanclient.utils import DocumentContentChange
5
5
 
6
+ from lean_lsp_mcp.models import FileOutline, OutlineEntry
7
+
6
8
 
# Symbol "kind" values treated as declarations: 6 is the LSP SymbolKind for
# Method, and "method" is the kind assigned to declarations recovered from
# namespace bodies by _extract_declarations.
METHOD_KIND = {"method", 6}
# Short outline tags for non-declaration symbol kinds.
KIND_TAGS = dict(namespace="Ns")
9
11
 
10
12
 
def _get_info_trees(
    client: LeanLSPClient, path: str, symbols: List[Dict]
) -> Dict[str, str]:
    """Insert ``#info_trees in`` commands, collect diagnostics, then revert changes.

    One ``#info_trees in`` line is inserted directly above each symbol's start
    line. The file's diagnostics are then read, and every inserted line is
    removed again, even if reading the diagnostics raises.

    Args:
        client: Lean LSP client used to edit the file and read diagnostics.
        path: Path of the file to inspect (expected to be open in ``client``).
        symbols: Document symbols. Each must provide ``name`` and
            ``range.start.line``.

    Returns:
        Mapping of symbol name -> message of the severity-3 (LSP
        "Information") diagnostic reported on that symbol's inserted line.
        If two symbols share a name, the later diagnostic wins. Returns an
        empty dict, without touching the file, when ``symbols`` is empty.
    """
    if not symbols:
        return {}

    symbol_by_line: Dict[int, str] = {}
    changes = []
    ordered = sorted(symbols, key=lambda s: s["range"]["start"]["line"])
    for i, sym in enumerate(ordered):
        # Assumes the changes are applied sequentially (as LSP didChange
        # does). Each earlier insertion pushes this symbol down one line, so
        # its position in the edited text is its original line + i.
        line = sym["range"]["start"]["line"] + i
        symbol_by_line[line] = sym["name"]
        changes.append(
            DocumentContentChange("#info_trees in\n", [line, 0], [line, 0])
        )

    client.update_file(path, changes)
    try:
        diagnostics = client.get_diagnostics(path)
        info_trees = {
            symbol_by_line[diag["range"]["start"]["line"]]: diag["message"]
            for diag in diagnostics
            if diag["severity"] == 3
            and diag["range"]["start"]["line"] in symbol_by_line
        }
    finally:
        # Always remove the inserted lines so a failure while collecting
        # diagnostics never leaves stray #info_trees commands in the file.
        # Delete bottom-up so earlier deletions do not shift later indices.
        client.update_file(
            path,
            [
                DocumentContentChange("", [line, 0], [line + 1, 0])
                for line in sorted(symbol_by_line, reverse=True)
            ],
        )
    return info_trees
38
45
 
39
46
 
40
47
  def _extract_type(info: str, name: str) -> Optional[str]:
41
48
  """Extract type signature from info tree message."""
42
- if m := re.search(rf' • \[Term\] {re.escape(name)} \(isBinder := true\) : ([^@]+) @', info):
49
+ if m := re.search(
50
+ rf" • \[Term\] {re.escape(name)} \(isBinder := true\) : ([^@]+) @", info
51
+ ):
43
52
  return m.group(1).strip()
44
53
  return None
45
54
 
@@ -47,14 +56,16 @@ def _extract_type(info: str, name: str) -> Optional[str]:
47
56
  def _extract_fields(info: str, name: str) -> List[Tuple[str, str]]:
48
57
  """Extract structure/class fields from info tree message."""
49
58
  fields = []
50
- for pattern in [rf'{re.escape(name)}\.(\w+)', rf'@{re.escape(name)}\.(\w+)']:
51
- for m in re.finditer(rf' • \[Term\] {pattern} \(isBinder := true\) : (.+?) @', info):
59
+ for pattern in [rf"{re.escape(name)}\.(\w+)", rf"@{re.escape(name)}\.(\w+)"]:
60
+ for m in re.finditer(
61
+ rf" • \[Term\] {pattern} \(isBinder := true\) : (.+?) @", info
62
+ ):
52
63
  field_name, full_type = m.groups()
53
64
  # Clean up the type signature
54
- if ']' in full_type:
55
- field_type = full_type[full_type.rfind(']')+1:].lstrip('').strip()
56
- elif '' in full_type:
57
- field_type = full_type.split('')[-1].strip()
65
+ if "]" in full_type:
66
+ field_type = full_type[full_type.rfind("]") + 1 :].lstrip("").strip()
67
+ elif "" in full_type:
68
+ field_type = full_type.split("")[-1].strip()
58
69
  else:
59
70
  field_type = full_type.strip()
60
71
  fields.append((field_name, field_type))
@@ -68,51 +79,61 @@ def _extract_declarations(content: str, start: int, end: int) -> List[Dict]:
68
79
 
69
80
  while i < min(end, len(lines)):
70
81
  line = lines[i].strip()
71
- for keyword in ['theorem', 'lemma', 'def']:
82
+ for keyword in ["theorem", "lemma", "def"]:
72
83
  if line.startswith(f"{keyword} "):
73
- name = line[len(keyword):].strip().split()[0]
74
- if name and not name.startswith('_'):
84
+ name = line[len(keyword) :].strip().split()[0]
85
+ if name and not name.startswith("_"):
75
86
  # Collect until :=
76
87
  decl_lines = [line]
77
88
  j = i + 1
78
- while j < min(end, len(lines)) and ':=' not in ' '.join(decl_lines):
79
- if (next_line := lines[j].strip()) and not next_line.startswith('--'):
89
+ while j < min(end, len(lines)) and ":=" not in " ".join(decl_lines):
90
+ if (next_line := lines[j].strip()) and not next_line.startswith(
91
+ "--"
92
+ ):
80
93
  decl_lines.append(next_line)
81
94
  j += 1
82
95
 
83
96
  # Extract signature (everything before :=, minus keyword and name)
84
- full_decl = ' '.join(decl_lines)
97
+ full_decl = " ".join(decl_lines)
85
98
  type_sig = None
86
- if ':=' in full_decl:
87
- sig_part = full_decl.split(':=', 1)[0].strip()[len(keyword):].strip()
99
+ if ":=" in full_decl:
100
+ sig_part = (
101
+ full_decl.split(":=", 1)[0].strip()[len(keyword) :].strip()
102
+ )
88
103
  if sig_part.startswith(name):
89
- type_sig = sig_part[len(name):].strip()
90
-
91
- decls.append({
92
- 'name': name,
93
- 'kind': 'method',
94
- 'range': {'start': {'line': i, 'character': 0},
95
- 'end': {'line': i, 'character': len(lines[i])}},
96
- '_keyword': keyword,
97
- '_type': type_sig
98
- })
104
+ type_sig = sig_part[len(name) :].strip()
105
+
106
+ decls.append(
107
+ {
108
+ "name": name,
109
+ "kind": "method",
110
+ "range": {
111
+ "start": {"line": i, "character": 0},
112
+ "end": {"line": i, "character": len(lines[i])},
113
+ },
114
+ "_keyword": keyword,
115
+ "_type": type_sig,
116
+ }
117
+ )
99
118
  break
100
119
  i += 1
101
120
  return decls
102
121
 
103
122
 
104
- def _flatten_symbols(symbols: List[Dict], indent: int = 0, content: str = "") -> List[Tuple[Dict, int]]:
123
+ def _flatten_symbols(
124
+ symbols: List[Dict], indent: int = 0, content: str = ""
125
+ ) -> List[Tuple[Dict, int]]:
105
126
  """Recursively flatten symbol hierarchy, extracting declarations from namespaces."""
106
127
  result = []
107
128
  for sym in symbols:
108
129
  result.append((sym, indent))
109
- children = sym.get('children', [])
130
+ children = sym.get("children", [])
110
131
 
111
132
  # Extract theorem/lemma/def from namespace bodies
112
- if content and sym.get('kind') == 'namespace':
113
- ns_range = sym['range']
114
- ns_start = ns_range['start']['line']
115
- ns_end = ns_range['end']['line']
133
+ if content and sym.get("kind") == "namespace":
134
+ ns_range = sym["range"]
135
+ ns_start = ns_range["start"]["line"]
136
+ ns_end = ns_range["end"]["line"]
116
137
  children = children + _extract_declarations(content, ns_start, ns_end)
117
138
 
118
139
  if children:
@@ -120,32 +141,36 @@ def _flatten_symbols(symbols: List[Dict], indent: int = 0, content: str = "") ->
120
141
  return result
121
142
 
122
143
 
123
- def _detect_tag(name: str, kind: str, type_sig: str, has_fields: bool, keyword: Optional[str]) -> str:
144
+ def _detect_tag(
145
+ name: str, kind: str, type_sig: str, has_fields: bool, keyword: Optional[str]
146
+ ) -> str:
124
147
  """Determine the appropriate tag for a symbol."""
125
148
  if has_fields:
126
- return "Class" if '' in type_sig else "Struct"
149
+ return "Class" if "" in type_sig else "Struct"
127
150
  if name == "example":
128
151
  return "Ex"
129
- if keyword in {'theorem', 'lemma'}:
152
+ if keyword in {"theorem", "lemma"}:
130
153
  return "Thm"
131
- if type_sig and any(marker in type_sig for marker in ['', '=']):
154
+ if type_sig and any(marker in type_sig for marker in ["", "="]):
132
155
  return "Thm"
133
- if type_sig and '' in type_sig.replace('', '', 1): # More than one arrow
156
+ if type_sig and "" in type_sig.replace("", "", 1): # More than one arrow
134
157
  return "Thm"
135
158
  return KIND_TAGS.get(kind, "Def")
136
159
 
137
160
 
138
161
  def _format_symbol(sym: Dict, type_sigs: Dict, fields_map: Dict, indent: int) -> str:
139
162
  """Format a single symbol with its type signature and fields."""
140
- name = sym['name']
141
- type_sig = sym.get('_type') or type_sigs.get(name, "")
163
+ name = sym["name"]
164
+ type_sig = sym.get("_type") or type_sigs.get(name, "")
142
165
  fields = fields_map.get(name, [])
143
166
 
144
- tag = _detect_tag(name, sym.get('kind', ''), type_sig, bool(fields), sym.get('_keyword'))
167
+ tag = _detect_tag(
168
+ name, sym.get("kind", ""), type_sig, bool(fields), sym.get("_keyword")
169
+ )
145
170
  prefix = "\t" * indent
146
171
 
147
- start = sym['range']['start']['line'] + 1
148
- end = sym['range']['end']['line'] + 1
172
+ start = sym["range"]["start"]["line"] + 1
173
+ end = sym["range"]["end"]["line"] + 1
149
174
  line_info = f"L{start}" if start == end else f"L{start}-{end}"
150
175
 
151
176
  result = f"{prefix}[{tag}: {line_info}] {name}"
@@ -158,14 +183,108 @@ def _format_symbol(sym: Dict, type_sigs: Dict, fields_map: Dict, indent: int) ->
158
183
  return result + "\n"
159
184
 
160
185
 
def _build_outline_entry(
    sym: Dict, type_sigs: Dict, fields_map: Dict, indent: int
) -> Optional[OutlineEntry]:
    """Convert one flattened symbol into a structured ``OutlineEntry``.

    Args:
        sym: Symbol dict with ``name`` and ``range`` (plus optional ``kind``).
            Declarations produced by ``_extract_declarations`` also carry
            ``_type`` and ``_keyword``.
        type_sigs: Name -> type signature, consulted when ``sym`` has no
            truthy ``_type``.
        fields_map: Name -> list of ``(field_name, field_type)`` pairs.
        indent: Nesting depth from ``_flatten_symbols``. Not used here.

    Returns:
        An entry whose line numbers are 1-based, with one ``"field"`` child
        per known field. The current implementation never returns None.
    """
    symbol_name = sym["name"]
    signature = sym.get("_type") or type_sigs.get(symbol_name, "")
    field_pairs = fields_map.get(symbol_name, [])

    tag = _detect_tag(
        symbol_name,
        sym.get("kind", ""),
        signature,
        bool(field_pairs),
        sym.get("_keyword"),
    )

    # LSP ranges are 0-based; the outline reports 1-based lines.
    symbol_range = sym["range"]
    first_line = symbol_range["start"]["line"] + 1
    last_line = symbol_range["end"]["line"] + 1

    # Individual field positions are not known, so every field is pinned to
    # the first line of its parent declaration.
    field_entries = []
    for field_name, field_type in field_pairs:
        field_entries.append(
            OutlineEntry(
                name=field_name,
                kind="field",
                start_line=first_line,
                end_line=first_line,
                type_signature=field_type,
                children=[],
            )
        )

    return OutlineEntry(
        name=symbol_name,
        kind=tag,
        start_line=first_line,
        end_line=last_line,
        type_signature=signature or None,
        children=field_entries,
    )
221
+
222
+
def generate_outline_data(client: LeanLSPClient, path: str) -> FileOutline:
    """Generate structured outline data for a Lean file.

    Opens ``path`` through the client, collects its ``import`` lines and
    document symbols, and enriches LSP method symbols with type signatures
    and structure fields parsed from ``#info_trees`` diagnostics. Obtaining
    those diagnostics temporarily edits the file through the client (see
    ``_get_info_trees``), and the edit is reverted.

    Args:
        client: Lean LSP client used to open, read and query the file.
        path: Path of the Lean file to outline.

    Returns:
        ``FileOutline`` holding the text after ``import `` on every import
        line, plus one ``OutlineEntry`` per method, extracted declaration,
        or namespace, in flattened symbol order.
    """
    client.open_file(path)
    content = client.get_file_content(path)

    import_prefix = "import "
    imports = [
        line.strip()[len(import_prefix) :]
        for line in content.splitlines()
        if line.strip().startswith(import_prefix)
    ]

    # LSP allows a null documentSymbol result. If the client surfaces that as
    # None, treat it as "no symbols" so a file that only has imports still
    # yields an outline instead of failing inside _flatten_symbols.
    symbols = client.get_document_symbols(path) or []
    if not symbols and not imports:
        return FileOutline(imports=[], declarations=[])

    # Flatten symbol tree and extract namespace declarations
    all_symbols = _flatten_symbols(symbols, content=content)

    # Info trees are only requested for symbols reported by the LSP server.
    # Declarations recovered from namespace text carry a "_keyword" marker
    # and their own "_type" parsed from source.
    lsp_methods = [
        s
        for s, _ in all_symbols
        if s.get("kind") in METHOD_KIND and "_keyword" not in s
    ]
    info_trees = _get_info_trees(client, path, lsp_methods)

    # Extract type signatures and fields from info trees
    type_sigs = {
        name: sig
        for name, info in info_trees.items()
        if (sig := _extract_type(info, name))
    }
    fields_map = {
        name: fields
        for name, info in info_trees.items()
        if (fields := _extract_fields(info, name))
    }

    # Build declarations list
    declarations = []
    for sym, indent in all_symbols:
        if (
            sym.get("kind") in METHOD_KIND
            or sym.get("_keyword")
            or sym.get("kind") == "namespace"
        ):
            entry = _build_outline_entry(sym, type_sigs, fields_map, indent)
            # _build_outline_entry is typed Optional; skip entries it declines.
            if entry:
                declarations.append(entry)

    return FileOutline(imports=imports, declarations=declarations)
275
+
276
+
161
277
  def generate_outline(client: LeanLSPClient, path: str) -> str:
162
278
  """Generate a concise outline of a Lean file showing structure and signatures."""
163
279
  client.open_file(path)
164
280
  content = client.get_file_content(path)
165
281
 
166
282
  # Extract imports
167
- imports = [line.strip()[7:] for line in content.splitlines()
168
- if line.strip().startswith("import ")]
283
+ imports = [
284
+ line.strip()[7:]
285
+ for line in content.splitlines()
286
+ if line.strip().startswith("import ")
287
+ ]
169
288
 
170
289
  symbols = client.get_document_symbols(path)
171
290
  if not symbols and not imports:
@@ -175,14 +294,24 @@ def generate_outline(client: LeanLSPClient, path: str) -> str:
175
294
  all_symbols = _flatten_symbols(symbols, content=content)
176
295
 
177
296
  # Get info trees only for LSP symbols (not extracted declarations)
178
- lsp_methods = [s for s, _ in all_symbols if s.get('kind') in METHOD_KIND and '_keyword' not in s]
297
+ lsp_methods = [
298
+ s
299
+ for s, _ in all_symbols
300
+ if s.get("kind") in METHOD_KIND and "_keyword" not in s
301
+ ]
179
302
  info_trees = _get_info_trees(client, path, lsp_methods)
180
303
 
181
304
  # Extract type signatures and fields from info trees
182
- type_sigs = {name: sig for name, info in info_trees.items()
183
- if (sig := _extract_type(info, name))}
184
- fields_map = {name: fields for name, info in info_trees.items()
185
- if (fields := _extract_fields(info, name))}
305
+ type_sigs = {
306
+ name: sig
307
+ for name, info in info_trees.items()
308
+ if (sig := _extract_type(info, name))
309
+ }
310
+ fields_map = {
311
+ name: fields
312
+ for name, info in info_trees.items()
313
+ if (fields := _extract_fields(info, name))
314
+ }
186
315
 
187
316
  # Build output
188
317
  parts = []
@@ -193,7 +322,9 @@ def generate_outline(client: LeanLSPClient, path: str) -> str:
193
322
  declarations = [
194
323
  _format_symbol(sym, type_sigs, fields_map, indent)
195
324
  for sym, indent in all_symbols
196
- if sym.get('kind') in METHOD_KIND or sym.get('_keyword') or sym.get('kind') == 'namespace'
325
+ if sym.get("kind") in METHOD_KIND
326
+ or sym.get("_keyword")
327
+ or sym.get("kind") == "namespace"
197
328
  ]
198
329
  parts.append("## Declarations\n" + "".join(declarations).rstrip())
199
330