wafer-lsp 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ from .code_action import handle_code_action
2
+ from .code_lens import handle_code_lens
3
+ from .completion import handle_completion
4
+ from .diagnostics import handle_diagnostics
5
+ from .document_symbol import handle_document_symbol
6
+ from .hip_diagnostics import (
7
+ HIPDiagnosticsProvider,
8
+ create_hip_diagnostics_provider,
9
+ get_hip_diagnostics,
10
+ )
11
+ from .hover import handle_hover
12
+ from .inlay_hint import handle_inlay_hint
13
+ from .semantic_tokens import handle_semantic_tokens, SEMANTIC_TOKENS_LEGEND
14
+ from .workspace_symbol import handle_workspace_symbol
15
+
16
# Names re-exported by this package (all imported above); ``from handlers
# import *`` exposes exactly these.
__all__ = [
    # Request handlers
    "handle_code_action",
    "handle_code_lens",
    "handle_completion",
    "handle_diagnostics",
    "handle_document_symbol",
    "handle_hover",
    "handle_inlay_hint",
    "handle_semantic_tokens",
    "handle_workspace_symbol",
    # Semantic-tokens legend
    "SEMANTIC_TOKENS_LEGEND",
    # HIP diagnostics
    "HIPDiagnosticsProvider",
    "create_hip_diagnostics_provider",
    "get_hip_diagnostics",
]
@@ -2,9 +2,24 @@
2
2
  from lsprotocol.types import Diagnostic
3
3
 
4
4
  from ..languages.registry import get_language_registry
5
+ from .hip_diagnostics import get_hip_diagnostics
5
6
 
6
7
 
7
- def handle_diagnostics(uri: str, content: str) -> list[Diagnostic]:
8
+ def handle_diagnostics(
9
+ uri: str,
10
+ content: str,
11
+ enable_wavefront_diagnostics: bool = True,
12
+ ) -> list[Diagnostic]:
13
+ """Handle diagnostics for a document.
14
+
15
+ Args:
16
+ uri: Document URI
17
+ content: Document content
18
+ enable_wavefront_diagnostics: Whether to enable HIP wavefront warnings
19
+
20
+ Returns:
21
+ List of diagnostics
22
+ """
8
23
  diagnostics: list[Diagnostic] = []
9
24
 
10
25
  registry = get_language_registry()
@@ -13,4 +28,14 @@ def handle_diagnostics(uri: str, content: str) -> list[Diagnostic]:
13
28
  if not language_info:
14
29
  return diagnostics
15
30
 
31
+ # Add HIP-specific diagnostics for HIP, CUDA, and C++ files
32
+ # (C++ files might contain HIP code if they have HIP markers)
33
+ if language_info.language in ("hip", "cuda", "cpp"):
34
+ hip_diagnostics = get_hip_diagnostics(
35
+ content,
36
+ uri,
37
+ enable_wavefront_diagnostics=enable_wavefront_diagnostics,
38
+ )
39
+ diagnostics.extend(hip_diagnostics)
40
+
16
41
  return diagnostics
@@ -12,9 +12,10 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
12
12
  return []
13
13
 
14
14
  symbols: list[DocumentSymbol] = []
15
+ lines = content.split("\n")
15
16
 
17
+ # Kernels
16
18
  for kernel in language_info.kernels:
17
- lines = content.split("\n")
18
19
  kernel_line = lines[kernel.line] if kernel.line < len(lines) else ""
19
20
  name_start = kernel_line.find(kernel.name)
20
21
  name_end = name_start + len(kernel.name) if name_start >= 0 else 0
@@ -28,16 +29,24 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
28
29
  end=Position(line=min(kernel.line + 10, len(lines) - 1), character=0)
29
30
  )
30
31
 
32
+ # Different detail based on language
33
+ if kernel.language == "hip":
34
+ detail = "🚀 HIP Kernel (AMD GPU)"
35
+ elif kernel.language in ("cuda", "cpp"):
36
+ detail = "🚀 CUDA Kernel"
37
+ else:
38
+ detail = f"GPU Kernel ({registry.get_language_name(kernel.language)})"
39
+
31
40
  symbols.append(DocumentSymbol(
32
41
  name=kernel.name,
33
42
  kind=SymbolKind.Function,
34
43
  range=full_range,
35
44
  selection_range=selection_range,
36
- detail=f"GPU Kernel ({registry.get_language_name(kernel.language)})",
45
+ detail=detail,
37
46
  ))
38
47
 
48
+ # Layouts (CuTeDSL)
39
49
  for layout in language_info.layouts:
40
- lines = content.split("\n")
41
50
  layout_line = lines[layout.line] if layout.line < len(lines) else ""
42
51
  name_start = layout_line.find(layout.name)
43
52
  name_end = name_start + len(layout.name) if name_start >= 0 else 0
@@ -61,8 +70,8 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
61
70
  detail=detail,
62
71
  ))
63
72
 
73
+ # Structs
64
74
  for struct in language_info.structs:
65
- lines = content.split("\n")
66
75
  struct_line = lines[struct.line] if struct.line < len(lines) else ""
67
76
  name_start = struct_line.find(struct.name)
68
77
  name_end = name_start + len(struct.name) if name_start >= 0 else 0
@@ -84,4 +93,84 @@ def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
84
93
  detail=f"Struct ({registry.get_language_name(struct.language)})",
85
94
  ))
86
95
 
96
+ # HIP-specific: Device functions and shared memory
97
+ if language_info.language in ("hip", "cuda", "cpp"):
98
+ symbols.extend(_get_hip_symbols(language_info.raw_data, lines))
99
+
100
+ return symbols
101
+
102
+
103
def _name_selection_range(lines: list[str], line_no: int, name: str) -> Range:
    """Return the range of the first occurrence of ``name`` on line ``line_no``.

    Falls back to an empty range at column 0 when the line is out of bounds
    or does not contain ``name``.
    """
    text = lines[line_no] if 0 <= line_no < len(lines) else ""
    name_start = text.find(name)
    if name_start < 0:
        return Range(
            start=Position(line=line_no, character=0),
            end=Position(line=line_no, character=0),
        )
    return Range(
        start=Position(line=line_no, character=name_start),
        end=Position(line=line_no, character=name_start + len(name)),
    )


def _get_hip_symbols(raw_data: dict, lines: list[str]) -> list[DocumentSymbol]:
    """Extract HIP-specific symbols: device functions and shared memory allocations.

    Args:
        raw_data: Parser output. Only the optional ``"device_functions"`` and
            ``"shared_memory"`` lists are read; entries lacking a ``line`` or
            ``name`` attribute are skipped.
        lines: The document content split into lines.

    Returns:
        One ``DocumentSymbol`` per device function (``SymbolKind.Method``),
        followed by one per shared-memory declaration (``SymbolKind.Variable``).
        Each symbol's ``selection_range`` is contained in its ``range``.
    """
    symbols: list[DocumentSymbol] = []
    last_line = max(len(lines) - 1, 0)

    # Device functions (from HIP parser)
    for func in raw_data.get("device_functions", []):
        if not hasattr(func, "line") or not hasattr(func, "name"):
            continue

        end_line = getattr(func, "end_line", None)
        if not isinstance(end_line, int):
            # No usable end line from the parser (missing or None): fall back
            # to the same "+10 lines" guess used for kernel symbols.
            end_line = func.line + 10

        # Clamp to the document, never end before the declaration line, and
        # extend to the end of the final line. Ending at column 0 of the
        # declaration line (one-line functions, or functions on the last line)
        # would leave the name's selection range outside the full range, which
        # the LSP DocumentSymbol contract forbids.
        end_line = max(func.line, min(end_line, last_line))
        end_text = lines[end_line] if 0 <= end_line < len(lines) else ""
        full_range = Range(
            start=Position(line=func.line, character=0),
            end=Position(line=end_line, character=len(end_text)),
        )

        return_type = getattr(func, "return_type", None) or "void"
        symbols.append(DocumentSymbol(
            name=func.name,
            kind=SymbolKind.Method,
            range=full_range,
            selection_range=_name_selection_range(lines, func.line, func.name),
            detail=f"⚡ Device Function -> {return_type}",
        ))

    # Shared memory allocations
    for shared in raw_data.get("shared_memory", []):
        if not hasattr(shared, "line") or not hasattr(shared, "name"):
            continue

        shared_line = lines[shared.line] if 0 <= shared.line < len(lines) else ""
        # The whole declaration line, which always contains the name's range.
        full_range = Range(
            start=Position(line=shared.line, character=0),
            end=Position(line=shared.line, character=len(shared_line)),
        )

        type_str = getattr(shared, "type_str", "")
        size_bytes = getattr(shared, "size_bytes", None)
        if not size_bytes:
            detail = f"📦 __shared__ {type_str}"
        elif size_bytes >= 1024:
            detail = f"📦 __shared__ {type_str} ({size_bytes / 1024:.1f} KB)"
        else:
            detail = f"📦 __shared__ {type_str} ({size_bytes} bytes)"

        symbols.append(DocumentSymbol(
            name=shared.name,
            kind=SymbolKind.Variable,
            range=full_range,
            selection_range=_name_selection_range(lines, shared.line, shared.name),
            detail=detail,
        ))

    return symbols
@@ -0,0 +1,303 @@
1
+ """
2
+ HIP Diagnostics Provider.
3
+
4
+ Generates warnings for common wavefront-related mistakes in HIP code.
5
+ AMD GPUs use 64-thread wavefronts (CDNA) vs CUDA's 32-thread warps,
6
+ and code written for CUDA often has hard-coded assumptions that break on AMD.
7
+
8
+ Diagnostic Codes:
9
+ - HIP001: Incorrect wavefront size assumption (using 32 instead of 64)
10
+ - HIP002: Wavefront intrinsic misuse
11
+ - HIP003: Ballot result handling errors (32-bit vs 64-bit)
12
+ - HIP004: Lane/wavefront index calculation errors
13
+ """
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any
17
+
18
+ from lsprotocol.types import (
19
+ Diagnostic,
20
+ DiagnosticSeverity,
21
+ DiagnosticTag,
22
+ Position,
23
+ Range,
24
+ )
25
+
26
+ from ..parsers.hip_parser import HIPParser, WavefrontPattern
27
+
28
+
29
@dataclass(frozen=True)
class HIPDiagnosticInfo:
    """Immutable description of a single HIP diagnostic code.

    Instances live in ``DIAGNOSTIC_INFO``, keyed by their code.
    """

    # Diagnostic code, e.g. "HIP001".
    code: str
    # Base message text shown to the user.
    message: str
    # LSP severity attached to emitted diagnostics.
    severity: DiagnosticSeverity
    # Remediation hint appended to messages and attached as diagnostic data.
    suggestion: str
    # Optional link to reference documentation.
    doc_link: str | None = None
37
+
38
+
39
# Diagnostic code -> info. Built from each entry's own ``code`` so a key can
# never disagree with the info it maps to.
DIAGNOSTIC_INFO: dict[str, HIPDiagnosticInfo] = {
    info.code: info
    for info in (
        HIPDiagnosticInfo(
            code="HIP001",
            message="Potential wavefront size mismatch: AMD GPUs use 64-thread wavefronts, not 32.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `warpSize` variable or `__AMDGCN_WAVEFRONT_SIZE__` macro instead of hard-coding 32.",
            doc_link="https://rocm.docs.amd.com/en/latest/",
        ),
        HIPDiagnosticInfo(
            code="HIP002",
            message="Wavefront intrinsic may behave differently on AMD GPUs.",
            severity=DiagnosticSeverity.Warning,
            suggestion="AMD wavefronts have 64 threads. Shuffle/ballot operations operate on the full wavefront. Adjust your code accordingly.",
        ),
        HIPDiagnosticInfo(
            code="HIP003",
            message="Ballot result comparison uses 32-bit mask, but AMD's __ballot returns 64-bit.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `0xFFFFFFFFFFFFFFFFull` for 64-bit comparison. Use `__popcll()` to count bits.",
        ),
        HIPDiagnosticInfo(
            code="HIP004",
            message="Lane/wavefront index calculation assumes 32-thread warps.",
            severity=DiagnosticSeverity.Warning,
            suggestion="Use `warpSize` (64 on AMD CDNA) instead of hard-coding 32 for lane/warp calculations.",
        ),
    )
}


# Parser ``WavefrontPattern.pattern_type`` -> diagnostic code.
PATTERN_TO_DIAGNOSTIC: dict[str, str] = dict(
    warp_size_32="HIP001",
    shuffle_mask="HIP002",
    ballot_32bit="HIP003",
    lane_calc_32="HIP004",
)
76
+
77
+
78
class HIPDiagnosticsProvider:
    """Provides diagnostics for HIP code.

    Detects patterns that indicate potential issues when running on AMD GPUs,
    particularly around wavefront size differences (64 vs 32). Findings come
    from two sources:

    * ``wavefront_patterns`` returned by ``HIPParser.parse_file``, mapped to
      diagnostic codes through ``PATTERN_TO_DIAGNOSTIC``;
    * line-based regex checks in this class (reduction loops seeded at 16 and
      shuffle intrinsics called with an explicit width of 32).

    Every diagnostic carries ``source="wafer-hip"`` and
    ``data={"suggestion": ...}`` taken from ``DIAGNOSTIC_INFO``.
    """

    def __init__(self):
        self._parser = HIPParser()

    def get_diagnostics(
        self,
        content: str,
        uri: str,
        enable_wavefront_diagnostics: bool = True,
    ) -> list[Diagnostic]:
        """Generate diagnostics for HIP code.

        Args:
            content: The document content.
            uri: The document URI (currently unused by this provider).
            enable_wavefront_diagnostics: Whether to check for wavefront
                issues. When False, nothing is parsed and ``[]`` is returned.

        Returns:
            List of LSP Diagnostic objects: parser-derived findings first,
            then the regex-based checks.
        """
        diagnostics: list[Diagnostic] = []

        if not enable_wavefront_diagnostics:
            return diagnostics

        # Parse the file to find wavefront patterns
        parsed = self._parser.parse_file(content)
        wavefront_patterns = parsed.get("wavefront_patterns", [])

        lines = content.split("\n")

        for pattern in wavefront_patterns:
            diagnostic = self._create_diagnostic_from_pattern(pattern, lines)
            if diagnostic is not None:
                diagnostics.append(diagnostic)

        # Additional line-based checks the parser does not report
        diagnostics.extend(self._check_warp_reduction_patterns(content, lines))
        diagnostics.extend(self._check_shuffle_width_parameters(content, lines))

        return diagnostics

    def _create_diagnostic_from_pattern(
        self,
        pattern: WavefrontPattern,
        lines: list[str],
    ) -> Diagnostic | None:
        """Create a diagnostic from a detected wavefront pattern.

        Returns None when the pattern type has no diagnostic mapping. The
        range covers the first occurrence of ``pattern.problematic_value`` on
        ``pattern.line``. If the value is not found on that line, the whole
        line is highlighted.
        """
        diag_code = PATTERN_TO_DIAGNOSTIC.get(pattern.pattern_type)
        if not diag_code:
            return None

        diag_info = DIAGNOSTIC_INFO.get(diag_code)
        if not diag_info:
            return None

        value = pattern.problematic_value or ""
        line = lines[pattern.line] if 0 <= pattern.line < len(lines) else ""
        start_char = line.find(value) if value else -1
        if start_char == -1:
            # Highlighting len(value) characters from column 0 would point at
            # unrelated code; mark the whole line instead.
            start_char, end_char = 0, len(line)
        else:
            end_char = start_char + len(value)

        range_ = Range(
            start=Position(line=pattern.line, character=start_char),
            end=Position(line=pattern.line, character=end_char),
        )

        message = self._format_diagnostic_message(
            diag_info,
            value,
            pattern.code_snippet,
        )

        return Diagnostic(
            range=range_,
            severity=diag_info.severity,
            code=diag_info.code,
            source="wafer-hip",
            message=message,
            data={"suggestion": diag_info.suggestion},
        )

    def _format_diagnostic_message(
        self,
        diag_info: HIPDiagnosticInfo,
        problematic_value: str,
        code_snippet: str,
    ) -> str:
        """Build the user-facing message for a parser-reported pattern.

        Starts from ``diag_info.message``, appends a Found/Expected note for
        the literal values ``32``, ``31`` and ``0xFFFFFFFF``, and ends with the
        code's suggestion. ``code_snippet`` is currently unused.
        """
        msg = diag_info.message

        # Add specific guidance based on the problematic value
        if problematic_value == "32":
            msg += f"\n\nFound: `{problematic_value}` (assumes CUDA warp size)"
            msg += "\nExpected: `warpSize` or `64` for AMD CDNA GPUs"
        elif problematic_value == "31":
            msg += f"\n\nFound: `& {problematic_value}` (assumes 32-thread warp)"
            msg += "\nExpected: `& (warpSize - 1)` or `& 63` for AMD CDNA GPUs"
        elif problematic_value == "0xFFFFFFFF":
            msg += f"\n\nFound: `{problematic_value}` (32-bit mask)"
            msg += "\nExpected: `0xFFFFFFFFFFFFFFFFull` (64-bit mask for AMD)"

        msg += f"\n\n💡 {diag_info.suggestion}"

        return msg

    def _check_warp_reduction_patterns(
        self,
        content: str,
        lines: list[str],
    ) -> list[Diagnostic]:
        """Flag single-line reduction loop headers seeded at 16 (half a CUDA warp).

        Matches headers such as ``for (int i = 16; i > 0; i >>= 1)`` or
        ``for (unsigned int i = 16; i > 0; i /= 2)``. The loop variable may be
        declared as ``int``, ``unsigned``, ``unsigned int``, ``auto``,
        ``int32_t`` or ``uint32_t``. ``content`` is unused; ``lines`` is the
        document split into lines. Reported as HIP001.
        """
        import re

        reduction_loop = re.compile(
            r"for\s*\(\s*(?:(?:unsigned\s+)?int|unsigned|auto|u?int32_t)\s+\w+"
            r"\s*=\s*16\s*;[^;]*;\s*\w+\s*(?:>>=\s*1|/=\s*2)\s*\)"
        )
        diag_info = DIAGNOSTIC_INFO["HIP001"]
        diagnostics: list[Diagnostic] = []

        for line_no, line in enumerate(lines):
            match = reduction_loop.search(line)
            if match is None:
                continue

            diagnostics.append(Diagnostic(
                range=Range(
                    start=Position(line=line_no, character=match.start()),
                    end=Position(line=line_no, character=match.end()),
                ),
                severity=diag_info.severity,
                code=diag_info.code,
                source="wafer-hip",
                message=(
                    "Warp reduction loop starts from 16 (half CUDA warp).\n\n"
                    "For AMD's 64-thread wavefronts, start from 32:\n"
                    "`for (int offset = 32; offset > 0; offset >>= 1)`\n\n"
                    "Or better, use `warpSize / 2` for portability."
                ),
                data={"suggestion": diag_info.suggestion},
            ))

        return diagnostics

    def _check_shuffle_width_parameters(
        self,
        content: str,
        lines: list[str],
    ) -> list[Diagnostic]:
        """Flag shuffle calls whose third (width) argument is the literal ``32``.

        Covers ``__shfl``, ``__shfl_up``, ``__shfl_down`` and ``__shfl_xor``.
        The first two arguments may be arbitrary expressions, including
        identifiers such as ``offset``; commas are only allowed inside one
        level of nested parentheses. Examples: ``__shfl_down(val, offset, 32)``
        and ``__shfl_xor(f(a, b), 1, 32)``. Every matching call on a line is
        reported, with the range covering the ``32`` token (HIP002).
        """
        import re

        # One call argument: characters other than parentheses/commas, or a
        # single level of parenthesised sub-expression.
        arg = r"(?:[^(),]|\([^()]*\))+"
        shuffle_width_32 = re.compile(
            r"\b(__shfl(?:_down|_up|_xor)?)\s*\(" + arg + "," + arg + r",\s*(32)\s*\)"
        )
        diag_info = DIAGNOSTIC_INFO["HIP002"]
        diagnostics: list[Diagnostic] = []

        for line_no, line in enumerate(lines):
            for match in shuffle_width_32.finditer(line):
                intrinsic_name = match.group(1)

                diagnostics.append(Diagnostic(
                    range=Range(
                        start=Position(line=line_no, character=match.start(2)),
                        end=Position(line=line_no, character=match.end(2)),
                    ),
                    severity=diag_info.severity,
                    code=diag_info.code,
                    source="wafer-hip",
                    message=(
                        f"`{intrinsic_name}` with explicit width=32 assumes CUDA warp size.\n\n"
                        "On AMD CDNA GPUs, wavefront size is 64. Either:\n"
                        "• Remove the width parameter (defaults to `warpSize`)\n"
                        "• Use `warpSize` as the width parameter\n\n"
                        f"Example: `{intrinsic_name}(val, offset)` or `{intrinsic_name}(val, offset, warpSize)`"
                    ),
                    data={"suggestion": diag_info.suggestion},
                ))

        return diagnostics
280
+
281
+
282
def create_hip_diagnostics_provider() -> HIPDiagnosticsProvider:
    """Build and return a fresh ``HIPDiagnosticsProvider``.

    Each call constructs a new instance, which in turn owns a new ``HIPParser``.
    """
    provider = HIPDiagnosticsProvider()
    return provider
285
+
286
+
287
def get_hip_diagnostics(
    content: str,
    uri: str,
    enable_wavefront_diagnostics: bool = True,
) -> list[Diagnostic]:
    """Run the HIP diagnostics over a single document.

    Thin wrapper: builds a new provider with ``create_hip_diagnostics_provider``
    and delegates to ``HIPDiagnosticsProvider.get_diagnostics``.

    Args:
        content: Full text of the document.
        uri: URI of the document.
        enable_wavefront_diagnostics: When False, the provider returns an
            empty list without running any checks.

    Returns:
        The LSP diagnostics produced by the provider.
    """
    return create_hip_diagnostics_provider().get_diagnostics(
        content,
        uri,
        enable_wavefront_diagnostics=enable_wavefront_diagnostics,
    )
@@ -5,8 +5,20 @@ from ..analyzers.compiler_integration import get_analysis_for_kernel
5
5
  from ..analyzers.docs_index import DocsIndex
6
6
  from ..languages.registry import get_language_registry
7
7
  from ..languages.types import KernelInfo, LayoutInfo
8
+ from ..services.hip_hover_service import create_hip_hover_service
8
9
  from ..utils.lsp_helpers import get_decorator_at_position, get_word_at_position
9
10
 
11
# HIP hover service, created on first use and then reused by every request.
_hip_hover_service = None


def _get_hip_hover_service():
    """Return the shared HIP hover service, creating it on the first call."""
    global _hip_hover_service
    service = _hip_hover_service
    if service is not None:
        return service
    service = create_hip_hover_service()
    _hip_hover_service = service
    return service
21
+
10
22
 
11
23
  def find_kernel_at_position(
12
24
  content: str, position: Position, uri: str
@@ -133,8 +145,31 @@ def format_decorator_hover(decorator_name: str, function_name: str | None = None
133
145
 
134
146
 
135
147
  def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
136
- test_message = "**HEYOOOOOO** 🎉\n\n"
137
-
148
+ """Handle hover requests for all supported languages.
149
+
150
+ For HIP/CUDA/C++ files, provides rich documentation for:
151
+ - HIP API functions (hipMalloc, hipMemcpy, etc.)
152
+ - Memory qualifiers (__device__, __shared__, etc.)
153
+ - Wavefront intrinsics (__shfl, __ballot, etc.)
154
+ - Thread indexing (threadIdx, blockIdx, etc.)
155
+ - Kernel and device functions
156
+
157
+ For CuTeDSL (Python) files:
158
+ - Decorators (@cute.kernel, @cute.struct)
159
+ - Layouts and kernel functions
160
+ """
161
+ # Detect language to route to appropriate hover handler
162
+ registry = get_language_registry()
163
+ language_info = registry.parse_file(uri, content)
164
+
165
+ # For HIP/CUDA/C++ files, try HIP hover service first
166
+ if language_info and language_info.language in ("hip", "cuda", "cpp"):
167
+ hip_service = _get_hip_hover_service()
168
+ hip_hover = hip_service.get_hover(content, position, uri)
169
+ if hip_hover:
170
+ return hip_hover
171
+
172
+ # CuTeDSL decorator handling
138
173
  decorator_info = get_decorator_at_position(content, position)
139
174
  if decorator_info:
140
175
  decorator_name, function_line = decorator_info
@@ -154,13 +189,13 @@ def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
154
189
  if class_name_end > class_name_start:
155
190
  function_name = func_line[class_name_start:class_name_end].strip()
156
191
 
157
- hover_content = test_message + format_decorator_hover(decorator_name, function_name)
192
+ hover_content = format_decorator_hover(decorator_name, function_name)
158
193
  return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
159
194
 
195
+ # CuTeDSL cute module
160
196
  word = get_word_at_position(content, position)
161
197
  if word == "cute" or word.startswith("cute."):
162
198
  hover_lines = [
163
- test_message,
164
199
  "**cutlass.cute**",
165
200
  "",
166
201
  "CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
@@ -176,24 +211,25 @@ def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
176
211
  hover_content = "\n".join(hover_lines)
177
212
  return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
178
213
 
214
+ # Kernel hover (all languages)
179
215
  kernel = find_kernel_at_position(content, position, uri)
180
216
  if kernel:
181
217
  analysis = get_analysis_for_kernel(uri, kernel.name)
182
218
 
183
219
  if analysis:
184
- hover_content = test_message + format_kernel_hover(kernel, analysis)
220
+ hover_content = format_kernel_hover(kernel, analysis)
185
221
  else:
186
- hover_content = test_message + format_kernel_hover_basic(kernel)
222
+ hover_content = format_kernel_hover_basic(kernel)
187
223
 
188
224
  return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
189
225
 
226
+ # Layout hover (CuTeDSL)
190
227
  layout = find_layout_at_position(content, position, uri)
191
228
  if layout:
192
229
  docs = DocsIndex()
193
230
  doc_link = docs.get_doc_for_concept("layout")
194
231
 
195
232
  hover_lines = [
196
- test_message,
197
233
  f"**Layout: {layout.name}**",
198
234
  ""
199
235
  ]
@@ -211,5 +247,5 @@ def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
211
247
 
212
248
  return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
213
249
 
214
- hover_content = test_message + "Hover is working! Move your cursor over symbols to see more info."
215
- return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=hover_content))
250
+ # No hover info found
251
+ return None