wafer-lsp 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/handlers/__init__.py +30 -0
- wafer_lsp/handlers/diagnostics.py +26 -1
- wafer_lsp/handlers/document_symbol.py +93 -4
- wafer_lsp/handlers/hip_diagnostics.py +303 -0
- wafer_lsp/handlers/hover.py +45 -9
- wafer_lsp/handlers/inlay_hint.py +180 -0
- wafer_lsp/handlers/semantic_tokens.py +146 -46
- wafer_lsp/languages/detector.py +82 -9
- wafer_lsp/languages/registry.py +22 -1
- wafer_lsp/parsers/__init__.py +18 -0
- wafer_lsp/parsers/hip_parser.py +688 -0
- wafer_lsp/services/__init__.py +17 -0
- wafer_lsp/services/hip_docs.py +806 -0
- wafer_lsp/services/hip_hover_service.py +412 -0
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/METADATA +4 -1
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/RECORD +18 -14
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/WHEEL +0 -0
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/entry_points.txt +0 -0
wafer_lsp/handlers/__init__.py
CHANGED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from .code_action import handle_code_action
|
|
2
|
+
from .code_lens import handle_code_lens
|
|
3
|
+
from .completion import handle_completion
|
|
4
|
+
from .diagnostics import handle_diagnostics
|
|
5
|
+
from .document_symbol import handle_document_symbol
|
|
6
|
+
from .hip_diagnostics import (
|
|
7
|
+
HIPDiagnosticsProvider,
|
|
8
|
+
create_hip_diagnostics_provider,
|
|
9
|
+
get_hip_diagnostics,
|
|
10
|
+
)
|
|
11
|
+
from .hover import handle_hover
|
|
12
|
+
from .inlay_hint import handle_inlay_hint
|
|
13
|
+
from .semantic_tokens import handle_semantic_tokens, SEMANTIC_TOKENS_LEGEND
|
|
14
|
+
from .workspace_symbol import handle_workspace_symbol
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"handle_code_action",
|
|
18
|
+
"handle_code_lens",
|
|
19
|
+
"handle_completion",
|
|
20
|
+
"handle_diagnostics",
|
|
21
|
+
"handle_document_symbol",
|
|
22
|
+
"handle_hover",
|
|
23
|
+
"handle_inlay_hint",
|
|
24
|
+
"handle_semantic_tokens",
|
|
25
|
+
"handle_workspace_symbol",
|
|
26
|
+
"SEMANTIC_TOKENS_LEGEND",
|
|
27
|
+
"HIPDiagnosticsProvider",
|
|
28
|
+
"create_hip_diagnostics_provider",
|
|
29
|
+
"get_hip_diagnostics",
|
|
30
|
+
]
|
|
@@ -2,9 +2,24 @@
|
|
|
2
2
|
from lsprotocol.types import Diagnostic
|
|
3
3
|
|
|
4
4
|
from ..languages.registry import get_language_registry
|
|
5
|
+
from .hip_diagnostics import get_hip_diagnostics
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
def handle_diagnostics(
|
|
8
|
+
def handle_diagnostics(
|
|
9
|
+
uri: str,
|
|
10
|
+
content: str,
|
|
11
|
+
enable_wavefront_diagnostics: bool = True,
|
|
12
|
+
) -> list[Diagnostic]:
|
|
13
|
+
"""Handle diagnostics for a document.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
uri: Document URI
|
|
17
|
+
content: Document content
|
|
18
|
+
enable_wavefront_diagnostics: Whether to enable HIP wavefront warnings
|
|
19
|
+
|
|
20
|
+
Returns:
|
|
21
|
+
List of diagnostics
|
|
22
|
+
"""
|
|
8
23
|
diagnostics: list[Diagnostic] = []
|
|
9
24
|
|
|
10
25
|
registry = get_language_registry()
|
|
@@ -13,4 +28,14 @@ def handle_diagnostics(uri: str, content: str) -> list[Diagnostic]:
|
|
|
13
28
|
if not language_info:
|
|
14
29
|
return diagnostics
|
|
15
30
|
|
|
31
|
+
# Add HIP-specific diagnostics for HIP, CUDA, and C++ files
|
|
32
|
+
# (C++ files might contain HIP code if they have HIP markers)
|
|
33
|
+
if language_info.language in ("hip", "cuda", "cpp"):
|
|
34
|
+
hip_diagnostics = get_hip_diagnostics(
|
|
35
|
+
content,
|
|
36
|
+
uri,
|
|
37
|
+
enable_wavefront_diagnostics=enable_wavefront_diagnostics,
|
|
38
|
+
)
|
|
39
|
+
diagnostics.extend(hip_diagnostics)
|
|
40
|
+
|
|
16
41
|
return diagnostics
|
|
@@ -12,9 +12,10 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
|
|
|
12
12
|
return []
|
|
13
13
|
|
|
14
14
|
symbols: list[DocumentSymbol] = []
|
|
15
|
+
lines = content.split("\n")
|
|
15
16
|
|
|
17
|
+
# Kernels
|
|
16
18
|
for kernel in language_info.kernels:
|
|
17
|
-
lines = content.split("\n")
|
|
18
19
|
kernel_line = lines[kernel.line] if kernel.line < len(lines) else ""
|
|
19
20
|
name_start = kernel_line.find(kernel.name)
|
|
20
21
|
name_end = name_start + len(kernel.name) if name_start >= 0 else 0
|
|
@@ -28,16 +29,24 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
|
|
|
28
29
|
end=Position(line=min(kernel.line + 10, len(lines) - 1), character=0)
|
|
29
30
|
)
|
|
30
31
|
|
|
32
|
+
# Different detail based on language
|
|
33
|
+
if kernel.language == "hip":
|
|
34
|
+
detail = "🚀 HIP Kernel (AMD GPU)"
|
|
35
|
+
elif kernel.language in ("cuda", "cpp"):
|
|
36
|
+
detail = "🚀 CUDA Kernel"
|
|
37
|
+
else:
|
|
38
|
+
detail = f"GPU Kernel ({registry.get_language_name(kernel.language)})"
|
|
39
|
+
|
|
31
40
|
symbols.append(DocumentSymbol(
|
|
32
41
|
name=kernel.name,
|
|
33
42
|
kind=SymbolKind.Function,
|
|
34
43
|
range=full_range,
|
|
35
44
|
selection_range=selection_range,
|
|
36
|
-
detail=
|
|
45
|
+
detail=detail,
|
|
37
46
|
))
|
|
38
47
|
|
|
48
|
+
# Layouts (CuTeDSL)
|
|
39
49
|
for layout in language_info.layouts:
|
|
40
|
-
lines = content.split("\n")
|
|
41
50
|
layout_line = lines[layout.line] if layout.line < len(lines) else ""
|
|
42
51
|
name_start = layout_line.find(layout.name)
|
|
43
52
|
name_end = name_start + len(layout.name) if name_start >= 0 else 0
|
|
@@ -61,8 +70,8 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
|
|
|
61
70
|
detail=detail,
|
|
62
71
|
))
|
|
63
72
|
|
|
73
|
+
# Structs
|
|
64
74
|
for struct in language_info.structs:
|
|
65
|
-
lines = content.split("\n")
|
|
66
75
|
struct_line = lines[struct.line] if struct.line < len(lines) else ""
|
|
67
76
|
name_start = struct_line.find(struct.name)
|
|
68
77
|
name_end = name_start + len(struct.name) if name_start >= 0 else 0
|
|
@@ -84,4 +93,84 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
|
|
|
84
93
|
detail=f"Struct ({registry.get_language_name(struct.language)})",
|
|
85
94
|
))
|
|
86
95
|
|
|
96
|
+
# HIP-specific: Device functions and shared memory
|
|
97
|
+
if language_info.language in ("hip", "cuda", "cpp"):
|
|
98
|
+
symbols.extend(_get_hip_symbols(language_info.raw_data, lines))
|
|
99
|
+
|
|
100
|
+
return symbols
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _get_hip_symbols(raw_data: dict, lines: list[str]) -> list[DocumentSymbol]:
    """Extract HIP-specific symbols: device functions, shared memory allocations.

    Args:
        raw_data: Parser output. The ``"device_functions"`` and
            ``"shared_memory"`` entries are read; missing keys are treated
            as empty lists.
        lines: The document content split into lines.

    Returns:
        One ``DocumentSymbol`` per device function (``SymbolKind.Method``)
        followed by one per shared memory allocation
        (``SymbolKind.Variable``). Entries lacking a ``line`` or ``name``
        attribute are skipped.
    """

    def text_at(index: int) -> str:
        # Out-of-range line numbers are treated as an empty line.
        return lines[index] if 0 <= index < len(lines) else ""

    def name_span(text: str, name: str) -> tuple[int, int]:
        # Column span of `name` inside `text`; (0, 0) when it is not present.
        start = text.find(name)
        if start < 0:
            return 0, 0
        return start, start + len(name)

    symbols: list[DocumentSymbol] = []

    # Device functions (from HIP parser)
    for func in raw_data.get("device_functions", []):
        if not hasattr(func, "line") or not hasattr(func, "name"):
            continue

        name_start, name_end = name_span(text_at(func.line), func.name)

        # Fall back to a 10-line guess when the entry carries no usable end line.
        end_line = getattr(func, "end_line", None)
        if end_line is None:
            end_line = func.line + 10

        # Clamp to the document, but never above the start line: LSP requires
        # `range` to contain `selection_range`, so the range must not be
        # inverted and must reach the end of its last line (this matters for
        # single-line functions).
        end_line = max(func.line, min(end_line, len(lines) - 1))

        selection_range = Range(
            start=Position(line=func.line, character=name_start),
            end=Position(line=func.line, character=name_end),
        )
        full_range = Range(
            start=Position(line=func.line, character=0),
            end=Position(line=end_line, character=len(text_at(end_line))),
        )

        return_type = getattr(func, "return_type", "void")

        symbols.append(DocumentSymbol(
            name=func.name,
            kind=SymbolKind.Method,
            range=full_range,
            selection_range=selection_range,
            detail=f"⚡ Device Function -> {return_type}",
        ))

    # Shared memory allocations
    for shared in raw_data.get("shared_memory", []):
        if not hasattr(shared, "line") or not hasattr(shared, "name"):
            continue

        shared_line = text_at(shared.line)
        name_start, name_end = name_span(shared_line, shared.name)

        selection_range = Range(
            start=Position(line=shared.line, character=name_start),
            end=Position(line=shared.line, character=name_end),
        )
        # A shared memory declaration is reported as spanning its own line only.
        full_range = Range(
            start=Position(line=shared.line, character=0),
            end=Position(line=shared.line, character=len(shared_line)),
        )

        type_str = getattr(shared, "type_str", "")
        size_bytes = getattr(shared, "size_bytes", None)
        if not size_bytes:
            # Unknown (None) or zero size: show the type only.
            detail = f"📦 __shared__ {type_str}"
        elif size_bytes >= 1024:
            detail = f"📦 __shared__ {type_str} ({size_bytes / 1024:.1f} KB)"
        else:
            detail = f"📦 __shared__ {type_str} ({size_bytes} bytes)"

        symbols.append(DocumentSymbol(
            name=shared.name,
            kind=SymbolKind.Variable,
            range=full_range,
            selection_range=selection_range,
            detail=detail,
        ))

    return symbols
|
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HIP Diagnostics Provider.
|
|
3
|
+
|
|
4
|
+
Generates warnings for common wavefront-related mistakes in HIP code.
|
|
5
|
+
AMD GPUs use 64-thread wavefronts (CDNA) vs CUDA's 32-thread warps,
|
|
6
|
+
and code written for CUDA often has hard-coded assumptions that break on AMD.
|
|
7
|
+
|
|
8
|
+
Diagnostic Codes:
|
|
9
|
+
- HIP001: Incorrect wavefront size assumption (using 32 instead of 64)
|
|
10
|
+
- HIP002: Wavefront intrinsic misuse
|
|
11
|
+
- HIP003: Ballot result handling errors (32-bit vs 64-bit)
|
|
12
|
+
- HIP004: Lane/wavefront index calculation errors
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from lsprotocol.types import (
|
|
19
|
+
Diagnostic,
|
|
20
|
+
DiagnosticSeverity,
|
|
21
|
+
DiagnosticTag,
|
|
22
|
+
Position,
|
|
23
|
+
Range,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from ..parsers.hip_parser import HIPParser, WavefrontPattern
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True)
class HIPDiagnosticInfo:
    """Immutable description of one HIP diagnostic code.

    Instances hold the static text and severity attached to a diagnostic
    code; they carry no position information.
    """

    # Diagnostic identifier, e.g. "HIP001".
    code: str
    # Headline text shown to the user.
    message: str
    # LSP severity reported for this diagnostic.
    severity: DiagnosticSeverity
    # Remediation advice accompanying the message.
    suggestion: str
    # Optional documentation URL; None when no link is available.
    doc_link: str | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Diagnostic code to info mapping
|
|
40
|
+
# Every entry is keyed by its own `code`, so the table is derived from the
# records themselves instead of repeating each code as a literal key.
DIAGNOSTIC_INFO: dict[str, HIPDiagnosticInfo] = {
    info.code: info
    for info in (
        HIPDiagnosticInfo(
            code="HIP001",
            message="Potential wavefront size mismatch: AMD GPUs use 64-thread wavefronts, not 32.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `warpSize` variable or `__AMDGCN_WAVEFRONT_SIZE__` macro instead of hard-coding 32.",
            doc_link="https://rocm.docs.amd.com/en/latest/",
        ),
        HIPDiagnosticInfo(
            code="HIP002",
            message="Wavefront intrinsic may behave differently on AMD GPUs.",
            severity=DiagnosticSeverity.Warning,
            suggestion="AMD wavefronts have 64 threads. Shuffle/ballot operations operate on the full wavefront. Adjust your code accordingly.",
        ),
        HIPDiagnosticInfo(
            code="HIP003",
            message="Ballot result comparison uses 32-bit mask, but AMD's __ballot returns 64-bit.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `0xFFFFFFFFFFFFFFFFull` for 64-bit comparison. Use `__popcll()` to count bits.",
        ),
        HIPDiagnosticInfo(
            code="HIP004",
            message="Lane/wavefront index calculation assumes 32-thread warps.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `warpSize` (64 on AMD CDNA) instead of hard-coding 32 for lane/warp calculations.",
        ),
    )
}


# Map pattern types to diagnostic codes
PATTERN_TO_DIAGNOSTIC: dict[str, str] = dict(
    warp_size_32="HIP001",
    shuffle_mask="HIP002",
    ballot_32bit="HIP003",
    lane_calc_32="HIP004",
)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class HIPDiagnosticsProvider:
    """Provides diagnostics for HIP code.

    Detects patterns that indicate potential issues when running on AMD GPUs,
    particularly around wavefront size differences (64 vs 32).

    Diagnostics come from two sources: the ``"wavefront_patterns"`` entries
    reported by ``HIPParser.parse_file`` and two line-based regex scans done
    here (reduction loops starting at 16, shuffles with explicit width 32).
    """

    def __init__(self) -> None:
        # Parser that supplies the "wavefront_patterns" entries.
        self._parser = HIPParser()

    def get_diagnostics(
        self,
        content: str,
        uri: str,
        enable_wavefront_diagnostics: bool = True,
    ) -> list[Diagnostic]:
        """Generate diagnostics for HIP code.

        Args:
            content: The document content
            uri: The document URI (accepted for interface symmetry; not read here)
            enable_wavefront_diagnostics: Whether to check for wavefront issues

        Returns:
            List of LSP Diagnostic objects; empty when wavefront diagnostics
            are disabled.
        """
        diagnostics: list[Diagnostic] = []

        if not enable_wavefront_diagnostics:
            return diagnostics

        # Parse the file to find wavefront patterns
        parsed = self._parser.parse_file(content)
        lines = content.split("\n")

        for pattern in parsed.get("wavefront_patterns", []):
            diagnostic = self._create_diagnostic_from_pattern(pattern, lines)
            if diagnostic is not None:
                diagnostics.append(diagnostic)

        # Additional checks that require more context
        diagnostics.extend(self._check_warp_reduction_patterns(content, lines))
        diagnostics.extend(self._check_shuffle_width_parameters(content, lines))

        return diagnostics

    @staticmethod
    def _locate_value(line: str, value: str) -> tuple[int, int]:
        """Return the ``(start, end)`` columns to highlight for *value* in *line*.

        A standalone occurrence (not embedded in a longer identifier or
        number, e.g. the ``32`` in ``x132``) is preferred; otherwise the first
        plain substring hit is used. When *value* is empty or absent the whole
        line is returned, so the range never extends past the line end.
        """
        import re

        if value:
            token = re.search(
                r"(?<![\w.])" + re.escape(value) + r"(?![\w.])", line
            )
            if token:
                return token.start(), token.end()
            start = line.find(value)
            if start != -1:
                return start, start + len(value)
        return 0, len(line)

    @staticmethod
    def _scan_reduction_loops(lines: list[str]) -> list[tuple[int, int, int]]:
        """Find ``for`` headers that start a halving loop at 16.

        Matches e.g. ``for (int i = 16; i > 0; i >>= 1)`` and the equivalent
        ``i /= 2`` form, with ``int``, ``unsigned`` or ``unsigned int``
        counters. At most one hit per line is reported.

        Returns:
            ``(line_index, start_col, end_col)`` tuples in document order.
        """
        import re

        loop_re = re.compile(
            r"for\s*\(\s*(?:unsigned\s+int|int|unsigned)\s+\w+\s*=\s*16\s*;"
            r"[^;]*;\s*\w+\s*(?:>>=\s*1|/=\s*2)\s*\)"
        )
        hits: list[tuple[int, int, int]] = []
        for index, line in enumerate(lines):
            match = loop_re.search(line)
            if match:
                hits.append((index, match.start(), match.end()))
        return hits

    @staticmethod
    def _scan_shuffle_width_32(lines: list[str]) -> list[tuple[int, str, int]]:
        """Find ``__shfl*`` calls whose trailing width argument is literally 32.

        The call must have at least two arguments before the ``32`` so that
        ``__shfl(val, 32)`` (a source lane, not a width) is not flagged. The
        preceding arguments may be any expression without parentheses or
        commas, so both ``__shfl_down(val, 16, 32)`` and
        ``__shfl_down(val, offset, 32)`` match. Every call on a line is
        reported.

        Returns:
            ``(line_index, intrinsic_name, column_of_32)`` tuples in
            document order.
        """
        import re

        shuffle_re = re.compile(
            r"(?<!\w)(__shfl(?:_down|_up|_xor)?)\s*\("
            r"(?:[^(),]+,){2,}\s*(32)\s*\)"
        )
        hits: list[tuple[int, str, int]] = []
        for index, line in enumerate(lines):
            for match in shuffle_re.finditer(line):
                hits.append((index, match.group(1), match.start(2)))
        return hits

    def _create_diagnostic_from_pattern(
        self,
        pattern: WavefrontPattern,
        lines: list[str],
    ) -> Diagnostic | None:
        """Create a diagnostic from a detected wavefront pattern.

        Returns ``None`` when the pattern type has no diagnostic code or the
        code has no registered info.
        """
        diag_code = PATTERN_TO_DIAGNOSTIC.get(pattern.pattern_type)
        if not diag_code:
            return None

        diag_info = DIAGNOSTIC_INFO.get(diag_code)
        if diag_info is None:
            return None

        # Find the problematic value in the line to highlight it
        line = lines[pattern.line] if 0 <= pattern.line < len(lines) else ""
        start_char, end_char = self._locate_value(line, pattern.problematic_value)

        range_ = Range(
            start=Position(line=pattern.line, character=start_char),
            end=Position(line=pattern.line, character=end_char),
        )

        message = self._format_diagnostic_message(
            diag_info,
            pattern.problematic_value,
            pattern.code_snippet,
        )

        return Diagnostic(
            range=range_,
            severity=diag_info.severity,
            code=diag_info.code,
            source="wafer-hip",
            message=message,
            data={"suggestion": diag_info.suggestion},
        )

    def _format_diagnostic_message(
        self,
        diag_info: HIPDiagnosticInfo,
        problematic_value: str,
        code_snippet: str,
    ) -> str:
        """Format the diagnostic message with context.

        Appends a "Found / Expected" hint for the known values ``32``, ``31``
        and ``0xFFFFFFFF`` (hex compared case-insensitively), then the
        suggestion text. ``code_snippet`` is accepted but not used.
        """
        msg = diag_info.message

        # Add specific guidance based on the problematic value
        if problematic_value == "32":
            msg += f"\n\nFound: `{problematic_value}` (assumes CUDA warp size)"
            msg += "\nExpected: `warpSize` or `64` for AMD CDNA GPUs"
        elif problematic_value == "31":
            msg += f"\n\nFound: `& {problematic_value}` (assumes 32-thread warp)"
            msg += "\nExpected: `& (warpSize - 1)` or `& 63` for AMD CDNA GPUs"
        elif problematic_value.lower() == "0xffffffff":
            msg += f"\n\nFound: `{problematic_value}` (32-bit mask)"
            msg += "\nExpected: `0xFFFFFFFFFFFFFFFFull` (64-bit mask for AMD)"

        msg += f"\n\n💡 {diag_info.suggestion}"

        return msg

    def _check_warp_reduction_patterns(
        self,
        content: str,
        lines: list[str],
    ) -> list[Diagnostic]:
        """Check for warp reduction loops that start from 16 instead of 32.

        ``content`` is accepted for signature symmetry; only ``lines`` is scanned.
        """
        # Look for reduction loops that start from 16 (half of CUDA warp)
        # This is a common CUDA pattern: for (int i = 16; i > 0; i >>= 1)
        diag_info = DIAGNOSTIC_INFO["HIP001"]
        diagnostics: list[Diagnostic] = []

        for line_index, start_col, end_col in self._scan_reduction_loops(lines):
            diagnostics.append(Diagnostic(
                range=Range(
                    start=Position(line=line_index, character=start_col),
                    end=Position(line=line_index, character=end_col),
                ),
                severity=diag_info.severity,
                code=diag_info.code,
                source="wafer-hip",
                message=(
                    "Warp reduction loop starts from 16 (half CUDA warp).\n\n"
                    "For AMD's 64-thread wavefronts, start from 32:\n"
                    "`for (int offset = 32; offset > 0; offset >>= 1)`\n\n"
                    "Or better, use `warpSize / 2` for portability."
                ),
            ))

        return diagnostics

    def _check_shuffle_width_parameters(
        self,
        content: str,
        lines: list[str],
    ) -> list[Diagnostic]:
        """Check for shuffle operations with explicit width=32.

        ``content`` is accepted for signature symmetry; only ``lines`` is scanned.
        """
        diag_info = DIAGNOSTIC_INFO["HIP002"]
        diagnostics: list[Diagnostic] = []

        for line_index, intrinsic_name, width_col in self._scan_shuffle_width_32(lines):
            # Highlight just the two characters of the literal `32`.
            diagnostics.append(Diagnostic(
                range=Range(
                    start=Position(line=line_index, character=width_col),
                    end=Position(line=line_index, character=width_col + 2),
                ),
                severity=diag_info.severity,
                code=diag_info.code,
                source="wafer-hip",
                message=(
                    f"`{intrinsic_name}` with explicit width=32 assumes CUDA warp size.\n\n"
                    "On AMD CDNA GPUs, wavefront size is 64. Either:\n"
                    "• Remove the width parameter (defaults to `warpSize`)\n"
                    "• Use `warpSize` as the width parameter\n\n"
                    f"Example: `{intrinsic_name}(val, offset)` or `{intrinsic_name}(val, offset, warpSize)`"
                ),
            ))

        return diagnostics
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def create_hip_diagnostics_provider() -> HIPDiagnosticsProvider:
    """Build and return a new ``HIPDiagnosticsProvider``.

    Each call constructs a fresh provider; nothing is cached here.
    """
    provider = HIPDiagnosticsProvider()
    return provider
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_hip_diagnostics(
    content: str,
    uri: str,
    enable_wavefront_diagnostics: bool = True,
) -> list[Diagnostic]:
    """Return HIP diagnostics for one document using a throwaway provider.

    Args:
        content: The document content
        uri: The document URI
        enable_wavefront_diagnostics: Whether to check for wavefront issues

    Returns:
        List of LSP Diagnostic objects
    """
    # A new provider is created per call via the factory and then discarded.
    return create_hip_diagnostics_provider().get_diagnostics(
        content,
        uri,
        enable_wavefront_diagnostics,
    )
|
wafer_lsp/handlers/hover.py
CHANGED
|
@@ -5,8 +5,20 @@ from ..analyzers.compiler_integration import get_analysis_for_kernel
|
|
|
5
5
|
from ..analyzers.docs_index import DocsIndex
|
|
6
6
|
from ..languages.registry import get_language_registry
|
|
7
7
|
from ..languages.types import KernelInfo, LayoutInfo
|
|
8
|
+
from ..services.hip_hover_service import create_hip_hover_service
|
|
8
9
|
from ..utils.lsp_helpers import get_decorator_at_position, get_word_at_position
|
|
9
10
|
|
|
11
|
+
# Lazy-initialized HIP hover service
|
|
12
|
+
_hip_hover_service = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_hip_hover_service():
    """Return the shared HIP hover service, creating it on first use.

    The instance is stored in the module-level ``_hip_hover_service`` so
    later calls reuse it.
    """
    global _hip_hover_service
    if _hip_hover_service is not None:
        return _hip_hover_service
    # First call: build the service and remember it.
    _hip_hover_service = create_hip_hover_service()
    return _hip_hover_service
|
|
21
|
+
|
|
10
22
|
|
|
11
23
|
def find_kernel_at_position(
|
|
12
24
|
content: str, position: Position, uri: str
|
|
@@ -133,8 +145,31 @@ def format_decorator_hover(decorator_name: str, function_name: str | None = None
|
|
|
133
145
|
|
|
134
146
|
|
|
135
147
|
def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
|
|
136
|
-
|
|
137
|
-
|
|
148
|
+
"""Handle hover requests for all supported languages.
|
|
149
|
+
|
|
150
|
+
For HIP/CUDA/C++ files, provides rich documentation for:
|
|
151
|
+
- HIP API functions (hipMalloc, hipMemcpy, etc.)
|
|
152
|
+
- Memory qualifiers (__device__, __shared__, etc.)
|
|
153
|
+
- Wavefront intrinsics (__shfl, __ballot, etc.)
|
|
154
|
+
- Thread indexing (threadIdx, blockIdx, etc.)
|
|
155
|
+
- Kernel and device functions
|
|
156
|
+
|
|
157
|
+
For CuTeDSL (Python) files:
|
|
158
|
+
- Decorators (@cute.kernel, @cute.struct)
|
|
159
|
+
- Layouts and kernel functions
|
|
160
|
+
"""
|
|
161
|
+
# Detect language to route to appropriate hover handler
|
|
162
|
+
registry = get_language_registry()
|
|
163
|
+
language_info = registry.parse_file(uri, content)
|
|
164
|
+
|
|
165
|
+
# For HIP/CUDA/C++ files, try HIP hover service first
|
|
166
|
+
if language_info and language_info.language in ("hip", "cuda", "cpp"):
|
|
167
|
+
hip_service = _get_hip_hover_service()
|
|
168
|
+
hip_hover = hip_service.get_hover(content, position, uri)
|
|
169
|
+
if hip_hover:
|
|
170
|
+
return hip_hover
|
|
171
|
+
|
|
172
|
+
# CuTeDSL decorator handling
|
|
138
173
|
decorator_info = get_decorator_at_position(content, position)
|
|
139
174
|
if decorator_info:
|
|
140
175
|
decorator_name, function_line = decorator_info
|
|
@@ -154,13 +189,13 @@ def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
|
|
|
154
189
|
if class_name_end > class_name_start:
|
|
155
190
|
function_name = func_line[class_name_start:class_name_end].strip()
|
|
156
191
|
|
|
157
|
-
hover_content =
|
|
192
|
+
hover_content = format_decorator_hover(decorator_name, function_name)
|
|
158
193
|
return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
|
|
159
194
|
|
|
195
|
+
# CuTeDSL cute module
|
|
160
196
|
word = get_word_at_position(content, position)
|
|
161
197
|
if word == "cute" or word.startswith("cute."):
|
|
162
198
|
hover_lines = [
|
|
163
|
-
test_message,
|
|
164
199
|
"**cutlass.cute**",
|
|
165
200
|
"",
|
|
166
201
|
"CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
|
|
@@ -176,24 +211,25 @@ def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
|
|
|
176
211
|
hover_content = "\n".join(hover_lines)
|
|
177
212
|
return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
|
|
178
213
|
|
|
214
|
+
# Kernel hover (all languages)
|
|
179
215
|
kernel = find_kernel_at_position(content, position, uri)
|
|
180
216
|
if kernel:
|
|
181
217
|
analysis = get_analysis_for_kernel(uri, kernel.name)
|
|
182
218
|
|
|
183
219
|
if analysis:
|
|
184
|
-
hover_content =
|
|
220
|
+
hover_content = format_kernel_hover(kernel, analysis)
|
|
185
221
|
else:
|
|
186
|
-
hover_content =
|
|
222
|
+
hover_content = format_kernel_hover_basic(kernel)
|
|
187
223
|
|
|
188
224
|
return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
|
|
189
225
|
|
|
226
|
+
# Layout hover (CuTeDSL)
|
|
190
227
|
layout = find_layout_at_position(content, position, uri)
|
|
191
228
|
if layout:
|
|
192
229
|
docs = DocsIndex()
|
|
193
230
|
doc_link = docs.get_doc_for_concept("layout")
|
|
194
231
|
|
|
195
232
|
hover_lines = [
|
|
196
|
-
test_message,
|
|
197
233
|
f"**Layout: {layout.name}**",
|
|
198
234
|
""
|
|
199
235
|
]
|
|
@@ -211,5 +247,5 @@ def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
|
|
|
211
247
|
|
|
212
248
|
return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
|
|
213
249
|
|
|
214
|
-
|
|
215
|
-
return
|
|
250
|
+
# No hover info found
|
|
251
|
+
return None
|