wafer-lsp 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/handlers/__init__.py +30 -0
- wafer_lsp/handlers/diagnostics.py +26 -1
- wafer_lsp/handlers/document_symbol.py +93 -4
- wafer_lsp/handlers/hip_diagnostics.py +303 -0
- wafer_lsp/handlers/hover.py +45 -9
- wafer_lsp/handlers/inlay_hint.py +180 -0
- wafer_lsp/handlers/semantic_tokens.py +146 -46
- wafer_lsp/languages/detector.py +82 -9
- wafer_lsp/languages/registry.py +22 -1
- wafer_lsp/parsers/__init__.py +18 -0
- wafer_lsp/parsers/hip_parser.py +688 -0
- wafer_lsp/services/__init__.py +17 -0
- wafer_lsp/services/hip_docs.py +806 -0
- wafer_lsp/services/hip_hover_service.py +412 -0
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.2.dist-info}/METADATA +4 -1
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.2.dist-info}/RECORD +18 -14
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.2.dist-info}/WHEEL +0 -0
- {wafer_lsp-0.1.0.dist-info → wafer_lsp-0.1.2.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,688 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HIP Parser for wafer-lsp.
|
|
3
|
+
|
|
4
|
+
Extracts GPU constructs from HIP (AMD GPU) source files:
|
|
5
|
+
- __global__ kernels
|
|
6
|
+
- __device__ helper functions
|
|
7
|
+
- __shared__ memory allocations (LDS)
|
|
8
|
+
- Kernel launch sites (<<<>>> and hipLaunchKernelGGL)
|
|
9
|
+
- Wavefront-sensitive patterns for diagnostics
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import re
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from .base_parser import BaseParser
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class HIPParameter:
    """One parameter of a HIP kernel or ``__device__`` function signature.

    Instances are immutable value objects produced by the parser.
    """

    # Identifier of the parameter as written in the signature.
    name: str
    # Declared type text with the parameter name removed.
    type_str: str
    # Whether the parameter is declared as a pointer or reference.
    is_pointer: bool = field(default=False)
    # Whether the declaration carries a const qualifier.
    is_const: bool = field(default=False)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
class HIPKernel:
    """A ``__global__`` GPU kernel discovered in a HIP source file."""

    # Kernel identifier.
    name: str
    # 0-based line index where the declaration match begins.
    line: int
    # 0-based line of the closing brace; the parser falls back to the
    # declaration's own line when no body can be found.
    end_line: int
    # Parameter names, in declaration order.
    parameters: list[str]
    # Detailed parameter records, parallel to `parameters`.
    parameter_info: list[HIPParameter]
    # Attribute strings attached to the kernel, e.g. "__launch_bounds__(256)".
    attributes: list[str]
    # Documentation comment found directly above the kernel, if any.
    docstring: str | None = field(default=None)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class HIPDeviceFunction:
    """A ``__device__`` helper function discovered in a HIP source file."""

    # Function identifier.
    name: str
    # 0-based line index where the declaration match begins.
    line: int
    # 0-based line of the closing brace (declaration line if not found).
    end_line: int
    # Parameter names, in declaration order.
    parameters: list[str]
    # Detailed parameter records, parallel to `parameters`.
    parameter_info: list[HIPParameter]
    # Declared return type text.
    return_type: str = field(default="void")
    # Whether the function was declared inline / __forceinline__.
    is_inline: bool = field(default=False)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True)
|
|
53
|
+
class SharedMemoryAllocation:
|
|
54
|
+
"""A __shared__ memory (LDS) allocation."""
|
|
55
|
+
name: str
|
|
56
|
+
type_str: str
|
|
57
|
+
line: int
|
|
58
|
+
size_bytes: int | None # None if dynamic/unknown
|
|
59
|
+
array_size: str | None # Array dimension if static
|
|
60
|
+
is_dynamic: bool = False
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass(frozen=True)
|
|
64
|
+
class KernelLaunchSite:
|
|
65
|
+
"""Where a kernel is launched."""
|
|
66
|
+
kernel_name: str
|
|
67
|
+
line: int
|
|
68
|
+
grid_dim: str | None # Grid dimensions if determinable
|
|
69
|
+
block_dim: str | None # Block dimensions if determinable
|
|
70
|
+
shared_mem_bytes: str | None # Dynamic shared memory size
|
|
71
|
+
stream: str | None # CUDA stream if specified
|
|
72
|
+
is_hip_launch_kernel_ggl: bool = False # True if using hipLaunchKernelGGL
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass(frozen=True)
class WavefrontPattern:
    """Code that may assume a 32-lane warp.

    CUDA warps have 32 threads while AMD wavefronts commonly have 64, so
    code written with hard-coded 32-lane assumptions can misbehave on AMD.
    """

    # Category: "warp_size_32", "ballot_32bit", "shuffle_mask" or "lane_calc_32".
    pattern_type: str
    # 0-based line index where the pattern was found.
    line: int
    # The (stripped) source line containing the pattern.
    code_snippet: str
    # The literal that triggered the match, e.g. "32" or "0xFFFFFFFF".
    problematic_value: str
    # Diagnostic severity: "warning", "error" or "info".
    severity: str = field(default="warning")
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class HIPParser(BaseParser):
    """Parser for HIP (AMD GPU) source files.

    Extracts kernels, device functions, shared memory allocations,
    launch sites, and wavefront-sensitive patterns using regex-based parsing.

    Why regex instead of full AST parsing:
    - HIP/CUDA syntax is C++ which is complex to fully parse
    - Key constructs (__global__, __device__, __shared__) are lexically distinct
    - Fast enough for real-time IDE use (<100ms)
    - Works on incomplete/invalid code during editing
    """

    # ------------------------------------------------------------------
    # Regex patterns for HIP constructs
    # ------------------------------------------------------------------

    # Matches __global__ function declarations.
    # Captures: (1) return type, (2) function name.
    # __launch_bounds__ is accepted before __global__, right after it, or
    # between the return type and the name. The last form is the one shown in
    # the HIP programming guide (`__global__ void __launch_bounds__(256) k(...)`);
    # without it the attribute itself would be reported as the kernel name.
    _KERNEL_PATTERN = re.compile(
        r'(?:template\s*<[^>]*>\s*)?'  # Optional template
        r'(?:__launch_bounds__\s*\([^)]+\)\s*)?'  # Optional launch bounds before __global__
        r'__global__\s+'
        r'(?:__device__\s+)?'  # Optional __device__
        r'(?:__launch_bounds__\s*\([^)]+\)\s*)?'  # Optional launch bounds after __global__
        r'([\w\s\*&:<>,]+?)\s+'  # Return type (usually void)
        r'(?:__launch_bounds__\s*\([^)]+\)\s*)?'  # Optional launch bounds before the name
        r'(\w+)\s*'  # Function name (captured)
        r'\(',  # Start of parameter list
        re.MULTILINE | re.DOTALL,
    )

    # Extracts the argument text of a __launch_bounds__ attribute.
    _LAUNCH_BOUNDS_PATTERN = re.compile(
        r'__launch_bounds__\s*\(([^)]+)\)',
        re.MULTILINE,
    )

    # Matches __device__ function declarations (not __global__).
    # Captures: (1) qualifiers after __device__ (inline/__forceinline__/__host__/
    # static, in any order), (2) return type, (3) function name.
    _DEVICE_FUNC_PATTERN = re.compile(
        r'(?:template\s*<[^>]*>\s*)?'  # Optional template
        r'__device__\s+'
        r'(?!__global__)'  # Not followed by __global__
        r'((?:(?:__forceinline__|__inline__|inline|__host__|static)\s+)*)'  # Qualifiers
        r'([\w\s\*&:<>,]+?)\s+'  # Return type (captured)
        r'(\w+)\s*'  # Function name (captured)
        r'\(',
        re.MULTILINE,
    )

    # Matches __shared__ variable declarations (scalars and N-D arrays).
    # Captures: (1) type, (2) variable name, (3) bracketed dimensions.
    _SHARED_MEMORY_PATTERN = re.compile(
        r'__shared__\s+'
        r'(?:__align__\s*\((?:[^()]|\([^()]*\))*\)\s*)?'  # Optional __align__(n)
        r'([\w\s:<>,]+?)\s+'  # Type (captured)
        r'(\w+)\s*'  # Variable name (captured)
        r'(\[[^\]]*\](?:\s*\[[^\]]*\])*)?'  # Optional array dimension(s) (captured)
        r'\s*[;=]',
        re.MULTILINE,
    )

    # Matches kernel launches with <<<...>>> syntax.
    # Captures: (1) kernel name, (2) raw launch configuration. The configuration
    # is split on top-level commas afterwards so that `dim3(x, y)` stays intact.
    _LAUNCH_SYNTAX_PATTERN = re.compile(
        r'(\w+)\s*'  # Kernel name (captured)
        r'(?:<[^<>;()]*>\s*)?'  # Optional explicit template arguments
        r'<<<([^;]*?)>>>',  # Launch configuration (captured)
        re.MULTILINE,
    )

    # Matches the start of a hipLaunchKernelGGL call; arguments are split
    # afterwards with bracket-depth tracking.
    _HIP_LAUNCH_PATTERN = re.compile(r'\bhipLaunchKernelGGL\s*\(', re.MULTILINE)

    # ------------------------------------------------------------------
    # Helper patterns and tables
    # ------------------------------------------------------------------

    _COMMENT_PATTERN = re.compile(r'/\*.*?\*/|//[^\n]*', re.DOTALL)
    _HIP_KERNEL_NAME_PATTERN = re.compile(r'HIP_KERNEL_NAME\s*\((.*)\)', re.DOTALL)
    _IDENTIFIER_PATTERN = re.compile(r'[A-Za-z_]\w*')
    # C integer literal: hex, octal or decimal, with optional u/l suffixes.
    _INT_LITERAL_PATTERN = re.compile(r'(0[xX][0-9a-fA-F]+|0[0-7]*|[1-9][0-9]*)[uUlL]*')
    _INLINE_KEYWORDS = frozenset({'inline', '__inline__', '__forceinline__'})
    _TYPE_QUALIFIERS = frozenset({'const', 'volatile', 'static', '__restrict__'})

    # Byte sizes of common scalar/vector types, looked up by exact (normalized)
    # type name. 'long' assumes an LP64 target (HIP on Linux).
    _TYPE_SIZES: dict[str, int] = {
        'char': 1, 'bool': 1, 'int8_t': 1, 'uint8_t': 1,
        'short': 2, 'int16_t': 2, 'uint16_t': 2, 'half': 2, '__half': 2,
        'int': 4, 'signed': 4, 'unsigned': 4, 'float': 4,
        'int32_t': 4, 'uint32_t': 4, 'half2': 4, '__half2': 4,
        'long': 8, 'long long': 8, 'double': 8, 'int64_t': 8, 'uint64_t': 8,
        'float2': 8, 'int2': 8, 'uint2': 8,
        'float4': 16, 'int4': 16, 'uint4': 16, 'double2': 16,
        'double4': 32,
    }

    # Fragments shared by the wavefront patterns below.
    _TID = r'(?:threadIdx\s*\.\s*[xyz]|hipThreadIdx_[xyz])'
    _CALL_ARG = r'(?:[^,()]|\([^()]*\))+'  # One call argument (one level of parens)
    _MASK_32 = r'0x[Ff]{8}[uUlL]*\b'  # 32-bit all-ones mask (not the 64-bit one)

    # Patterns that suggest incorrect wavefront size assumptions (CUDA's 32 vs AMD's 64)
    _WARP_SIZE_32_PATTERNS = [
        # threadIdx.x < 32, threadIdx.x % 32, threadIdx.x & 31, threadIdx.x / 32
        (re.compile(_TID + r'\s*[<%&]\s*32\b'), "warp_size_32", "32"),
        (re.compile(_TID + r'\s*&\s*31\b'), "warp_size_32", "31"),
        (re.compile(_TID + r'\s*/\s*32\b'), "warp_size_32", "32"),

        # Lane/warp index calculations with hard-coded 32
        (re.compile(r'\blaneId\s*=\s*[^;]*%\s*32\b'), "lane_calc_32", "32"),
        (re.compile(r'\bwarpId\s*=\s*[^;]*/\s*32\b'), "lane_calc_32", "32"),
        (re.compile(r'\blane\s*=\s*[^;]*&\s*31\b'), "lane_calc_32", "31"),

        # Ballot result compared to / masked with a 32-bit all-ones value
        (re.compile(r'__ballot(?:_sync)?\s*\((?:[^()]|\([^()]*\))*\)\s*(?:==|!=|&)\s*' + _MASK_32),
         "ballot_32bit", "0xFFFFFFFF"),

        # 32-lane full mask passed as the FIRST argument of a *_sync shuffle
        (re.compile(r'__shfl(?:_up|_down|_xor)?_sync\s*\(\s*' + _MASK_32),
         "shuffle_mask", "0xFFFFFFFF"),
        # Explicit width=32: 3rd argument of __shfl*, 4th argument of __shfl*_sync
        (re.compile(r'__shfl(?:_up|_down|_xor)?\s*\(' + _CALL_ARG + r',' + _CALL_ARG + r',\s*32\s*\)'),
         "shuffle_mask", "32"),
        (re.compile(r'__shfl(?:_up|_down|_xor)?_sync\s*\(' + _CALL_ARG + r',' + _CALL_ARG + r','
                    + _CALL_ARG + r',\s*32\s*\)'),
         "shuffle_mask", "32"),

        # __activemask() compared to a 32-bit value
        (re.compile(r'__activemask\s*\(\s*\)\s*(?:==|!=)\s*' + _MASK_32), "ballot_32bit", "0xFFFFFFFF"),

        # Hard-coded warp size constants (WARP_SIZE, kWarpSize, CUDA_WARP_SIZE, ...)
        (re.compile(r'#\s*define\s+\w*(?:[Ww]arp|WARP)_?(?:[Ss]ize|SIZE)\w*\s+32\b'), "warp_size_32", "32"),
        (re.compile(r'\bconst(?:expr)?\s+(?:\w+\s+)+\w*(?:[Ww]arp|WARP)_?(?:[Ss]ize|SIZE)\w*\s*=\s*32\b'),
         "warp_size_32", "32"),
    ]

    def parse_file(self, content: str) -> dict[str, Any]:
        """Parse a HIP source file and extract GPU constructs.

        Args:
            content: The source file content as a string.

        Returns:
            Dictionary containing:
            - kernels: List of HIPKernel
            - device_functions: List of HIPDeviceFunction
            - shared_memory: List of SharedMemoryAllocation
            - launch_sites: List of KernelLaunchSite
            - wavefront_patterns: List of WavefrontPattern
        """
        return {
            "kernels": self._extract_kernels(content),
            "device_functions": self._extract_device_functions(content),
            "shared_memory": self._extract_shared_memory(content),
            "launch_sites": self._extract_launch_sites(content),
            "wavefront_patterns": self._extract_wavefront_patterns(content),
        }

    def _extract_kernels(self, content: str) -> list[HIPKernel]:
        """Extract all __global__ kernel functions."""
        kernels: list[HIPKernel] = []

        for match in self._KERNEL_PATTERN.finditer(content):
            start = match.start()

            # Look for __launch_bounds__ anywhere in this declaration: inside the
            # match or just before it, but never past the previous `;`/`{`/`}`
            # so an earlier declaration's attribute is not picked up.
            lb_match = self._LAUNCH_BOUNDS_PATTERN.search(
                content, self._declaration_start(content, start), match.end()
            )
            attributes = [f"__launch_bounds__({lb_match.group(1)})"] if lb_match else []

            # match.end() - 1 is the '(' that opens the parameter list.
            params, param_info = self._extract_parameters(content, match.end() - 1)

            kernels.append(HIPKernel(
                name=match.group(2),
                line=content.count('\n', 0, start),
                end_line=self._find_function_end(content, match.end()),
                parameters=params,
                parameter_info=param_info,
                attributes=attributes,
                docstring=self._extract_docstring(content, start),
            ))

        return kernels

    def _extract_device_functions(self, content: str) -> list[HIPDeviceFunction]:
        """Extract all __device__ helper functions."""
        device_funcs: list[HIPDeviceFunction] = []

        for match in self._DEVICE_FUNC_PATTERN.finditer(content):
            start = match.start()
            leading = self._prefix_tokens(content, start)

            # `__global__ __device__ ...` is a kernel (reported by _extract_kernels).
            # Only qualifiers of this declaration count, so a kernel declared on
            # a previous line does not hide the device function that follows.
            if '__global__' in leading:
                continue

            qualifiers = match.group(1).split()
            return_type = match.group(2).strip()

            # inline may appear before __device__, among the captured qualifiers,
            # or (for unusual orderings) inside the return-type text.
            is_inline = not self._INLINE_KEYWORDS.isdisjoint(
                leading.union(qualifiers, return_type.split())
            )

            params, param_info = self._extract_parameters(content, match.end() - 1)

            device_funcs.append(HIPDeviceFunction(
                name=match.group(3),
                line=content.count('\n', 0, start),
                end_line=self._find_function_end(content, match.end()),
                parameters=params,
                parameter_info=param_info,
                return_type=return_type,
                is_inline=is_inline,
            ))

        return device_funcs

    def _extract_shared_memory(self, content: str) -> list[SharedMemoryAllocation]:
        """Extract all __shared__ memory declarations."""
        shared_mem: list[SharedMemoryAllocation] = []

        for match in self._SHARED_MEMORY_PATTERN.finditer(content):
            start = match.start()
            type_str = match.group(1).strip()
            array_dims = match.group(3)  # e.g. "[n]", "[n][m]" or "[]"
            array_size = array_dims.strip() if array_dims else None

            # `extern __shared__` arrays are sized at launch time, so their size
            # is unknown at parse time regardless of the declared dimensions.
            is_dynamic = 'extern' in self._prefix_tokens(content, start)
            size_bytes = None if is_dynamic else self._estimate_shared_mem_size(type_str, array_size)

            shared_mem.append(SharedMemoryAllocation(
                name=match.group(2),
                type_str=type_str,
                line=content.count('\n', 0, start),
                size_bytes=size_bytes,
                array_size=array_size,
                is_dynamic=is_dynamic,
            ))

        return shared_mem

    def _extract_launch_sites(self, content: str) -> list[KernelLaunchSite]:
        """Extract all kernel launch sites (<<<>>> launches, then hipLaunchKernelGGL)."""
        launch_sites: list[KernelLaunchSite] = []

        # kernel<<<grid, block[, shared_mem[, stream]]>>>(args)
        for match in self._LAUNCH_SYNTAX_PATTERN.finditer(content):
            config = self._split_top_level(match.group(2))
            if len(config) < 2:
                continue  # grid and block are mandatory
            launch_sites.append(self._launch_site(
                match.group(1), content.count('\n', 0, match.start()), config, False
            ))

        # hipLaunchKernelGGL(kernel, grid, block, shared_mem, stream, args...)
        for match in self._HIP_LAUNCH_PATTERN.finditer(content):
            args = self._call_arguments(content, match.end() - 1)
            if len(args) < 5:
                continue  # kernel, grid, block, shared mem and stream are mandatory
            kernel_name = self._kernel_name_from_expression(args[0])
            if kernel_name is None:
                continue
            launch_sites.append(self._launch_site(
                kernel_name, content.count('\n', 0, match.start()), args[1:5], True
            ))

        return launch_sites

    @staticmethod
    def _launch_site(
        kernel_name: str, line: int, config: list[str], is_hip_launch_kernel_ggl: bool
    ) -> KernelLaunchSite:
        """Build a KernelLaunchSite from [grid, block, shared_mem, stream] texts.

        Missing or empty entries become None.
        """
        values = [part or None for part in config[:4]]
        values += [None] * (4 - len(values))
        grid, block, shared, stream = values
        return KernelLaunchSite(
            kernel_name=kernel_name,
            line=line,
            grid_dim=grid,
            block_dim=block,
            shared_mem_bytes=shared,
            stream=stream,
            is_hip_launch_kernel_ggl=is_hip_launch_kernel_ggl,
        )

    def _kernel_name_from_expression(self, expr: str) -> str | None:
        """Return the kernel identifier from hipLaunchKernelGGL's first argument.

        Unwraps `HIP_KERNEL_NAME(...)`, drops template arguments and namespace
        qualifiers (`ns::kernel<T>` -> `kernel`). Returns None if no identifier.
        """
        wrapped = self._HIP_KERNEL_NAME_PATTERN.fullmatch(expr.strip())
        if wrapped:
            expr = wrapped.group(1)
        identifiers = self._IDENTIFIER_PATTERN.findall(expr.split('<', 1)[0])
        return identifiers[-1] if identifiers else None

    def _call_arguments(self, content: str, open_paren: int) -> list[str]:
        """Return the top-level arguments of the call whose '(' is at `open_paren`.

        Stops at the matching ')' or, for incomplete code being edited, at a
        ';' directly inside the call's parentheses.
        """
        depth = 0
        end = len(content)
        for i in range(open_paren, len(content)):
            char = content[i]
            if char in '([{':
                depth += 1
            elif char in ')]}':
                depth -= 1
                if depth == 0:
                    end = i
                    break
            elif char == ';' and depth == 1:
                end = i
                break
        return self._split_top_level(content[open_paren + 1:end])

    @staticmethod
    def _split_top_level(text: str) -> list[str]:
        """Split `text` on commas not nested in (), [] or {}; parts are stripped.

        Returns [] for blank input.
        """
        parts: list[str] = []
        depth = 0
        part_start = 0
        for i, char in enumerate(text):
            if char in '([{':
                depth += 1
            elif char in ')]}':
                depth -= 1
            elif char == ',' and depth == 0:
                parts.append(text[part_start:i].strip())
                part_start = i + 1
        tail = text[part_start:].strip()
        if tail or parts:
            parts.append(tail)
        return parts

    def _extract_wavefront_patterns(self, content: str) -> list[WavefrontPattern]:
        """Extract patterns that might indicate incorrect wavefront size assumptions.

        AMD GPUs use 64-thread wavefronts (CDNA) or configurable 32/64 (RDNA),
        while CUDA uses 32-thread warps. Code written for CUDA may have
        hard-coded assumptions about warp size that break on AMD.

        Comments (including block-comment interiors and trailing `//` comments)
        are ignored; `code_snippet` still holds the full stripped source line.
        """
        patterns: list[WavefrontPattern] = []
        in_block_comment = False

        for i, line in enumerate(content.split('\n')):
            code, in_block_comment = self._strip_line_comments(line, in_block_comment)
            if not code.strip():
                continue

            # Skip lines that use warpSize (correct portable code)
            if 'warpSize' in code or '__AMDGCN_WAVEFRONT_SIZE__' in code:
                continue

            for pattern_re, pattern_type, problematic_value in self._WARP_SIZE_32_PATTERNS:
                if pattern_re.search(code):
                    patterns.append(WavefrontPattern(
                        pattern_type=pattern_type,
                        line=i,
                        code_snippet=line.strip(),
                        problematic_value=problematic_value,
                    ))

        return patterns

    @staticmethod
    def _strip_line_comments(line: str, in_block: bool) -> tuple[str, bool]:
        """Remove comments from one line.

        Args:
            line: A single source line.
            in_block: Whether a /* ... */ comment is open at the start of the line.

        Returns:
            (code text with comments replaced by spaces, whether a block comment
            is still open at the end of the line). String literals are not
            special-cased, so `//` inside a string also starts a comment here.
        """
        code_parts: list[str] = []
        pos = 0
        while pos < len(line):
            if in_block:
                end = line.find('*/', pos)
                if end == -1:
                    break  # comment continues on the next line
                in_block = False
                pos = end + 2
                continue
            line_comment = line.find('//', pos)
            block_comment = line.find('/*', pos)
            if line_comment != -1 and (block_comment == -1 or line_comment < block_comment):
                code_parts.append(line[pos:line_comment])
                break
            if block_comment == -1:
                code_parts.append(line[pos:])
                break
            code_parts.append(line[pos:block_comment])
            in_block = True
            pos = block_comment + 2
        return ' '.join(code_parts), in_block

    def _extract_parameters(self, content: str, paren_start: int) -> tuple[list[str], list[HIPParameter]]:
        """Extract function parameters starting from the opening parenthesis.

        Comments inside the parameter list are ignored, and commas nested in
        template arguments or parentheses do not split parameters.

        Returns:
            Tuple of (parameter names list, detailed parameter info list).
            Both are empty if `paren_start` is not a '(' or it is unbalanced.
        """
        if paren_start >= len(content) or content[paren_start] != '(':
            return [], []

        # Find matching closing paren
        depth = 0
        param_end = -1
        for i in range(paren_start, len(content)):
            char = content[i]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    param_end = i
                    break

        if param_end == -1:
            return [], []

        param_str = self._COMMENT_PATTERN.sub(' ', content[paren_start + 1:param_end])

        # Split on top-level commas, honouring template and paren nesting.
        raw_params: list[str] = []
        current: list[str] = []
        template_depth = 0
        paren_depth = 0
        for char in param_str:
            if char == '<':
                template_depth += 1
            elif char == '>':
                template_depth -= 1
            elif char == '(':
                paren_depth += 1
            elif char == ')':
                paren_depth -= 1
            elif char == ',' and template_depth == 0 and paren_depth == 0:
                raw_params.append(''.join(current))
                current = []
                continue
            current.append(char)
        raw_params.append(''.join(current))

        params: list[str] = []
        param_info: list[HIPParameter] = []
        for raw in raw_params:
            param = self._parse_single_parameter(raw.strip())
            if param is not None:
                params.append(param.name)
                param_info.append(param)

        return params, param_info

    def _parse_single_parameter(self, param_str: str) -> HIPParameter | None:
        """Parse a single parameter string into HIPParameter.

        `float* __restrict__ out` -> name "out", type "float* __restrict__".
        Array parameters (`float a[]`) decay to pointers; the bracket suffix is
        kept on the type. Returns None for empty input or a bare `void`
        (an empty parameter list).
        """
        param_str = param_str.strip()
        if not param_str or param_str == 'void':
            return None

        # Remove default value if present
        if '=' in param_str:
            param_str = param_str.split('=', 1)[0].strip()

        array_suffix = ''
        array_match = re.search(r'(?:\s*\[[^\]]*\])+$', param_str)
        if array_match:
            array_suffix = re.sub(r'\s+', '', array_match.group(0))
            param_str = param_str[:array_match.start()].rstrip()

        parts = param_str.replace('*', ' * ').replace('&', ' & ').split()
        if not parts:
            return None

        if parts[-1] in ('*', '&'):
            # Unnamed pointer/reference parameter, e.g. `float*` in a prototype.
            name, type_tokens = '', parts
        else:
            # With a single token we can't separate type from name; treat it as
            # a name with unknown type.
            name, type_tokens = parts[-1], parts[:-1]

        # Attach '*'/'&' to the preceding token: "const float* __restrict__".
        type_str = ''
        for token in type_tokens:
            if token in ('*', '&'):
                type_str += token
            else:
                type_str += (' ' if type_str else '') + token
        type_str += array_suffix

        return HIPParameter(
            name=name,
            type_str=type_str,
            is_pointer='*' in param_str or '&' in param_str or bool(array_suffix),
            is_const='const' in parts,
        )

    def _find_function_end(self, content: str, start_pos: int) -> int:
        """Find the line number where a function ends (closing brace).

        Args:
            content: Full source text.
            start_pos: Offset just past the '(' that opens the function's
                parameter list (`match.end()` of the declaration patterns).

        Returns:
            Line number of the closing brace, or the start line if there is no
            body (forward declaration) or the braces are unbalanced (incomplete
            code during editing) -- the caller should handle this case.
        """
        start_line = content.count('\n', 0, start_pos)

        # Walk past the parameter list, then decide between a body ('{') and a
        # declaration (';'). Searching for the next '{' directly would give a
        # forward declaration the body of whatever function follows it.
        paren_depth = 1
        brace_pos = -1
        for i in range(start_pos, len(content)):
            char = content[i]
            if char == ';':
                return start_line
            if char == '(':
                paren_depth += 1
            elif char == ')':
                paren_depth -= 1
            elif char == '{' and paren_depth <= 0:
                brace_pos = i
                break

        if brace_pos == -1:
            return start_line

        # Find matching closing brace
        depth = 1
        for i in range(brace_pos + 1, len(content)):
            if content[i] == '{':
                depth += 1
            elif content[i] == '}':
                depth -= 1
                if depth == 0:
                    return content.count('\n', 0, i)

        # Unbalanced braces - couldn't determine the end
        return start_line

    def _extract_docstring(self, content: str, func_start: int) -> str | None:
        """Extract the documentation comment directly above a declaration.

        Recognises `///` line comments and `/* ... */` / `/** ... */` blocks
        within the 500 characters before `func_start`. Blank lines and plain
        `//` comments between the documentation and the declaration are
        skipped; any other code line ends the search, so a comment belonging
        to an earlier declaration is never returned.

        Returns:
            The comment text with comment markers removed, or None.
        """
        search_start = max(0, func_start - 500)
        # Text on the declaration's own line (e.g. `extern "C"`) is not
        # documentation; only complete lines above it are examined.
        last_newline = content.rfind('\n', search_start, func_start)
        if last_newline == -1:
            return None

        doc_lines: list[str] = []
        in_block = False

        for raw_line in reversed(content[search_start:last_newline].split('\n')):
            text = raw_line.strip()
            if in_block:
                opener = text.find('/*')
                if opener != -1:
                    doc_lines.insert(0, text[opener + 2:].lstrip('*!').strip())
                    break
                doc_lines.insert(0, text[1:].strip() if text.startswith('*') else text)
                continue
            if text.endswith('*/'):
                opener = text.rfind('/*')
                if opener > 0:
                    break  # code with a trailing comment, not documentation
                body = text[:-2]
                if opener == 0:
                    # Single-line block comment: /** text */
                    doc_lines.insert(0, body[2:].lstrip('*!').strip())
                    break
                body = body.strip()
                doc_lines.insert(0, body[1:].strip() if body.startswith('*') else body)
                in_block = True
            elif text.startswith('///'):
                doc_lines.insert(0, text[3:].strip())
            elif not text or text.startswith('//'):
                continue
            else:
                break  # Not a doc comment, stop

        # Drop blank lines produced by the comment delimiters themselves.
        while doc_lines and not doc_lines[0]:
            doc_lines.pop(0)
        while doc_lines and not doc_lines[-1]:
            doc_lines.pop()

        return '\n'.join(doc_lines) if doc_lines else None

    @staticmethod
    def _declaration_start(content: str, pos: int, window: int = 50) -> int:
        """Offset where the declaration containing `pos` plausibly starts.

        Looks back at most `window` characters and stops just after the
        nearest `;`, `{` or `}` so text from a previous statement is excluded.
        """
        start = max(0, pos - window)
        boundary = max(content.rfind(ch, start, pos) for ch in ';{}')
        return boundary + 1 if boundary != -1 else start

    def _prefix_tokens(self, content: str, pos: int) -> set[str]:
        """Identifiers of the current declaration that appear before `pos`.

        Comments are removed first so words inside them are not mistaken
        for qualifiers such as `inline` or `extern`.
        """
        prefix = content[self._declaration_start(content, pos):pos]
        return set(re.findall(r'\w+', self._COMMENT_PATTERN.sub(' ', prefix)))

    def _estimate_shared_mem_size(self, type_str: str, array_size: str | None) -> int | None:
        """Estimate shared memory size in bytes.

        Handles scalars and N-D arrays (e.g. [256], [16][16], [32][32 + 1]).
        Returns None if the size cannot be determined: unknown type, empty
        brackets (`[]`, dynamic size) or a non-literal dimension.

        Why no eval(): only integer literals combined with `*` and `+` are
        evaluated. Variable-based dimensions (e.g. [BLOCK_SIZE]) return None -
        that's correct behavior since we can't know the value at parse time.
        """
        type_size = self._scalar_type_size(type_str)
        if type_size is None:
            return None

        if not array_size:
            return type_size

        dims = re.findall(r'\[([^\]]*)\]', array_size)
        if not dims:
            return type_size

        total_elements = 1
        for dim in dims:
            dim_value = self._parse_dimension_expression(dim)
            if dim_value is None:
                return None
            total_elements *= dim_value

        return type_size * total_elements

    def _scalar_type_size(self, type_str: str) -> int | None:
        """Size in bytes of a (qualified) type name, or None if unknown.

        Exact lookup after normalization, so `float4` is 16 bytes rather than
        matching `float`, and `int64_t` is 8 rather than matching `int`.
        """
        words = [
            word
            for word in type_str.replace('std::', ' ').replace('::', ' ').split()
            if word not in self._TYPE_QUALIFIERS
        ]
        # signed/unsigned don't change the size: `unsigned char` -> `char`.
        if len(words) > 1 and words[0] in ('signed', 'unsigned'):
            words = words[1:]
        # `short int` -> `short`, `long long int` -> `long long`.
        if len(words) > 1 and words[-1] == 'int' and words[0] in ('short', 'long'):
            words = words[:-1]
        return self._TYPE_SIZES.get(' '.join(words))

    def _parse_dimension_expression(self, expr: str) -> int | None:
        """Parse a dimension expression safely without eval().

        Handles:
        - Integer literals: "256", "0x100", "256u"
        - Products and sums of literals: "16 * 16", "32 + 1"

        Returns None for anything we can't safely parse (empty, variables,
        parentheses, other operators) - we don't know variable values at
        parse time.
        """
        expr = expr.strip()
        if not expr:
            return None

        total = 0
        for term in expr.split('+'):
            product = 1
            for factor in term.split('*'):
                value = self._parse_int_literal(factor)
                if value is None:
                    return None
                product *= value
            total += product
        return total

    def _parse_int_literal(self, text: str) -> int | None:
        """Value of a C integer literal (decimal, hex or octal; u/l suffixes), else None."""
        match = self._INT_LITERAL_PATTERN.fullmatch(text.strip())
        if not match:
            return None
        digits = match.group(1)
        if digits[:2] in ('0x', '0X'):
            return int(digits, 16)
        if len(digits) > 1 and digits[0] == '0':
            return int(digits, 8)
        return int(digits)
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def is_hip_file(content: str) -> bool:
    """Check if content appears to be a HIP file based on content markers.

    Marker matching is case-insensitive. HIP runtime include directives are
    matched with flexible whitespace, so `#include<hip/hip_runtime.h>` and
    `#  include "hip/hip_runtime_api.h"` are recognised as well as the
    canonical `#include <hip/hip_runtime.h>` spelling.
    """
    # Include directives: the preprocessor allows spaces after '#' and none
    # before the header name, which plain substring markers would miss.
    include_pattern = re.compile(
        r'#\s*include\s*[<"]\s*hip/hip_runtime(?:_api)?\.h\s*[>"]',
        re.IGNORECASE,
    )
    if include_pattern.search(content):
        return True

    hip_markers = (
        'hipMalloc',
        'hipMemcpy',
        'hipFree',
        'hipLaunchKernelGGL',
        'hipDeviceSynchronize',
        '__HIP_PLATFORM_AMD__',
        '__HIP_PLATFORM_HCC__',
        'HIP_KERNEL_NAME',
    )

    content_lower = content.lower()
    return any(marker.lower() in content_lower for marker in hip_markers)
|