wafer-lsp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. wafer_lsp/__init__.py +1 -0
  2. wafer_lsp/__main__.py +9 -0
  3. wafer_lsp/analyzers/__init__.py +0 -0
  4. wafer_lsp/analyzers/compiler_integration.py +16 -0
  5. wafer_lsp/analyzers/docs_index.py +36 -0
  6. wafer_lsp/handlers/__init__.py +0 -0
  7. wafer_lsp/handlers/code_action.py +48 -0
  8. wafer_lsp/handlers/code_lens.py +48 -0
  9. wafer_lsp/handlers/completion.py +6 -0
  10. wafer_lsp/handlers/diagnostics.py +16 -0
  11. wafer_lsp/handlers/document_symbol.py +87 -0
  12. wafer_lsp/handlers/hover.py +215 -0
  13. wafer_lsp/handlers/inlay_hint.py +65 -0
  14. wafer_lsp/handlers/semantic_tokens.py +124 -0
  15. wafer_lsp/handlers/workspace_symbol.py +87 -0
  16. wafer_lsp/languages/README.md +195 -0
  17. wafer_lsp/languages/__init__.py +17 -0
  18. wafer_lsp/languages/converter.py +88 -0
  19. wafer_lsp/languages/detector.py +34 -0
  20. wafer_lsp/languages/parser_manager.py +33 -0
  21. wafer_lsp/languages/registry.py +99 -0
  22. wafer_lsp/languages/types.py +37 -0
  23. wafer_lsp/parsers/__init__.py +18 -0
  24. wafer_lsp/parsers/base_parser.py +9 -0
  25. wafer_lsp/parsers/cuda_parser.py +95 -0
  26. wafer_lsp/parsers/cutedsl_parser.py +114 -0
  27. wafer_lsp/server.py +58 -0
  28. wafer_lsp/services/__init__.py +21 -0
  29. wafer_lsp/services/analysis_service.py +22 -0
  30. wafer_lsp/services/docs_service.py +40 -0
  31. wafer_lsp/services/document_service.py +20 -0
  32. wafer_lsp/services/hover_service.py +237 -0
  33. wafer_lsp/services/language_registry_service.py +26 -0
  34. wafer_lsp/services/position_service.py +77 -0
  35. wafer_lsp/utils/__init__.py +0 -0
  36. wafer_lsp/utils/lsp_helpers.py +79 -0
  37. wafer_lsp-0.1.0.dist-info/METADATA +57 -0
  38. wafer_lsp-0.1.0.dist-info/RECORD +40 -0
  39. wafer_lsp-0.1.0.dist-info/WHEEL +4 -0
  40. wafer_lsp-0.1.0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,95 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from .base_parser import BaseParser
6
+
7
+
8
@dataclass()
class CUDAKernel:
    """A ``__global__`` kernel definition discovered in CUDA source text."""

    # Kernel function name.
    name: str
    # 0-based line number of the ``__global__`` keyword.
    line: int
    # Parameter names in declaration order.
    parameters: list[str]
13
+
14
+
15
class CUDAParser(BaseParser):
    """Extract ``__global__`` kernel definitions from CUDA/C++ source text.

    This is a lightweight regex scanner, not a C++ parser: it runs over the raw
    text, so comments and string literals are not skipped.
    """

    # ``__global__`` + optional ``__device__`` + return type (``void`` or any
    # text on the same line) + optional ``__launch_bounds__(...)`` placed after
    # the return type + kernel name + the '(' opening its parameter list.
    _KERNEL_PATTERN = re.compile(
        r'__global__\s+(?:__device__\s+)?(?:void|.*?)\s+'
        r'(?:__launch_bounds__\s*\([^)]*\)\s*)?'
        r'(\w+)\s*\('
    )

    def parse_file(self, content: str) -> dict[str, Any]:
        """Return ``{"kernels": [CUDAKernel, ...]}`` for every kernel in ``content``.

        ``CUDAKernel.line`` is the 0-based line of the ``__global__`` keyword.
        """
        kernels: list[CUDAKernel] = []

        for match in self._KERNEL_PATTERN.finditer(content):
            # The match ends just past the '(' opening the parameter list;
            # _extract_parameters expects the index of that '(' itself.
            open_paren = match.end() - 1
            kernels.append(CUDAKernel(
                name=match.group(1),
                line=content.count('\n', 0, match.start()),
                parameters=self._extract_parameters(content, open_paren),
            ))

        return {"kernels": kernels}

    def _extract_parameters(self, content: str, start: int) -> list[str]:
        """Return parameter names of the parenthesised list opening at ``content[start]``.

        ``start`` must be the index of the opening ``'('``. Returns ``[]`` when it
        is not, when the list is never closed, or when it declares no names
        (e.g. ``()`` or ``(void)``).
        """
        if not 0 <= start < len(content) or content[start] != '(':
            return []

        # Find the matching ')', allowing nested parentheses inside the list.
        depth = 0
        close = -1
        for i in range(start, len(content)):
            char = content[i]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    close = i
                    break
        if close == -1:
            return []

        # Split on top-level commas only: commas nested inside template
        # arguments (<...>) or parentheses belong to a single parameter.
        declarations: list[str] = []
        current: list[str] = []
        template_depth = 0
        paren_depth = 0
        for char in content[start + 1:close]:
            if char == ',' and template_depth == 0 and paren_depth == 0:
                declarations.append(''.join(current))
                current = []
                continue
            if char == '<':
                template_depth += 1
            elif char == '>':
                template_depth = max(template_depth - 1, 0)
            elif char == '(':
                paren_depth += 1
            elif char == ')':
                paren_depth -= 1
            current.append(char)
        declarations.append(''.join(current))

        names = (self._parameter_name(decl) for decl in declarations)
        return [name for name in names if name]

    @staticmethod
    def _parameter_name(declaration: str) -> str | None:
        """Return the identifier declared by one C/C++ parameter, or ``None``.

        ``"const float* __restrict__ a"`` -> ``"a"``, ``"int n = 4"`` -> ``"n"``,
        ``"float tile[16]"`` -> ``"tile"``, ``"void"`` / ``""`` -> ``None``.
        """
        decl = declaration.split('=', 1)[0].strip()            # drop default value
        decl = re.sub(r'(?:\s*\[[^\]]*\])+$', '', decl)         # drop array extents
        if not decl or decl == 'void':
            return None
        name = decl.split()[-1].strip('*&')
        return name or None
@@ -0,0 +1,114 @@
1
+ import ast
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from .base_parser import BaseParser
6
+
7
+
8
+ @dataclass
9
+ class CuTeDSLKernel:
10
+ name: str
11
+ line: int
12
+ parameters: list[str]
13
+ docstring: str | None = None
14
+
15
+
16
+ @dataclass
17
+ class CuTeDSLLayout:
18
+ name: str
19
+ line: int
20
+ shape: str | None = None
21
+ stride: str | None = None
22
+
23
+
24
+ @dataclass
25
+ class CuTeDSLStruct:
26
+ name: str
27
+ line: int
28
+ docstring: str | None = None
29
+
30
+
31
class CuTeDSLParser(BaseParser):
    """Collect CuTeDSL kernels, structs and layouts from Python source via ``ast``.

    Recognised constructs (all reported ``line`` values are 0-based):

    * functions decorated with ``@cute.kernel`` or ``@kernel``, also in called
      form such as ``@cute.kernel(...)``;
    * classes decorated with ``@cute.struct`` or ``@struct`` (same forms);
    * assignments ``name = cute.make_layout(shape, stride)``, where shape and
      stride may be positional or passed as ``shape=`` / ``stride=``.
    """

    def parse_file(self, content: str) -> dict[str, Any]:
        """Return ``{"kernels": [...], "layouts": [...], "structs": [...]}``.

        Source that fails to parse (e.g. mid-edit) yields three empty lists.
        """
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError):
            # Some Python versions raise ValueError (not SyntaxError) for
            # source containing NUL bytes.
            return {"kernels": [], "layouts": [], "structs": []}

        kernels: list[CuTeDSLKernel] = []
        layouts: list[CuTeDSLLayout] = []
        structs: list[CuTeDSLStruct] = []

        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                if self._has_decorator(node, "cute.kernel"):
                    kernels.append(self._extract_kernel(node))
            elif isinstance(node, ast.ClassDef):
                if self._has_decorator(node, "cute.struct"):
                    structs.append(self._extract_struct(node))
            elif isinstance(node, ast.Assign):
                layout = self._extract_layout(node, content)
                if layout:
                    layouts.append(layout)

        return {"kernels": kernels, "layouts": layouts, "structs": structs}

    def _has_decorator(self, node: ast.FunctionDef | ast.ClassDef, decorator: str) -> bool:
        """Return True if ``node`` carries ``@cute.<last part of decorator>`` or the bare name.

        Only the final dotted component of ``decorator`` is compared.
        """
        wanted = decorator.split(".")[-1]
        for dec in node.decorator_list:
            # ``@cute.kernel(...)`` is an ast.Call wrapping the decorator expression.
            target = dec.func if isinstance(dec, ast.Call) else dec
            if isinstance(target, ast.Attribute):
                if (
                    isinstance(target.value, ast.Name)
                    and target.value.id == "cute"
                    and target.attr == wanted
                ):
                    return True
            elif isinstance(target, ast.Name) and target.id == wanted:
                return True
        return False

    def _extract_kernel(self, node: ast.FunctionDef) -> CuTeDSLKernel:
        """Build a kernel record from a decorated function definition."""
        args = node.args
        # Positional-only, regular and keyword-only parameters in signature order
        # (*args / **kwargs are not included).
        parameters = [arg.arg for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs)]

        return CuTeDSLKernel(
            name=node.name,
            line=node.lineno - 1,
            parameters=parameters,
            docstring=ast.get_docstring(node),
        )

    def _extract_struct(self, node: ast.ClassDef) -> CuTeDSLStruct:
        """Build a struct record from a decorated class definition."""
        return CuTeDSLStruct(
            name=node.name,
            line=node.lineno - 1,
            docstring=ast.get_docstring(node),
        )

    def _extract_layout(self, node: ast.Assign, content: str) -> CuTeDSLLayout | None:
        """Return a layout record if ``node`` assigns ``cute.make_layout(...)`` to a name.

        Uses the first plain-name assignment target. ``content`` is unused and
        kept for interface compatibility.
        """
        call = node.value
        if not (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and isinstance(call.func.value, ast.Name)
            and call.func.value.id == "cute"
            and call.func.attr == "make_layout"
        ):
            return None

        target = next((t for t in node.targets if isinstance(t, ast.Name)), None)
        if target is None:
            return None

        keywords = {kw.arg: kw.value for kw in call.keywords if kw.arg}
        shape_node = call.args[0] if call.args else keywords.get("shape")
        stride_node = call.args[1] if len(call.args) > 1 else keywords.get("stride")

        return CuTeDSLLayout(
            name=target.id,
            line=node.lineno - 1,
            shape=ast.unparse(shape_node) if shape_node is not None else None,
            stride=ast.unparse(stride_node) if stride_node is not None else None,
        )
wafer_lsp/server.py ADDED
@@ -0,0 +1,58 @@
1
+ from lsprotocol.types import (
2
+ INITIALIZE,
3
+ TEXT_DOCUMENT_HOVER,
4
+ )
5
+ from pygls.lsp.server import LanguageServer
6
+
7
+ from .services import (
8
+ create_analysis_service,
9
+ create_docs_service,
10
+ create_document_service,
11
+ create_hover_service,
12
+ create_language_registry_service,
13
+ create_position_service,
14
+ )
15
+
16
# Service singletons shared by the feature handlers, built once at import time.
language_registry_service = create_language_registry_service()
analysis_service = create_analysis_service()
docs_service = create_docs_service()
position_service = create_position_service()
hover_service = create_hover_service(
    language_registry=language_registry_service,
    analysis_service=analysis_service,
    docs_service=docs_service,
    position_service=position_service,
)

# NOTE(review): the advertised server version "1.0.0" differs from the
# package version (0.1.0) — confirm which is intended.
server = LanguageServer("wafer-lsp", "1.0.0")
# Reads open-document text from the server's workspace.
document_service = create_document_service(server=server)
29
+
30
+
31
@server.feature(INITIALIZE)
def initialize(params):
    """Handle ``initialize`` by reporting hover support.

    NOTE(review): pygls also performs its own built-in ``initialize`` handling;
    confirm whether this handler's return value is actually sent to the client.
    """
    capabilities = {"hoverProvider": True}
    return {"capabilities": capabilities}
38
+
39
+
40
@server.feature(TEXT_DOCUMENT_HOVER)
def hover(params):
    """Answer ``textDocument/hover`` for the document and position in ``params``.

    Returns ``None`` (LSP "no hover") when the document text is unavailable or
    empty, or when the hover service has nothing to show. Earlier builds
    returned a debug banner in those cases, which surfaced as a spurious
    popup everywhere in the editor.
    """
    uri = params.text_document.uri
    content = document_service.get_document_content(uri)
    if not content:
        return None
    return hover_service.handle_hover(uri, params.position, content)
55
+
56
+
57
# Entry point when this module is executed directly
# (e.g. ``python -m wafer_lsp.server``): run the pygls server via start_io().
if __name__ == "__main__":
    server.start_io()
@@ -0,0 +1,21 @@
1
+ from .analysis_service import AnalysisService, create_analysis_service
2
+ from .docs_service import DocsService, create_docs_service
3
+ from .document_service import DocumentService, create_document_service
4
+ from .hover_service import HoverService, create_hover_service
5
+ from .language_registry_service import LanguageRegistryService, create_language_registry_service
6
+ from .position_service import PositionService, create_position_service
7
+
8
# Public API of the services package: each service class, then its factory.
__all__ = [
    "AnalysisService", "DocsService", "DocumentService",
    "HoverService", "LanguageRegistryService", "PositionService",
    "create_analysis_service", "create_docs_service", "create_document_service",
    "create_hover_service", "create_language_registry_service", "create_position_service",
]
@@ -0,0 +1,22 @@
1
+ from typing import Any
2
+
3
+
4
+ class AnalysisService:
5
+
6
+ def __init__(self):
7
+ self._cache: dict[str, dict[str, Any]] = {}
8
+
9
+ def get_analysis_for_kernel(self, uri: str, kernel_name: str) -> dict[str, Any] | None:
10
+ cache_key = f"{uri}:{kernel_name}"
11
+ return self._cache.get(cache_key)
12
+
13
+ def set_analysis_for_kernel(self, uri: str, kernel_name: str, analysis: dict[str, Any]):
14
+ cache_key = f"{uri}:{kernel_name}"
15
+ self._cache[cache_key] = analysis
16
+
17
+ def clear_cache(self):
18
+ self._cache.clear()
19
+
20
+
21
def create_analysis_service() -> AnalysisService:
    """Factory returning a new, empty :class:`AnalysisService`."""
    service = AnalysisService()
    return service
@@ -0,0 +1,40 @@
1
+ from pathlib import Path
2
+
3
+
4
+ class DocsService:
5
+
6
+ def __init__(self, docs_path: str | None = None):
7
+ if docs_path:
8
+ self.docs_path = Path(docs_path)
9
+ else:
10
+ self.docs_path = Path(__file__).parent.parent.parent.parent.parent / \
11
+ "curriculum" / "cutlass-docs" / "cutedsl-docs"
12
+
13
+ self.index = self._build_index()
14
+
15
+ def _build_index(self):
16
+ return {
17
+ "layout": [
18
+ "intro-to-cutedsl.md",
19
+ "partitioning-strategies-inner-outer-threadvalue.md"
20
+ ],
21
+ "TMA": ["blackwell-tutorial-fp16-gemm-0.md"],
22
+ "TMEM": ["colfax-blackwell-umma-tensor-memory-part1.md"],
23
+ "kernel": ["blackwell-tutorial-fp16-gemm-0.md"],
24
+ "struct": ["blackwell-tutorial-fp16-gemm-0.md"],
25
+ "pipeline": ["blackwell-tutorial-fp16-gemm-0.md"],
26
+ "MMA": ["mma-atoms-fundamentals-sm70-example.md"],
27
+ }
28
+
29
+ def get_doc_for_concept(self, concept: str) -> str | None:
30
+ concept_lower = concept.lower()
31
+ if self.index.get(concept_lower):
32
+ doc_file = self.index[concept_lower][0]
33
+ doc_path = self.docs_path / doc_file
34
+ if doc_path.exists():
35
+ return str(doc_path)
36
+ return None
37
+
38
+
39
def create_docs_service(docs_path: str | None = None) -> DocsService:
    """Factory for :class:`DocsService`; ``docs_path`` is forwarded unchanged."""
    return DocsService(docs_path=docs_path)
@@ -0,0 +1,20 @@
1
+ from pygls.lsp.server import LanguageServer
2
+
3
+
4
class DocumentService:
    """Read access to the text documents held by a pygls server's workspace."""

    def __init__(self, server: LanguageServer):
        self._server = server

    def get_document(self, uri: str):
        """Return ``workspace.text_documents.get(uri)`` — the document object, if tracked."""
        documents = self._server.workspace.text_documents
        return documents.get(uri)

    def get_document_content(self, uri: str) -> str:
        """Return the document's text, or ``''`` when the document is not found.

        Reads the ``text`` attribute, falling back to ``source`` and then ``''``.
        """
        document = self.get_document(uri)
        if document is None:
            return ''
        fallback = getattr(document, 'source', '')
        return getattr(document, 'text', fallback)
17
+
18
+
19
def create_document_service(server: LanguageServer) -> DocumentService:
    """Factory binding a :class:`DocumentService` to ``server``."""
    return DocumentService(server=server)
@@ -0,0 +1,237 @@
1
+
2
+ from lsprotocol.types import Hover, MarkupContent, MarkupKind, Position
3
+
4
+ from ..languages.types import KernelInfo, LayoutInfo
5
+ from .analysis_service import AnalysisService
6
+ from .docs_service import DocsService
7
+ from .language_registry_service import LanguageRegistryService
8
+ from .position_service import PositionService
9
+
10
+
11
class HoverService:
    """Build Markdown hover responses for CuTeDSL and CUDA documents.

    ``handle_hover`` tries, in order:

    1. a decorator under the cursor (located by the position service);
    2. the ``cute`` module name or a ``cute.*`` word;
    3. a kernel name and then a layout variable, both as reported by the
       language registry's parser.

    Anything else gets no hover (``None``).
    """

    def __init__(
        self,
        language_registry: LanguageRegistryService,
        analysis_service: AnalysisService,
        docs_service: DocsService,
        position_service: PositionService
    ):
        self._language_registry = language_registry
        self._analysis_service = analysis_service
        self._docs_service = docs_service
        self._position_service = position_service

    def handle_hover(self, uri: str, position: Position, content: str) -> Hover | None:
        """Return hover content for ``position`` in document ``uri``, or ``None``.

        ``content`` is the current text of the document.
        """
        decorator_info = self._position_service.get_decorator_at_position(content, position)
        if decorator_info:
            decorator_name, definition_line = decorator_info
            target_name = self._definition_name_at_line(content, definition_line)
            return self._markdown_hover(self._format_decorator_hover(decorator_name, target_name))

        word = self._position_service.get_word_at_position(content, position)
        if not word:
            return None

        if word == "cute" or word.startswith("cute."):
            return self._markdown_hover(self._format_cute_module_hover())

        kernel = self._find_kernel_at_position(content, position, uri)
        if kernel:
            # _format_kernel_hover omits the analysis section when this is falsy.
            analysis = self._analysis_service.get_analysis_for_kernel(uri, kernel.name)
            return self._markdown_hover(self._format_kernel_hover(kernel, analysis))

        layout = self._find_layout_at_position(content, position, uri)
        if layout:
            return self._markdown_hover(self._format_layout_hover(layout))

        return None

    @staticmethod
    def _markdown_hover(value: str) -> Hover:
        """Wrap Markdown text in an LSP ``Hover``."""
        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=value))

    @staticmethod
    def _definition_name_at_line(content: str, line: int) -> str | None:
        """Return the function/class name defined on ``line`` (0-based) of ``content``, if any.

        Handles ``def``, ``async def`` and ``class``; the name ends at the first
        ``(`` or ``:`` so base-class lists are not included.
        """
        lines = content.split("\n")
        if not 0 <= line < len(lines):
            return None
        text = lines[line].strip()
        for keyword in ("def ", "async def ", "class "):
            if text.startswith(keyword):
                rest = text[len(keyword):]
                end = len(rest)
                for stop in "(:":
                    idx = rest.find(stop)
                    if idx != -1:
                        end = min(end, idx)
                return rest[:end].strip() or None
        return None

    @staticmethod
    def _format_cute_module_hover() -> str:
        """Return the Markdown overview shown when hovering the ``cute`` module."""
        return "\n".join([
            "**cutlass.cute**",
            "",
            "CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
            "",
            "**Key Features:**",
            "- `@cute.kernel` - Define GPU kernels",
            "- `@cute.struct` - Define GPU structs",
            "- `cute.make_layout()` - Create tensor layouts",
            "- `cute.Tensor` - Tensor type annotations",
            "",
            "[Documentation](https://github.com/NVIDIA/cutlass)"
        ])

    def _format_layout_hover(self, layout: LayoutInfo) -> str:
        """Return Markdown describing ``layout`` (shape/stride when known, plus a doc link)."""
        lines = [f"**Layout: {layout.name}**", ""]
        if layout.shape:
            lines.append(f"Shape: `{layout.shape}`")
        if layout.stride:
            lines.append(f"Stride: `{layout.stride}`")
        doc_link = self._docs_service.get_doc_for_concept("layout")
        if doc_link:
            lines.append("")
            lines.append(f"[Documentation]({doc_link})")
        return "\n".join(lines)

    def _find_kernel_at_position(
        self, content: str, position: Position, uri: str
    ) -> KernelInfo | None:
        """Return the parsed kernel whose name is the word at ``position``.

        Only matches when ``position`` is on or after the kernel's definition line.
        """
        language_info = self._language_registry.parse_file(uri, content)
        if not language_info:
            return None

        word = self._position_service.get_word_at_position(content, position)
        for kernel in language_info.kernels:
            if kernel.name == word and position.line >= kernel.line:
                return kernel
        return None

    def _find_layout_at_position(
        self, content: str, position: Position, uri: str
    ) -> LayoutInfo | None:
        """Return the parsed layout whose name is the word at ``position``, if any."""
        language_info = self._language_registry.parse_file(uri, content)
        if not language_info:
            return None

        word = self._position_service.get_word_at_position(content, position)
        for layout in language_info.layouts:
            if layout.name == word:
                return layout
        return None

    def _format_kernel_hover(self, kernel: KernelInfo, analysis: dict | None) -> str:
        """Return Markdown describing ``kernel``, plus ``analysis`` entries when truthy."""
        language_name = self._language_registry.get_language_name(kernel.language) or kernel.language
        is_cuda = kernel.language in ("cuda", "cpp")

        title = "CUDA Kernel" if is_cuda else "GPU Kernel"
        lines = [f"**{title}: {kernel.name}**", f"*Language: {language_name}*", ""]

        if kernel.docstring:
            lines.append(kernel.docstring)
            lines.append("")

        if kernel.parameters:
            params_str = ", ".join(kernel.parameters)
            lines.append(f"**Parameters:** `{params_str}`")
            lines.append("")

        if is_cuda:
            lines.append("**CUDA Features:**")
            lines.append("- `__global__` function executed on GPU")
            lines.append("- Can be launched with `<<<grid, block>>>` syntax")
            lines.append("")

        if analysis:
            lines.append("**Analysis:**")
            if "layouts" in analysis:
                lines.append(f"- Layouts: {analysis['layouts']}")
            if "memory_paths" in analysis:
                lines.append(f"- Memory paths: {analysis['memory_paths']}")
            if "pipeline_stages" in analysis:
                lines.append(f"- Pipeline stages: {analysis['pipeline_stages']}")

        return "\n".join(lines)

    def _format_kernel_hover_basic(self, kernel: KernelInfo) -> str:
        """Return kernel Markdown without an analysis section."""
        return self._format_kernel_hover(kernel, None)

    def _format_decorator_hover(self, decorator_name: str, function_name: str | None = None) -> str:
        """Return Markdown describing ``decorator_name``.

        ``function_name`` is the decorated function/class, shown when known. A doc
        link is added only for kernel- or struct-related decorators.
        """
        lines: list[str] = []

        if decorator_name in ("cute.kernel", "kernel"):
            lines.extend([
                "**@cute.kernel**",
                "",
                "CuTeDSL kernel decorator. Marks a function as a GPU kernel.",
                "",
                "**Usage:**",
                "```python",
                "@cute.kernel",
                "def my_kernel(a: cute.Tensor, b: cute.Tensor):",
                " # Kernel implementation",
                " pass",
                "```",
                "",
                "**Features:**",
                "- Automatic GPU code generation",
                "- Tensor layout optimization",
                "- Memory access pattern analysis",
            ])
            if function_name:
                lines.append("")
                lines.append(f"Applied to: `{function_name}()`")

        elif decorator_name in ("cute.struct", "struct"):
            lines.extend([
                "**@cute.struct**",
                "",
                "CuTeDSL struct decorator. Marks a class as a GPU struct.",
                "",
                "**Usage:**",
                "```python",
                "@cute.struct",
                "class MyStruct:",
                " field1: int",
                " field2: float",
                "```",
            ])
            if function_name:
                lines.append("")
                lines.append(f"Applied to: `{function_name}`")

        else:
            lines.append(f"**{decorator_name}**")
            lines.append("")
            lines.append("CuTeDSL decorator")

        # Previously every non-kernel decorator linked the struct docs.
        if "kernel" in decorator_name:
            concept = "kernel"
        elif "struct" in decorator_name:
            concept = "struct"
        else:
            concept = None

        doc_link = self._docs_service.get_doc_for_concept(concept) if concept else None
        if doc_link:
            lines.append("")
            lines.append(f"[Documentation]({doc_link})")

        return "\n".join(lines)
229
+
230
+
231
def create_hover_service(
    language_registry: LanguageRegistryService,
    analysis_service: AnalysisService,
    docs_service: DocsService,
    position_service: PositionService
) -> HoverService:
    """Factory wiring the four collaborator services into a :class:`HoverService`."""
    return HoverService(
        language_registry=language_registry,
        analysis_service=analysis_service,
        docs_service=docs_service,
        position_service=position_service,
    )
@@ -0,0 +1,26 @@
1
+
2
+ from ..languages.registry import LanguageRegistry, get_language_registry
3
+ from ..languages.types import LanguageInfo
4
+
5
+
6
class LanguageRegistryService:
    """Service-layer facade over a ``LanguageRegistry``; every method delegates to it."""

    def __init__(self, registry: LanguageRegistry):
        self._registry = registry

    def detect_language(self, uri: str) -> str | None:
        """Return the registry's language id for ``uri``."""
        registry = self._registry
        return registry.detect_language(uri)

    def parse_file(self, uri: str, content: str) -> LanguageInfo | None:
        """Parse ``content`` with the parser the registry selects for ``uri``."""
        registry = self._registry
        return registry.parse_file(uri, content)

    def get_language_name(self, language_id: str) -> str | None:
        """Return the registry's display name for ``language_id``."""
        registry = self._registry
        return registry.get_language_name(language_id)

    def get_supported_extensions(self) -> list[str]:
        """Return the file extensions the registry reports as supported."""
        registry = self._registry
        return registry.get_supported_extensions()
22
+
23
+
24
def create_language_registry_service() -> LanguageRegistryService:
    """Wrap the registry returned by ``get_language_registry()`` in a service."""
    return LanguageRegistryService(registry=get_language_registry())