wafer-lsp 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/__init__.py +1 -0
- wafer_lsp/__main__.py +9 -0
- wafer_lsp/analyzers/__init__.py +0 -0
- wafer_lsp/analyzers/compiler_integration.py +16 -0
- wafer_lsp/analyzers/docs_index.py +36 -0
- wafer_lsp/handlers/__init__.py +30 -0
- wafer_lsp/handlers/code_action.py +48 -0
- wafer_lsp/handlers/code_lens.py +48 -0
- wafer_lsp/handlers/completion.py +6 -0
- wafer_lsp/handlers/diagnostics.py +41 -0
- wafer_lsp/handlers/document_symbol.py +176 -0
- wafer_lsp/handlers/hip_diagnostics.py +303 -0
- wafer_lsp/handlers/hover.py +251 -0
- wafer_lsp/handlers/inlay_hint.py +245 -0
- wafer_lsp/handlers/semantic_tokens.py +224 -0
- wafer_lsp/handlers/workspace_symbol.py +87 -0
- wafer_lsp/languages/README.md +195 -0
- wafer_lsp/languages/__init__.py +17 -0
- wafer_lsp/languages/converter.py +88 -0
- wafer_lsp/languages/detector.py +107 -0
- wafer_lsp/languages/parser_manager.py +33 -0
- wafer_lsp/languages/registry.py +120 -0
- wafer_lsp/languages/types.py +37 -0
- wafer_lsp/parsers/__init__.py +36 -0
- wafer_lsp/parsers/base_parser.py +9 -0
- wafer_lsp/parsers/cuda_parser.py +95 -0
- wafer_lsp/parsers/cutedsl_parser.py +114 -0
- wafer_lsp/parsers/hip_parser.py +688 -0
- wafer_lsp/server.py +58 -0
- wafer_lsp/services/__init__.py +38 -0
- wafer_lsp/services/analysis_service.py +22 -0
- wafer_lsp/services/docs_service.py +40 -0
- wafer_lsp/services/document_service.py +20 -0
- wafer_lsp/services/hip_docs.py +806 -0
- wafer_lsp/services/hip_hover_service.py +412 -0
- wafer_lsp/services/hover_service.py +237 -0
- wafer_lsp/services/language_registry_service.py +26 -0
- wafer_lsp/services/position_service.py +77 -0
- wafer_lsp/utils/__init__.py +0 -0
- wafer_lsp/utils/lsp_helpers.py +79 -0
- wafer_lsp-0.1.13.dist-info/METADATA +60 -0
- wafer_lsp-0.1.13.dist-info/RECORD +44 -0
- wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
- wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HIP Hover Service.
|
|
3
|
+
|
|
4
|
+
Provides rich hover documentation for HIP code including:
|
|
5
|
+
- HIP API functions (hipMalloc, hipMemcpy, etc.)
|
|
6
|
+
- Memory qualifiers (__device__, __shared__, __constant__)
|
|
7
|
+
- Wavefront intrinsics (__shfl, __ballot, etc.)
|
|
8
|
+
- Thread indexing (threadIdx, blockIdx, etc.)
|
|
9
|
+
- Kernel function information
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from lsprotocol.types import Hover, MarkupContent, MarkupKind, Position
|
|
13
|
+
|
|
14
|
+
from ..parsers.hip_parser import (
|
|
15
|
+
HIPKernel,
|
|
16
|
+
HIPDeviceFunction,
|
|
17
|
+
SharedMemoryAllocation,
|
|
18
|
+
KernelLaunchSite,
|
|
19
|
+
HIPParser,
|
|
20
|
+
)
|
|
21
|
+
from .hip_docs import (
|
|
22
|
+
HIPDocsService,
|
|
23
|
+
HIPAPIDoc,
|
|
24
|
+
MemoryQualifierDoc,
|
|
25
|
+
IntrinsicDoc,
|
|
26
|
+
ThreadIndexDoc,
|
|
27
|
+
create_hip_docs_service,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class HIPHoverService:
    """Provides hover documentation for HIP code.

    Lookup order, from most to least specific:

    1. HIP runtime API functions (``hipMalloc``, ``hipMemcpy``, ...),
    2. memory qualifiers (``__device__``, ``__shared__``, ``__constant__``),
    3. wavefront intrinsics (``__shfl``, ``__ballot``, ...),
    4. thread-indexing builtins (``threadIdx``, ``blockIdx``, ...),
    5. kernel launch sites on the hovered line,
    6. kernels and device functions defined in the document,
    7. shared-memory variables declared in the document.

    Steps 1-4 query the ``HIPDocsService``; steps 5-7 use the ``HIPParser``
    result for the current document.
    """

    # Threads per wavefront used for the wavefront-count estimate shown at
    # launch sites; matches the "(64 threads each on CDNA)" hover wording.
    WAVEFRONT_SIZE = 64

    def __init__(self, docs_service: HIPDocsService | None = None):
        """Create the service.

        Args:
            docs_service: Documentation source to query. If omitted (or
                falsy), one is built with ``create_hip_docs_service()``.
        """
        self._docs = docs_service or create_hip_docs_service()
        self._parser = HIPParser()

    def get_hover(self, content: str, position: Position, uri: str) -> Hover | None:
        """Get hover information for a position in HIP code.

        Args:
            content: The document content
            position: The cursor position
            uri: The document URI (not used by this implementation)

        Returns:
            Hover information or None
        """
        word = self._get_word_at_position(content, position)
        if not word:
            return None

        # 1-4. Built-in documentation, independent of the document text.
        api_doc = self._docs.get_api_doc(word)
        if api_doc:
            return self._format_api_hover(api_doc)

        qualifier_doc = self._docs.get_memory_qualifier_doc(word)
        if qualifier_doc:
            return self._format_qualifier_hover(qualifier_doc)

        intrinsic_doc = self._docs.get_intrinsic_doc(word)
        if intrinsic_doc:
            return self._format_intrinsic_hover(intrinsic_doc)

        thread_doc = self._docs.get_thread_index_doc(word)
        if thread_doc:
            return self._format_thread_index_hover(thread_doc)

        # 5-7. Symbols defined in this document.
        parsed = self._parser.parse_file(content)

        # Launch sites are checked before kernel definitions: they only match
        # on the launch's own line, so they are the more specific hover. If
        # definitions were checked first, a launch of a kernel defined in the
        # same file could never show its launch configuration.
        for launch in parsed.get("launch_sites", []):
            if launch.kernel_name == word and position.line == launch.line:
                return self._format_launch_site_hover(launch)

        for kernel in parsed.get("kernels", []):
            if kernel.name == word:
                return self._format_kernel_hover(kernel)

        for device_func in parsed.get("device_functions", []):
            if device_func.name == word:
                return self._format_device_function_hover(device_func)

        for shared_var in parsed.get("shared_memory", []):
            if shared_var.name == word:
                return self._format_shared_memory_hover(shared_var)

        return None

    def _get_word_at_position(self, content: str, position: Position) -> str:
        """Return the identifier touching ``position``, or ``""`` if none.

        Identifier characters are alphanumerics and ``_`` so that names such
        as ``__global__`` come back whole. A cursor placed directly after an
        identifier -- including at the very end of a line -- resolves to it.
        """
        lines = content.split("\n")
        if position.line >= len(lines):
            return ""

        line = lines[position.line]
        # ``character == len(line)`` is allowed: an end-of-line cursor still
        # touches the word on its left, just like a cursor before a space does.
        if position.character > len(line):
            return ""

        def is_word_char(ch: str) -> bool:
            return ch.isalnum() or ch == "_"

        start = position.character
        while start > 0 and is_word_char(line[start - 1]):
            start -= 1

        end = position.character
        while end < len(line) and is_word_char(line[end]):
            end += 1

        return line[start:end]

    @staticmethod
    def _markdown_hover(lines: list[str]) -> Hover:
        """Join ``lines`` with newlines and wrap them in a Markdown ``Hover``."""
        return Hover(contents=MarkupContent(
            kind=MarkupKind.Markdown,
            value="\n".join(lines)
        ))

    @staticmethod
    def _append_parameters(lines: list[str], parameters) -> None:
        """Append a ``**Parameters:**`` list built from ``(name, description)`` pairs."""
        lines.append("**Parameters:**")
        for param_name, param_desc in parameters:
            lines.append(f"- `{param_name}`: {param_desc}")
        lines.append("")

    @staticmethod
    def _append_typed_parameters(lines: list[str], parameter_info) -> None:
        """Append a ``**Parameters:**`` list from parsed parameter objects.

        Each item needs ``.name`` and ``.type_str``; the type is shown only
        when ``type_str`` is non-empty.
        """
        lines.append("**Parameters:**")
        for param in parameter_info:
            type_info = f" (`{param.type_str}`)" if param.type_str else ""
            lines.append(f"- `{param.name}`{type_info}")
        lines.append("")

    @staticmethod
    def _append_example(lines: list[str], example: str) -> None:
        """Append ``example`` as a fenced C++ block followed by a blank line.

        Literal backslash-n character pairs in ``example`` are converted to
        real line breaks before splitting into lines.
        """
        lines.append("**Example:**")
        lines.append("```cpp")
        lines.extend(example.replace("\\n", "\n").split("\n"))
        lines.append("```")
        lines.append("")

    @staticmethod
    def _format_byte_size(size_bytes: int) -> str:
        """Render a byte count as ``"N bytes"`` below 1024, else ``"X.Y KB"``."""
        if size_bytes >= 1024:
            return f"{size_bytes / 1024:.1f} KB"
        return f"{size_bytes} bytes"

    def _format_api_hover(self, doc: HIPAPIDoc) -> Hover:
        """Format hover content for a HIP API function."""
        lines = [
            f"### `{doc.name}`",
            "",
            "```cpp",
            doc.signature,
            "```",
            "",
            doc.description,
            "",
        ]

        if doc.parameters:
            self._append_parameters(lines, doc.parameters)

        lines.append(f"**Returns:** {doc.return_value}")
        lines.append("")

        if doc.amd_notes:
            lines.extend(["**AMD Notes:**", f"> {doc.amd_notes}", ""])

        if doc.example:
            self._append_example(lines, doc.example)

        if doc.related:
            related = ", ".join(f"`{r}`" for r in doc.related)
            lines.append(f"**Related:** {related}")

        if doc.doc_url:
            lines.extend(["", f"[📖 Documentation]({doc.doc_url})"])

        return self._markdown_hover(lines)

    def _format_qualifier_hover(self, doc: MemoryQualifierDoc) -> Hover:
        """Format hover content for a memory qualifier."""
        lines = [
            f"### `{doc.name}`",
            "",
            doc.description,
            "",
            "**AMD Architecture:**",
            doc.amd_details,
            "",
            "**Performance Tips:**",
            doc.performance_tips,
        ]

        if doc.example:
            # Unlike API/intrinsic examples, qualifier examples are split as-is
            # (no backslash-n unescaping).
            lines.extend(["", "**Example:**", "```cpp"])
            lines.extend(doc.example.split("\n"))
            lines.append("```")

        return self._markdown_hover(lines)

    def _format_intrinsic_hover(self, doc: IntrinsicDoc) -> Hover:
        """Format hover content for a wavefront intrinsic."""
        lines = [
            f"### `{doc.name}`",
            "",
            "```cpp",
            doc.signature,
            "```",
            "",
            doc.description,
            "",
            "**⚠️ AMD Wavefront Behavior:**",
            f"> {doc.amd_behavior}",
            "",
        ]

        if doc.parameters:
            self._append_parameters(lines, doc.parameters)

        lines.append(f"**Returns:** {doc.return_value}")
        lines.append("")

        if doc.example:
            self._append_example(lines, doc.example)

        if doc.cuda_equivalent:
            lines.append(f"**CUDA Equivalent:** `{doc.cuda_equivalent}`")

        return self._markdown_hover(lines)

    def _format_thread_index_hover(self, doc: ThreadIndexDoc) -> Hover:
        """Format hover content for thread indexing variables."""
        lines = [
            f"### `{doc.name}`",
            "",
            doc.description,
            "",
            "**AMD Context:**",
            f"> {doc.amd_context}",
            "",
        ]

        if doc.common_patterns:
            lines.append("**Common Patterns:**")
            lines.append("```cpp")
            lines.extend(doc.common_patterns)
            lines.append("```")

        return self._markdown_hover(lines)

    def _format_kernel_hover(self, kernel: HIPKernel) -> Hover:
        """Format hover content for a ``__global__`` kernel parsed from the document."""
        lines = [f"### 🚀 HIP Kernel: `{kernel.name}`", ""]

        if kernel.docstring:
            lines.extend([kernel.docstring, ""])

        # Reconstructed signature: attributes (one per line) then the declaration.
        params_str = ", ".join(kernel.parameters or [])
        lines.append("```cpp")
        lines.extend(kernel.attributes or [])
        lines.append(f"__global__ void {kernel.name}({params_str})")
        lines.extend(["```", ""])

        if kernel.parameter_info:
            self._append_typed_parameters(lines, kernel.parameter_info)

        # Parser lines are 0-based; shown to the user as 1-based.
        lines.append(f"**Location:** Lines {kernel.line + 1} - {kernel.end_line + 1}")
        lines.append("")
        lines.append("**AMD GPU Execution:**")
        lines.append("- Executed on GPU Compute Units")
        lines.append("- Threads grouped into 64-thread wavefronts (CDNA)")
        lines.append("- Use `<<<grid, block>>>` or `hipLaunchKernelGGL` to launch")

        return self._markdown_hover(lines)

    def _format_device_function_hover(self, func: HIPDeviceFunction) -> Hover:
        """Format hover content for a ``__device__`` function parsed from the document."""
        lines = [f"### ⚡ Device Function: `{func.name}`", ""]

        params_str = ", ".join(func.parameters or [])
        inline_str = "__forceinline__ " if func.is_inline else ""
        lines.append("```cpp")
        lines.append(f"__device__ {inline_str}{func.return_type} {func.name}({params_str})")
        lines.extend(["```", ""])

        if func.parameter_info:
            self._append_typed_parameters(lines, func.parameter_info)

        lines.append(f"**Returns:** `{func.return_type}`")
        lines.append("")
        # Parser lines are 0-based; shown to the user as 1-based.
        lines.append(f"**Location:** Lines {func.line + 1} - {func.end_line + 1}")
        lines.append("")
        lines.append("**Note:** Device functions can only be called from kernel or other device functions.")

        return self._markdown_hover(lines)

    def _format_shared_memory_hover(self, shared: SharedMemoryAllocation) -> Hover:
        """Format hover content for a shared memory allocation.

        Properties are emitted as list items: plain consecutive lines would be
        merged into a single paragraph by CommonMark renderers.
        """
        lines = [
            f"### 📦 Shared Memory (LDS): `{shared.name}`",
            "",
            f"- **Type:** `{shared.type_str}`",
        ]

        if shared.array_size:
            lines.append(f"- **Array Size:** `[{shared.array_size}]`")

        if shared.size_bytes:
            lines.append(f"- **Size:** {self._format_byte_size(shared.size_bytes)}")

        allocation = "Dynamic (extern)" if shared.is_dynamic else "Static"
        lines.append(f"- **Allocation:** {allocation}")

        lines.append("")
        lines.append("**AMD LDS Details:**")
        lines.append("- On-chip memory with ~100x lower latency than HBM")
        lines.append("- 64 KB per Compute Unit")
        lines.append("- Shared by all threads in the block")
        lines.append("- 32 banks of 4 bytes each")
        lines.append("")
        lines.append("💡 **Tip:** Avoid bank conflicts by using padding: `[SIZE + 1]`")

        return self._markdown_hover(lines)

    def _format_launch_site_hover(self, launch: KernelLaunchSite) -> Hover:
        """Format hover content for a kernel launch site.

        Launch parameters are emitted as list items so each renders on its
        own line instead of being merged into one paragraph.
        """
        if launch.is_hip_launch_kernel_ggl:
            method = "`hipLaunchKernelGGL`"
        else:
            method = "`<<<>>>` syntax"

        lines = [
            f"### 🎯 Kernel Launch: `{launch.kernel_name}`",
            "",
            f"**Launch Method:** {method}",
            "",
        ]

        if launch.grid_dim:
            lines.append(f"- **Grid Dimensions:** `{launch.grid_dim}`")
        if launch.block_dim:
            lines.append(f"- **Block Dimensions:** `{launch.block_dim}`")
            self._add_wavefront_info(lines, launch.block_dim)

        if launch.shared_mem_bytes:
            lines.append(f"- **Dynamic Shared Memory:** `{launch.shared_mem_bytes}`")
        if launch.stream:
            lines.append(f"- **Stream:** `{launch.stream}`")

        return self._markdown_hover(lines)

    def _add_wavefront_info(self, lines: list[str], block_dim: str) -> None:
        """Add wavefront information based on block dimensions.

        Only adds info when ``block_dim`` is a plain decimal integer literal
        (surrounding whitespace ignored). For anything else -- variables,
        ``dim3(...)``, arithmetic -- nothing is added because the value
        cannot be determined from the source text alone.
        """
        text = block_dim.strip()
        # isdecimal() rather than isdigit(): isdigit() accepts characters such
        # as superscript digits that int() rejects with ValueError.
        if not text.isdecimal():
            return
        block_size = int(text)
        size = self.WAVEFRONT_SIZE
        wavefronts = (block_size + size - 1) // size  # ceiling division
        lines.append(f"- **Wavefronts per Block:** {wavefronts} ({size} threads each on CDNA)")
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def create_hip_hover_service(docs_service: HIPDocsService | None = None) -> HIPHoverService:
    """Build a ``HIPHoverService``.

    Args:
        docs_service: Optional documentation source; when omitted the
            service constructs its own default one.

    Returns:
        A new ``HIPHoverService`` instance.
    """
    service = HIPHoverService(docs_service=docs_service)
    return service
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import Hover, MarkupContent, MarkupKind, Position
|
|
3
|
+
|
|
4
|
+
from ..languages.types import KernelInfo, LayoutInfo
|
|
5
|
+
from .analysis_service import AnalysisService
|
|
6
|
+
from .docs_service import DocsService
|
|
7
|
+
from .language_registry_service import LanguageRegistryService
|
|
8
|
+
from .position_service import PositionService
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class HoverService:
    """Produce Markdown hover content for CuTeDSL / CUDA documents.

    A hover request is resolved in this order:

    1. a decorator under the cursor (e.g. ``@cute.kernel``),
    2. the ``cute`` module name or a dotted ``cute.`` word,
    3. a kernel name found by the language registry,
    4. a layout name found by the language registry.

    Positions that match none of these yield ``None`` so the editor shows no
    hover instead of a placeholder.
    """

    def __init__(
        self,
        language_registry: LanguageRegistryService,
        analysis_service: AnalysisService,
        docs_service: DocsService,
        position_service: PositionService
    ):
        self._language_registry = language_registry
        self._analysis_service = analysis_service
        self._docs_service = docs_service
        self._position_service = position_service

    def handle_hover(self, uri: str, position: Position, content: str) -> Hover | None:
        """Return hover information for ``position`` in ``content``.

        Args:
            uri: Document URI, used for language detection and for looking up
                kernel analysis results.
            position: Cursor position.
            content: Full document text.

        Returns:
            A Markdown ``Hover``, or ``None`` when nothing under the cursor is
            recognised. No debug banner is added to the content.
        """
        decorator_info = self._position_service.get_decorator_at_position(content, position)
        if decorator_info:
            decorator_name, function_line = decorator_info
            function_name = self._get_decorated_name(content, function_line)
            return self._markdown_hover(self._format_decorator_hover(decorator_name, function_name))

        word = self._position_service.get_word_at_position(content, position)
        if word and (word == "cute" or word.startswith("cute.")):
            return self._markdown_hover(self._format_cute_module_hover())

        kernel = self._find_kernel_at_position(content, position, uri)
        if kernel:
            analysis = self._analysis_service.get_analysis_for_kernel(uri, kernel.name)
            # A falsy analysis result is rendered like "no analysis available".
            return self._markdown_hover(self._format_kernel_hover(kernel, analysis or None))

        layout = self._find_layout_at_position(content, position, uri)
        if layout:
            return self._markdown_hover(self._format_layout_hover(layout))

        return None

    @staticmethod
    def _markdown_hover(value: str) -> Hover:
        """Wrap a Markdown string in a ``Hover``."""
        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=value))

    @staticmethod
    def _get_decorated_name(content: str, function_line: int) -> str | None:
        """Return the name defined on ``function_line``, if it is a def/class line.

        Recognises ``def``, ``async def`` and ``class``. The name ends at the
        first ``(`` or ``:`` so ``class Foo(Base):`` yields ``Foo`` rather
        than ``Foo(Base)``. Returns ``None`` for out-of-range lines or lines
        that define nothing.
        """
        lines = content.split("\n")
        if not 0 <= function_line < len(lines):
            return None

        text = lines[function_line].strip()
        for keyword in ("def ", "async def ", "class "):
            if not text.startswith(keyword):
                continue
            rest = text[len(keyword):]
            cut_points = [i for i in (rest.find("("), rest.find(":")) if i != -1]
            name = rest[:min(cut_points)].strip() if cut_points else ""
            return name or None
        return None

    def _format_cute_module_hover(self) -> str:
        """Markdown overview shown when hovering the ``cute`` module name."""
        return "\n".join([
            "**cutlass.cute**",
            "",
            "CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
            "",
            "**Key Features:**",
            "- `@cute.kernel` - Define GPU kernels",
            "- `@cute.struct` - Define GPU structs",
            "- `cute.make_layout()` - Create tensor layouts",
            "- `cute.Tensor` - Tensor type annotations",
            "",
            "[Documentation](https://github.com/NVIDIA/cutlass)",
        ])

    def _format_layout_hover(self, layout: LayoutInfo) -> str:
        """Markdown for a layout variable: name, optional shape/stride, doc link.

        Shape and stride are list items so they render on separate lines.
        """
        lines = [f"**Layout: {layout.name}**", ""]

        if layout.shape:
            lines.append(f"- Shape: `{layout.shape}`")
        if layout.stride:
            lines.append(f"- Stride: `{layout.stride}`")

        doc_link = self._docs_service.get_doc_for_concept("layout")
        if doc_link:
            lines.extend(["", f"[Documentation]({doc_link})"])

        return "\n".join(lines)

    def _find_kernel_at_position(
        self, content: str, position: Position, uri: str
    ) -> KernelInfo | None:
        """Return the registry-parsed kernel whose name is under the cursor.

        A kernel matches only when the cursor line is on or after the
        kernel's own ``line``. Returns ``None`` if the file cannot be parsed
        or nothing matches.
        """
        language_info = self._language_registry.parse_file(uri, content)
        if not language_info:
            return None

        word = self._position_service.get_word_at_position(content, position)
        for kernel in language_info.kernels:
            if kernel.name == word and position.line >= kernel.line:
                return kernel
        return None

    def _find_layout_at_position(
        self, content: str, position: Position, uri: str
    ) -> LayoutInfo | None:
        """Return the registry-parsed layout whose name is under the cursor, if any."""
        language_info = self._language_registry.parse_file(uri, content)
        if not language_info:
            return None

        word = self._position_service.get_word_at_position(content, position)
        for layout in language_info.layouts:
            if layout.name == word:
                return layout
        return None

    def _format_kernel_hover(self, kernel: KernelInfo, analysis: dict | None) -> str:
        """Markdown for a kernel: header, docstring, parameters, optional analysis.

        Args:
            kernel: Kernel parsed by the language registry.
            analysis: Optional analysis dict; the ``layouts``,
                ``memory_paths`` and ``pipeline_stages`` keys are shown when
                present.
        """
        language_name = self._language_registry.get_language_name(kernel.language) or kernel.language
        is_cuda = kernel.language in ("cuda", "cpp")

        title = "CUDA Kernel" if is_cuda else "GPU Kernel"
        lines = [f"**{title}: {kernel.name}**", f"*Language: {language_name}*", ""]

        if kernel.docstring:
            lines.extend([kernel.docstring, ""])

        if kernel.parameters:
            params_str = ", ".join(kernel.parameters)
            lines.extend([f"**Parameters:** `{params_str}`", ""])

        if is_cuda:
            lines.append("**CUDA Features:**")
            lines.append("- `__global__` function executed on GPU")
            lines.append("- Can be launched with `<<<grid, block>>>` syntax")
            lines.append("")

        if analysis:
            lines.append("**Analysis:**")
            if "layouts" in analysis:
                lines.append(f"- Layouts: {analysis['layouts']}")
            if "memory_paths" in analysis:
                lines.append(f"- Memory paths: {analysis['memory_paths']}")
            if "pipeline_stages" in analysis:
                lines.append(f"- Pipeline stages: {analysis['pipeline_stages']}")

        return "\n".join(lines)

    def _format_kernel_hover_basic(self, kernel: KernelInfo) -> str:
        """Kernel hover without analysis information."""
        return self._format_kernel_hover(kernel, None)

    def _format_decorator_hover(self, decorator_name: str, function_name: str | None = None) -> str:
        """Markdown for a CuTeDSL decorator.

        ``cute.kernel``/``kernel`` and ``cute.struct``/``struct`` get full
        documentation plus a concept doc link. Any other decorator gets a
        generic entry and no link, rather than being linked to the struct
        documentation.
        """
        concept: str | None
        if decorator_name in ("cute.kernel", "kernel"):
            concept = "kernel"
            lines = [
                "**@cute.kernel**",
                "",
                "CuTeDSL kernel decorator. Marks a function as a GPU kernel.",
                "",
                "**Usage:**",
                "```python",
                "@cute.kernel",
                "def my_kernel(a: cute.Tensor, b: cute.Tensor):",
                "    # Kernel implementation",
                "    pass",
                "```",
                "",
                "**Features:**",
                "- Automatic GPU code generation",
                "- Tensor layout optimization",
                "- Memory access pattern analysis",
            ]
            if function_name:
                lines.extend(["", f"Applied to: `{function_name}()`"])

        elif decorator_name in ("cute.struct", "struct"):
            concept = "struct"
            lines = [
                "**@cute.struct**",
                "",
                "CuTeDSL struct decorator. Marks a class as a GPU struct.",
                "",
                "**Usage:**",
                "```python",
                "@cute.struct",
                "class MyStruct:",
                "    field1: int",
                "    field2: float",
                "```",
            ]
            if function_name:
                lines.extend(["", f"Applied to: `{function_name}`"])

        else:
            concept = None
            lines = [f"**{decorator_name}**", "", "CuTeDSL decorator"]

        if concept:
            doc_link = self._docs_service.get_doc_for_concept(concept)
            if doc_link:
                lines.extend(["", f"[Documentation]({doc_link})"])

        return "\n".join(lines)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def create_hover_service(
    language_registry: LanguageRegistryService,
    analysis_service: AnalysisService,
    docs_service: DocsService,
    position_service: PositionService
) -> HoverService:
    """Assemble a ``HoverService`` from its four collaborating services.

    Returns:
        A new ``HoverService`` wired to the given services.
    """
    service = HoverService(
        language_registry=language_registry,
        analysis_service=analysis_service,
        docs_service=docs_service,
        position_service=position_service,
    )
    return service
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
|
|
2
|
+
from ..languages.registry import LanguageRegistry, get_language_registry
|
|
3
|
+
from ..languages.types import LanguageInfo
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LanguageRegistryService:
    """Service-layer facade over a ``LanguageRegistry``.

    Each method forwards its arguments unchanged to the wrapped registry and
    returns the registry's answer as-is.
    """

    def __init__(self, registry: LanguageRegistry):
        # Every query below is delegated to this registry.
        self._registry = registry

    def detect_language(self, uri: str) -> str | None:
        """Return the language id the registry assigns to ``uri``, or None."""
        registry = self._registry
        return registry.detect_language(uri)

    def parse_file(self, uri: str, content: str) -> LanguageInfo | None:
        """Parse ``content`` (located at ``uri``) via the registry."""
        registry = self._registry
        return registry.parse_file(uri, content)

    def get_language_name(self, language_id: str) -> str | None:
        """Return the registry's display name for ``language_id``, or None."""
        registry = self._registry
        return registry.get_language_name(language_id)

    def get_supported_extensions(self) -> list[str]:
        """Return the file extensions the registry can handle."""
        registry = self._registry
        return registry.get_supported_extensions()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_language_registry_service() -> LanguageRegistryService:
    """Wrap the registry returned by ``get_language_registry()`` in a service."""
    return LanguageRegistryService(get_language_registry())
|