wafer-lsp 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/handlers/__init__.py +30 -0
- wafer_lsp/handlers/diagnostics.py +26 -1
- wafer_lsp/handlers/document_symbol.py +93 -4
- wafer_lsp/handlers/hip_diagnostics.py +303 -0
- wafer_lsp/handlers/hover.py +45 -9
- wafer_lsp/handlers/inlay_hint.py +180 -0
- wafer_lsp/handlers/semantic_tokens.py +146 -46
- wafer_lsp/languages/detector.py +82 -9
- wafer_lsp/languages/registry.py +22 -1
- wafer_lsp/parsers/__init__.py +18 -0
- wafer_lsp/parsers/hip_parser.py +688 -0
- wafer_lsp/services/__init__.py +17 -0
- wafer_lsp/services/hip_docs.py +806 -0
- wafer_lsp/services/hip_hover_service.py +412 -0
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/METADATA +4 -1
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/RECORD +18 -14
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/WHEEL +0 -0
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.1.dist-info}/entry_points.txt +0 -0
wafer_lsp/handlers/inlay_hint.py
CHANGED
|
@@ -1,9 +1,27 @@
|
|
|
1
1
|
|
|
2
|
+
import re
|
|
3
|
+
|
|
2
4
|
from lsprotocol.types import InlayHint, InlayHintKind, Position, Range
|
|
3
5
|
|
|
4
6
|
from ..languages.registry import get_language_registry
|
|
5
7
|
|
|
6
8
|
|
|
9
|
+
# HIP kernel launch pattern: kernel<<<grid, block, shared, stream>>>
#
# A launch argument may itself contain commas (and '>') inside parentheses,
# e.g. ``dim3(16, 16)`` or ``dim3((n + 255) / 256, 1)``. Splitting on "anything
# up to the next comma" would cut such an argument in half, so an argument is
# defined as a run of plain characters and balanced parenthesised groups
# (up to two levels of nesting).
_LAUNCH_ARG = r'(?:[^,>()]|\((?:[^()]|\([^()]*\))*\))+'

HIP_LAUNCH_PATTERN = re.compile(
    r'(\w+)\s*<<<\s*'                     # group 1: kernel name
    rf'({_LAUNCH_ARG})\s*,\s*'            # group 2: grid
    rf'({_LAUNCH_ARG})'                   # group 3: block
    rf'(?:\s*,\s*({_LAUNCH_ARG}))?'       # group 4: shared (optional)
    r'(?:\s*,\s*([^>]+))?'                # group 5: stream (optional)
    r'\s*>>>'
)

# Shared memory declaration pattern: __shared__ <type> <name>[<size>]
# group 1: type text, group 2: variable name, group 3: text of the first
# bracketed extent.
SHARED_MEM_PATTERN = re.compile(
    r'__shared__\s+([\w\s:<>]+?)\s+(\w+)\s*\[([^\]]+)\]'
)
|
|
23
|
+
|
|
24
|
+
|
|
7
25
|
def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
|
|
8
26
|
registry = get_language_registry()
|
|
9
27
|
language_info = registry.parse_file(uri, content)
|
|
@@ -14,6 +32,7 @@ def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
|
|
|
14
32
|
hints: list[InlayHint] = []
|
|
15
33
|
lines = content.split("\n")
|
|
16
34
|
|
|
35
|
+
# Layout hints (CuTeDSL)
|
|
17
36
|
for layout in language_info.layouts:
|
|
18
37
|
if layout.line < range.start.line or layout.line > range.end.line:
|
|
19
38
|
continue
|
|
@@ -39,6 +58,7 @@ def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
|
|
|
39
58
|
padding_right=False
|
|
40
59
|
))
|
|
41
60
|
|
|
61
|
+
# Kernel hints (CuTeDSL)
|
|
42
62
|
for kernel in language_info.kernels:
|
|
43
63
|
if kernel.line < range.start.line or kernel.line > range.end.line:
|
|
44
64
|
continue
|
|
@@ -62,4 +82,164 @@ def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
|
|
|
62
82
|
padding_right=True
|
|
63
83
|
))
|
|
64
84
|
|
|
85
|
+
# HIP/CUDA-specific hints
|
|
86
|
+
if language_info.language in ("hip", "cuda", "cpp"):
|
|
87
|
+
hints.extend(_get_hip_inlay_hints(lines, range))
|
|
88
|
+
|
|
65
89
|
return hints
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _get_hip_inlay_hints(lines: list[str], range: Range) -> list[InlayHint]:
    """Generate HIP-specific inlay hints for the lines inside ``range``.

    Two kinds of hints are produced:

    - Kernel launch dimension annotations (placed after each ``<<<...>>>``)
    - Shared memory size annotations (placed after each ``__shared__`` array)

    Args:
        lines: The document content split into lines.
        range: The requested range. Only ``start.line`` and ``end.line`` are
            read; ``end.line`` is inclusive and is clamped to the document.

    Returns:
        The hints found, line by line; on a given line the launch hints come
        before the shared memory hints.
    """
    hints: list[InlayHint] = []

    # The ``range`` parameter shadows the builtin of the same name, so the
    # builtin cannot be called in this function. Slice the requested window
    # and let enumerate() supply the absolute line numbers instead.
    first_line = max(range.start.line, 0)
    last_line = min(range.end.line + 1, len(lines))

    for i, line in enumerate(lines[first_line:last_line], start=first_line):
        # Kernel launch hints
        for match in HIP_LAUNCH_PATTERN.finditer(line):
            grid_dim = match.group(2).strip()
            block_dim = match.group(3).strip()

            hint_parts = []

            # Describe each dimension only when it has a recognisable form
            grid_info = _parse_dim(grid_dim)
            block_info = _parse_dim(block_dim)

            if grid_info:
                hint_parts.append(f"{grid_info} blocks")
            if block_info:
                hint_parts.append(f"{block_info} threads/block")
                # Calculate wavefronts (AMD uses 64-thread wavefronts);
                # only possible when the block size is a literal.
                try:
                    total_threads = _eval_dim(block_dim)
                except (ValueError, SyntaxError):
                    total_threads = None
                if total_threads:
                    wavefronts = (total_threads + 63) // 64
                    hint_parts.append(f"{wavefronts} wavefront{'s' if wavefronts != 1 else ''}")

            if hint_parts:
                hint_text = " // " + ", ".join(hint_parts)

                hints.append(InlayHint(
                    # Position after >>>
                    position=Position(line=i, character=match.end()),
                    label=hint_text,
                    kind=InlayHintKind.Parameter,
                    padding_left=True,
                    padding_right=False
                ))

        # Shared memory size hints
        for match in SHARED_MEM_PATTERN.finditer(line):
            type_str = match.group(1).strip()
            array_size = match.group(3).strip()

            size_bytes = _estimate_size(type_str, array_size)
            if size_bytes:
                if size_bytes >= 1024:
                    size_str = f" // {size_bytes / 1024:.1f} KB LDS"
                else:
                    size_str = f" // {size_bytes} bytes LDS"

                hints.append(InlayHint(
                    # Position at end of declaration
                    position=Position(line=i, character=match.end()),
                    label=size_str,
                    kind=InlayHintKind.Type,
                    padding_left=True,
                    padding_right=False
                ))

    return hints
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _parse_dim(dim_str: str) -> str | None:
|
|
173
|
+
"""Parse a dimension string and return a human-readable description."""
|
|
174
|
+
dim_str = dim_str.strip()
|
|
175
|
+
|
|
176
|
+
# Simple number
|
|
177
|
+
if dim_str.isdigit():
|
|
178
|
+
return dim_str
|
|
179
|
+
|
|
180
|
+
# dim3(x, y, z)
|
|
181
|
+
if dim_str.startswith("dim3("):
|
|
182
|
+
return dim_str
|
|
183
|
+
|
|
184
|
+
# Variable or expression
|
|
185
|
+
if re.match(r'^[\w_]+$', dim_str):
|
|
186
|
+
return dim_str
|
|
187
|
+
|
|
188
|
+
return None
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _eval_dim(dim_str: str) -> int | None:
|
|
192
|
+
"""Try to evaluate a dimension to an integer."""
|
|
193
|
+
dim_str = dim_str.strip()
|
|
194
|
+
|
|
195
|
+
# Simple number
|
|
196
|
+
if dim_str.isdigit():
|
|
197
|
+
return int(dim_str)
|
|
198
|
+
|
|
199
|
+
# dim3(x) or dim3(x, y) or dim3(x, y, z) - try to multiply
|
|
200
|
+
if dim_str.startswith("dim3(") and dim_str.endswith(")"):
|
|
201
|
+
inner = dim_str[5:-1]
|
|
202
|
+
parts = [p.strip() for p in inner.split(",")]
|
|
203
|
+
try:
|
|
204
|
+
total = 1
|
|
205
|
+
for p in parts:
|
|
206
|
+
if p.isdigit():
|
|
207
|
+
total *= int(p)
|
|
208
|
+
else:
|
|
209
|
+
return None # Can't evaluate variable
|
|
210
|
+
return total
|
|
211
|
+
except (ValueError, SyntaxError):
|
|
212
|
+
return None
|
|
213
|
+
|
|
214
|
+
return None
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _estimate_size(type_str: str, array_size: str) -> int | None:
|
|
218
|
+
"""Estimate size in bytes for a shared memory allocation."""
|
|
219
|
+
type_sizes = {
|
|
220
|
+
'char': 1, 'int8_t': 1, 'uint8_t': 1,
|
|
221
|
+
'short': 2, 'int16_t': 2, 'uint16_t': 2, 'half': 2, '__half': 2,
|
|
222
|
+
'int': 4, 'int32_t': 4, 'uint32_t': 4, 'float': 4, 'unsigned': 4,
|
|
223
|
+
'long': 8, 'int64_t': 8, 'uint64_t': 8, 'double': 8,
|
|
224
|
+
'float4': 16, 'float2': 8, 'int4': 16, 'int2': 8,
|
|
225
|
+
'double2': 16, 'double4': 32,
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
# Find base type
|
|
229
|
+
base_type = type_str.strip()
|
|
230
|
+
type_size = None
|
|
231
|
+
for known_type, size in type_sizes.items():
|
|
232
|
+
if known_type in base_type:
|
|
233
|
+
type_size = size
|
|
234
|
+
break
|
|
235
|
+
|
|
236
|
+
if type_size is None:
|
|
237
|
+
type_size = 4 # Default to 4 bytes
|
|
238
|
+
|
|
239
|
+
# Try to evaluate array size
|
|
240
|
+
try:
|
|
241
|
+
# Handle simple expressions
|
|
242
|
+
arr_size = eval(array_size.replace('*', ' * '))
|
|
243
|
+
return type_size * arr_size
|
|
244
|
+
except (ValueError, SyntaxError, NameError):
|
|
245
|
+
return None
|
|
@@ -1,13 +1,20 @@
|
|
|
1
1
|
|
|
2
|
+
import re
|
|
3
|
+
|
|
2
4
|
from lsprotocol.types import SemanticTokens, SemanticTokensLegend
|
|
3
5
|
|
|
4
6
|
from ..languages.registry import get_language_registry
|
|
5
7
|
|
|
6
8
|
TOKEN_TYPES = [
|
|
7
|
-
"kernel",
|
|
8
|
-
"layout",
|
|
9
|
-
"struct",
|
|
10
|
-
"decorator",
|
|
9
|
+
"kernel", # 0: GPU kernel functions (__global__)
|
|
10
|
+
"layout", # 1: Layout variables
|
|
11
|
+
"struct", # 2: Structs
|
|
12
|
+
"decorator", # 3: Python decorators (@cute.kernel)
|
|
13
|
+
"keyword_gpu", # 4: GPU keywords (__global__, __device__, __shared__)
|
|
14
|
+
"keyword_memory", # 5: Memory qualifiers (__shared__, __constant__)
|
|
15
|
+
"function_hip_api", # 6: HIP API calls (hipMalloc, etc.)
|
|
16
|
+
"function_intrinsic", # 7: Wavefront intrinsics (__shfl, __ballot)
|
|
17
|
+
"device_function", # 8: __device__ functions
|
|
11
18
|
]
|
|
12
19
|
|
|
13
20
|
TOKEN_MODIFIERS = [
|
|
@@ -20,6 +27,31 @@ SEMANTIC_TOKENS_LEGEND = SemanticTokensLegend(
|
|
|
20
27
|
token_modifiers=TOKEN_MODIFIERS
|
|
21
28
|
)
|
|
22
29
|
|
|
30
|
+
# HIP-specific patterns for semantic highlighting.
# The keyword/API/intrinsic patterns are whole-word alternations built from
# the name tuples below; every name consists of word characters only, so no
# regex escaping is required.


def _whole_word_alternation(names: tuple[str, ...]) -> re.Pattern[str]:
    """Compile a pattern matching any one of ``names`` as a whole word."""
    return re.compile(r'\b(' + '|'.join(names) + r')\b')


_HIP_KEYWORDS = ('__global__', '__device__', '__host__', '__forceinline__')

_HIP_MEMORY_KEYWORDS = ('__shared__', '__constant__', '__restrict__')

_HIP_API_NAMES = (
    # Allocation
    'hipMalloc', 'hipMallocManaged', 'hipMallocAsync', 'hipFree', 'hipFreeAsync',
    # Copies and fills
    'hipMemcpy', 'hipMemcpyAsync', 'hipMemset', 'hipMemsetAsync',
    # Host memory
    'hipHostMalloc', 'hipHostFree', 'hipHostRegister', 'hipHostUnregister',
    # Synchronisation
    'hipDeviceSynchronize', 'hipStreamSynchronize',
    # Streams
    'hipStreamCreate', 'hipStreamDestroy', 'hipStreamCreateWithFlags',
    # Events
    'hipEventCreate', 'hipEventDestroy', 'hipEventRecord',
    'hipEventSynchronize', 'hipEventElapsedTime',
    # Devices
    'hipSetDevice', 'hipGetDevice', 'hipGetDeviceCount', 'hipGetDeviceProperties',
    # Launches
    'hipLaunchKernelGGL', 'hipLaunchCooperativeKernel',
    # Errors
    'hipGetLastError', 'hipPeekAtLastError', 'hipGetErrorString', 'hipGetErrorName',
)

_HIP_INTRINSIC_NAMES = (
    '__shfl', '__shfl_down', '__shfl_up', '__shfl_xor', '__shfl_sync',
    '__ballot', '__any', '__all', '__activemask',
    '__syncthreads', '__syncwarp',
    '__threadfence', '__threadfence_block', '__threadfence_system',
    'atomicAdd', 'atomicSub', 'atomicMax', 'atomicMin', 'atomicExch',
    'atomicCAS', 'atomicAnd', 'atomicOr', 'atomicXor',
    '__popc', '__popcll', '__clz', '__clzll', '__ffs', '__ffsll',
    '__float2half', '__half2float', '__float2int_rn', '__int2float_rn',
    '__ldg', '__ldcg', '__ldca', '__ldcs',
)

HIP_KEYWORD_PATTERN = _whole_word_alternation(_HIP_KEYWORDS)
HIP_MEMORY_KEYWORD_PATTERN = _whole_word_alternation(_HIP_MEMORY_KEYWORDS)
# Matches the attribute together with its parenthesised argument list.
HIP_LAUNCH_BOUNDS_PATTERN = re.compile(r'__launch_bounds__\s*\([^)]+\)')
HIP_API_PATTERN = _whole_word_alternation(_HIP_API_NAMES)
HIP_INTRINSIC_PATTERN = _whole_word_alternation(_HIP_INTRINSIC_NAMES)
|
|
54
|
+
|
|
23
55
|
|
|
24
56
|
def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
|
|
25
57
|
registry = get_language_registry()
|
|
@@ -28,11 +60,12 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
|
|
|
28
60
|
if not language_info:
|
|
29
61
|
return SemanticTokens(data=[])
|
|
30
62
|
|
|
31
|
-
tokens
|
|
63
|
+
# Collect all tokens with their positions
|
|
64
|
+
# We'll sort them later to ensure proper delta calculation
|
|
65
|
+
token_entries: list[tuple[int, int, int, int, int]] = [] # (line, char, length, type, modifier)
|
|
32
66
|
lines = content.split("\n")
|
|
33
|
-
prev_line = 0
|
|
34
|
-
prev_char = 0
|
|
35
67
|
|
|
68
|
+
# Add kernel tokens
|
|
36
69
|
for kernel in language_info.kernels:
|
|
37
70
|
if kernel.line >= len(lines):
|
|
38
71
|
continue
|
|
@@ -41,20 +74,15 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
|
|
|
41
74
|
name_start = kernel_line.find(kernel.name)
|
|
42
75
|
|
|
43
76
|
if name_start >= 0:
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
tokens.extend([
|
|
48
|
-
delta_line,
|
|
49
|
-
delta_char,
|
|
77
|
+
token_entries.append((
|
|
78
|
+
kernel.line,
|
|
79
|
+
name_start,
|
|
50
80
|
len(kernel.name),
|
|
51
81
|
TOKEN_TYPES.index("kernel"),
|
|
52
82
|
0
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
prev_line = kernel.line
|
|
56
|
-
prev_char = name_start + len(kernel.name)
|
|
83
|
+
))
|
|
57
84
|
|
|
85
|
+
# Add layout tokens
|
|
58
86
|
for layout in language_info.layouts:
|
|
59
87
|
if layout.line >= len(lines):
|
|
60
88
|
continue
|
|
@@ -63,20 +91,15 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
|
|
|
63
91
|
name_start = layout_line.find(layout.name)
|
|
64
92
|
|
|
65
93
|
if name_start >= 0:
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
tokens.extend([
|
|
70
|
-
delta_line,
|
|
71
|
-
delta_char,
|
|
94
|
+
token_entries.append((
|
|
95
|
+
layout.line,
|
|
96
|
+
name_start,
|
|
72
97
|
len(layout.name),
|
|
73
98
|
TOKEN_TYPES.index("layout"),
|
|
74
99
|
0
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
prev_line = layout.line
|
|
78
|
-
prev_char = name_start + len(layout.name)
|
|
100
|
+
))
|
|
79
101
|
|
|
102
|
+
# Add struct tokens
|
|
80
103
|
for struct in language_info.structs:
|
|
81
104
|
if struct.line >= len(lines):
|
|
82
105
|
continue
|
|
@@ -85,20 +108,15 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
|
|
|
85
108
|
name_start = struct_line.find(struct.name)
|
|
86
109
|
|
|
87
110
|
if name_start >= 0:
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
tokens.extend([
|
|
92
|
-
delta_line,
|
|
93
|
-
delta_char,
|
|
111
|
+
token_entries.append((
|
|
112
|
+
struct.line,
|
|
113
|
+
name_start,
|
|
94
114
|
len(struct.name),
|
|
95
115
|
TOKEN_TYPES.index("struct"),
|
|
96
116
|
0
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
prev_line = struct.line
|
|
100
|
-
prev_char = name_start + len(struct.name)
|
|
117
|
+
))
|
|
101
118
|
|
|
119
|
+
# Add CuTeDSL decorator tokens
|
|
102
120
|
for i, line in enumerate(lines):
|
|
103
121
|
if "@cute.kernel" in line or "@cute.struct" in line:
|
|
104
122
|
decorator_start = line.find("@")
|
|
@@ -107,18 +125,100 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
|
|
|
107
125
|
if decorator_end == -1:
|
|
108
126
|
decorator_end = len(line)
|
|
109
127
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
tokens.extend([
|
|
114
|
-
delta_line,
|
|
115
|
-
delta_char,
|
|
128
|
+
token_entries.append((
|
|
129
|
+
i,
|
|
130
|
+
decorator_start,
|
|
116
131
|
decorator_end - decorator_start,
|
|
117
132
|
TOKEN_TYPES.index("decorator"),
|
|
118
133
|
0
|
|
119
|
-
|
|
134
|
+
))
|
|
135
|
+
|
|
136
|
+
# Add HIP-specific tokens if this is a HIP or CUDA file
|
|
137
|
+
if language_info.language in ("hip", "cuda", "cpp"):
|
|
138
|
+
token_entries.extend(_get_hip_tokens(lines))
|
|
120
139
|
|
|
121
|
-
|
|
122
|
-
|
|
140
|
+
# Sort tokens by position (line, then character)
|
|
141
|
+
token_entries.sort(key=lambda x: (x[0], x[1]))
|
|
142
|
+
|
|
143
|
+
# Convert to delta-encoded format
|
|
144
|
+
tokens: list[int] = []
|
|
145
|
+
prev_line = 0
|
|
146
|
+
prev_char = 0
|
|
147
|
+
|
|
148
|
+
for line, char, length, token_type, modifier in token_entries:
|
|
149
|
+
delta_line = line - prev_line
|
|
150
|
+
delta_char = char - (prev_char if delta_line == 0 else 0)
|
|
151
|
+
|
|
152
|
+
tokens.extend([
|
|
153
|
+
delta_line,
|
|
154
|
+
delta_char,
|
|
155
|
+
length,
|
|
156
|
+
token_type,
|
|
157
|
+
modifier
|
|
158
|
+
])
|
|
159
|
+
|
|
160
|
+
prev_line = line
|
|
161
|
+
prev_char = char
|
|
123
162
|
|
|
124
163
|
return SemanticTokens(data=tokens)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _get_hip_tokens(lines: list[str]) -> list[tuple[int, int, int, int, int]]:
    """Collect HIP-specific semantic tokens from the given source lines.

    Each entry is ``(line, char, length, token_type, modifier)`` with absolute
    positions; ``token_type`` is an index into ``TOKEN_TYPES`` and the
    modifier is always 0. Entries are grouped per line and, within a line,
    per pattern in the order of the table below.
    """
    gpu_keyword = TOKEN_TYPES.index("keyword_gpu")

    # (pattern, token type index) pairs, applied to every line in this order.
    highlighters = (
        (HIP_KEYWORD_PATTERN, gpu_keyword),            # __global__, __device__, etc.
        (HIP_MEMORY_KEYWORD_PATTERN, TOKEN_TYPES.index("keyword_memory")),
        (HIP_LAUNCH_BOUNDS_PATTERN, gpu_keyword),      # __launch_bounds__(...)
        (HIP_API_PATTERN, TOKEN_TYPES.index("function_hip_api")),
        (HIP_INTRINSIC_PATTERN, TOKEN_TYPES.index("function_intrinsic")),
    )

    collected: list[tuple[int, int, int, int, int]] = []
    for line_no, text in enumerate(lines):
        for pattern, type_index in highlighters:
            collected.extend(
                (line_no, found.start(), len(found.group()), type_index, 0)
                for found in pattern.finditer(text)
            )
    return collected
|
wafer_lsp/languages/detector.py
CHANGED
|
@@ -2,33 +2,106 @@ from pathlib import Path
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class LanguageDetector:
|
|
5
|
+
"""Detects language based on file extension and content markers.
|
|
6
|
+
|
|
7
|
+
Supports both extension-based detection (fast) and content-based detection
|
|
8
|
+
(for files that share extensions, e.g., .cpp files that could be HIP or CUDA).
|
|
9
|
+
"""
|
|
5
10
|
|
|
6
11
|
def __init__(self):
|
|
7
12
|
self._extensions: dict[str, str] = {}
|
|
13
|
+
self._content_markers: dict[str, list[str]] = {} # language_id -> markers
|
|
14
|
+
# Compound extensions like .hip.cpp need special handling
|
|
15
|
+
self._compound_extensions: dict[str, str] = {}
|
|
8
16
|
|
|
9
17
|
def register_extension(self, extension: str, language_id: str):
|
|
10
|
-
normalized_ext = extension if extension.startswith(".") else f".{
|
|
11
|
-
|
|
18
|
+
normalized_ext = extension if extension.startswith(".") else f".{extension}"
|
|
19
|
+
|
|
20
|
+
# Check if this is a compound extension (e.g., .hip.cpp)
|
|
21
|
+
if normalized_ext.count(".") > 1:
|
|
22
|
+
self._compound_extensions[normalized_ext] = language_id
|
|
23
|
+
else:
|
|
24
|
+
self._extensions[normalized_ext] = language_id
|
|
25
|
+
|
|
26
|
+
def register_content_markers(self, language_id: str, markers: list[str]):
|
|
27
|
+
"""Register content markers for content-based language detection."""
|
|
28
|
+
self._content_markers[language_id] = markers
|
|
12
29
|
|
|
13
|
-
def detect_from_uri(self, uri: str) -> str | None:
|
|
30
|
+
def detect_from_uri(self, uri: str, content: str | None = None) -> str | None:
|
|
31
|
+
"""Detect language from URI and optionally content.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
uri: File URI or path
|
|
35
|
+
content: Optional file content for content-based detection
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
Language ID or None
|
|
39
|
+
"""
|
|
14
40
|
if uri.startswith("file://"):
|
|
15
41
|
file_path = uri[7:]
|
|
16
42
|
else:
|
|
17
43
|
file_path = uri
|
|
18
44
|
|
|
19
|
-
return self.detect_from_path(file_path)
|
|
45
|
+
return self.detect_from_path(file_path, content)
|
|
20
46
|
|
|
21
|
-
def detect_from_path(self, file_path: str) -> str | None:
|
|
47
|
+
def detect_from_path(self, file_path: str, content: str | None = None) -> str | None:
|
|
48
|
+
"""Detect language from file path and optionally content.
|
|
49
|
+
|
|
50
|
+
Order of detection:
|
|
51
|
+
1. Compound extensions (e.g., .hip.cpp) - most specific
|
|
52
|
+
2. Content markers (for shared extensions like .cpp)
|
|
53
|
+
3. Simple extension
|
|
54
|
+
"""
|
|
22
55
|
path = Path(file_path)
|
|
56
|
+
|
|
57
|
+
# 1. Check compound extensions first
|
|
58
|
+
# Get the last two suffixes for compound extension detection
|
|
59
|
+
suffixes = path.suffixes
|
|
60
|
+
if len(suffixes) >= 2:
|
|
61
|
+
compound_ext = "".join(suffixes[-2:]).lower()
|
|
62
|
+
if compound_ext in self._compound_extensions:
|
|
63
|
+
return self._compound_extensions[compound_ext]
|
|
64
|
+
|
|
65
|
+
# 2. If content is provided, check content markers
|
|
66
|
+
if content:
|
|
67
|
+
content_lang = self._detect_from_content(content)
|
|
68
|
+
if content_lang:
|
|
69
|
+
return content_lang
|
|
70
|
+
|
|
71
|
+
# 3. Fall back to simple extension
|
|
23
72
|
ext = path.suffix.lower()
|
|
24
73
|
return self._extensions.get(ext)
|
|
25
74
|
|
|
75
|
+
def _detect_from_content(self, content: str) -> str | None:
|
|
76
|
+
"""Detect language based on content markers.
|
|
77
|
+
|
|
78
|
+
Returns the language with the most matching markers.
|
|
79
|
+
"""
|
|
80
|
+
best_match: str | None = None
|
|
81
|
+
best_count = 0
|
|
82
|
+
|
|
83
|
+
for language_id, markers in self._content_markers.items():
|
|
84
|
+
match_count = sum(1 for marker in markers if marker in content)
|
|
85
|
+
if match_count > best_count:
|
|
86
|
+
best_count = match_count
|
|
87
|
+
best_match = language_id
|
|
88
|
+
|
|
89
|
+
# Require at least one marker match
|
|
90
|
+
return best_match if best_count > 0 else None
|
|
91
|
+
|
|
26
92
|
def detect_from_extension(self, extension: str) -> str | None:
|
|
27
|
-
normalized_ext = extension if extension.startswith(".") else f".{
|
|
93
|
+
normalized_ext = extension if extension.startswith(".") else f".{extension}"
|
|
94
|
+
normalized_ext = normalized_ext.lower() # Case insensitive
|
|
28
95
|
return self._extensions.get(normalized_ext)
|
|
29
96
|
|
|
30
97
|
def get_supported_extensions(self) -> list[str]:
|
|
31
|
-
|
|
98
|
+
all_extensions = list(self._extensions.keys())
|
|
99
|
+
all_extensions.extend(self._compound_extensions.keys())
|
|
100
|
+
return all_extensions
|
|
101
|
+
|
|
102
|
+
def is_supported(self, uri: str, content: str | None = None) -> bool:
|
|
103
|
+
return self.detect_from_uri(uri, content) is not None
|
|
32
104
|
|
|
33
|
-
def
|
|
34
|
-
|
|
105
|
+
def get_content_markers(self, language_id: str) -> list[str]:
|
|
106
|
+
"""Get content markers for a language."""
|
|
107
|
+
return self._content_markers.get(language_id, [])
|
wafer_lsp/languages/registry.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
|
|
2
2
|
from ..parsers.cuda_parser import CUDAParser
|
|
3
3
|
from ..parsers.cutedsl_parser import CuTeDSLParser
|
|
4
|
+
from ..parsers.hip_parser import HIPParser
|
|
4
5
|
from .converter import ParserResultConverter
|
|
5
6
|
from .detector import LanguageDetector
|
|
6
7
|
from .parser_manager import ParserManager
|
|
@@ -33,6 +34,22 @@ class LanguageRegistry:
|
|
|
33
34
|
file_patterns=["*.cu", "*.cuh"]
|
|
34
35
|
)
|
|
35
36
|
|
|
37
|
+
# HIP (AMD GPU) - Register before cpp so .hip.cpp files get detected as HIP
|
|
38
|
+
self.register_language(
|
|
39
|
+
language_id="hip",
|
|
40
|
+
display_name="HIP (AMD GPU)",
|
|
41
|
+
parser=HIPParser(),
|
|
42
|
+
extensions=[".hip", ".hip.cpp", ".hip.hpp", ".hipcc"],
|
|
43
|
+
file_patterns=["*.hip", "*.hip.cpp", "*.hip.hpp", "*.hipcc"],
|
|
44
|
+
content_markers=[
|
|
45
|
+
"#include <hip/hip_runtime.h>",
|
|
46
|
+
"#include \"hip/hip_runtime.h\"",
|
|
47
|
+
"hipMalloc",
|
|
48
|
+
"hipLaunchKernelGGL",
|
|
49
|
+
"__HIP_PLATFORM_AMD__",
|
|
50
|
+
]
|
|
51
|
+
)
|
|
52
|
+
|
|
36
53
|
self.register_language(
|
|
37
54
|
language_id="cpp",
|
|
38
55
|
display_name="C++",
|
|
@@ -47,12 +64,16 @@ class LanguageRegistry:
|
|
|
47
64
|
display_name: str,
|
|
48
65
|
parser,
|
|
49
66
|
extensions: list[str],
|
|
50
|
-
file_patterns: list[str] | None = None
|
|
67
|
+
file_patterns: list[str] | None = None,
|
|
68
|
+
content_markers: list[str] | None = None
|
|
51
69
|
):
|
|
52
70
|
self._parser_manager.register_parser(language_id, display_name, parser)
|
|
53
71
|
|
|
54
72
|
for ext in extensions:
|
|
55
73
|
self._detector.register_extension(ext, language_id)
|
|
74
|
+
|
|
75
|
+
if content_markers:
|
|
76
|
+
self._detector.register_content_markers(language_id, content_markers)
|
|
56
77
|
|
|
57
78
|
def detect_language(self, uri: str) -> str | None:
|
|
58
79
|
return self._detector.detect_from_uri(uri)
|
wafer_lsp/parsers/__init__.py
CHANGED
|
@@ -6,6 +6,16 @@ from .cutedsl_parser import (
|
|
|
6
6
|
CuTeDSLParser,
|
|
7
7
|
CuTeDSLStruct,
|
|
8
8
|
)
|
|
9
|
+
from .hip_parser import (
|
|
10
|
+
HIPKernel,
|
|
11
|
+
HIPDeviceFunction,
|
|
12
|
+
HIPParameter,
|
|
13
|
+
HIPParser,
|
|
14
|
+
KernelLaunchSite,
|
|
15
|
+
SharedMemoryAllocation,
|
|
16
|
+
WavefrontPattern,
|
|
17
|
+
is_hip_file,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
__all__ = [
|
|
11
21
|
"BaseParser",
|
|
@@ -15,4 +25,12 @@ __all__ = [
|
|
|
15
25
|
"CuTeDSLLayout",
|
|
16
26
|
"CuTeDSLParser",
|
|
17
27
|
"CuTeDSLStruct",
|
|
28
|
+
"HIPKernel",
|
|
29
|
+
"HIPDeviceFunction",
|
|
30
|
+
"HIPParameter",
|
|
31
|
+
"HIPParser",
|
|
32
|
+
"KernelLaunchSite",
|
|
33
|
+
"SharedMemoryAllocation",
|
|
34
|
+
"WavefrontPattern",
|
|
35
|
+
"is_hip_file",
|
|
18
36
|
]
|