lean-lsp-mcp 0.1.7__py3-none-any.whl → 0.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lean_lsp_mcp/utils.py CHANGED
@@ -1,37 +1,74 @@
1
1
  import os
2
+ import secrets
2
3
  import sys
3
- from typing import List, Dict
4
+ import tempfile
5
+ from typing import List, Dict, Optional
4
6
 
7
+ from mcp.server.auth.provider import AccessToken, TokenVerifier
5
8
 
6
- class StdoutToStderr:
7
- """Redirects stdout to stderr at the file descriptor level bc lake build logging"""
9
+
10
class OutputCapture:
    """Capture any output to stdout and stderr at the file descriptor level.

    While the context is active, the file descriptors behind ``sys.stdout``
    and ``sys.stderr`` are pointed at a temporary file, so output written by
    C extensions or child processes is captured as well as Python-level
    writes. On exit the original descriptors are restored, the captured text
    is stored in ``captured_output`` and the temporary file is removed.
    """

    def __init__(self):
        self.original_stdout_fd = None
        self.original_stderr_fd = None
        self.temp_file = None
        self.captured_output = ""

    def __enter__(self):
        # Flush text buffered before the capture started so it reaches the
        # real streams instead of leaking into the captured output.
        sys.stdout.flush()
        sys.stderr.flush()

        # errors="replace": captured bytes may come from arbitrary processes
        # and are not guaranteed to be valid UTF-8.
        self.temp_file = tempfile.NamedTemporaryFile(
            mode="w+", delete=False, encoding="utf-8", errors="replace"
        )
        try:
            self.original_stdout_fd = os.dup(sys.stdout.fileno())
            self.original_stderr_fd = os.dup(sys.stderr.fileno())
            os.dup2(self.temp_file.fileno(), sys.stdout.fileno())
            os.dup2(self.temp_file.fileno(), sys.stderr.fileno())
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo any
            # partial redirection and remove the temporary file here.
            try:
                self._restore_streams()
            finally:
                self._collect_output()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            try:
                # Push Python-level buffers into the temp file while the
                # descriptors are still redirected.
                sys.stdout.flush()
                sys.stderr.flush()
            finally:
                self._restore_streams()
        finally:
            self._collect_output()

    def _restore_streams(self):
        """Point stdout/stderr back at their saved descriptors and close the copies."""
        for attr, stream in (
            ("original_stdout_fd", sys.stdout),
            ("original_stderr_fd", sys.stderr),
        ):
            saved_fd = getattr(self, attr)
            if saved_fd is None:
                continue
            try:
                os.dup2(saved_fd, stream.fileno())
            finally:
                os.close(saved_fd)
                setattr(self, attr, None)

    def _collect_output(self):
        """Read the captured text, then close and delete the temporary file."""
        try:
            self.temp_file.flush()
            self.temp_file.seek(0)
            self.captured_output = self.temp_file.read()
        finally:
            self.temp_file.close()
            try:
                os.unlink(self.temp_file.name)
            except FileNotFoundError:
                # Already removed; nothing left to clean up.
                pass

    def get_output(self):
        """Return the text captured during the most recent ``with`` block."""
        return self.captured_output
43
+
44
+
45
class OptionalTokenVerifier(TokenVerifier):
    """Minimal verifier that accepts a single pre-shared token."""

    def __init__(self, expected_token: str):
        self._expected_token = expected_token

    async def verify_token(self, token: str | None) -> AccessToken | None:
        """Return an ``AccessToken`` when ``token`` matches the pre-shared token.

        Args:
            token (str | None): The token presented by the client.

        Returns:
            AccessToken | None: An access token with no scopes on a match,
            otherwise None.
        """
        if token is None:
            return None

        # compare_digest() only accepts str arguments that are pure ASCII and
        # raises TypeError otherwise. Comparing the encoded bytes keeps the
        # constant-time check while rejecting (rather than crashing on) a
        # non-ASCII token. "surrogatepass" lets lone surrogates encode too.
        supplied = token.encode("utf-8", "surrogatepass")
        expected = self._expected_token.encode("utf-8", "surrogatepass")
        if not secrets.compare_digest(supplied, expected):
            return None

        # AccessToken requires both client_id and scopes parameters to be provided.
        return AccessToken(token=token, client_id="lean-lsp-mcp-optional", scopes=[])
23
56
 
24
57
 
25
- def format_diagnostics(diagnostics: List[Dict]) -> List[str]:
58
+ def format_diagnostics(diagnostics: List[Dict], select_line: int = -1) -> List[str]:
26
59
  """Format the diagnostics messages.
27
60
 
28
61
  Args:
29
62
  diagnostics (List[Dict]): List of diagnostics.
63
+ select_line (int): If -1, format all diagnostics. If >= 0, only format diagnostics for this line.
30
64
 
31
65
  Returns:
32
66
  List[str]: Formatted diagnostics messages.
33
67
  """
34
68
  msgs = []
69
+ if select_line != -1:
70
+ diagnostics = filter_diagnostics_by_position(diagnostics, select_line, None)
71
+
35
72
  # Format more compact
36
73
  for diag in diagnostics:
37
74
  r = diag.get("fullRange", diag.get("range", None))
@@ -41,3 +78,184 @@ def format_diagnostics(diagnostics: List[Dict]) -> List[str]:
41
78
  r_text = f"l{r['start']['line'] + 1}c{r['start']['character'] + 1}-l{r['end']['line'] + 1}c{r['end']['character'] + 1}"
42
79
  msgs.append(f"{r_text}, severity: {diag['severity']}\n{diag['message']}")
43
80
  return msgs
81
+
82
+
83
def format_goal(goal, default_msg):
    """Strip the Lean code fence from a goal's rendered text.

    Args:
        goal: Mapping that may carry a "rendered" entry, or None.
        default_msg: Value returned when ``goal`` is None.

    Returns:
        ``default_msg`` when ``goal`` is None, None when the goal has no
        (or an empty) "rendered" entry, otherwise the rendered text with the
        fence markers removed.
    """
    if goal is None:
        return default_msg

    text = goal.get("rendered")
    if not text:
        return None

    without_opening = text.replace("```lean\n", "")
    return without_opening.replace("\n```", "")
88
+
89
+
90
+ def _utf16_index_to_py_index(text: str, utf16_index: int) -> int | None:
91
+ """Convert an LSP UTF-16 column index into a Python string index."""
92
+ if utf16_index < 0:
93
+ return None
94
+
95
+ units = 0
96
+ for idx, ch in enumerate(text):
97
+ code_point = ord(ch)
98
+ next_units = units + (2 if code_point > 0xFFFF else 1)
99
+
100
+ if utf16_index < next_units:
101
+ return idx
102
+ if utf16_index == next_units:
103
+ return idx + 1
104
+
105
+ units = next_units
106
+ if units >= utf16_index:
107
+ return len(text)
108
+ return None
109
+
110
+
111
def extract_range(content: str, range: dict) -> str:
    """Extract the text from the content based on the range.

    Lines are delimited only by the end-of-line sequences LSP recognises
    ("\\r\\n", "\\n" and "\\r"). ``str.splitlines`` is deliberately not used
    because it also breaks on form feed, vertical tab, U+2028 and similar
    characters, which would shift every following line number.

    Args:
        content (str): The content to extract from.
        range (dict): The range to extract, shaped like
            {"start": {"line", "character"}, "end": {"line", "character"}}
            with 0-indexed lines and UTF-16 character offsets.

    Returns:
        str: The extracted range text, or the literal string
        "Range out of bounds" when either position does not exist in the
        content or the start lies after the end.
    """
    import re  # local import keeps the block self-contained

    start_line = range["start"]["line"]
    start_char = range["start"]["character"]
    end_line = range["end"]["line"]
    end_char = range["end"]["character"]

    # Each entry keeps its own terminator, so the entries concatenate back to
    # ``content`` and Python offsets can be summed directly.
    lines = re.findall(r"[^\r\n]*(?:\r\n|\r|\n)|[^\r\n]+", content)
    if not lines:
        lines = [""]

    line_offsets: List[int] = []
    offset = 0
    for line in lines:
        line_offsets.append(offset)
        offset += len(line)
    total_length = len(content)

    def position_to_offset(line: int, character: int) -> int | None:
        # A position at column 0 of the line just past the last one denotes
        # the end of the document.
        if line == len(lines) and character == 0:
            return total_length
        if line < 0 or line >= len(lines):
            return None
        py_index = _utf16_index_to_py_index(lines[line], character)
        if py_index is None or py_index > len(lines[line]):
            return None
        return line_offsets[line] + py_index

    start_offset = position_to_offset(start_line, start_char)
    end_offset = position_to_offset(end_line, end_char)

    if start_offset is None or end_offset is None or start_offset > end_offset:
        return "Range out of bounds"

    return content[start_offset:end_offset]
156
+
157
+
158
+ def find_start_position(content: str, query: str) -> dict | None:
159
+ """Find the position of the query in the content.
160
+
161
+ Args:
162
+ content (str): The content to search in.
163
+ query (str): The query to find.
164
+
165
+ Returns:
166
+ dict | None: The position of the query in the content. {"line": int, "column": int}
167
+ """
168
+ lines = content.splitlines()
169
+ for line_number, line in enumerate(lines):
170
+ char_index = line.find(query)
171
+ if char_index != -1:
172
+ return {"line": line_number, "column": char_index}
173
+ return None
174
+
175
+
176
def format_line(
    file_content: str,
    line_number: int,
    column: Optional[int] = None,
    cursor_tag: Optional[str] = "<cursor>",
) -> str:
    """Return one line of a file, optionally marking a cursor position.

    Args:
        file_content (str): The content of the file.
        line_number (int): The line number (1-indexed).
        column (Optional[int]): The column number (1-indexed). If None, the
            line is returned without a cursor marker.
        cursor_tag (Optional[str]): Text inserted at the cursor position.
            Defaults to "<cursor>".

    Returns:
        str: The line (with the cursor tag inserted when a column is given),
        or "Line number out of range" / "Invalid column number" when the
        requested position does not exist.
    """
    rows = file_content.splitlines()
    row_index = line_number - 1
    if not 0 <= row_index < len(rows):
        return "Line number out of range"

    text = rows[row_index]
    if column is None:
        return text

    split_at = column - 1
    # split_at == len(text) is allowed: it puts the cursor at end-of-line.
    if not 0 <= split_at <= len(text):
        return "Invalid column number"

    head, tail = text[:split_at], text[split_at:]
    return f"{head}{cursor_tag}{tail}"
204
+
205
+
206
def filter_diagnostics_by_position(
    diagnostics: List[Dict], line: Optional[int], column: Optional[int]
) -> List[Dict]:
    """Select the diagnostics whose range covers a 0-indexed position.

    Args:
        diagnostics (List[Dict]): Diagnostics carrying a "range" (or, failing
            that, a "fullRange") with "start"/"end" positions.
        line (Optional[int]): Line to match. If None, every diagnostic is
            returned.
        column (Optional[int]): Column to match. If None, any diagnostic
            touching the line matches.

    Returns:
        List[Dict]: The matching diagnostics, in their original order.
    """
    if line is None:
        return list(diagnostics)

    def covers_position(diagnostic: Dict) -> bool:
        span = diagnostic.get("range") or diagnostic.get("fullRange")
        if not span:
            return False

        first = span.get("start", {})
        last = span.get("end", {})
        first_line = first.get("line")
        last_line = last.get("line")
        if first_line is None or last_line is None:
            return False
        if not first_line <= line <= last_line:
            return False

        first_col = first.get("character")
        last_col = last.get("character")

        if column is None:
            # A multi-line range that stops at column 0 of this line does not
            # actually contain any of the line's text.
            stops_at_line_start = (
                line == last_line
                and line != first_line
                and last_col is not None
                and last_col == 0
            )
            return not stops_at_line_start

        if first_col is None:
            first_col = 0
        if last_col is None:
            # No end column recorded: treat the range as extending past the column.
            last_col = column + 1

        # A zero-width range only matches the exact column it sits on.
        if first_line == last_line and first_col == last_col:
            return column == first_col

        if line == first_line and column < first_col:
            return False
        if line == last_line and column >= last_col:
            return False
        return True

    return [diagnostic for diagnostic in diagnostics if covers_position(diagnostic)]