wafer-lsp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/__init__.py +1 -0
- wafer_lsp/__main__.py +9 -0
- wafer_lsp/analyzers/__init__.py +0 -0
- wafer_lsp/analyzers/compiler_integration.py +16 -0
- wafer_lsp/analyzers/docs_index.py +36 -0
- wafer_lsp/handlers/__init__.py +0 -0
- wafer_lsp/handlers/code_action.py +48 -0
- wafer_lsp/handlers/code_lens.py +48 -0
- wafer_lsp/handlers/completion.py +6 -0
- wafer_lsp/handlers/diagnostics.py +16 -0
- wafer_lsp/handlers/document_symbol.py +87 -0
- wafer_lsp/handlers/hover.py +215 -0
- wafer_lsp/handlers/inlay_hint.py +65 -0
- wafer_lsp/handlers/semantic_tokens.py +124 -0
- wafer_lsp/handlers/workspace_symbol.py +87 -0
- wafer_lsp/languages/README.md +195 -0
- wafer_lsp/languages/__init__.py +17 -0
- wafer_lsp/languages/converter.py +88 -0
- wafer_lsp/languages/detector.py +34 -0
- wafer_lsp/languages/parser_manager.py +33 -0
- wafer_lsp/languages/registry.py +99 -0
- wafer_lsp/languages/types.py +37 -0
- wafer_lsp/parsers/__init__.py +18 -0
- wafer_lsp/parsers/base_parser.py +9 -0
- wafer_lsp/parsers/cuda_parser.py +95 -0
- wafer_lsp/parsers/cutedsl_parser.py +114 -0
- wafer_lsp/server.py +58 -0
- wafer_lsp/services/__init__.py +21 -0
- wafer_lsp/services/analysis_service.py +22 -0
- wafer_lsp/services/docs_service.py +40 -0
- wafer_lsp/services/document_service.py +20 -0
- wafer_lsp/services/hover_service.py +237 -0
- wafer_lsp/services/language_registry_service.py +26 -0
- wafer_lsp/services/position_service.py +77 -0
- wafer_lsp/utils/__init__.py +0 -0
- wafer_lsp/utils/lsp_helpers.py +79 -0
- wafer_lsp-0.1.0.dist-info/METADATA +57 -0
- wafer_lsp-0.1.0.dist-info/RECORD +40 -0
- wafer_lsp-0.1.0.dist-info/WHEEL +4 -0
- wafer_lsp-0.1.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .base_parser import BaseParser
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class CUDAKernel:
    """A ``__global__`` function located in CUDA source by ``CUDAParser``."""

    # Kernel identifier exactly as written in the source.
    name: str
    # 0-based line index of the ``__global__`` keyword that introduces the kernel.
    line: int
    # Parameter names, in declaration order.
    parameters: list[str]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CUDAParser(BaseParser):
    """Lightweight, regex-based scanner for CUDA ``__global__`` kernels.

    This is not a C++ parser: it scans the raw text, so kernels that appear in
    comments or disabled preprocessor branches are reported too.
    """

    # `__global__`, an optional `__device__`, the return type (normally `void`),
    # an optional `__launch_bounds__(...)` qualifier, then `<name>(`.
    # Group 1 captures the kernel name. The qualifier group is tried before the
    # name so `__launch_bounds__` itself is never reported as the kernel name.
    _KERNEL_PATTERN = re.compile(
        r"__global__\s+(?:__device__\s+)?(?:void|.*?)\s+"
        r"(?:__launch_bounds__\s*\([^)]*\)\s*)?"
        r"(\w+)\s*\("
    )

    def parse_file(self, content: str) -> dict[str, Any]:
        """Find every ``__global__`` kernel in *content*.

        Args:
            content: Full text of a CUDA/C++ source file.

        Returns:
            ``{"kernels": [CUDAKernel, ...]}`` in source order. Each kernel's
            ``line`` is the 0-based line of its ``__global__`` keyword.
        """
        kernels: list[CUDAKernel] = []
        line = 0
        scanned_to = 0

        for match in self._KERNEL_PATTERN.finditer(content):
            # Count newlines incrementally rather than re-scanning from offset 0
            # for every match (which is quadratic in the number of kernels).
            line += content.count("\n", scanned_to, match.start())
            scanned_to = match.start()

            # The pattern consumes the opening "(", so step back one character:
            # _extract_parameters expects `start` to index that "(" itself.
            params = self._extract_parameters(content, match.end() - 1)

            kernels.append(CUDAKernel(
                name=match.group(1),
                line=line,
                parameters=params,
            ))

        return {"kernels": kernels}

    def _extract_parameters(self, content: str, start: int) -> list[str]:
        """Return the parameter names of the list whose ``(`` is at ``content[start]``.

        Args:
            content: Source text.
            start: Index of the opening parenthesis of the parameter list.

        Returns:
            Parameter names in order. Returns ``[]`` when *start* is out of range,
            does not point at ``(``, or the list is never closed.
        """
        if not 0 <= start < len(content) or content[start] != "(":
            return []

        # Find the ")" that balances the "(" at `start`.
        depth = 0
        end = -1
        for i in range(start, len(content)):
            char = content[i]
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    end = i
                    break
        if end == -1:
            return []

        # Split on top-level commas only. Commas nested in template arguments
        # (<...>) or parentheses belong to a single parameter declaration.
        params: list[str] = []
        current: list[str] = []
        template_depth = 0
        paren_depth = 0
        for char in content[start + 1:end]:
            if char == "<":
                template_depth += 1
            elif char == ">":
                template_depth -= 1
            elif char == "(":
                paren_depth += 1
            elif char == ")":
                paren_depth -= 1
            elif char == "," and template_depth == 0 and paren_depth == 0:
                name = self._param_name("".join(current))
                if name:
                    params.append(name)
                current = []
                continue
            current.append(char)

        name = self._param_name("".join(current))
        if name:
            params.append(name)
        return params

    @staticmethod
    def _param_name(declaration: str) -> str | None:
        """Return the identifier declared by one parameter declaration, or ``None``."""
        # Drop a default value ("int n = 4") and array extents ("float a[16]").
        declaration = declaration.split("=", 1)[0].split("[", 1)[0].strip()
        parts = declaration.split()
        if not parts:
            return None
        # The name is the last token; pointer/reference sigils may be glued to it.
        name = parts[-1].strip("*&")
        # "(void)" declares no parameters at all.
        if not name or name == "void":
            return None
        return name
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .base_parser import BaseParser
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class CuTeDSLKernel:
|
|
10
|
+
name: str
|
|
11
|
+
line: int
|
|
12
|
+
parameters: list[str]
|
|
13
|
+
docstring: str | None = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class CuTeDSLLayout:
|
|
18
|
+
name: str
|
|
19
|
+
line: int
|
|
20
|
+
shape: str | None = None
|
|
21
|
+
stride: str | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class CuTeDSLStruct:
|
|
26
|
+
name: str
|
|
27
|
+
line: int
|
|
28
|
+
docstring: str | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CuTeDSLParser(BaseParser):
    """Extract CuTe DSL kernels, structs and layouts from Python source via ``ast``."""

    def parse_file(self, content: str) -> dict[str, Any]:
        """Parse *content* and collect CuTe DSL constructs.

        Args:
            content: Full text of a Python source file.

        Returns:
            ``{"kernels": [...], "layouts": [...], "structs": [...]}``. All three
            lists are empty when *content* cannot be parsed.
        """
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError):
            # Editors routinely hand us half-typed code (SyntaxError). ValueError
            # covers source containing NUL bytes on Python versions before 3.12.
            return {"kernels": [], "layouts": [], "structs": []}

        kernels: list[CuTeDSLKernel] = []
        layouts: list[CuTeDSLLayout] = []
        structs: list[CuTeDSLStruct] = []

        # ast.walk visits nested scopes as well, so inner definitions are found too.
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                if self._has_decorator(node, "cute.kernel"):
                    kernels.append(self._extract_kernel(node))
            elif isinstance(node, ast.ClassDef):
                if self._has_decorator(node, "cute.struct"):
                    structs.append(self._extract_struct(node))
            elif isinstance(node, ast.Assign):
                layout = self._extract_layout(node, content)
                if layout:
                    layouts.append(layout)

        return {"kernels": kernels, "layouts": layouts, "structs": structs}

    def _has_decorator(self, node: ast.FunctionDef | ast.ClassDef, decorator: str) -> bool:
        """Return True if *node* is decorated with ``@cute.<attr>`` or a bare ``@<attr>``.

        ``<attr>`` is the last dotted component of *decorator*. Called forms such
        as ``@cute.kernel(...)`` are matched as well.
        """
        attr = decorator.split(".")[-1]
        for dec in node.decorator_list:
            # Unwrap "@deco(...)" to the callee expression "deco".
            target = dec.func if isinstance(dec, ast.Call) else dec
            if isinstance(target, ast.Attribute):
                if (
                    isinstance(target.value, ast.Name)
                    and target.value.id == "cute"
                    and target.attr == attr
                ):
                    return True
            elif isinstance(target, ast.Name) and target.id == attr:
                return True
        return False

    def _extract_kernel(self, node: ast.FunctionDef) -> CuTeDSLKernel:
        """Build a kernel record for *node*; ``line`` is 0-based."""
        args = node.args
        # Positional-only, regular and keyword-only parameters, in declaration order.
        parameters = [arg.arg for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs)]

        return CuTeDSLKernel(
            name=node.name,
            line=node.lineno - 1,
            parameters=parameters,
            docstring=ast.get_docstring(node),
        )

    def _extract_struct(self, node: ast.ClassDef) -> CuTeDSLStruct:
        """Build a struct record for *node*; ``line`` is 0-based."""
        return CuTeDSLStruct(
            name=node.name,
            line=node.lineno - 1,
            docstring=ast.get_docstring(node),
        )

    def _extract_layout(self, node: ast.Assign, content: str) -> CuTeDSLLayout | None:
        """Return a layout record if *node* is ``<name> = cute.make_layout(...)``.

        Shape and stride are taken from the first/second positional arguments or
        from the ``shape=`` / ``stride=`` keywords, rendered back to source text.
        ``content`` is unused; it is kept for signature compatibility.
        """
        call = node.value
        if not (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and isinstance(call.func.value, ast.Name)
            and call.func.value.id == "cute"
            and call.func.attr == "make_layout"
        ):
            return None

        # The first plain-name target wins ("a = b = cute.make_layout(...)" -> "a");
        # tuple/attribute/subscript targets are ignored.
        target = next((t for t in node.targets if isinstance(t, ast.Name)), None)
        if target is None:
            return None

        keywords = {kw.arg: kw.value for kw in call.keywords if kw.arg is not None}
        shape_node = call.args[0] if call.args else keywords.get("shape")
        stride_node = call.args[1] if len(call.args) > 1 else keywords.get("stride")

        return CuTeDSLLayout(
            name=target.id,
            line=node.lineno - 1,
            shape=ast.unparse(shape_node) if shape_node is not None else None,
            stride=ast.unparse(stride_node) if stride_node is not None else None,
        )
|
wafer_lsp/server.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from lsprotocol.types import (
|
|
2
|
+
INITIALIZE,
|
|
3
|
+
TEXT_DOCUMENT_HOVER,
|
|
4
|
+
)
|
|
5
|
+
from pygls.lsp.server import LanguageServer
|
|
6
|
+
|
|
7
|
+
from .services import (
|
|
8
|
+
create_analysis_service,
|
|
9
|
+
create_docs_service,
|
|
10
|
+
create_document_service,
|
|
11
|
+
create_hover_service,
|
|
12
|
+
create_language_registry_service,
|
|
13
|
+
create_position_service,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Process-wide service instances shared by the LSP feature handlers below.
language_registry_service = create_language_registry_service()
analysis_service = create_analysis_service()
docs_service = create_docs_service()
position_service = create_position_service()

# The hover service composes the four services above.
hover_service = create_hover_service(
    language_registry=language_registry_service,
    analysis_service=analysis_service,
    docs_service=docs_service,
    position_service=position_service,
)

# NOTE(review): the advertised server version "1.0.0" differs from the package
# version (0.1.0) — confirm which one is intended.
server = LanguageServer("wafer-lsp", "1.0.0")

# The document service reads open documents through the server's workspace.
document_service = create_document_service(server)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@server.feature(INITIALIZE)
def initialize(params):
    """Handle the LSP ``initialize`` request by advertising hover support.

    NOTE(review): confirm that pygls actually uses this return value; it may
    compute capabilities from the registered features instead.
    """
    capabilities = {"hoverProvider": True}
    return {"capabilities": capabilities}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@server.feature(TEXT_DOCUMENT_HOVER)
def hover(params):
    """Answer ``textDocument/hover`` for the position in *params*.

    Returns whatever ``hover_service.handle_hover`` produces, or ``None`` (no
    tooltip) when the document text is unavailable or empty. ``None`` is the
    protocol's "no hover" result, so debug placeholder text is never shown.
    """
    uri = params.text_document.uri
    content = document_service.get_document_content(uri)
    if not content:
        # Nothing to analyse for a missing or empty document.
        return None
    return hover_service.handle_hover(uri, params.position, content)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
if __name__ == "__main__":
    # Running this module directly starts serving via the server's start_io().
    server.start_io()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from .analysis_service import AnalysisService, create_analysis_service
|
|
2
|
+
from .docs_service import DocsService, create_docs_service
|
|
3
|
+
from .document_service import DocumentService, create_document_service
|
|
4
|
+
from .hover_service import HoverService, create_hover_service
|
|
5
|
+
from .language_registry_service import LanguageRegistryService, create_language_registry_service
|
|
6
|
+
from .position_service import PositionService, create_position_service
|
|
7
|
+
|
|
8
|
+
# Public API of the services package: service classes first, then the
# factory functions that construct them.
__all__ = [
    # Service classes
    "AnalysisService",
    "DocsService",
    "DocumentService",
    "HoverService",
    "LanguageRegistryService",
    "PositionService",
    # Factories
    "create_analysis_service",
    "create_docs_service",
    "create_document_service",
    "create_hover_service",
    "create_language_registry_service",
    "create_position_service",
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class AnalysisService:
|
|
5
|
+
|
|
6
|
+
def __init__(self):
|
|
7
|
+
self._cache: dict[str, dict[str, Any]] = {}
|
|
8
|
+
|
|
9
|
+
def get_analysis_for_kernel(self, uri: str, kernel_name: str) -> dict[str, Any] | None:
|
|
10
|
+
cache_key = f"{uri}:{kernel_name}"
|
|
11
|
+
return self._cache.get(cache_key)
|
|
12
|
+
|
|
13
|
+
def set_analysis_for_kernel(self, uri: str, kernel_name: str, analysis: dict[str, Any]):
|
|
14
|
+
cache_key = f"{uri}:{kernel_name}"
|
|
15
|
+
self._cache[cache_key] = analysis
|
|
16
|
+
|
|
17
|
+
def clear_cache(self):
|
|
18
|
+
self._cache.clear()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_analysis_service() -> AnalysisService:
    """Factory returning a fresh, empty :class:`AnalysisService`."""
    service = AnalysisService()
    return service
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DocsService:
|
|
5
|
+
|
|
6
|
+
def __init__(self, docs_path: str | None = None):
|
|
7
|
+
if docs_path:
|
|
8
|
+
self.docs_path = Path(docs_path)
|
|
9
|
+
else:
|
|
10
|
+
self.docs_path = Path(__file__).parent.parent.parent.parent.parent / \
|
|
11
|
+
"curriculum" / "cutlass-docs" / "cutedsl-docs"
|
|
12
|
+
|
|
13
|
+
self.index = self._build_index()
|
|
14
|
+
|
|
15
|
+
def _build_index(self):
|
|
16
|
+
return {
|
|
17
|
+
"layout": [
|
|
18
|
+
"intro-to-cutedsl.md",
|
|
19
|
+
"partitioning-strategies-inner-outer-threadvalue.md"
|
|
20
|
+
],
|
|
21
|
+
"TMA": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
22
|
+
"TMEM": ["colfax-blackwell-umma-tensor-memory-part1.md"],
|
|
23
|
+
"kernel": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
24
|
+
"struct": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
25
|
+
"pipeline": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
26
|
+
"MMA": ["mma-atoms-fundamentals-sm70-example.md"],
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
def get_doc_for_concept(self, concept: str) -> str | None:
|
|
30
|
+
concept_lower = concept.lower()
|
|
31
|
+
if self.index.get(concept_lower):
|
|
32
|
+
doc_file = self.index[concept_lower][0]
|
|
33
|
+
doc_path = self.docs_path / doc_file
|
|
34
|
+
if doc_path.exists():
|
|
35
|
+
return str(doc_path)
|
|
36
|
+
return None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def create_docs_service(docs_path: str | None = None) -> DocsService:
    """Factory for :class:`DocsService`; *docs_path* is forwarded unchanged."""
    return DocsService(docs_path=docs_path)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from pygls.lsp.server import LanguageServer
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DocumentService:
    """Read access to documents held in the language server's workspace."""

    def __init__(self, server: LanguageServer):
        self._server = server

    def get_document(self, uri: str):
        """Return the workspace document for *uri*, or ``None`` if it is not in ``text_documents``."""
        return self._server.workspace.text_documents.get(uri)

    def get_document_content(self, uri: str) -> str:
        """Return the text of the document at *uri*, or ``''`` when unavailable.

        The document's ``text`` attribute is preferred. ``source`` is consulted
        only when ``text`` is absent or ``None``. The result is never ``None``.
        """
        doc = self.get_document(uri)
        if doc is None:
            return ''
        # Look up `source` lazily rather than as an eagerly-evaluated getattr default.
        text = getattr(doc, 'text', None)
        if text is None:
            text = getattr(doc, 'source', None)
        return text if text is not None else ''
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def create_document_service(server: LanguageServer) -> DocumentService:
    """Factory binding a :class:`DocumentService` to *server*."""
    return DocumentService(server=server)
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import Hover, MarkupContent, MarkupKind, Position
|
|
3
|
+
|
|
4
|
+
from ..languages.types import KernelInfo, LayoutInfo
|
|
5
|
+
from .analysis_service import AnalysisService
|
|
6
|
+
from .docs_service import DocsService
|
|
7
|
+
from .language_registry_service import LanguageRegistryService
|
|
8
|
+
from .position_service import PositionService
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class HoverService:
    """Build markdown hover tooltips for CuTe DSL (Python) and CUDA sources.

    ``handle_hover`` tries, in order: a decorator at the cursor, the ``cute``
    module name (or a ``cute.``-prefixed word), a parsed kernel whose name is
    the hovered word, and a parsed layout whose name is the hovered word. When
    nothing matches it returns ``None`` so the client shows no tooltip.
    """

    # Static tooltip for the ``cute`` module name or a ``cute.*`` word.
    _CUTE_MODULE_HOVER = (
        "**cutlass.cute**",
        "",
        "CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
        "",
        "**Key Features:**",
        "- `@cute.kernel` - Define GPU kernels",
        "- `@cute.struct` - Define GPU structs",
        "- `cute.make_layout()` - Create tensor layouts",
        "- `cute.Tensor` - Tensor type annotations",
        "",
        "[Documentation](https://github.com/NVIDIA/cutlass)",
    )

    def __init__(
        self,
        language_registry: LanguageRegistryService,
        analysis_service: AnalysisService,
        docs_service: DocsService,
        position_service: PositionService
    ):
        self._language_registry = language_registry
        self._analysis_service = analysis_service
        self._docs_service = docs_service
        self._position_service = position_service

    def handle_hover(self, uri: str, position: Position, content: str) -> Hover | None:
        """Return the hover for *position* in document *uri*, or ``None``.

        Args:
            uri: Document URI, used for language detection and the analysis cache.
            position: Cursor position from the hover request.
            content: Full document text.
        """
        decorator_info = self._position_service.get_decorator_at_position(content, position)
        if decorator_info:
            decorator_name, function_line = decorator_info
            function_name = self._extract_definition_name(content, function_line)
            return self._markdown_hover(
                self._format_decorator_hover(decorator_name, function_name)
            )

        word = self._position_service.get_word_at_position(content, position)
        # Guard against an empty/None word before calling string methods on it.
        if word and (word == "cute" or word.startswith("cute.")):
            return self._markdown_hover("\n".join(self._CUTE_MODULE_HOVER))

        kernel = self._find_kernel_at_position(content, position, uri)
        if kernel:
            analysis = self._analysis_service.get_analysis_for_kernel(uri, kernel.name)
            # _format_kernel_hover omits the analysis section when it is empty/None.
            return self._markdown_hover(self._format_kernel_hover(kernel, analysis))

        layout = self._find_layout_at_position(content, position, uri)
        if layout:
            return self._markdown_hover(self._format_layout_hover(layout))

        # Nothing CuTe/CUDA-specific under the cursor: no tooltip.
        return None

    @staticmethod
    def _markdown_hover(text: str) -> Hover:
        """Wrap markdown *text* in an LSP ``Hover``."""
        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=text))

    @staticmethod
    def _extract_definition_name(content: str, line_index: int) -> str | None:
        """Return the name defined by a ``def``/``async def``/``class`` on *line_index*.

        The name stops at the first ``(``, ``[`` or ``:``, so a class with bases
        yields ``Foo`` rather than ``Foo(Base)``. Returns ``None`` for an
        out-of-range index (including negative ones) or a non-definition line.
        """
        lines = content.split("\n")
        if not 0 <= line_index < len(lines):
            return None
        stripped = lines[line_index].strip()
        for keyword in ("async def ", "def ", "class "):
            if not stripped.startswith(keyword):
                continue
            rest = stripped[len(keyword):]
            end = len(rest)
            for delimiter in "([:":
                idx = rest.find(delimiter)
                if idx != -1:
                    end = min(end, idx)
            return rest[:end].strip() or None
        return None

    def _find_kernel_at_position(
        self, content: str, position: Position, uri: str
    ) -> KernelInfo | None:
        """Return the parsed kernel named like the hovered word, if any.

        Only kernels defined on or before the hovered line are considered.
        """
        language_info = self._language_registry.parse_file(uri, content)
        if not language_info:
            return None

        word = self._position_service.get_word_at_position(content, position)
        for kernel in language_info.kernels:
            if kernel.name == word and position.line >= kernel.line:
                return kernel
        return None

    def _find_layout_at_position(
        self, content: str, position: Position, uri: str
    ) -> LayoutInfo | None:
        """Return the parsed layout named like the hovered word, if any."""
        language_info = self._language_registry.parse_file(uri, content)
        if not language_info:
            return None

        word = self._position_service.get_word_at_position(content, position)
        for layout in language_info.layouts:
            if layout.name == word:
                return layout
        return None

    def _format_layout_hover(self, layout: LayoutInfo) -> str:
        """Render the markdown tooltip for a layout variable."""
        lines = [f"**Layout: {layout.name}**", ""]
        if layout.shape:
            lines.append(f"Shape: `{layout.shape}`")
        if layout.stride:
            lines.append(f"Stride: `{layout.stride}`")
        self._append_doc_link(lines, "layout")
        return "\n".join(lines)

    def _format_kernel_hover(self, kernel: KernelInfo, analysis: dict | None) -> str:
        """Render the markdown tooltip for a kernel, plus *analysis* when non-empty."""
        language_name = self._language_registry.get_language_name(kernel.language) or kernel.language
        is_cuda = kernel.language in ("cuda", "cpp")

        heading = "CUDA Kernel" if is_cuda else "GPU Kernel"
        lines = [f"**{heading}: {kernel.name}**", f"*Language: {language_name}*", ""]

        if kernel.docstring:
            lines.append(kernel.docstring)
            lines.append("")

        if kernel.parameters:
            params_str = ", ".join(kernel.parameters)
            lines.append(f"**Parameters:** `{params_str}`")
            lines.append("")

        if is_cuda:
            lines.append("**CUDA Features:**")
            lines.append("- `__global__` function executed on GPU")
            lines.append("- Can be launched with `<<<grid, block>>>` syntax")
            lines.append("")

        if analysis:
            lines.append("**Analysis:**")
            if "layouts" in analysis:
                lines.append(f"- Layouts: {analysis['layouts']}")
            if "memory_paths" in analysis:
                lines.append(f"- Memory paths: {analysis['memory_paths']}")
            if "pipeline_stages" in analysis:
                lines.append(f"- Pipeline stages: {analysis['pipeline_stages']}")

        return "\n".join(lines)

    def _format_kernel_hover_basic(self, kernel: KernelInfo) -> str:
        """Render a kernel tooltip without analysis data."""
        return self._format_kernel_hover(kernel, None)

    def _format_decorator_hover(self, decorator_name: str, function_name: str | None = None) -> str:
        """Render the markdown tooltip for a CuTe DSL decorator.

        Args:
            decorator_name: Decorator as reported by the position service,
                e.g. ``"cute.kernel"`` or ``"kernel"``.
            function_name: Name of the decorated function/class, if known.
        """
        if decorator_name in ("cute.kernel", "kernel"):
            lines = [
                "**@cute.kernel**",
                "",
                "CuTeDSL kernel decorator. Marks a function as a GPU kernel.",
                "",
                "**Usage:**",
                "```python",
                "@cute.kernel",
                "def my_kernel(a: cute.Tensor, b: cute.Tensor):",
                "    # Kernel implementation",
                "    pass",
                "```",
                "",
                "**Features:**",
                "- Automatic GPU code generation",
                "- Tensor layout optimization",
                "- Memory access pattern analysis",
            ]
            if function_name:
                lines += ["", f"Applied to: `{function_name}()`"]
        elif decorator_name in ("cute.struct", "struct"):
            lines = [
                "**@cute.struct**",
                "",
                "CuTeDSL struct decorator. Marks a class as a GPU struct.",
                "",
                "**Usage:**",
                "```python",
                "@cute.struct",
                "class MyStruct:",
                "    field1: int",
                "    field2: float",
                "```",
            ]
            if function_name:
                lines += ["", f"Applied to: `{function_name}`"]
        else:
            lines = [f"**{decorator_name}**", "", "CuTeDSL decorator"]

        # Only link docs for decorators we have a concept for; unrelated
        # decorators previously fell through to the struct docs.
        if "kernel" in decorator_name:
            self._append_doc_link(lines, "kernel")
        elif "struct" in decorator_name:
            self._append_doc_link(lines, "struct")

        return "\n".join(lines)

    def _append_doc_link(self, lines: list[str], concept: str) -> None:
        """Append a ``[Documentation](...)`` link for *concept* when a doc file exists."""
        doc_path = self._docs_service.get_doc_for_concept(concept)
        if doc_path:
            lines.append("")
            lines.append(f"[Documentation]({self._to_link_target(doc_path)})")

    @staticmethod
    def _to_link_target(doc_path: str) -> str:
        """Convert a filesystem path into a ``file://`` URI for a markdown link.

        A bare filesystem path is not a URI, and spaces in it would break the
        markdown link syntax; ``as_uri`` percent-encodes them.
        """
        from pathlib import Path

        try:
            return Path(doc_path).resolve().as_uri()
        except ValueError:
            return doc_path
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def create_hover_service(
    language_registry: LanguageRegistryService,
    analysis_service: AnalysisService,
    docs_service: DocsService,
    position_service: PositionService
) -> HoverService:
    """Factory wiring a :class:`HoverService` to its four collaborating services."""
    return HoverService(
        language_registry=language_registry,
        analysis_service=analysis_service,
        docs_service=docs_service,
        position_service=position_service,
    )
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
|
|
2
|
+
from ..languages.registry import LanguageRegistry, get_language_registry
|
|
3
|
+
from ..languages.types import LanguageInfo
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LanguageRegistryService:
    """Service-layer wrapper that forwards every call to a ``LanguageRegistry``."""

    def __init__(self, registry: LanguageRegistry):
        # The wrapped registry; every method below delegates to it unchanged.
        self._registry = registry

    def detect_language(self, uri: str) -> str | None:
        """Return the registry's language id for *uri* (possibly ``None``)."""
        registry = self._registry
        return registry.detect_language(uri)

    def parse_file(self, uri: str, content: str) -> LanguageInfo | None:
        """Parse *content* with the parser the registry selects for *uri*."""
        registry = self._registry
        return registry.parse_file(uri, content)

    def get_language_name(self, language_id: str) -> str | None:
        """Return the display name the registry holds for *language_id* (possibly ``None``)."""
        registry = self._registry
        return registry.get_language_name(language_id)

    def get_supported_extensions(self) -> list[str]:
        """Return the file extensions the registry reports as supported."""
        registry = self._registry
        return registry.get_supported_extensions()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_language_registry_service() -> LanguageRegistryService:
    """Factory wrapping the registry returned by ``get_language_registry()``."""
    return LanguageRegistryService(get_language_registry())
|