wafer-lsp 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. wafer_lsp/__init__.py +1 -0
  2. wafer_lsp/__main__.py +9 -0
  3. wafer_lsp/analyzers/__init__.py +0 -0
  4. wafer_lsp/analyzers/compiler_integration.py +16 -0
  5. wafer_lsp/analyzers/docs_index.py +36 -0
  6. wafer_lsp/handlers/__init__.py +30 -0
  7. wafer_lsp/handlers/code_action.py +48 -0
  8. wafer_lsp/handlers/code_lens.py +48 -0
  9. wafer_lsp/handlers/completion.py +6 -0
  10. wafer_lsp/handlers/diagnostics.py +41 -0
  11. wafer_lsp/handlers/document_symbol.py +176 -0
  12. wafer_lsp/handlers/hip_diagnostics.py +303 -0
  13. wafer_lsp/handlers/hover.py +251 -0
  14. wafer_lsp/handlers/inlay_hint.py +245 -0
  15. wafer_lsp/handlers/semantic_tokens.py +224 -0
  16. wafer_lsp/handlers/workspace_symbol.py +87 -0
  17. wafer_lsp/languages/README.md +195 -0
  18. wafer_lsp/languages/__init__.py +17 -0
  19. wafer_lsp/languages/converter.py +88 -0
  20. wafer_lsp/languages/detector.py +107 -0
  21. wafer_lsp/languages/parser_manager.py +33 -0
  22. wafer_lsp/languages/registry.py +120 -0
  23. wafer_lsp/languages/types.py +37 -0
  24. wafer_lsp/parsers/__init__.py +36 -0
  25. wafer_lsp/parsers/base_parser.py +9 -0
  26. wafer_lsp/parsers/cuda_parser.py +95 -0
  27. wafer_lsp/parsers/cutedsl_parser.py +114 -0
  28. wafer_lsp/parsers/hip_parser.py +688 -0
  29. wafer_lsp/server.py +58 -0
  30. wafer_lsp/services/__init__.py +38 -0
  31. wafer_lsp/services/analysis_service.py +22 -0
  32. wafer_lsp/services/docs_service.py +40 -0
  33. wafer_lsp/services/document_service.py +20 -0
  34. wafer_lsp/services/hip_docs.py +806 -0
  35. wafer_lsp/services/hip_hover_service.py +412 -0
  36. wafer_lsp/services/hover_service.py +237 -0
  37. wafer_lsp/services/language_registry_service.py +26 -0
  38. wafer_lsp/services/position_service.py +77 -0
  39. wafer_lsp/utils/__init__.py +0 -0
  40. wafer_lsp/utils/lsp_helpers.py +79 -0
  41. wafer_lsp-0.1.13.dist-info/METADATA +60 -0
  42. wafer_lsp-0.1.13.dist-info/RECORD +44 -0
  43. wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
  44. wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,688 @@
1
+ """
2
+ HIP Parser for wafer-lsp.
3
+
4
+ Extracts GPU constructs from HIP (AMD GPU) source files:
5
+ - __global__ kernels
6
+ - __device__ helper functions
7
+ - __shared__ memory allocations (LDS)
8
+ - Kernel launch sites (<<<>>> and hipLaunchKernelGGL)
9
+ - Wavefront-sensitive patterns for diagnostics
10
+ """
11
+
12
+ import re
13
+ from dataclasses import dataclass, field
14
+ from typing import Any
15
+
16
+ from .base_parser import BaseParser
17
+
18
+
19
@dataclass(frozen=True)
class HIPParameter:
    """One entry of a kernel or device-function parameter list.

    Instances are immutable (``frozen=True``).
    """

    # Identifier of the parameter as written in the signature.
    name: str
    # Declared type text with the parameter name removed.
    type_str: str
    # True when the declaration is a pointer or reference.
    is_pointer: bool = False
    # True when the declaration is const-qualified.
    is_const: bool = False
26
+
27
+
28
@dataclass(frozen=True)
class HIPKernel:
    """Immutable record describing one ``__global__`` kernel."""

    # Kernel identifier.
    name: str
    # Zero-based line index at which the kernel declaration starts.
    line: int
    # Zero-based line index reported by the parser's function-end search.
    end_line: int
    # Parameter names in declaration order.
    parameters: list[str]
    # Detailed description of each parameter.
    parameter_info: list[HIPParameter]
    # Attribute text attached to the kernel, e.g. "__launch_bounds__(256)".
    attributes: list[str]
    # Documentation comment found before the kernel, if any.
    docstring: str | None = None
38
+
39
+
40
@dataclass(frozen=True)
class HIPDeviceFunction:
    """Immutable record describing one ``__device__`` helper function."""

    # Function identifier.
    name: str
    # Zero-based line index at which the declaration starts.
    line: int
    # Zero-based line index reported by the parser's function-end search.
    end_line: int
    # Parameter names in declaration order.
    parameters: list[str]
    # Detailed description of each parameter.
    parameter_info: list[HIPParameter]
    # Return type text as written in the source.
    return_type: str = "void"
    # True when the function carries an inline qualifier.
    is_inline: bool = False
50
+
51
+
52
+ @dataclass(frozen=True)
53
+ class SharedMemoryAllocation:
54
+ """A __shared__ memory (LDS) allocation."""
55
+ name: str
56
+ type_str: str
57
+ line: int
58
+ size_bytes: int | None # None if dynamic/unknown
59
+ array_size: str | None # Array dimension if static
60
+ is_dynamic: bool = False
61
+
62
+
63
+ @dataclass(frozen=True)
64
+ class KernelLaunchSite:
65
+ """Where a kernel is launched."""
66
+ kernel_name: str
67
+ line: int
68
+ grid_dim: str | None # Grid dimensions if determinable
69
+ block_dim: str | None # Block dimensions if determinable
70
+ shared_mem_bytes: str | None # Dynamic shared memory size
71
+ stream: str | None # CUDA stream if specified
72
+ is_hip_launch_kernel_ggl: bool = False # True if using hipLaunchKernelGGL
73
+
74
+
75
@dataclass(frozen=True)
class WavefrontPattern:
    """A source construct that may assume a particular wavefront size.

    Such constructs are typically written for CUDA's 32-thread warps and
    may misbehave on AMD hardware that executes 64-thread wavefronts.
    """

    # Category: "warp_size_32", "ballot_32bit", "shuffle_mask" or "lane_calc_32".
    pattern_type: str
    # Zero-based line index of the offending line.
    line: int
    # The offending source line with surrounding whitespace removed.
    code_snippet: str
    # The literal that triggered the finding, e.g. "32" or "0xFFFFFFFF".
    problematic_value: str
    # One of "warning", "error" or "info".
    severity: str = "warning"
87
+
88
+
89
class HIPParser(BaseParser):
    """Parser for HIP (AMD GPU) source files.

    Extracts kernels, device functions, shared memory allocations,
    launch sites, and wavefront-sensitive patterns using regex-based parsing.

    Why regex instead of full AST parsing:
    - HIP/CUDA syntax is C++ which is complex to fully parse
    - Key constructs (__global__, __device__, __shared__) are lexically distinct
    - Fast enough for real-time IDE use (<100ms)
    - Works on incomplete/invalid code during editing
    """

    # Regex patterns for HIP constructs.  All are compiled once at class
    # definition time and shared by every parser instance.

    # Matches __global__ function declarations up to and including the
    # opening parenthesis of the parameter list.
    # Groups: 1 = return type, 2 = function name.
    # Note: __launch_bounds__ can be on the same line or line before
    _KERNEL_PATTERN = re.compile(
        r'(?:template\s*<[^>]*>\s*)?'  # Optional template
        r'(?:__launch_bounds__\s*\([^)]+\)\s*[\n\r]*)?'  # Optional launch bounds before __global__
        r'__global__\s+'
        r'(?:__device__\s+)?'  # Optional __device__ (CUDA allows both)
        r'(?:__launch_bounds__\s*\([^)]+\)\s*)?'  # Optional launch bounds after __global__
        r'([\w\s\*&:<>,]+?)\s+'  # Return type (usually void)
        r'(\w+)\s*'  # Function name (captured)
        r'\(',  # Start of parameter list
        re.MULTILINE | re.DOTALL
    )

    # Separate pattern to detect __launch_bounds__ near __global__.
    # Group 1 = the text between the parentheses.
    _LAUNCH_BOUNDS_PATTERN = re.compile(
        r'__launch_bounds__\s*\(([^)]+)\)',
        re.MULTILINE
    )

    # Matches __device__ function declarations (not __global__) up to and
    # including the opening parenthesis of the parameter list.
    # Groups: 1 = __forceinline__ (optional), 2 = return type, 3 = name.
    _DEVICE_FUNC_PATTERN = re.compile(
        r'(?:template\s*<[^>]*>\s*)?'  # Optional template
        r'__device__\s+'
        r'(?!__global__)'  # Not followed by __global__
        r'(__forceinline__\s+)?'  # Optional __forceinline__ (captured)
        r'(?:__host__\s+)?'  # Optional __host__
        r'(?:inline\s+)?'  # Optional inline (consumed, not captured)
        r'([\w\s\*&:<>,]+?)\s+'  # Return type (captured)
        r'(\w+)\s*'  # Function name (captured)
        r'\(',
        re.MULTILINE
    )

    # Matches __shared__ variable declarations (including 2D arrays).
    # Groups: 1 = type, 2 = variable name, 3 = bracketed dimensions (optional).
    _SHARED_MEMORY_PATTERN = re.compile(
        r'__shared__\s+'
        r'([\w\s:<>,]+?)\s+'  # Type (captured)
        r'(\w+)\s*'  # Variable name (captured)
        r'(\[[^\]]*\](?:\s*\[[^\]]*\])?)?'  # Optional array size(s) including 2D (captured)
        r'\s*[;=]',
        re.MULTILINE
    )

    # Matches kernel launches with <<<>>> syntax.
    # Groups: 1 = kernel, 2 = grid, 3 = block, 4 = shared mem, 5 = stream.
    # The launch arguments are split on commas, so an argument that itself
    # contains a comma (e.g. dim3(16, 16)) is not captured as one unit.
    _LAUNCH_SYNTAX_PATTERN = re.compile(
        r'(\w+)\s*'  # Kernel name (captured)
        r'<<<\s*'
        r'([^,>]+)\s*,\s*'  # Grid dim (captured)
        r'([^,>]+)'  # Block dim (captured)
        r'(?:\s*,\s*([^,>]+))?'  # Optional shared mem (captured)
        r'(?:\s*,\s*([^>]+))?'  # Optional stream (captured)
        r'\s*>>>',
        re.MULTILINE
    )

    # Matches hipLaunchKernelGGL calls.
    # Groups: 1 = kernel, 2 = grid, 3 = block, 4 = shared mem, 5 = stream.
    # As above, arguments are split on commas.
    _HIP_LAUNCH_PATTERN = re.compile(
        r'hipLaunchKernelGGL\s*\(\s*'
        r'(\w+)\s*,\s*'  # Kernel name (captured)
        r'([^,]+)\s*,\s*'  # Grid dim (captured)
        r'([^,]+)\s*,\s*'  # Block dim (captured)
        r'([^,]+)\s*,\s*'  # Shared mem (captured)
        r'([^,)]+)'  # Stream (captured)
        r'(?:\s*,)?',  # Optional comma before args
        re.MULTILINE
    )

    # Patterns that suggest incorrect wavefront size assumptions (CUDA's 32 vs AMD's 64).
    # Each entry is (compiled regex, pattern_type, problematic_value); the
    # regexes are applied to one source line at a time.
    _WARP_SIZE_32_PATTERNS = [
        # threadIdx.x < 32, threadIdx.x % 32, threadIdx.x & 32, threadIdx.x & 31, threadIdx.x / 32
        (re.compile(r'threadIdx\s*\.\s*[xyz]\s*[<%&]\s*32\b'), "warp_size_32", "32"),
        (re.compile(r'threadIdx\s*\.\s*[xyz]\s*&\s*31\b'), "warp_size_32", "31"),
        (re.compile(r'threadIdx\s*\.\s*[xyz]\s*/\s*32\b'), "warp_size_32", "32"),

        # Lane/warp index calculations with hard-coded 32
        (re.compile(r'\blaneId\s*=\s*[^;]*%\s*32\b'), "lane_calc_32", "32"),
        (re.compile(r'\bwarpId\s*=\s*[^;]*/\s*32\b'), "lane_calc_32", "32"),
        (re.compile(r'\blane\s*=\s*[^;]*&\s*31\b'), "lane_calc_32", "31"),

        # Ballot result compared to 32-bit mask
        (re.compile(r'__ballot\s*\([^)]*\)\s*==\s*0x[Ff]{8}\b'), "ballot_32bit", "0xFFFFFFFF"),
        (re.compile(r'__ballot\s*\([^)]*\)\s*!=\s*0x[Ff]{8}\b'), "ballot_32bit", "0xFFFFFFFF"),
        (re.compile(r'__ballot\s*\([^)]*\)\s*&\s*0x[Ff]{8}\b'), "ballot_32bit", "0xFFFFFFFF"),

        # Shuffle operations with mask suggesting 32-thread warp
        # (matches only when the 0xFFFFFFFF literal is the last argument)
        (re.compile(r'__shfl(?:_sync)?\s*\([^)]*,\s*0x[Ff]{8}\s*\)'), "shuffle_mask", "0xFFFFFFFF"),
        # Match __shfl_down(val, offset, 32) or similar with explicit width=32
        (re.compile(r'__shfl_down(?:_sync)?\s*\([^,]+,\s*[^,]+,\s*32\s*\)'), "shuffle_mask", "32"),
        (re.compile(r'__shfl_up(?:_sync)?\s*\([^,]+,\s*[^,]+,\s*32\s*\)'), "shuffle_mask", "32"),
        (re.compile(r'__shfl_xor(?:_sync)?\s*\([^,]+,\s*[^,]+,\s*32\s*\)'), "shuffle_mask", "32"),
        (re.compile(r'__shfl(?:_sync)?\s*\([^,]+,\s*[^,]+,\s*32\s*\)'), "shuffle_mask", "32"),

        # activemask() compared to 32-bit value
        (re.compile(r'__activemask\s*\(\s*\)\s*==\s*0x[Ff]{8}\b'), "ballot_32bit", "0xFFFFFFFF"),

        # Hard-coded warp size
        (re.compile(r'#define\s+WARP_SIZE\s+32\b'), "warp_size_32", "32"),
        (re.compile(r'const(?:expr)?\s+\w+\s+(?:warp|WARP)_?(?:size|SIZE)\s*=\s*32\b'), "warp_size_32", "32"),
    ]
205
+
206
def parse_file(self, content: str) -> dict[str, Any]:
    """Run every extractor over *content* and bundle the results.

    Args:
        content: Full text of a HIP source file.

    Returns:
        A dictionary with these keys, in this order:
        - kernels: list of HIPKernel
        - device_functions: list of HIPDeviceFunction
        - shared_memory: list of SharedMemoryAllocation
        - launch_sites: list of KernelLaunchSite
        - wavefront_patterns: list of WavefrontPattern
    """
    # Dict displays evaluate their values top to bottom, so the extractors
    # run in the order listed here.
    return {
        "kernels": self._extract_kernels(content),
        "device_functions": self._extract_device_functions(content),
        "shared_memory": self._extract_shared_memory(content),
        "launch_sites": self._extract_launch_sites(content),
        "wavefront_patterns": self._extract_wavefront_patterns(content),
    }
233
+
234
def _extract_kernels(self, content: str) -> list[HIPKernel]:
    """Extract all ``__global__`` kernel functions.

    Args:
        content: Full text of the source file.

    Returns:
        One HIPKernel per match of ``_KERNEL_PATTERN``, in source order.
        ``line`` is the zero-based line index where the match starts.
    """
    kernels: list[HIPKernel] = []

    for match in self._KERNEL_PATTERN.finditer(content):
        decl_start = match.start()
        line = content.count('\n', 0, decl_start)

        attributes = []

        # Prefer a __launch_bounds__ that is part of the matched declaration.
        lb_match = self._LAUNCH_BOUNDS_PATTERN.search(match.group(0))
        if lb_match is None:
            # Otherwise look a short distance back, but never past the end
            # of the previous statement or block: a __launch_bounds__ that
            # sits before a ';' or '}' belongs to an earlier function.
            lookback = content[max(0, decl_start - 50):decl_start]
            boundary = max(lookback.rfind(';'), lookback.rfind('}'))
            lb_match = self._LAUNCH_BOUNDS_PATTERN.search(lookback[boundary + 1:])
        if lb_match:
            attributes.append(f"__launch_bounds__({lb_match.group(1)})")

        kernel_name = match.group(2)

        # The match ends just after '(', so the parenthesis is at end - 1.
        params, param_info = self._extract_parameters(content, match.end() - 1)

        # Find end of function (approximate by finding matching brace)
        end_line = self._find_function_end(content, match.end())

        # Documentation comment immediately before the kernel, if any
        docstring = self._extract_docstring(content, decl_start)

        kernels.append(HIPKernel(
            name=kernel_name,
            line=line,
            end_line=end_line,
            parameters=params,
            parameter_info=param_info,
            attributes=attributes,
            docstring=docstring,
        ))

    return kernels
275
+
276
def _extract_device_functions(self, content: str) -> list[HIPDeviceFunction]:
    """Extract all ``__device__`` helper functions.

    Declarations that are also ``__global__`` are skipped; they are
    reported as kernels instead.

    Args:
        content: Full text of the source file.

    Returns:
        One HIPDeviceFunction per accepted match of
        ``_DEVICE_FUNC_PATTERN``, in source order.
    """
    device_funcs: list[HIPDeviceFunction] = []

    for match in self._DEVICE_FUNC_PATTERN.finditer(content):
        decl_start = match.start()

        # Qualifiers written before __device__ (e.g. `static inline`,
        # `__global__`).  Cut the look-back at the end of the previous
        # statement or block so that a kernel or an `inline` function
        # defined just above does not leak into this declaration.
        lookback = content[max(0, decl_start - 50):decl_start]
        lookback = lookback[max(lookback.rfind(';'), lookback.rfind('}')) + 1:]

        # Skip if this is part of a __global__ __device__ combination
        if '__global__' in lookback:
            continue

        line = content.count('\n', 0, decl_start)

        # Group 2 is the return type, group 3 is the function name
        return_type = match.group(2).strip()
        func_name = match.group(3)

        # Inline qualifiers may appear before __device__ (look-back) or
        # after it (inside the match, before the function name).
        qualifiers = lookback + content[decl_start:match.start(3)]
        is_inline = re.search(
            r'\b(?:inline|__inline__|__forceinline__)\b', qualifiers
        ) is not None

        # The match ends just after '(', so the parenthesis is at end - 1.
        params, param_info = self._extract_parameters(content, match.end() - 1)

        end_line = self._find_function_end(content, match.end())

        device_funcs.append(HIPDeviceFunction(
            name=func_name,
            line=line,
            end_line=end_line,
            parameters=params,
            parameter_info=param_info,
            return_type=return_type,
            is_inline=is_inline,
        ))

    return device_funcs
314
+
315
def _extract_shared_memory(self, content: str) -> list[SharedMemoryAllocation]:
    """Extract all ``__shared__`` memory declarations.

    Args:
        content: Full text of the source file.

    Returns:
        One SharedMemoryAllocation per match of ``_SHARED_MEMORY_PATTERN``,
        in source order.  ``array_size`` keeps the brackets as written
        (e.g. ``"[16][16]"``).  ``size_bytes`` is None for dynamic
        (``extern``) allocations because their size is chosen at launch.
    """
    shared_mem: list[SharedMemoryAllocation] = []

    for match in self._SHARED_MEMORY_PATTERN.finditer(content):
        decl_start = match.start()
        line = content.count('\n', 0, decl_start)

        type_str = match.group(1).strip()
        var_name = match.group(2)
        array_dims = match.group(3)  # Could be [n] or [n][m] for 2D

        array_size = array_dims.strip() if array_dims else None

        # Dynamic allocations are written `extern __shared__ ...`.  Only
        # look at the current statement, and match `extern` as a whole word.
        lookback = content[max(0, decl_start - 20):decl_start]
        boundary = max(lookback.rfind(';'), lookback.rfind('}'), lookback.rfind('{'))
        is_dynamic = re.search(r'\bextern\b', lookback[boundary + 1:]) is not None

        # The size of a dynamic allocation is not known at parse time.
        if is_dynamic:
            size_bytes = None
        else:
            size_bytes = self._estimate_shared_mem_size(type_str, array_size)

        shared_mem.append(SharedMemoryAllocation(
            name=var_name,
            type_str=type_str,
            line=line,
            size_bytes=size_bytes,
            array_size=array_size,
            is_dynamic=is_dynamic,
        ))

    return shared_mem
348
+
349
def _extract_launch_sites(self, content: str) -> list[KernelLaunchSite]:
    """Collect every kernel launch found in *content*.

    All ``<<<...>>>`` launches are listed first (in source order), followed
    by all ``hipLaunchKernelGGL`` launches (in source order).
    """

    def _clean(text):
        # Captured argument text with whitespace trimmed; None when the
        # group did not participate or captured nothing.
        return text.strip() if text else None

    sites: list[KernelLaunchSite] = []

    # Both patterns expose the same five groups:
    # kernel name, grid, block, shared memory, stream.
    sources = (
        (self._LAUNCH_SYNTAX_PATTERN, False),
        (self._HIP_LAUNCH_PATTERN, True),
    )
    for pattern, via_ggl in sources:
        for found in pattern.finditer(content):
            kernel, grid, block, shmem, stream = found.groups()
            sites.append(KernelLaunchSite(
                kernel_name=kernel,
                line=content.count('\n', 0, found.start()),
                grid_dim=_clean(grid),
                block_dim=_clean(block),
                shared_mem_bytes=_clean(shmem),
                stream=_clean(stream),
                is_hip_launch_kernel_ggl=via_ggl,
            ))

    return sites
382
+
383
def _extract_wavefront_patterns(self, content: str) -> list[WavefrontPattern]:
    """Extract patterns that might indicate incorrect wavefront size assumptions.

    AMD GPUs use 64-thread wavefronts (CDNA) or configurable 32/64 (RDNA),
    while CUDA uses 32-thread warps. Code written for CUDA may have
    hard-coded assumptions about warp size that break on AMD.

    Lines that are comments are ignored: ``//`` lines, lines starting with
    ``/*``, and every line of a multi-line ``/* ... */`` block that opens
    at the start of a line.  Lines mentioning ``warpSize`` or
    ``__AMDGCN_WAVEFRONT_SIZE__`` are treated as portable and ignored.

    Args:
        content: Full text of the source file.

    Returns:
        One WavefrontPattern per (line, matching pattern) pair, in source
        order.  A line matching several patterns yields several entries.
    """
    patterns: list[WavefrontPattern] = []
    in_block_comment = False

    for i, line in enumerate(content.split('\n')):
        stripped = line.strip()

        # Inside a multi-line block comment: skip until it is closed.
        if in_block_comment:
            if '*/' in stripped:
                in_block_comment = False
            continue

        if stripped.startswith('//'):
            continue

        if stripped.startswith('/*'):
            # Stay in comment mode when the block is not closed on this line.
            in_block_comment = '*/' not in stripped[2:]
            continue

        # Skip lines that use warpSize (correct portable code)
        if 'warpSize' in line or '__AMDGCN_WAVEFRONT_SIZE__' in line:
            continue

        for pattern_re, pattern_type, problematic_value in self._WARP_SIZE_32_PATTERNS:
            if pattern_re.search(line):
                patterns.append(WavefrontPattern(
                    pattern_type=pattern_type,
                    line=i,
                    code_snippet=stripped,
                    problematic_value=problematic_value,
                ))

    return patterns
414
+
415
+ def _extract_parameters(self, content: str, paren_start: int) -> tuple[list[str], list[HIPParameter]]:
416
+ """Extract function parameters starting from the opening parenthesis.
417
+
418
+ Returns:
419
+ Tuple of (parameter names list, detailed parameter info list)
420
+ """
421
+ if paren_start >= len(content) or content[paren_start] != '(':
422
+ return [], []
423
+
424
+ # Find matching closing paren
425
+ depth = 0
426
+ param_end = paren_start
427
+
428
+ for i in range(paren_start, len(content)):
429
+ char = content[i]
430
+ if char == '(':
431
+ depth += 1
432
+ elif char == ')':
433
+ depth -= 1
434
+ if depth == 0:
435
+ param_end = i
436
+ break
437
+
438
+ if param_end == paren_start:
439
+ return [], []
440
+
441
+ param_str = content[paren_start + 1:param_end]
442
+
443
+ # Parse parameters handling templates and nested parens
444
+ params: list[str] = []
445
+ param_info: list[HIPParameter] = []
446
+ current_param = ""
447
+ template_depth = 0
448
+ paren_depth = 0
449
+
450
+ for char in param_str:
451
+ if char == '<':
452
+ template_depth += 1
453
+ current_param += char
454
+ elif char == '>':
455
+ template_depth -= 1
456
+ current_param += char
457
+ elif char == '(':
458
+ paren_depth += 1
459
+ current_param += char
460
+ elif char == ')':
461
+ paren_depth -= 1
462
+ current_param += char
463
+ elif char == ',' and template_depth == 0 and paren_depth == 0:
464
+ param = self._parse_single_parameter(current_param.strip())
465
+ if param:
466
+ params.append(param.name)
467
+ param_info.append(param)
468
+ current_param = ""
469
+ else:
470
+ current_param += char
471
+
472
+ # Handle last parameter
473
+ if current_param.strip():
474
+ param = self._parse_single_parameter(current_param.strip())
475
+ if param:
476
+ params.append(param.name)
477
+ param_info.append(param)
478
+
479
+ return params, param_info
480
+
481
def _parse_single_parameter(self, param_str: str) -> HIPParameter | None:
    """Parse a single parameter string into HIPParameter.

    Args:
        param_str: Text of one parameter, e.g. ``"const float* in"``.

    Returns:
        The parsed parameter, or None when ``param_str`` is empty or is the
        C-style empty parameter list ``void``.  A trailing array declarator
        (``float tile[16]``) is moved from the name onto the type
        (name ``tile``, type ``float[16]``).
    """
    if not param_str:
        return None

    # Remove default value if present
    param_str = param_str.split('=', 1)[0].strip()

    # `f(void)` declares a function without parameters, not one named "void".
    if not param_str or param_str == 'void':
        return None

    # Detach a trailing array declarator so it does not end up in the name.
    array_suffix = ''
    array_match = re.search(r'(?:\s*\[[^\]]*\])+$', param_str)
    if array_match:
        array_suffix = re.sub(r'\s+', '', array_match.group(0))
        param_str = param_str[:array_match.start()].rstrip()

    # Match `const` as a whole word so names like `constant_val` don't count.
    is_const = re.search(r'\bconst\b', param_str) is not None
    is_pointer = '*' in param_str or '&' in param_str

    # Tokenise with '*' and '&' as separate tokens; the last token is the name.
    parts = param_str.replace('*', ' * ').replace('&', ' & ').split()

    if not parts:
        return None

    name = parts[-1].strip('*&')

    if len(parts) == 1:
        # A single token cannot be split into type and name; it is treated
        # as a name with unknown type.
        type_str = ""
    else:
        # Reconstruct type from all parts except the last (name)
        type_str = ' '.join(parts[:-1]).replace(' * ', '*').replace(' & ', '&')

    type_str += array_suffix

    return HIPParameter(
        name=name,
        type_str=type_str,
        is_pointer=is_pointer,
        is_const=is_const,
    )
521
+
522
+ def _find_function_end(self, content: str, start_pos: int) -> int:
523
+ """Find the line number where a function ends (closing brace).
524
+
525
+ Returns:
526
+ Line number of closing brace, or start line if function body not found.
527
+
528
+ Note: Returns start line (not an approximation) when closing brace cannot be found.
529
+ This happens for incomplete code during editing - the caller should handle this case.
530
+ """
531
+ start_line = content[:start_pos].count('\n')
532
+
533
+ # Find opening brace
534
+ brace_pos = content.find('{', start_pos)
535
+ if brace_pos == -1:
536
+ # No function body found (e.g., forward declaration or incomplete code)
537
+ return start_line
538
+
539
+ # Find matching closing brace
540
+ depth = 1
541
+ for i in range(brace_pos + 1, len(content)):
542
+ if content[i] == '{':
543
+ depth += 1
544
+ elif content[i] == '}':
545
+ depth -= 1
546
+ if depth == 0:
547
+ return content[:i].count('\n')
548
+
549
+ # Unbalanced braces - incomplete code during editing
550
+ # Return start line to indicate we couldn't determine the end
551
+ return start_line
552
+
553
+ def _extract_docstring(self, content: str, func_start: int) -> str | None:
554
+ """Extract documentation comment before a function."""
555
+ # Look backwards for a comment block
556
+ search_start = max(0, func_start - 500)
557
+ prefix = content[search_start:func_start].rstrip()
558
+
559
+ # Check for /// or /** style comments
560
+ lines = prefix.split('\n')
561
+ doc_lines: list[str] = []
562
+
563
+ for line in reversed(lines):
564
+ stripped = line.strip()
565
+ if stripped.startswith('///'):
566
+ doc_lines.insert(0, stripped[3:].strip())
567
+ elif stripped.startswith('*') and not stripped.startswith('*/'):
568
+ doc_lines.insert(0, stripped[1:].strip())
569
+ elif stripped.startswith('/**'):
570
+ doc_lines.insert(0, stripped[3:].strip())
571
+ break
572
+ elif stripped.endswith('*/'):
573
+ doc_lines.insert(0, stripped[:-2].strip())
574
+ elif stripped == '' or stripped.startswith('//'):
575
+ continue
576
+ else:
577
+ # Not a doc comment, stop
578
+ if doc_lines:
579
+ break
580
+
581
+ if doc_lines:
582
+ return '\n'.join(doc_lines)
583
+ return None
584
+
585
+ def _estimate_shared_mem_size(self, type_str: str, array_size: str | None) -> int | None:
586
+ """Estimate shared memory size in bytes.
587
+
588
+ Handles both 1D (e.g., [256]) and 2D (e.g., [16][16]) arrays.
589
+ Returns None if size cannot be determined (e.g., non-literal dimension).
590
+
591
+ Why no eval(): We only handle literal integer dimensions and simple multiplication.
592
+ Variable-based dimensions (e.g., [BLOCK_SIZE]) return None - that's correct behavior
593
+ since we can't know the value at parse time.
594
+ """
595
+ # Size mapping for common types
596
+ type_sizes = {
597
+ 'char': 1, 'int8_t': 1, 'uint8_t': 1,
598
+ 'short': 2, 'int16_t': 2, 'uint16_t': 2, 'half': 2, '__half': 2,
599
+ 'int': 4, 'int32_t': 4, 'uint32_t': 4, 'float': 4, 'unsigned': 4,
600
+ 'long': 8, 'int64_t': 8, 'uint64_t': 8, 'double': 8, 'long long': 8,
601
+ 'float4': 16, 'float2': 8, 'int4': 16, 'int2': 8,
602
+ 'double2': 16, 'double4': 32,
603
+ }
604
+
605
+ # Find base type
606
+ base_type = type_str.strip()
607
+ type_size: int | None = None
608
+ for known_type, size in type_sizes.items():
609
+ if known_type in base_type:
610
+ type_size = size
611
+ break
612
+
613
+ if type_size is None:
614
+ return None
615
+
616
+ if not array_size:
617
+ return type_size
618
+
619
+ # Parse array dimensions - could be [n] or [n][m]
620
+ # Extract all bracketed dimensions
621
+ dims = re.findall(r'\[([^\]]+)\]', array_size)
622
+ if not dims:
623
+ return type_size
624
+
625
+ total_elements = 1
626
+ for dim in dims:
627
+ dim_value = self._parse_dimension_expression(dim.strip())
628
+ if dim_value is None:
629
+ # Non-literal dimension (e.g., variable) - cannot determine size
630
+ return None
631
+ total_elements *= dim_value
632
+
633
+ return type_size * total_elements
634
+
635
+ def _parse_dimension_expression(self, expr: str) -> int | None:
636
+ """Parse a dimension expression safely without eval().
637
+
638
+ Handles:
639
+ - Simple integers: "256"
640
+ - Simple multiplication: "16 * 16", "BLOCK_SIZE * 4" (only if all literal)
641
+
642
+ Returns None for anything we can't safely parse (variables, complex expressions).
643
+ This is correct behavior - we don't know variable values at parse time.
644
+ """
645
+ expr = expr.strip()
646
+
647
+ # Simple integer
648
+ if expr.isdigit():
649
+ return int(expr)
650
+
651
+ # Handle expressions with * (multiplication only)
652
+ if '*' in expr:
653
+ parts = expr.split('*')
654
+ result = 1
655
+ for part in parts:
656
+ part = part.strip()
657
+ if not part.isdigit():
658
+ # Contains a variable or non-integer - cannot evaluate
659
+ return None
660
+ result *= int(part)
661
+ return result
662
+
663
+ # Not a simple literal - could be a variable like BLOCK_SIZE
664
+ return None
665
+
666
+
667
def is_hip_file(content: str) -> bool:
    """Return True when *content* contains any known HIP marker.

    Markers are HIP runtime includes, common HIP API names and HIP platform
    macros.  The comparison ignores case.
    """
    haystack = content.lower()
    markers = (
        '#include <hip/hip_runtime.h>',
        '#include "hip/hip_runtime.h"',
        '#include <hip/hip_runtime_api.h>',
        'hipMalloc',
        'hipMemcpy',
        'hipFree',
        'hipLaunchKernelGGL',
        'hipDeviceSynchronize',
        '__HIP_PLATFORM_AMD__',
        '__HIP_PLATFORM_HCC__',
        'HIP_KERNEL_NAME',
    )
    return any(marker.lower() in haystack for marker in markers)