wafer_lsp-0.1.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. wafer_lsp/__init__.py +1 -0
  2. wafer_lsp/__main__.py +9 -0
  3. wafer_lsp/analyzers/__init__.py +0 -0
  4. wafer_lsp/analyzers/compiler_integration.py +16 -0
  5. wafer_lsp/analyzers/docs_index.py +36 -0
  6. wafer_lsp/handlers/__init__.py +30 -0
  7. wafer_lsp/handlers/code_action.py +48 -0
  8. wafer_lsp/handlers/code_lens.py +48 -0
  9. wafer_lsp/handlers/completion.py +6 -0
  10. wafer_lsp/handlers/diagnostics.py +41 -0
  11. wafer_lsp/handlers/document_symbol.py +176 -0
  12. wafer_lsp/handlers/hip_diagnostics.py +303 -0
  13. wafer_lsp/handlers/hover.py +251 -0
  14. wafer_lsp/handlers/inlay_hint.py +245 -0
  15. wafer_lsp/handlers/semantic_tokens.py +224 -0
  16. wafer_lsp/handlers/workspace_symbol.py +87 -0
  17. wafer_lsp/languages/README.md +195 -0
  18. wafer_lsp/languages/__init__.py +17 -0
  19. wafer_lsp/languages/converter.py +88 -0
  20. wafer_lsp/languages/detector.py +107 -0
  21. wafer_lsp/languages/parser_manager.py +33 -0
  22. wafer_lsp/languages/registry.py +120 -0
  23. wafer_lsp/languages/types.py +37 -0
  24. wafer_lsp/parsers/__init__.py +36 -0
  25. wafer_lsp/parsers/base_parser.py +9 -0
  26. wafer_lsp/parsers/cuda_parser.py +95 -0
  27. wafer_lsp/parsers/cutedsl_parser.py +114 -0
  28. wafer_lsp/parsers/hip_parser.py +688 -0
  29. wafer_lsp/server.py +58 -0
  30. wafer_lsp/services/__init__.py +38 -0
  31. wafer_lsp/services/analysis_service.py +22 -0
  32. wafer_lsp/services/docs_service.py +40 -0
  33. wafer_lsp/services/document_service.py +20 -0
  34. wafer_lsp/services/hip_docs.py +806 -0
  35. wafer_lsp/services/hip_hover_service.py +412 -0
  36. wafer_lsp/services/hover_service.py +237 -0
  37. wafer_lsp/services/language_registry_service.py +26 -0
  38. wafer_lsp/services/position_service.py +77 -0
  39. wafer_lsp/utils/__init__.py +0 -0
  40. wafer_lsp/utils/lsp_helpers.py +79 -0
  41. wafer_lsp-0.1.13.dist-info/METADATA +60 -0
  42. wafer_lsp-0.1.13.dist-info/RECORD +44 -0
  43. wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
  44. wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
wafer_lsp/handlers/inlay_hint.py
@@ -0,0 +1,245 @@
+
+ import re
+
+ from lsprotocol.types import InlayHint, InlayHintKind, Position, Range
+
+ from ..languages.registry import get_language_registry
+
+
+ # HIP kernel launch pattern: kernel<<<grid, block, shared, stream>>>
+ HIP_LAUNCH_PATTERN = re.compile(
+     r'(\w+)\s*<<<\s*'
+     r'([^,>]+)\s*,\s*'       # grid
+     r'([^,>]+)'              # block
+     r'(?:\s*,\s*([^,>]+))?'  # shared (optional)
+     r'(?:\s*,\s*([^>]+))?'   # stream (optional)
+     r'\s*>>>'
+ )
+
+ # Shared memory declaration pattern
+ SHARED_MEM_PATTERN = re.compile(
+     r'__shared__\s+([\w\s:<>]+?)\s+(\w+)\s*\[([^\]]+)\]'
+ )
+
+
+ def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
+     registry = get_language_registry()
+     language_info = registry.parse_file(uri, content)
+
+     if not language_info:
+         return []
+
+     hints: list[InlayHint] = []
+     lines = content.split("\n")
+
+     # Layout hints (CuTeDSL)
+     for layout in language_info.layouts:
+         if layout.line < range.start.line or layout.line > range.end.line:
+             continue
+
+         layout_line = lines[layout.line] if layout.line < len(lines) else ""
+
+         if "=" in layout_line:
+             equals_pos = layout_line.find("=")
+             hint_text = ": Layout"
+             if layout.shape:
+                 hint_text = f": Layout[Shape{layout.shape}]"
+
+             hint_position = Position(
+                 line=layout.line,
+                 character=equals_pos + 1
+             )
+
+             hints.append(InlayHint(
+                 position=hint_position,
+                 label=hint_text,
+                 kind=InlayHintKind.Type,
+                 padding_left=True,
+                 padding_right=False
+             ))
+
+     # Kernel hints (CuTeDSL)
+     for kernel in language_info.kernels:
+         if kernel.line < range.start.line or kernel.line > range.end.line:
+             continue
+
+         kernel_line = lines[kernel.line] if kernel.line < len(lines) else ""
+
+         if "def " in kernel_line and "(" in kernel_line:
+             paren_pos = kernel_line.find("(")
+             hint_text = " -> Kernel"
+
+             hint_position = Position(
+                 line=kernel.line,
+                 character=paren_pos
+             )
+
+             hints.append(InlayHint(
+                 position=hint_position,
+                 label=hint_text,
+                 kind=InlayHintKind.Type,
+                 padding_left=True,
+                 padding_right=True
+             ))
+
+     # HIP/CUDA-specific hints
+     if language_info.language in ("hip", "cuda", "cpp"):
+         hints.extend(_get_hip_inlay_hints(lines, range))
+
+     return hints
+
+
+ def _get_hip_inlay_hints(lines: list[str], hint_range: Range) -> list[InlayHint]:
+     """Generate HIP-specific inlay hints.
+
+     - Kernel launch dimension annotations
+     - Shared memory size annotations
+     """
+     hints: list[InlayHint] = []
+
+     for i in range(hint_range.start.line, min(hint_range.end.line + 1, len(lines))):
+         line = lines[i]
+
+         # Kernel launch hints
+         for match in HIP_LAUNCH_PATTERN.finditer(line):
+             kernel_name = match.group(1)
+             grid_dim = match.group(2).strip()
+             block_dim = match.group(3).strip()
+             shared_mem = match.group(4)
+             stream = match.group(5)
+
+             # Add hint after >>> showing launch configuration
+             hint_parts = []
+
+             # Try to parse and annotate dimensions
+             grid_info = _parse_dim(grid_dim)
+             block_info = _parse_dim(block_dim)
+
+             if grid_info:
+                 hint_parts.append(f"{grid_info} blocks")
+             if block_info:
+                 hint_parts.append(f"{block_info} threads/block")
+                 # Calculate wavefronts (AMD uses 64-thread wavefronts)
+                 try:
+                     total_threads = _eval_dim(block_dim)
+                     if total_threads:
+                         wavefronts = (total_threads + 63) // 64
+                         hint_parts.append(f"{wavefronts} wavefront{'s' if wavefronts != 1 else ''}")
+                 except (ValueError, SyntaxError):
+                     pass
+
+             if hint_parts:
+                 hint_text = " // " + ", ".join(hint_parts)
+
+                 # Position after >>>
+                 hint_pos = match.end()
+
+                 hints.append(InlayHint(
+                     position=Position(line=i, character=hint_pos),
+                     label=hint_text,
+                     kind=InlayHintKind.Parameter,
+                     padding_left=True,
+                     padding_right=False
+                 ))
+
+         # Shared memory size hints
+         for match in SHARED_MEM_PATTERN.finditer(line):
+             type_str = match.group(1).strip()
+             var_name = match.group(2)
+             array_size = match.group(3).strip()
+
+             size_bytes = _estimate_size(type_str, array_size)
+             if size_bytes:
+                 if size_bytes >= 1024:
+                     size_str = f" // {size_bytes / 1024:.1f} KB LDS"
+                 else:
+                     size_str = f" // {size_bytes} bytes LDS"
+
+                 # Position at end of declaration
+                 hint_pos = match.end()
+
+                 hints.append(InlayHint(
+                     position=Position(line=i, character=hint_pos),
+                     label=size_str,
+                     kind=InlayHintKind.Type,
+                     padding_left=True,
+                     padding_right=False
+                 ))
+
+     return hints
+
+
+ def _parse_dim(dim_str: str) -> str | None:
+     """Parse a dimension string and return a human-readable description."""
+     dim_str = dim_str.strip()
+
+     # Simple number
+     if dim_str.isdigit():
+         return dim_str
+
+     # dim3(x, y, z)
+     if dim_str.startswith("dim3("):
+         return dim_str
+
+     # Variable or expression
+     if re.match(r'^[\w_]+$', dim_str):
+         return dim_str
+
+     return None
+
+
+ def _eval_dim(dim_str: str) -> int | None:
+     """Try to evaluate a dimension to an integer."""
+     dim_str = dim_str.strip()
+
+     # Simple number
+     if dim_str.isdigit():
+         return int(dim_str)
+
+     # dim3(x) or dim3(x, y) or dim3(x, y, z) - try to multiply
+     if dim_str.startswith("dim3(") and dim_str.endswith(")"):
+         inner = dim_str[5:-1]
+         parts = [p.strip() for p in inner.split(",")]
+         try:
+             total = 1
+             for p in parts:
+                 if p.isdigit():
+                     total *= int(p)
+                 else:
+                     return None  # Can't evaluate variable
+             return total
+         except (ValueError, SyntaxError):
+             return None
+
+     return None
+
+
+ def _estimate_size(type_str: str, array_size: str) -> int | None:
+     """Estimate size in bytes for a shared memory allocation."""
+     type_sizes = {
+         'char': 1, 'int8_t': 1, 'uint8_t': 1,
+         'short': 2, 'int16_t': 2, 'uint16_t': 2, 'half': 2, '__half': 2,
+         'int': 4, 'int32_t': 4, 'uint32_t': 4, 'float': 4, 'unsigned': 4,
+         'long': 8, 'int64_t': 8, 'uint64_t': 8, 'double': 8,
+         'float4': 16, 'float2': 8, 'int4': 16, 'int2': 8,
+         'double2': 16, 'double4': 32,
+     }
+
+     # Find base type (check longer names first so 'float4' is not matched as 'float')
+     base_type = type_str.strip()
+     type_size = None
+     for known_type, size in sorted(type_sizes.items(), key=lambda kv: -len(kv[0])):
+         if known_type in base_type:
+             type_size = size
+             break
+
+     if type_size is None:
+         type_size = 4  # Default to 4 bytes
+
+     # Try to evaluate array size
+     try:
+         # Handle simple expressions
+         arr_size = eval(array_size.replace('*', ' * '))
+         return type_size * arr_size
+     except (ValueError, SyntaxError, NameError):
+         return None
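Note on the launch-configuration hint above: the wavefront count is the block's thread total rounded up to AMD's 64-wide wavefront. A minimal sketch of that arithmetic (illustrative only; wavefronts_per_block is not a function in this package):

def wavefronts_per_block(threads_per_block: int) -> int:
    # Ceil-divide by the 64-thread AMD wavefront width, matching (total_threads + 63) // 64 above.
    return (threads_per_block + 63) // 64

assert wavefronts_per_block(256) == 4  # e.g. a dim3(256) block launches 4 wavefronts
assert wavefronts_per_block(96) == 2   # a partial wavefront still occupies a whole one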
wafer_lsp/handlers/semantic_tokens.py
@@ -0,0 +1,224 @@
+
+ import re
+
+ from lsprotocol.types import SemanticTokens, SemanticTokensLegend
+
+ from ..languages.registry import get_language_registry
+
+ TOKEN_TYPES = [
+     "kernel",              # 0: GPU kernel functions (__global__)
+     "layout",              # 1: Layout variables
+     "struct",              # 2: Structs
+     "decorator",           # 3: Python decorators (@cute.kernel)
+     "keyword_gpu",         # 4: GPU keywords (__global__, __device__, __shared__)
+     "keyword_memory",      # 5: Memory qualifiers (__shared__, __constant__)
+     "function_hip_api",    # 6: HIP API calls (hipMalloc, etc.)
+     "function_intrinsic",  # 7: Wavefront intrinsics (__shfl, __ballot)
+     "device_function",     # 8: __device__ functions
+ ]
+
+ TOKEN_MODIFIERS = [
+     "definition",
+     "declaration",
+ ]
+
+ SEMANTIC_TOKENS_LEGEND = SemanticTokensLegend(
+     token_types=TOKEN_TYPES,
+     token_modifiers=TOKEN_MODIFIERS
+ )
+
+ # HIP-specific patterns for semantic highlighting
+ HIP_KEYWORD_PATTERN = re.compile(r'\b(__global__|__device__|__host__|__forceinline__)\b')
+ HIP_MEMORY_KEYWORD_PATTERN = re.compile(r'\b(__shared__|__constant__|__restrict__)\b')
+ HIP_LAUNCH_BOUNDS_PATTERN = re.compile(r'__launch_bounds__\s*\([^)]+\)')
+ HIP_API_PATTERN = re.compile(
+     r'\b(hipMalloc|hipMallocManaged|hipMallocAsync|hipFree|hipFreeAsync|'
+     r'hipMemcpy|hipMemcpyAsync|hipMemset|hipMemsetAsync|'
+     r'hipHostMalloc|hipHostFree|hipHostRegister|hipHostUnregister|'
+     r'hipDeviceSynchronize|hipStreamSynchronize|'
+     r'hipStreamCreate|hipStreamDestroy|hipStreamCreateWithFlags|'
+     r'hipEventCreate|hipEventDestroy|hipEventRecord|hipEventSynchronize|hipEventElapsedTime|'
+     r'hipSetDevice|hipGetDevice|hipGetDeviceCount|hipGetDeviceProperties|'
+     r'hipLaunchKernelGGL|hipLaunchCooperativeKernel|'
+     r'hipGetLastError|hipPeekAtLastError|hipGetErrorString|hipGetErrorName)\b'
+ )
+ HIP_INTRINSIC_PATTERN = re.compile(
+     r'\b(__shfl|__shfl_down|__shfl_up|__shfl_xor|__shfl_sync|'
+     r'__ballot|__any|__all|__activemask|'
+     r'__syncthreads|__syncwarp|__threadfence|__threadfence_block|__threadfence_system|'
+     r'atomicAdd|atomicSub|atomicMax|atomicMin|atomicExch|atomicCAS|atomicAnd|atomicOr|atomicXor|'
+     r'__popc|__popcll|__clz|__clzll|__ffs|__ffsll|'
+     r'__float2half|__half2float|__float2int_rn|__int2float_rn|'
+     r'__ldg|__ldcg|__ldca|__ldcs)\b'
+ )
+
+
+ def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
+     registry = get_language_registry()
+     language_info = registry.parse_file(uri, content)
+
+     if not language_info:
+         return SemanticTokens(data=[])
+
+     # Collect all tokens with their positions
+     # We'll sort them later to ensure proper delta calculation
+     token_entries: list[tuple[int, int, int, int, int]] = []  # (line, char, length, type, modifier)
+     lines = content.split("\n")
+
+     # Add kernel tokens
+     for kernel in language_info.kernels:
+         if kernel.line >= len(lines):
+             continue
+
+         kernel_line = lines[kernel.line]
+         name_start = kernel_line.find(kernel.name)
+
+         if name_start >= 0:
+             token_entries.append((
+                 kernel.line,
+                 name_start,
+                 len(kernel.name),
+                 TOKEN_TYPES.index("kernel"),
+                 0
+             ))
+
+     # Add layout tokens
+     for layout in language_info.layouts:
+         if layout.line >= len(lines):
+             continue
+
+         layout_line = lines[layout.line]
+         name_start = layout_line.find(layout.name)
+
+         if name_start >= 0:
+             token_entries.append((
+                 layout.line,
+                 name_start,
+                 len(layout.name),
+                 TOKEN_TYPES.index("layout"),
+                 0
+             ))
+
+     # Add struct tokens
+     for struct in language_info.structs:
+         if struct.line >= len(lines):
+             continue
+
+         struct_line = lines[struct.line]
+         name_start = struct_line.find(struct.name)
+
+         if name_start >= 0:
+             token_entries.append((
+                 struct.line,
+                 name_start,
+                 len(struct.name),
+                 TOKEN_TYPES.index("struct"),
+                 0
+             ))
+
+     # Add CuTeDSL decorator tokens
+     for i, line in enumerate(lines):
+         if "@cute.kernel" in line or "@cute.struct" in line:
+             decorator_start = line.find("@")
+             if decorator_start >= 0:
+                 decorator_end = line.find(" ", decorator_start)
+                 if decorator_end == -1:
+                     decorator_end = len(line)
+
+                 token_entries.append((
+                     i,
+                     decorator_start,
+                     decorator_end - decorator_start,
+                     TOKEN_TYPES.index("decorator"),
+                     0
+                 ))
+
+     # Add HIP-specific tokens if this is a HIP or CUDA file
+     if language_info.language in ("hip", "cuda", "cpp"):
+         token_entries.extend(_get_hip_tokens(lines))
+
+     # Sort tokens by position (line, then character)
+     token_entries.sort(key=lambda x: (x[0], x[1]))
+
+     # Convert to delta-encoded format
+     tokens: list[int] = []
+     prev_line = 0
+     prev_char = 0
+
+     for line, char, length, token_type, modifier in token_entries:
+         delta_line = line - prev_line
+         delta_char = char - (prev_char if delta_line == 0 else 0)
+
+         tokens.extend([
+             delta_line,
+             delta_char,
+             length,
+             token_type,
+             modifier
+         ])
+
+         prev_line = line
+         prev_char = char
+
+     return SemanticTokens(data=tokens)
+
+
+ def _get_hip_tokens(lines: list[str]) -> list[tuple[int, int, int, int, int]]:
+     """Extract HIP-specific semantic tokens from code.
+
+     Returns list of (line, char, length, token_type, modifier) tuples.
+     """
+     token_entries: list[tuple[int, int, int, int, int]] = []
+
+     for i, line in enumerate(lines):
+         # GPU keywords (__global__, __device__, etc.)
+         for match in HIP_KEYWORD_PATTERN.finditer(line):
+             token_entries.append((
+                 i,
+                 match.start(),
+                 len(match.group()),
+                 TOKEN_TYPES.index("keyword_gpu"),
+                 0
+             ))
+
+         # Memory keywords (__shared__, __constant__)
+         for match in HIP_MEMORY_KEYWORD_PATTERN.finditer(line):
+             token_entries.append((
+                 i,
+                 match.start(),
+                 len(match.group()),
+                 TOKEN_TYPES.index("keyword_memory"),
+                 0
+             ))
+
+         # __launch_bounds__
+         for match in HIP_LAUNCH_BOUNDS_PATTERN.finditer(line):
+             token_entries.append((
+                 i,
+                 match.start(),
+                 len(match.group()),
+                 TOKEN_TYPES.index("keyword_gpu"),
+                 0
+             ))
+
+         # HIP API functions
+         for match in HIP_API_PATTERN.finditer(line):
+             token_entries.append((
+                 i,
+                 match.start(),
+                 len(match.group()),
+                 TOKEN_TYPES.index("function_hip_api"),
+                 0
+             ))
+
+         # Wavefront intrinsics
+         for match in HIP_INTRINSIC_PATTERN.finditer(line):
+             token_entries.append((
+                 i,
+                 match.start(),
+                 len(match.group()),
+                 TOKEN_TYPES.index("function_intrinsic"),
+                 0
+             ))
+
+     return token_entries
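Note on the data array built above: LSP semantic tokens are delta-encoded, so each token records its line offset from the previous token and, when both sit on the same line, its start-character offset from the previous token. A small self-contained sketch of that encoding (illustrative only, not code from the package):

# Three (line, char, length, token_type, modifier) tokens, delta-encoded as the LSP expects.
tokens = [(3, 4, 8, 0, 0), (3, 15, 8, 0, 0), (5, 0, 8, 0, 0)]
data: list[int] = []
prev_line = prev_char = 0
for line, char, length, token_type, modifier in tokens:
    delta_line = line - prev_line
    delta_char = char - (prev_char if delta_line == 0 else 0)
    data += [delta_line, delta_char, length, token_type, modifier]
    prev_line, prev_char = line, char
assert data == [3, 4, 8, 0, 0, 0, 11, 8, 0, 0, 2, 0, 8, 0, 0]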
wafer_lsp/handlers/workspace_symbol.py
@@ -0,0 +1,87 @@
+
+ from lsprotocol.types import Location, Position, Range, SymbolKind, WorkspaceSymbol
+
+ from ..languages.registry import get_language_registry
+
+
+ def _matches_query(name: str, query: str) -> bool:
+     if not query:
+         return True
+
+     name_lower = name.lower()
+     query_lower = query.lower()
+
+     query_idx = 0
+     for char in name_lower:
+         if query_idx < len(query_lower) and char == query_lower[query_idx]:
+             query_idx += 1
+
+     return query_idx == len(query_lower)
+
+
+ def handle_workspace_symbol(query: str) -> list[WorkspaceSymbol]:
+     registry = get_language_registry()
+     symbols: list[WorkspaceSymbol] = []
+
+     return symbols
+
+
+ def handle_workspace_symbol_with_documents(
+     query: str,
+     document_contents: dict[str, str]
+ ) -> list[WorkspaceSymbol]:
+     registry = get_language_registry()
+     symbols: list[WorkspaceSymbol] = []
+
+     for uri, content in document_contents.items():
+         language_info = registry.parse_file(uri, content)
+
+         if not language_info:
+             continue
+
+         for kernel in language_info.kernels:
+             if _matches_query(kernel.name, query):
+                 symbols.append(WorkspaceSymbol(
+                     name=kernel.name,
+                     kind=SymbolKind.Function,
+                     location=Location(
+                         uri=uri,
+                         range=Range(
+                             start=Position(line=kernel.line, character=0),
+                             end=Position(line=kernel.line, character=0)
+                         )
+                     ),
+                     container_name=f"GPU Kernel ({registry.get_language_name(kernel.language)})"
+                 ))
+
+         for layout in language_info.layouts:
+             if _matches_query(layout.name, query):
+                 symbols.append(WorkspaceSymbol(
+                     name=layout.name,
+                     kind=SymbolKind.Variable,
+                     location=Location(
+                         uri=uri,
+                         range=Range(
+                             start=Position(line=layout.line, character=0),
+                             end=Position(line=layout.line, character=0)
+                         )
+                     ),
+                     container_name="Layout"
+                 ))
+
+         for struct in language_info.structs:
+             if _matches_query(struct.name, query):
+                 symbols.append(WorkspaceSymbol(
+                     name=struct.name,
+                     kind=SymbolKind.Struct,
+                     location=Location(
+                         uri=uri,
+                         range=Range(
+                             start=Position(line=struct.line, character=0),
+                             end=Position(line=struct.line, character=0)
+                         )
+                     ),
+                     container_name=f"Struct ({registry.get_language_name(struct.language)})"
+                 ))
+
+     return symbols
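Note on _matches_query above: it is a case-insensitive subsequence match, so a query matches whenever its characters appear in the symbol name in order (not necessarily adjacent). Illustrative behaviour with hypothetical symbol names (not part of the package's tests):

from wafer_lsp.handlers.workspace_symbol import _matches_query

assert _matches_query("vector_add_kernel", "vak")    # 'v', 'a', 'k' appear in order
assert _matches_query("softmax_kernel", "SMK")       # comparison is case-insensitive
assert not _matches_query("softmax_kernel", "kms")   # out-of-order characters do not match
assert _matches_query("anything", "")                # an empty query matches every symbol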