wafer-lsp 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/__init__.py +1 -0
- wafer_lsp/__main__.py +9 -0
- wafer_lsp/analyzers/__init__.py +0 -0
- wafer_lsp/analyzers/compiler_integration.py +16 -0
- wafer_lsp/analyzers/docs_index.py +36 -0
- wafer_lsp/handlers/__init__.py +30 -0
- wafer_lsp/handlers/code_action.py +48 -0
- wafer_lsp/handlers/code_lens.py +48 -0
- wafer_lsp/handlers/completion.py +6 -0
- wafer_lsp/handlers/diagnostics.py +41 -0
- wafer_lsp/handlers/document_symbol.py +176 -0
- wafer_lsp/handlers/hip_diagnostics.py +303 -0
- wafer_lsp/handlers/hover.py +251 -0
- wafer_lsp/handlers/inlay_hint.py +245 -0
- wafer_lsp/handlers/semantic_tokens.py +224 -0
- wafer_lsp/handlers/workspace_symbol.py +87 -0
- wafer_lsp/languages/README.md +195 -0
- wafer_lsp/languages/__init__.py +17 -0
- wafer_lsp/languages/converter.py +88 -0
- wafer_lsp/languages/detector.py +107 -0
- wafer_lsp/languages/parser_manager.py +33 -0
- wafer_lsp/languages/registry.py +120 -0
- wafer_lsp/languages/types.py +37 -0
- wafer_lsp/parsers/__init__.py +36 -0
- wafer_lsp/parsers/base_parser.py +9 -0
- wafer_lsp/parsers/cuda_parser.py +95 -0
- wafer_lsp/parsers/cutedsl_parser.py +114 -0
- wafer_lsp/parsers/hip_parser.py +688 -0
- wafer_lsp/server.py +58 -0
- wafer_lsp/services/__init__.py +38 -0
- wafer_lsp/services/analysis_service.py +22 -0
- wafer_lsp/services/docs_service.py +40 -0
- wafer_lsp/services/document_service.py +20 -0
- wafer_lsp/services/hip_docs.py +806 -0
- wafer_lsp/services/hip_hover_service.py +412 -0
- wafer_lsp/services/hover_service.py +237 -0
- wafer_lsp/services/language_registry_service.py +26 -0
- wafer_lsp/services/position_service.py +77 -0
- wafer_lsp/utils/__init__.py +0 -0
- wafer_lsp/utils/lsp_helpers.py +79 -0
- wafer_lsp-0.1.13.dist-info/METADATA +60 -0
- wafer_lsp-0.1.13.dist-info/RECORD +44 -0
- wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
- wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HIP Diagnostics Provider.
|
|
3
|
+
|
|
4
|
+
Generates warnings for common wavefront-related mistakes in HIP code.
|
|
5
|
+
AMD GPUs use 64-thread wavefronts (CDNA) vs CUDA's 32-thread warps,
|
|
6
|
+
and code written for CUDA often has hard-coded assumptions that break on AMD.
|
|
7
|
+
|
|
8
|
+
Diagnostic Codes:
|
|
9
|
+
- HIP001: Incorrect wavefront size assumption (using 32 instead of 64)
|
|
10
|
+
- HIP002: Wavefront intrinsic misuse
|
|
11
|
+
- HIP003: Ballot result handling errors (32-bit vs 64-bit)
|
|
12
|
+
- HIP004: Lane/wavefront index calculation errors
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from lsprotocol.types import (
|
|
19
|
+
Diagnostic,
|
|
20
|
+
DiagnosticSeverity,
|
|
21
|
+
DiagnosticTag,
|
|
22
|
+
Position,
|
|
23
|
+
Range,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from ..parsers.hip_parser import HIPParser, WavefrontPattern
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True)
class HIPDiagnosticInfo:
    """Information about a HIP diagnostic.

    Instances are immutable (``frozen=True``) and are stored in the
    module-level ``DIAGNOSTIC_INFO`` table, keyed by ``code``.
    """

    # Rule identifier, e.g. "HIP001"; emitted as the LSP diagnostic code.
    code: str
    # Base text of the diagnostic; value-specific context is appended to it
    # when the final message is formatted.
    message: str
    # LSP severity used for diagnostics created from this rule.
    severity: DiagnosticSeverity
    # Remediation hint; appended to the message and attached to the
    # diagnostic's ``data`` payload under the "suggestion" key.
    suggestion: str
    # Optional documentation URL. NOTE(review): not read anywhere in this
    # module - confirm whether another handler consumes it.
    doc_link: str | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Diagnostic code to info mapping
|
|
40
|
+
# Rule table: every entry is keyed by its own ``code`` so the key and the
# dataclass field cannot drift apart.
DIAGNOSTIC_INFO: dict[str, HIPDiagnosticInfo] = {
    info.code: info
    for info in (
        HIPDiagnosticInfo(
            code="HIP001",
            message="Potential wavefront size mismatch: AMD GPUs use 64-thread wavefronts, not 32.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `warpSize` variable or `__AMDGCN_WAVEFRONT_SIZE__` macro instead of hard-coding 32.",
            doc_link="https://rocm.docs.amd.com/en/latest/",
        ),
        HIPDiagnosticInfo(
            code="HIP002",
            message="Wavefront intrinsic may behave differently on AMD GPUs.",
            severity=DiagnosticSeverity.Warning,
            suggestion="AMD wavefronts have 64 threads. Shuffle/ballot operations operate on the full wavefront. Adjust your code accordingly.",
        ),
        HIPDiagnosticInfo(
            code="HIP003",
            message="Ballot result comparison uses 32-bit mask, but AMD's __ballot returns 64-bit.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `0xFFFFFFFFFFFFFFFFull` for 64-bit comparison. Use `__popcll()` to count bits.",
        ),
        HIPDiagnosticInfo(
            code="HIP004",
            message="Lane/wavefront index calculation assumes 32-thread warps.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `warpSize` (64 on AMD CDNA) instead of hard-coding 32 for lane/warp calculations.",
        ),
    )
}
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Map pattern types to diagnostic codes
|
|
70
|
+
# Parser pattern type -> diagnostic code reported for it.
PATTERN_TO_DIAGNOSTIC: dict[str, str] = dict(
    warp_size_32="HIP001",
    shuffle_mask="HIP002",
    ballot_32bit="HIP003",
    lane_calc_32="HIP004",
)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class HIPDiagnosticsProvider:
    """Provides diagnostics for HIP code.

    Detects patterns that indicate potential issues when running on AMD GPUs,
    particularly around wavefront size differences (64 vs 32).

    Diagnostics come from two sources: the entries the parser reports under
    ``"wavefront_patterns"``, and two line-based regex scans implemented here
    (half-warp reduction loops and shuffles with an explicit width of 32).
    """

    # One call argument: anything without commas/parentheses, optionally
    # containing a single level of balanced parentheses (casts, simple calls).
    _CALL_ARG = r'(?:[^(),]|\([^()]*\))*'

    # Legacy shuffle intrinsic called with three arguments whose last one is
    # the literal 32. Group 1 is the intrinsic name, group 2 the "32".
    _SHUFFLE_WIDTH_32 = (
        r'\b(__shfl(?:_down|_up|_xor)?)\s*\('
        + _CALL_ARG + ',' + _CALL_ARG
        + r',\s*(32)\s*\)'
    )

    # Reduction loop header initialised to 16 and halved on each iteration,
    # e.g. ``for (int i = 16; i > 0; i >>= 1)`` or ``...; i /= 2)``.
    _HALF_WARP_LOOP = (
        r'\bfor\s*\(\s*(?:unsigned\s+int|unsigned|int)\s+\w+\s*=\s*16\s*;'
        r'[^;]*;\s*\w+\s*(?:>>=\s*1|/=\s*2)\s*\)'
    )

    def __init__(self):
        self._parser = HIPParser()

    def get_diagnostics(
        self,
        content: str,
        uri: str,
        enable_wavefront_diagnostics: bool = True,
    ) -> list[Diagnostic]:
        """Generate diagnostics for HIP code.

        Args:
            content: The document content
            uri: The document URI (accepted for interface symmetry; not used here)
            enable_wavefront_diagnostics: Whether to check for wavefront issues

        Returns:
            List of LSP Diagnostic objects (empty when the checks are disabled)
        """
        diagnostics: list[Diagnostic] = []

        if not enable_wavefront_diagnostics:
            return diagnostics

        # Parse the file to find wavefront patterns
        parsed = self._parser.parse_file(content)
        lines = content.split("\n")

        for pattern in parsed.get("wavefront_patterns", []):
            diagnostic = self._create_diagnostic_from_pattern(pattern, lines)
            if diagnostic:
                diagnostics.append(diagnostic)

        # Additional checks that require more context
        diagnostics.extend(self._check_warp_reduction_patterns(content, lines))
        diagnostics.extend(self._check_shuffle_width_parameters(content, lines))

        return diagnostics

    def _create_diagnostic_from_pattern(
        self,
        pattern: WavefrontPattern,
        lines: list[str],
    ) -> Diagnostic | None:
        """Create a diagnostic from a detected wavefront pattern.

        Returns None when the pattern type has no entry in
        ``PATTERN_TO_DIAGNOSTIC`` / ``DIAGNOSTIC_INFO``.
        """
        diag_code = PATTERN_TO_DIAGNOSTIC.get(pattern.pattern_type)
        diag_info = DIAGNOSTIC_INFO.get(diag_code) if diag_code else None
        if diag_info is None:
            return None

        # Find the problematic value in the line to highlight it. A line
        # index outside the document is treated as an empty line.
        line = lines[pattern.line] if 0 <= pattern.line < len(lines) else ""
        value = pattern.problematic_value
        start_char = line.find(value)
        if start_char == -1:
            # The value is not on this line: highlight the whole line rather
            # than an arbitrary prefix of len(value) characters.
            start_char, end_char = 0, len(line)
        else:
            end_char = start_char + len(value)

        range_ = Range(
            start=Position(line=pattern.line, character=start_char),
            end=Position(line=pattern.line, character=end_char),
        )

        message = self._format_diagnostic_message(
            diag_info,
            value,
            pattern.code_snippet,
        )

        return Diagnostic(
            range=range_,
            severity=diag_info.severity,
            code=diag_info.code,
            source="wafer-hip",
            message=message,
            data={"suggestion": diag_info.suggestion},
        )

    def _format_diagnostic_message(
        self,
        diag_info: HIPDiagnosticInfo,
        problematic_value: str,
        code_snippet: str,
    ) -> str:
        """Format the diagnostic message with context.

        Starts from ``diag_info.message``, adds a Found/Expected section for
        the values "32", "31" and "0xFFFFFFFF", and always ends with the
        suggestion. ``code_snippet`` is accepted but not used.
        """
        msg = diag_info.message

        # Add specific guidance based on the problematic value
        if problematic_value == "32":
            msg += f"\n\nFound: `{problematic_value}` (assumes CUDA warp size)"
            msg += "\nExpected: `warpSize` or `64` for AMD CDNA GPUs"
        elif problematic_value == "31":
            msg += f"\n\nFound: `& {problematic_value}` (assumes 32-thread warp)"
            msg += "\nExpected: `& (warpSize - 1)` or `& 63` for AMD CDNA GPUs"
        elif problematic_value == "0xFFFFFFFF":
            msg += f"\n\nFound: `{problematic_value}` (32-bit mask)"
            msg += "\nExpected: `0xFFFFFFFFFFFFFFFFull` (64-bit mask for AMD)"

        msg += f"\n\n💡 {diag_info.suggestion}"

        return msg

    @staticmethod
    def _find_half_warp_reduction_loops(line: str) -> list[tuple[int, int]]:
        """Return ``(start, end)`` column spans of reduction loops starting at 16."""
        import re

        return [
            (m.start(), m.end())
            for m in re.finditer(HIPDiagnosticsProvider._HALF_WARP_LOOP, line)
        ]

    @staticmethod
    def _find_shuffle_width_32(line: str) -> list[tuple[str, int]]:
        """Return ``(intrinsic_name, column_of_32)`` for shuffles with width 32.

        The offset argument may be any expression (variable or literal); only
        the third argument has to be the literal ``32``.
        """
        import re

        return [
            (m.group(1), m.start(2))
            for m in re.finditer(HIPDiagnosticsProvider._SHUFFLE_WIDTH_32, line)
        ]

    def _check_warp_reduction_patterns(
        self,
        content: str,
        lines: list[str],
    ) -> list[Diagnostic]:
        """Check for warp reduction loops that start from 16 instead of 32.

        Looks for the common CUDA pattern ``for (int i = 16; i > 0; i >>= 1)``
        (16 being half of a 32-thread warp). ``content`` is unused; the scan
        is line based and reports every match on a line.
        """
        info = DIAGNOSTIC_INFO["HIP001"]
        diagnostics: list[Diagnostic] = []

        for line_no, line in enumerate(lines):
            for start, end in self._find_half_warp_reduction_loops(line):
                diagnostics.append(Diagnostic(
                    range=Range(
                        start=Position(line=line_no, character=start),
                        end=Position(line=line_no, character=end),
                    ),
                    severity=info.severity,
                    code=info.code,
                    source="wafer-hip",
                    message=(
                        "Warp reduction loop starts from 16 (half CUDA warp).\n\n"
                        "For AMD's 64-thread wavefronts, start from 32:\n"
                        "`for (int offset = 32; offset > 0; offset >>= 1)`\n\n"
                        "Or better, use `warpSize / 2` for portability."
                    ),
                    # Same payload shape as the parser-pattern diagnostics.
                    data={"suggestion": info.suggestion},
                ))

        return diagnostics

    def _check_shuffle_width_parameters(
        self,
        content: str,
        lines: list[str],
    ) -> list[Diagnostic]:
        """Check for shuffle operations with explicit width=32.

        Covers ``__shfl``, ``__shfl_down``, ``__shfl_up`` and ``__shfl_xor``,
        e.g. ``__shfl_down(val, offset, 32)``. The diagnostic range is the
        ``32`` itself. ``content`` is unused; the scan is line based.
        """
        info = DIAGNOSTIC_INFO["HIP002"]
        diagnostics: list[Diagnostic] = []

        for line_no, line in enumerate(lines):
            for intrinsic_name, start_char in self._find_shuffle_width_32(line):
                diagnostics.append(Diagnostic(
                    range=Range(
                        start=Position(line=line_no, character=start_char),
                        end=Position(line=line_no, character=start_char + 2),
                    ),
                    severity=info.severity,
                    code=info.code,
                    source="wafer-hip",
                    message=(
                        f"`{intrinsic_name}` with explicit width=32 assumes CUDA warp size.\n\n"
                        f"On AMD CDNA GPUs, wavefront size is 64. Either:\n"
                        f"• Remove the width parameter (defaults to `warpSize`)\n"
                        f"• Use `warpSize` as the width parameter\n\n"
                        f"Example: `{intrinsic_name}(val, offset)` or `{intrinsic_name}(val, offset, warpSize)`"
                    ),
                    # Same payload shape as the parser-pattern diagnostics.
                    data={"suggestion": info.suggestion},
                ))

        return diagnostics
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def create_hip_diagnostics_provider() -> HIPDiagnosticsProvider:
    """Build and return a fresh ``HIPDiagnosticsProvider``."""
    provider = HIPDiagnosticsProvider()
    return provider
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_hip_diagnostics(
    content: str,
    uri: str,
    enable_wavefront_diagnostics: bool = True,
) -> list[Diagnostic]:
    """Run the HIP diagnostics on one document using a throwaway provider.

    Args:
        content: Full text of the document.
        uri: URI of the document.
        enable_wavefront_diagnostics: Set to False to skip the wavefront checks.

    Returns:
        The LSP ``Diagnostic`` objects produced by the provider.
    """
    return create_hip_diagnostics_provider().get_diagnostics(
        content,
        uri,
        enable_wavefront_diagnostics,
    )
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import Hover, MarkupContent, MarkupKind, Position
|
|
3
|
+
|
|
4
|
+
from ..analyzers.compiler_integration import get_analysis_for_kernel
|
|
5
|
+
from ..analyzers.docs_index import DocsIndex
|
|
6
|
+
from ..languages.registry import get_language_registry
|
|
7
|
+
from ..languages.types import KernelInfo, LayoutInfo
|
|
8
|
+
from ..services.hip_hover_service import create_hip_hover_service
|
|
9
|
+
from ..utils.lsp_helpers import get_decorator_at_position, get_word_at_position
|
|
10
|
+
|
|
11
|
+
# Lazy-initialized HIP hover service
|
|
12
|
+
_hip_hover_service = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_hip_hover_service():
    """Return the shared HIP hover service, creating it on first use."""
    global _hip_hover_service
    service = _hip_hover_service
    if service is None:
        # First call: build the service and remember it for later calls.
        service = create_hip_hover_service()
        _hip_hover_service = service
    return service
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def find_kernel_at_position(
    content: str, position: Position, uri: str
) -> KernelInfo | None:
    """Return the kernel named by the word under ``position``, if any.

    A kernel qualifies when its name equals that word and ``position`` is on
    or after the kernel's line. The first qualifying kernel wins; None is
    returned when the file cannot be parsed or nothing qualifies.
    """
    language_info = get_language_registry().parse_file(uri, content)
    if not language_info:
        return None

    target = get_word_at_position(content, position)
    candidates = (
        kernel
        for kernel in language_info.kernels
        if kernel.name == target and position.line >= kernel.line
    )
    return next(candidates, None)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def find_layout_at_position(
    content: str, position: Position, uri: str
) -> LayoutInfo | None:
    """Return the first layout whose name equals the word under ``position``.

    Returns None when the file cannot be parsed or no layout has that name.
    """
    language_info = get_language_registry().parse_file(uri, content)
    if not language_info:
        return None

    target = get_word_at_position(content, position)
    matching = (layout for layout in language_info.layouts if layout.name == target)
    return next(matching, None)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def format_kernel_hover(kernel: KernelInfo, analysis: dict | None) -> str:
    """Render the Markdown hover text for a kernel.

    The text holds a title and language line, then the docstring and
    parameter list when present, then an "Analysis" section listing whichever
    of the ``layouts``, ``memory_paths`` and ``pipeline_stages`` keys are in
    ``analysis``. The section is skipped when ``analysis`` is falsy.
    """
    display_language = (
        get_language_registry().get_language_name(kernel.language) or kernel.language
    )

    parts = [f"**GPU Kernel: {kernel.name}**", f"*Language: {display_language}*", ""]

    if kernel.docstring:
        parts += [kernel.docstring, ""]

    if kernel.parameters:
        joined = ", ".join(kernel.parameters)
        parts += [f"**Parameters:** `{joined}`", ""]

    if analysis:
        parts.append("**Analysis:**")
        labelled_keys = (
            ("layouts", "Layouts"),
            ("memory_paths", "Memory paths"),
            ("pipeline_stages", "Pipeline stages"),
        )
        for key, label in labelled_keys:
            if key in analysis:
                parts.append(f"- {label}: {analysis[key]}")

    return "\n".join(parts)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def format_kernel_hover_basic(kernel: KernelInfo) -> str:
    """Render the kernel hover text without an analysis section."""
    no_analysis = None
    return format_kernel_hover(kernel, no_analysis)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def format_decorator_hover(decorator_name: str, function_name: str | None = None) -> str:
    """Render the Markdown hover text for a CuTeDSL decorator.

    Args:
        decorator_name: Decorator as written, e.g. ``"cute.kernel"`` or ``"struct"``.
        function_name: Name of the decorated function/class, if known. It is
            only shown for the kernel and struct decorators.

    Returns:
        Markdown text. A documentation link is appended only when the
        decorator relates to kernels or structs and the docs index has an
        entry for that concept.
    """
    if decorator_name in ("cute.kernel", "kernel"):
        concept = "kernel"
        lines = [
            "**@cute.kernel**",
            "",
            "CuTeDSL kernel decorator. Marks a function as a GPU kernel.",
            "",
            "**Usage:**",
            "```python",
            "@cute.kernel",
            "def my_kernel(a: cute.Tensor, b: cute.Tensor):",
            "    # Kernel implementation",
            "    pass",
            "```",
            "",
            "**Features:**",
            "- Automatic GPU code generation",
            "- Tensor layout optimization",
            "- Memory access pattern analysis",
        ]
        if function_name:
            lines += ["", f"Applied to: `{function_name}()`"]

    elif decorator_name in ("cute.struct", "struct"):
        concept = "struct"
        lines = [
            "**@cute.struct**",
            "",
            "CuTeDSL struct decorator. Marks a class as a GPU struct.",
            "",
            "**Usage:**",
            "```python",
            "@cute.struct",
            "class MyStruct:",
            "    field1: int",
            "    field2: float",
            "```",
        ]
        if function_name:
            lines += ["", f"Applied to: `{function_name}`"]

    else:
        lines = [f"**{decorator_name}**", "", "CuTeDSL decorator"]
        # Only link documentation that matches the decorator; an unrelated
        # decorator must not be pointed at the struct docs.
        if "kernel" in decorator_name:
            concept = "kernel"
        elif "struct" in decorator_name:
            concept = "struct"
        else:
            concept = None

    if concept is not None:
        doc_link = DocsIndex().get_doc_for_concept(concept)
        if doc_link:
            lines += ["", f"[Documentation]({doc_link})"]

    return "\n".join(lines)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _extract_decorated_name(line: str) -> str | None:
    """Return the identifier declared by a ``def``/``async def``/``class`` line.

    Only the identifier is returned, so ``class Foo(Base):`` yields ``"Foo"``.
    Returns None when the line declares neither a function nor a class.
    """
    import re

    match = re.match(r"\s*(?:(?:async\s+)?def|class)\s+(\w+)", line)
    return match.group(1) if match else None


def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
    """Handle hover requests for all supported languages.

    For HIP/CUDA/C++ files, provides rich documentation for:
    - HIP API functions (hipMalloc, hipMemcpy, etc.)
    - Memory qualifiers (__device__, __shared__, etc.)
    - Wavefront intrinsics (__shfl, __ballot, etc.)
    - Thread indexing (threadIdx, blockIdx, etc.)
    - Kernel and device functions

    For CuTeDSL (Python) files:
    - Decorators (@cute.kernel, @cute.struct)
    - Layouts and kernel functions

    The lookups run in a fixed order (HIP service, decorator, ``cute``
    module, kernel, layout) and the first one that yields content wins.

    Args:
        uri: URI of the document being hovered.
        position: Cursor position inside the document.
        content: Full text of the document.

    Returns:
        A Markdown ``Hover``, or None when nothing is known about the position.
    """
    # Detect language to route to appropriate hover handler
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    # For HIP/CUDA/C++ files, try HIP hover service first
    if language_info and language_info.language in ("hip", "cuda", "cpp"):
        hip_hover = _get_hip_hover_service().get_hover(content, position, uri)
        if hip_hover:
            return hip_hover

    # CuTeDSL decorator handling
    decorator_info = get_decorator_at_position(content, position)
    if decorator_info:
        decorator_name, function_line = decorator_info

        function_name = None
        lines = content.split("\n")
        # Reject negative indices too: they would silently read from the end
        # of the document instead of the decorated line.
        if 0 <= function_line < len(lines):
            function_name = _extract_decorated_name(lines[function_line])

        hover_content = format_decorator_hover(decorator_name, function_name)
        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))

    # CuTeDSL cute module
    word = get_word_at_position(content, position)
    # Guard against a missing word before calling str methods on it.
    if word and (word == "cute" or word.startswith("cute.")):
        hover_lines = [
            "**cutlass.cute**",
            "",
            "CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
            "",
            "**Key Features:**",
            "- `@cute.kernel` - Define GPU kernels",
            "- `@cute.struct` - Define GPU structs",
            "- `cute.make_layout()` - Create tensor layouts",
            "- `cute.Tensor` - Tensor type annotations",
            "",
            "[Documentation](https://github.com/NVIDIA/cutlass)",
        ]
        hover_content = "\n".join(hover_lines)
        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))

    # Kernel hover (all languages)
    kernel = find_kernel_at_position(content, position, uri)
    if kernel:
        analysis = get_analysis_for_kernel(uri, kernel.name)

        if analysis:
            hover_content = format_kernel_hover(kernel, analysis)
        else:
            hover_content = format_kernel_hover_basic(kernel)

        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))

    # Layout hover (CuTeDSL)
    layout = find_layout_at_position(content, position, uri)
    if layout:
        doc_link = DocsIndex().get_doc_for_concept("layout")

        hover_lines = [f"**Layout: {layout.name}**", ""]

        if layout.shape:
            hover_lines.append(f"Shape: `{layout.shape}`")
        if layout.stride:
            hover_lines.append(f"Stride: `{layout.stride}`")

        if doc_link:
            hover_lines.append("")
            hover_lines.append(f"[Documentation]({doc_link})")

        hover_content = "\n".join(hover_lines)

        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))

    # No hover info found
    return None
|