wafer-lsp 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

@@ -1,9 +1,27 @@
 
+import re
+
 from lsprotocol.types import InlayHint, InlayHintKind, Position, Range
 
 from ..languages.registry import get_language_registry
 
 
+# HIP kernel launch pattern: kernel<<<grid, block, shared, stream>>>
+HIP_LAUNCH_PATTERN = re.compile(
+    r'(\w+)\s*<<<\s*'
+    r'([^,>]+)\s*,\s*'       # grid
+    r'([^,>]+)'              # block
+    r'(?:\s*,\s*([^,>]+))?'  # shared (optional)
+    r'(?:\s*,\s*([^>]+))?'   # stream (optional)
+    r'\s*>>>'
+)
+
+# Shared memory declaration pattern
+SHARED_MEM_PATTERN = re.compile(
+    r'__shared__\s+([\w\s:<>]+?)\s+(\w+)\s*\[([^\]]+)\]'
+)
+
+
 def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
     registry = get_language_registry()
     language_info = registry.parse_file(uri, content)
@@ -14,6 +32,7 @@ def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
     hints: list[InlayHint] = []
     lines = content.split("\n")
 
+    # Layout hints (CuTeDSL)
     for layout in language_info.layouts:
         if layout.line < range.start.line or layout.line > range.end.line:
             continue
@@ -39,6 +58,7 @@ def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
             padding_right=False
         ))
 
+    # Kernel hints (CuTeDSL)
    for kernel in language_info.kernels:
        if kernel.line < range.start.line or kernel.line > range.end.line:
            continue
@@ -62,4 +82,164 @@ def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
             padding_right=True
         ))
 
+    # HIP/CUDA-specific hints
+    if language_info.language in ("hip", "cuda", "cpp"):
+        hints.extend(_get_hip_inlay_hints(lines, range))
+
     return hints
+
+
+def _get_hip_inlay_hints(lines: list[str], range_: Range) -> list[InlayHint]:
+    """Generate HIP-specific inlay hints.
+
+    - Kernel launch dimension annotations
+    - Shared memory size annotations
+    """
+    hints: list[InlayHint] = []
+
+    # The parameter is named range_ so the builtin range() below stays usable.
+    for i in range(range_.start.line, min(range_.end.line + 1, len(lines))):
+        line = lines[i]
+
+        # Kernel launch hints
+        for match in HIP_LAUNCH_PATTERN.finditer(line):
+            kernel_name = match.group(1)
+            grid_dim = match.group(2).strip()
+            block_dim = match.group(3).strip()
+            shared_mem = match.group(4)
+            stream = match.group(5)
+
+            # Add hint after >>> showing launch configuration
+            hint_parts = []
+
+            # Try to parse and annotate dimensions
+            grid_info = _parse_dim(grid_dim)
+            block_info = _parse_dim(block_dim)
+
+            if grid_info:
+                hint_parts.append(f"{grid_info} blocks")
+            if block_info:
+                hint_parts.append(f"{block_info} threads/block")
+                # Calculate wavefronts (AMD GCN/CDNA GPUs use 64-thread wavefronts)
+                total_threads = _eval_dim(block_dim)
+                if total_threads:
+                    wavefronts = (total_threads + 63) // 64
+                    hint_parts.append(f"{wavefronts} wavefront{'s' if wavefronts != 1 else ''}")
+
+            if hint_parts:
+                hint_text = " // " + ", ".join(hint_parts)
+
+                # Position after >>>
+                hint_pos = match.end()
+
+                hints.append(InlayHint(
+                    position=Position(line=i, character=hint_pos),
+                    label=hint_text,
+                    kind=InlayHintKind.Parameter,
+                    padding_left=True,
+                    padding_right=False
+                ))
+
+        # Shared memory size hints
+        for match in SHARED_MEM_PATTERN.finditer(line):
+            type_str = match.group(1).strip()
+            var_name = match.group(2)
+            array_size = match.group(3).strip()
+
+            size_bytes = _estimate_size(type_str, array_size)
+            if size_bytes:
+                if size_bytes >= 1024:
+                    size_str = f" // {size_bytes / 1024:.1f} KB LDS"
+                else:
+                    size_str = f" // {size_bytes} bytes LDS"
+
+                # Position at end of declaration
+                hint_pos = match.end()
+
+                hints.append(InlayHint(
+                    position=Position(line=i, character=hint_pos),
+                    label=size_str,
+                    kind=InlayHintKind.Type,
+                    padding_left=True,
+                    padding_right=False
+                ))
+
+    return hints
+
+
+def _parse_dim(dim_str: str) -> str | None:
+    """Parse a dimension string and return a human-readable description."""
+    dim_str = dim_str.strip()
+
+    # Simple number
+    if dim_str.isdigit():
+        return dim_str
+
+    # dim3(x, y, z)
+    if dim_str.startswith("dim3("):
+        return dim_str
+
+    # Plain identifier (variable name)
+    if re.match(r'^\w+$', dim_str):
+        return dim_str
+
+    return None
+
+
+def _eval_dim(dim_str: str) -> int | None:
+    """Try to evaluate a dimension to an integer."""
+    dim_str = dim_str.strip()
+
+    # Simple number
+    if dim_str.isdigit():
+        return int(dim_str)
+
+    # dim3(x), dim3(x, y) or dim3(x, y, z) - multiply the components
+    if dim_str.startswith("dim3(") and dim_str.endswith(")"):
+        inner = dim_str[5:-1]
+        total = 1
+        for p in (p.strip() for p in inner.split(",")):
+            if p.isdigit():
+                total *= int(p)
+            else:
+                return None  # Can't evaluate a variable component
+        return total
+
+    return None
+
+
+def _estimate_size(type_str: str, array_size: str) -> int | None:
+    """Estimate size in bytes for a shared memory allocation."""
+    type_sizes = {
+        'char': 1, 'int8_t': 1, 'uint8_t': 1,
+        'short': 2, 'int16_t': 2, 'uint16_t': 2, 'half': 2, '__half': 2,
+        'int': 4, 'int32_t': 4, 'uint32_t': 4, 'float': 4, 'unsigned': 4,
+        'long': 8, 'int64_t': 8, 'uint64_t': 8, 'double': 8,
+        'float4': 16, 'float2': 8, 'int4': 16, 'int2': 8,
+        'double2': 16, 'double4': 32,
+    }
+
+    # Find the base type. Prefer the longest matching type name so that e.g.
+    # 'float4' is not mistaken for 'float' and 'uint64_t' for 'int'.
+    base_type = type_str.strip()
+    type_size = None
+    for known_type in sorted(type_sizes, key=len, reverse=True):
+        if known_type in base_type:
+            type_size = type_sizes[known_type]
+            break
+
+    if type_size is None:
+        type_size = 4  # Default to 4 bytes
+
+    # Only evaluate simple arithmetic expressions; eval() on arbitrary
+    # document text would be unsafe.
+    if not re.fullmatch(r'[\d\s+\-*/()]+', array_size):
+        return None
+    try:
+        return type_size * eval(array_size)
+    except (ValueError, SyntaxError, ZeroDivisionError):
+        return None
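
For context, the hint text the new handler produces can be reproduced in isolation: for a launch such as reduce_kernel<<<numBlocks, 256>>>(in, out);, HIP_LAUNCH_PATTERN captures the grid and block expressions, and the wavefront arithmetic rounds the block size up to whole 64-thread wavefronts. A minimal standalone sketch (the kernel name, dimensions, and shared-memory declaration are invented for illustration; note that grid/block expressions containing commas, e.g. dim3(64, 64), are not split cleanly by this pattern):

    import re

    HIP_LAUNCH_PATTERN = re.compile(
        r'(\w+)\s*<<<\s*([^,>]+)\s*,\s*([^,>]+)'
        r'(?:\s*,\s*([^,>]+))?(?:\s*,\s*([^>]+))?\s*>>>'
    )
    SHARED_MEM_PATTERN = re.compile(r'__shared__\s+([\w\s:<>]+?)\s+(\w+)\s*\[([^\]]+)\]')

    m = HIP_LAUNCH_PATTERN.search("reduce_kernel<<<numBlocks, 256>>>(in, out);")
    grid, block = m.group(2).strip(), m.group(3).strip()  # "numBlocks", "256"
    wavefronts = (int(block) + 63) // 64                  # 256 threads -> 4 wavefronts
    print(f"// {grid} blocks, {block} threads/block, {wavefronts} wavefronts")

    m = SHARED_MEM_PATTERN.search("__shared__ float tile[32 * 33];")
    elems = eval(m.group(3))  # 1056 elements; trusted literal input in this sketch
    print(f"// {4 * elems / 1024:.1f} KB LDS")  # float = 4 bytes -> "4.1 KB LDS"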
@@ -1,13 +1,20 @@
 
+import re
+
 from lsprotocol.types import SemanticTokens, SemanticTokensLegend
 
 from ..languages.registry import get_language_registry
 
 TOKEN_TYPES = [
-    "kernel",
-    "layout",
-    "struct",
-    "decorator",
+    "kernel",              # 0: GPU kernel functions (__global__)
+    "layout",              # 1: Layout variables
+    "struct",              # 2: Structs
+    "decorator",           # 3: Python decorators (@cute.kernel)
+    "keyword_gpu",         # 4: GPU keywords (__global__, __device__, __host__)
+    "keyword_memory",      # 5: Memory qualifiers (__shared__, __constant__)
+    "function_hip_api",    # 6: HIP API calls (hipMalloc, etc.)
+    "function_intrinsic",  # 7: Wavefront intrinsics (__shfl, __ballot)
+    "device_function",     # 8: __device__ functions
 ]
 
 TOKEN_MODIFIERS = [
@@ -20,6 +27,31 @@ SEMANTIC_TOKENS_LEGEND = SemanticTokensLegend(
     token_modifiers=TOKEN_MODIFIERS
 )
 
+# HIP-specific patterns for semantic highlighting
+HIP_KEYWORD_PATTERN = re.compile(r'\b(__global__|__device__|__host__|__forceinline__)\b')
+HIP_MEMORY_KEYWORD_PATTERN = re.compile(r'\b(__shared__|__constant__|__restrict__)\b')
+HIP_LAUNCH_BOUNDS_PATTERN = re.compile(r'__launch_bounds__\s*\([^)]+\)')
+HIP_API_PATTERN = re.compile(
+    r'\b(hipMalloc|hipMallocManaged|hipMallocAsync|hipFree|hipFreeAsync|'
+    r'hipMemcpy|hipMemcpyAsync|hipMemset|hipMemsetAsync|'
+    r'hipHostMalloc|hipHostFree|hipHostRegister|hipHostUnregister|'
+    r'hipDeviceSynchronize|hipStreamSynchronize|'
+    r'hipStreamCreate|hipStreamDestroy|hipStreamCreateWithFlags|'
+    r'hipEventCreate|hipEventDestroy|hipEventRecord|hipEventSynchronize|hipEventElapsedTime|'
+    r'hipSetDevice|hipGetDevice|hipGetDeviceCount|hipGetDeviceProperties|'
+    r'hipLaunchKernelGGL|hipLaunchCooperativeKernel|'
+    r'hipGetLastError|hipPeekAtLastError|hipGetErrorString|hipGetErrorName)\b'
+)
+HIP_INTRINSIC_PATTERN = re.compile(
+    r'\b(__shfl|__shfl_down|__shfl_up|__shfl_xor|__shfl_sync|'
+    r'__ballot|__any|__all|__activemask|'
+    r'__syncthreads|__syncwarp|__threadfence|__threadfence_block|__threadfence_system|'
+    r'atomicAdd|atomicSub|atomicMax|atomicMin|atomicExch|atomicCAS|atomicAnd|atomicOr|atomicXor|'
+    r'__popc|__popcll|__clz|__clzll|__ffs|__ffsll|'
+    r'__float2half|__half2float|__float2int_rn|__int2float_rn|'
+    r'__ldg|__ldcg|__ldca|__ldcs)\b'
+)
+
 
 def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
     registry = get_language_registry()
@@ -28,11 +60,12 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
     if not language_info:
         return SemanticTokens(data=[])
 
-    tokens: list[int] = []
+    # Collect all tokens with their positions.
+    # We'll sort them later to ensure proper delta calculation.
+    token_entries: list[tuple[int, int, int, int, int]] = []  # (line, char, length, type, modifier)
     lines = content.split("\n")
-    prev_line = 0
-    prev_char = 0
 
+    # Add kernel tokens
     for kernel in language_info.kernels:
         if kernel.line >= len(lines):
             continue
@@ -41,20 +74,15 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
         name_start = kernel_line.find(kernel.name)
 
         if name_start >= 0:
-            delta_line = kernel.line - prev_line
-            delta_char = name_start - (prev_char if delta_line == 0 else 0)
-
-            tokens.extend([
-                delta_line,
-                delta_char,
+            token_entries.append((
+                kernel.line,
+                name_start,
                 len(kernel.name),
                 TOKEN_TYPES.index("kernel"),
                 0
-            ])
-
-            prev_line = kernel.line
-            prev_char = name_start + len(kernel.name)
+            ))
 
+    # Add layout tokens
     for layout in language_info.layouts:
         if layout.line >= len(lines):
             continue
@@ -63,20 +91,15 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
         name_start = layout_line.find(layout.name)
 
         if name_start >= 0:
-            delta_line = layout.line - prev_line
-            delta_char = name_start - (prev_char if delta_line == 0 else 0)
-
-            tokens.extend([
-                delta_line,
-                delta_char,
+            token_entries.append((
+                layout.line,
+                name_start,
                 len(layout.name),
                 TOKEN_TYPES.index("layout"),
                 0
-            ])
-
-            prev_line = layout.line
-            prev_char = name_start + len(layout.name)
+            ))
 
+    # Add struct tokens
     for struct in language_info.structs:
         if struct.line >= len(lines):
             continue
@@ -85,20 +108,15 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
         name_start = struct_line.find(struct.name)
 
         if name_start >= 0:
-            delta_line = struct.line - prev_line
-            delta_char = name_start - (prev_char if delta_line == 0 else 0)
-
-            tokens.extend([
-                delta_line,
-                delta_char,
+            token_entries.append((
+                struct.line,
+                name_start,
                 len(struct.name),
                 TOKEN_TYPES.index("struct"),
                 0
-            ])
-
-            prev_line = struct.line
-            prev_char = name_start + len(struct.name)
+            ))
 
+    # Add CuTeDSL decorator tokens
     for i, line in enumerate(lines):
         if "@cute.kernel" in line or "@cute.struct" in line:
             decorator_start = line.find("@")
@@ -107,18 +125,100 @@ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
             if decorator_end == -1:
                 decorator_end = len(line)
 
-            delta_line = i - prev_line
-            delta_char = decorator_start - (prev_char if delta_line == 0 else 0)
-
-            tokens.extend([
-                delta_line,
-                delta_char,
+            token_entries.append((
+                i,
+                decorator_start,
                 decorator_end - decorator_start,
                 TOKEN_TYPES.index("decorator"),
                 0
-            ])
+            ))
+
+    # Add HIP-specific tokens if this is a HIP or CUDA file
+    if language_info.language in ("hip", "cuda", "cpp"):
+        token_entries.extend(_get_hip_tokens(lines))
 
-            prev_line = i
-            prev_char = decorator_end
+    # Sort tokens by position (line, then character)
+    token_entries.sort(key=lambda x: (x[0], x[1]))
+
+    # Convert to delta-encoded format
+    tokens: list[int] = []
+    prev_line = 0
+    prev_char = 0
+
+    for line, char, length, token_type, modifier in token_entries:
+        delta_line = line - prev_line
+        delta_char = char - (prev_char if delta_line == 0 else 0)
+
+        tokens.extend([
+            delta_line,
+            delta_char,
+            length,
+            token_type,
+            modifier
+        ])
+
+        prev_line = line
+        prev_char = char
 
     return SemanticTokens(data=tokens)
+
+
+def _get_hip_tokens(lines: list[str]) -> list[tuple[int, int, int, int, int]]:
+    """Extract HIP-specific semantic tokens from code.
+
+    Returns list of (line, char, length, token_type, modifier) tuples.
+    """
+    token_entries: list[tuple[int, int, int, int, int]] = []
+
+    for i, line in enumerate(lines):
+        # GPU keywords (__global__, __device__, etc.)
+        for match in HIP_KEYWORD_PATTERN.finditer(line):
+            token_entries.append((
+                i,
+                match.start(),
+                len(match.group()),
+                TOKEN_TYPES.index("keyword_gpu"),
+                0
+            ))
+
+        # Memory keywords (__shared__, __constant__)
+        for match in HIP_MEMORY_KEYWORD_PATTERN.finditer(line):
+            token_entries.append((
+                i,
+                match.start(),
+                len(match.group()),
+                TOKEN_TYPES.index("keyword_memory"),
+                0
+            ))
+
+        # __launch_bounds__
+        for match in HIP_LAUNCH_BOUNDS_PATTERN.finditer(line):
+            token_entries.append((
+                i,
+                match.start(),
+                len(match.group()),
+                TOKEN_TYPES.index("keyword_gpu"),
+                0
+            ))
+
+        # HIP API functions
+        for match in HIP_API_PATTERN.finditer(line):
+            token_entries.append((
+                i,
+                match.start(),
+                len(match.group()),
+                TOKEN_TYPES.index("function_hip_api"),
+                0
+            ))
+
+        # Wavefront intrinsics
+        for match in HIP_INTRINSIC_PATTERN.finditer(line):
+            token_entries.append((
+                i,
+                match.start(),
+                len(match.group()),
+                TOKEN_TYPES.index("function_intrinsic"),
+                0
+            ))
+
+    return token_entries
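
The refactor above replaces inline delta emission with a collect-sort-encode pipeline, which is what makes it safe to merge tokens from the parser results with the regex-based HIP scans. For reference, the LSP delta encoding produced by the final loop works as follows (standalone sketch; the token data is invented):

    # Each entry is (line, char, length, type, modifier); the LSP wire format is
    # a flat list [deltaLine, deltaChar, length, type, modifier, ...] where each
    # token's position is relative to the previous token.
    entries = [(0, 4, 6, 0, 0), (0, 15, 3, 4, 0), (2, 0, 8, 6, 0)]
    entries.sort(key=lambda t: (t[0], t[1]))

    data: list[int] = []
    prev_line = prev_char = 0
    for line, char, length, ttype, mod in entries:
        delta_line = line - prev_line
        # deltaChar is relative only while the token stays on the same line
        delta_char = char - (prev_char if delta_line == 0 else 0)
        data.extend([delta_line, delta_char, length, ttype, mod])
        prev_line, prev_char = line, char

    print(data)  # [0, 4, 6, 0, 0, 0, 11, 3, 4, 0, 2, 0, 8, 6, 0]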
@@ -2,33 +2,106 @@ from pathlib import Path
 
 
 class LanguageDetector:
+    """Detects the language of a file from its extension and content markers.
+
+    Supports both extension-based detection (fast) and content-based detection
+    (for files that share extensions, e.g., .cpp files that could be HIP or CUDA).
+    """
 
     def __init__(self):
         self._extensions: dict[str, str] = {}
+        self._content_markers: dict[str, list[str]] = {}  # language_id -> markers
+        # Compound extensions like .hip.cpp need special handling
+        self._compound_extensions: dict[str, str] = {}
 
     def register_extension(self, extension: str, language_id: str):
-        normalized_ext = extension if extension.startswith(".") else f".{ext}"
-        self._extensions[normalized_ext] = language_id
+        normalized_ext = extension if extension.startswith(".") else f".{extension}"
+
+        # Check if this is a compound extension (e.g., .hip.cpp)
+        if normalized_ext.count(".") > 1:
+            self._compound_extensions[normalized_ext] = language_id
+        else:
+            self._extensions[normalized_ext] = language_id
+
+    def register_content_markers(self, language_id: str, markers: list[str]):
+        """Register content markers for content-based language detection."""
+        self._content_markers[language_id] = markers
 
-    def detect_from_uri(self, uri: str) -> str | None:
+    def detect_from_uri(self, uri: str, content: str | None = None) -> str | None:
+        """Detect the language from a URI and, optionally, file content.
+
+        Args:
+            uri: File URI or path
+            content: Optional file content for content-based detection
+
+        Returns:
+            Language ID, or None if the file is not recognized
+        """
         if uri.startswith("file://"):
             file_path = uri[7:]
         else:
             file_path = uri
 
-        return self.detect_from_path(file_path)
+        return self.detect_from_path(file_path, content)
 
-    def detect_from_path(self, file_path: str) -> str | None:
+    def detect_from_path(self, file_path: str, content: str | None = None) -> str | None:
+        """Detect the language from a file path and, optionally, content.
+
+        Detection order:
+        1. Compound extensions (e.g., .hip.cpp) - most specific
+        2. Content markers (for shared extensions like .cpp)
+        3. Simple extension
+        """
         path = Path(file_path)
+
+        # 1. Check compound extensions first, using the last two suffixes
+        suffixes = path.suffixes
+        if len(suffixes) >= 2:
+            compound_ext = "".join(suffixes[-2:]).lower()
+            if compound_ext in self._compound_extensions:
+                return self._compound_extensions[compound_ext]
+
+        # 2. If content is provided, check content markers
+        if content:
+            content_lang = self._detect_from_content(content)
+            if content_lang:
+                return content_lang
+
+        # 3. Fall back to the simple extension
         ext = path.suffix.lower()
         return self._extensions.get(ext)
 
+    def _detect_from_content(self, content: str) -> str | None:
+        """Detect the language from content markers.
+
+        Returns the language with the most matching markers.
+        """
+        best_match: str | None = None
+        best_count = 0
+
+        for language_id, markers in self._content_markers.items():
+            match_count = sum(1 for marker in markers if marker in content)
+            if match_count > best_count:
+                best_count = match_count
+                best_match = language_id
+
+        # Require at least one marker match
+        return best_match if best_count > 0 else None
+
     def detect_from_extension(self, extension: str) -> str | None:
-        normalized_ext = extension if extension.startswith(".") else f".{ext}"
+        normalized_ext = extension if extension.startswith(".") else f".{extension}"
+        normalized_ext = normalized_ext.lower()  # Case-insensitive
         return self._extensions.get(normalized_ext)
 
     def get_supported_extensions(self) -> list[str]:
-        return list(self._extensions.keys())
+        all_extensions = list(self._extensions.keys())
+        all_extensions.extend(self._compound_extensions.keys())
+        return all_extensions
+
+    def is_supported(self, uri: str, content: str | None = None) -> bool:
+        return self.detect_from_uri(uri, content) is not None
 
-    def is_supported(self, uri: str) -> bool:
-        return self.detect_from_uri(uri) is not None
+    def get_content_markers(self, language_id: str) -> list[str]:
+        """Get content markers for a language."""
+        return self._content_markers.get(language_id, [])
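
A short usage sketch of the detection order that the new detect_from_path implements (file names and content are invented; LanguageDetector is the class above):

    detector = LanguageDetector()
    detector.register_extension(".hip", "hip")
    detector.register_extension(".hip.cpp", "hip")  # stored as a compound extension
    detector.register_extension(".cpp", "cpp")
    detector.register_content_markers("hip", ["#include <hip/hip_runtime.h>", "hipMalloc"])

    detector.detect_from_path("kernels/reduce.hip.cpp")  # -> "hip" (compound extension wins)
    detector.detect_from_path("kernels/reduce.cpp")      # -> "cpp" (simple extension)
    hip_source = '#include <hip/hip_runtime.h>\nint main() { hipMalloc(&p, n); }'
    detector.detect_from_path("kernels/reduce.cpp", hip_source)  # -> "hip" (content markers)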
@@ -1,6 +1,7 @@
 
 from ..parsers.cuda_parser import CUDAParser
 from ..parsers.cutedsl_parser import CuTeDSLParser
+from ..parsers.hip_parser import HIPParser
 from .converter import ParserResultConverter
 from .detector import LanguageDetector
 from .parser_manager import ParserManager
@@ -33,6 +34,22 @@ class LanguageRegistry:
             file_patterns=["*.cu", "*.cuh"]
         )
 
+        # HIP (AMD GPU) - register before cpp so .hip.cpp files get detected as HIP
+        self.register_language(
+            language_id="hip",
+            display_name="HIP (AMD GPU)",
+            parser=HIPParser(),
+            extensions=[".hip", ".hip.cpp", ".hip.hpp", ".hipcc"],
+            file_patterns=["*.hip", "*.hip.cpp", "*.hip.hpp", "*.hipcc"],
+            content_markers=[
+                "#include <hip/hip_runtime.h>",
+                "#include \"hip/hip_runtime.h\"",
+                "hipMalloc",
+                "hipLaunchKernelGGL",
+                "__HIP_PLATFORM_AMD__",
+            ]
+        )
+
         self.register_language(
             language_id="cpp",
             display_name="C++",
@@ -47,12 +64,16 @@ class LanguageRegistry:
         display_name: str,
         parser,
         extensions: list[str],
-        file_patterns: list[str] | None = None
+        file_patterns: list[str] | None = None,
+        content_markers: list[str] | None = None
     ):
         self._parser_manager.register_parser(language_id, display_name, parser)
 
         for ext in extensions:
             self._detector.register_extension(ext, language_id)
+
+        if content_markers:
+            self._detector.register_content_markers(language_id, content_markers)
 
     def detect_language(self, uri: str) -> str | None:
         return self._detector.detect_from_uri(uri)
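
With the registration above in place, URI-based detection resolves HIP files ahead of generic C++. A sketch (the wafer_lsp module path is an assumption inferred from the relative imports):

    from wafer_lsp.languages.registry import get_language_registry  # path assumed

    registry = get_language_registry()
    registry.detect_language("file:///work/reduce.hip")      # -> "hip" (simple extension)
    registry.detect_language("file:///work/reduce.hip.cpp")  # -> "hip" (compound extension)
    registry.detect_language("file:///work/reduce.cu")       # -> "cuda"

Note that detect_language as shown does not forward file content, so the content markers only apply when a caller passes content to the detector directly.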
@@ -6,6 +6,16 @@ from .cutedsl_parser import (
     CuTeDSLParser,
     CuTeDSLStruct,
 )
+from .hip_parser import (
+    HIPKernel,
+    HIPDeviceFunction,
+    HIPParameter,
+    HIPParser,
+    KernelLaunchSite,
+    SharedMemoryAllocation,
+    WavefrontPattern,
+    is_hip_file,
+)
 
 __all__ = [
     "BaseParser",
@@ -15,4 +25,12 @@ __all__ = [
     "CuTeDSLLayout",
     "CuTeDSLParser",
     "CuTeDSLStruct",
+    "HIPKernel",
+    "HIPDeviceFunction",
+    "HIPParameter",
+    "HIPParser",
+    "KernelLaunchSite",
+    "SharedMemoryAllocation",
+    "WavefrontPattern",
+    "is_hip_file",
 ]