wafer-lsp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/__init__.py +1 -0
- wafer_lsp/__main__.py +9 -0
- wafer_lsp/analyzers/__init__.py +0 -0
- wafer_lsp/analyzers/compiler_integration.py +16 -0
- wafer_lsp/analyzers/docs_index.py +36 -0
- wafer_lsp/handlers/__init__.py +0 -0
- wafer_lsp/handlers/code_action.py +48 -0
- wafer_lsp/handlers/code_lens.py +48 -0
- wafer_lsp/handlers/completion.py +6 -0
- wafer_lsp/handlers/diagnostics.py +16 -0
- wafer_lsp/handlers/document_symbol.py +87 -0
- wafer_lsp/handlers/hover.py +215 -0
- wafer_lsp/handlers/inlay_hint.py +65 -0
- wafer_lsp/handlers/semantic_tokens.py +124 -0
- wafer_lsp/handlers/workspace_symbol.py +87 -0
- wafer_lsp/languages/README.md +195 -0
- wafer_lsp/languages/__init__.py +17 -0
- wafer_lsp/languages/converter.py +88 -0
- wafer_lsp/languages/detector.py +34 -0
- wafer_lsp/languages/parser_manager.py +33 -0
- wafer_lsp/languages/registry.py +99 -0
- wafer_lsp/languages/types.py +37 -0
- wafer_lsp/parsers/__init__.py +18 -0
- wafer_lsp/parsers/base_parser.py +9 -0
- wafer_lsp/parsers/cuda_parser.py +95 -0
- wafer_lsp/parsers/cutedsl_parser.py +114 -0
- wafer_lsp/server.py +58 -0
- wafer_lsp/services/__init__.py +21 -0
- wafer_lsp/services/analysis_service.py +22 -0
- wafer_lsp/services/docs_service.py +40 -0
- wafer_lsp/services/document_service.py +20 -0
- wafer_lsp/services/hover_service.py +237 -0
- wafer_lsp/services/language_registry_service.py +26 -0
- wafer_lsp/services/position_service.py +77 -0
- wafer_lsp/utils/__init__.py +0 -0
- wafer_lsp/utils/lsp_helpers.py +79 -0
- wafer_lsp-0.1.0.dist-info/METADATA +57 -0
- wafer_lsp-0.1.0.dist-info/RECORD +40 -0
- wafer_lsp-0.1.0.dist-info/WHEEL +4 -0
- wafer_lsp-0.1.0.dist-info/entry_points.txt +2 -0
wafer_lsp/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version string exposed as ``wafer_lsp.__version__``.
__version__: str = "0.1.0"
|
wafer_lsp/__main__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
_analysis_cache: dict[str, dict[str, Any]] = {}
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_analysis_for_kernel(uri: str, kernel_name: str) -> dict[str, Any] | None:
|
|
7
|
+
cache_key = f"{uri}:{kernel_name}"
|
|
8
|
+
|
|
9
|
+
if cache_key in _analysis_cache:
|
|
10
|
+
return _analysis_cache[cache_key]
|
|
11
|
+
|
|
12
|
+
return None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def clear_cache():
|
|
16
|
+
_analysis_cache.clear()
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DocsIndex:
|
|
5
|
+
|
|
6
|
+
def __init__(self, docs_path: str | None = None):
|
|
7
|
+
if docs_path:
|
|
8
|
+
self.docs_path = Path(docs_path)
|
|
9
|
+
else:
|
|
10
|
+
self.docs_path = Path(__file__).parent.parent.parent.parent.parent / \
|
|
11
|
+
"curriculum" / "cutlass-docs" / "cutedsl-docs"
|
|
12
|
+
|
|
13
|
+
self.index = self._build_index()
|
|
14
|
+
|
|
15
|
+
def _build_index(self) -> dict[str, list[str]]:
|
|
16
|
+
return {
|
|
17
|
+
"layout": [
|
|
18
|
+
"intro-to-cutedsl.md",
|
|
19
|
+
"partitioning-strategies-inner-outer-threadvalue.md"
|
|
20
|
+
],
|
|
21
|
+
"TMA": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
22
|
+
"TMEM": ["colfax-blackwell-umma-tensor-memory-part1.md"],
|
|
23
|
+
"kernel": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
24
|
+
"struct": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
25
|
+
"pipeline": ["blackwell-tutorial-fp16-gemm-0.md"],
|
|
26
|
+
"MMA": ["mma-atoms-fundamentals-sm70-example.md"],
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
def get_doc_for_concept(self, concept: str) -> str | None:
|
|
30
|
+
concept_lower = concept.lower()
|
|
31
|
+
if self.index.get(concept_lower):
|
|
32
|
+
doc_file = self.index[concept_lower][0]
|
|
33
|
+
doc_path = self.docs_path / doc_file
|
|
34
|
+
if doc_path.exists():
|
|
35
|
+
return str(doc_path)
|
|
36
|
+
return None
|
|
File without changes
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import CodeAction, CodeActionKind, Command, Range
|
|
3
|
+
|
|
4
|
+
from ..languages.registry import get_language_registry
|
|
5
|
+
from ..languages.types import KernelInfo
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def find_kernel_at_range(content: str, range: Range, uri: str) -> KernelInfo | None:
    """Return the kernel whose definition most closely precedes ``range``.

    A kernel qualifies when ``range.start.line`` lies between its definition
    line and 50 lines after it. This is a heuristic: the parser result used
    here carries only a start line, not where the kernel body ends. When
    several kernels qualify, the one defined closest above the cursor wins.
    This way a cursor inside the second of two nearby kernels is not
    attributed to the first.

    Args:
        content: Full document text.
        range: Editor range the code action was requested for.
        uri: Document URI, used by the registry to choose a parser.

    Returns:
        The matching kernel, or ``None`` if no parser handles the file or no
        kernel qualifies.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return None

    cursor_line = range.start.line
    candidates = [
        kernel
        for kernel in language_info.kernels
        if kernel.line <= cursor_line <= kernel.line + 50
    ]
    return max(candidates, key=lambda kernel: kernel.line, default=None)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def handle_code_action(uri: str, range: Range, content: str) -> list[CodeAction]:
    """Offer "Analyze Kernel" and "Profile Kernel" source actions.

    Args:
        uri: Document URI; passed through as the first command argument.
        range: Range the client requested actions for.
        content: Full document text.

    Returns:
        Two ``CodeAction`` objects (analyze, then profile) for the kernel
        found near ``range``, or an empty list if there is none.
    """
    kernel = find_kernel_at_range(content, range, uri)
    if not kernel:
        return []

    def source_action(title: str, command_id: str) -> CodeAction:
        # Each action carries a command with the same title, invoked with
        # [uri, kernel name].
        return CodeAction(
            title=title,
            kind=CodeActionKind.Source,
            command=Command(
                title=title,
                command=command_id,
                arguments=[uri, kernel.name],
            ),
        )

    return [
        source_action(f"Analyze Kernel: {kernel.name}", "wafer.analyzeKernel"),
        source_action(f"Profile Kernel: {kernel.name}", "wafer.profileKernel"),
    ]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import CodeLens, Command, Position, Range
|
|
3
|
+
|
|
4
|
+
from ..languages.registry import get_language_registry
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def handle_code_lens(uri: str, content: str) -> list[CodeLens]:
    """Attach "Analyze" and "Profile" code lenses to each kernel's line.

    Args:
        uri: Document URI; passed through as the first command argument.
        content: Full document text.

    Returns:
        Two lenses per kernel, in kernel order. The analyze lens is anchored
        at column 0 and the profile lens at column 20 of the kernel's line.
        An empty list is returned when no parser handles the document.
    """
    language_info = get_language_registry().parse_file(uri, content)
    if not language_info:
        return []

    def lens_at(line: int, column: int, command: Command) -> CodeLens:
        # Zero-width range: the lens is attached to a point, not a span.
        return CodeLens(
            range=Range(
                start=Position(line=line, character=column),
                end=Position(line=line, character=column),
            ),
            command=command,
        )

    result: list[CodeLens] = []
    for kernel in language_info.kernels:
        analyze = Command(
            title=f"Analyze {kernel.name}",
            command="wafer.analyzeKernel",
            arguments=[uri, kernel.name],
        )
        profile = Command(
            title=f"Profile {kernel.name}",
            command="wafer.profileKernel",
            arguments=[uri, kernel.name],
        )
        # NOTE(review): column 20 for the profile lens looks arbitrary; both
        # lenses end up on the same line either way.
        result.append(lens_at(kernel.line, 0, analyze))
        result.append(lens_at(kernel.line, 20, profile))

    return result
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import Diagnostic
|
|
3
|
+
|
|
4
|
+
from ..languages.registry import get_language_registry
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def handle_diagnostics(uri: str, content: str) -> list[Diagnostic]:
    """Compute diagnostics for a document.

    The document is still run through the language registry's parser.
    However, no diagnostic rules exist yet, so the result is always a fresh
    empty list.

    Args:
        uri: Document URI, used by the registry to choose a parser.
        content: Full document text.
    """
    get_language_registry().parse_file(uri, content)
    no_diagnostics: list[Diagnostic] = []
    return no_diagnostics
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
|
|
3
|
+
|
|
4
|
+
from ..languages.registry import get_language_registry
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
    """Build the document outline: kernels, then layouts, then structs.

    Args:
        uri: Document URI, used by the registry to choose a parser.
        content: Full document text.

    Returns:
        One ``DocumentSymbol`` per kernel, layout and struct reported by the
        parser, in that order, or an empty list when no parser handles the
        document.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return []

    # Split once for all symbols instead of once per symbol.
    lines = content.split("\n")
    symbols: list[DocumentSymbol] = []

    for kernel in language_info.kernels:
        full_range, selection_range = _symbol_ranges(lines, kernel.line, kernel.name, span_lines=10)
        # Fall back to the raw language id, as the hover code does, instead of
        # rendering "None".
        language_name = registry.get_language_name(kernel.language) or kernel.language
        symbols.append(DocumentSymbol(
            name=kernel.name,
            kind=SymbolKind.Function,
            range=full_range,
            selection_range=selection_range,
            detail=f"GPU Kernel ({language_name})",
        ))

    for layout in language_info.layouts:
        # Layouts are single-line assignments: the full range is the whole line.
        full_range, selection_range = _symbol_ranges(lines, layout.line, layout.name, span_lines=0)
        symbols.append(DocumentSymbol(
            name=layout.name,
            kind=SymbolKind.Variable,
            range=full_range,
            selection_range=selection_range,
            detail=f"Layout: {layout.shape}" if layout.shape else "Layout",
        ))

    for struct in language_info.structs:
        full_range, selection_range = _symbol_ranges(lines, struct.line, struct.name, span_lines=10)
        language_name = registry.get_language_name(struct.language) or struct.language
        symbols.append(DocumentSymbol(
            name=struct.name,
            kind=SymbolKind.Struct,
            range=full_range,
            selection_range=selection_range,
            detail=f"Struct ({language_name})",
        ))

    return symbols


def _symbol_ranges(lines: list[str], line: int, name: str, span_lines: int) -> tuple[Range, Range]:
    """Return ``(full_range, selection_range)`` for ``name`` declared on ``line``.

    ``selection_range`` covers the first occurrence of ``name`` on the line.
    It is empty at column 0 if the name is not found.

    ``full_range`` starts at column 0 and extends ``span_lines`` lines down,
    clamped to the last document line and never before ``line``. When it
    cannot extend past ``line``, it ends at the end of ``line`` instead of
    column 0. This keeps ``selection_range`` inside ``full_range``, as LSP
    requires for ``DocumentSymbol``.
    """
    text = lines[line] if line < len(lines) else ""
    name_start = text.find(name)
    if name_start >= 0:
        sel_start, sel_end = name_start, name_start + len(name)
    else:
        sel_start = sel_end = 0

    selection_range = Range(
        start=Position(line=line, character=sel_start),
        end=Position(line=line, character=sel_end),
    )

    end_line = max(line, min(line + span_lines, len(lines) - 1))
    end_char = len(text) if end_line == line else 0
    full_range = Range(
        start=Position(line=line, character=0),
        end=Position(line=end_line, character=end_char),
    )
    return full_range, selection_range
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import Hover, MarkupContent, MarkupKind, Position
|
|
3
|
+
|
|
4
|
+
from ..analyzers.compiler_integration import get_analysis_for_kernel
|
|
5
|
+
from ..analyzers.docs_index import DocsIndex
|
|
6
|
+
from ..languages.registry import get_language_registry
|
|
7
|
+
from ..languages.types import KernelInfo, LayoutInfo
|
|
8
|
+
from ..utils.lsp_helpers import get_decorator_at_position, get_word_at_position
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def find_kernel_at_position(
    content: str, position: Position, uri: str
) -> KernelInfo | None:
    """Return the kernel named by the word under ``position``.

    Only kernels defined at or above the cursor line are considered. The
    first such kernel in parser order is returned.

    Args:
        content: Full document text.
        position: Cursor position.
        uri: Document URI, used by the registry to choose a parser.

    Returns:
        The matching kernel, or ``None``.
    """
    language_info = get_language_registry().parse_file(uri, content)
    if not language_info:
        return None

    word = get_word_at_position(content, position)
    return next(
        (
            candidate
            for candidate in language_info.kernels
            if candidate.name == word and position.line >= candidate.line
        ),
        None,
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def find_layout_at_position(
    content: str, position: Position, uri: str
) -> LayoutInfo | None:
    """Return the first layout whose name equals the word under ``position``.

    Args:
        content: Full document text.
        position: Cursor position.
        uri: Document URI, used by the registry to choose a parser.

    Returns:
        The matching layout, or ``None``.
    """
    language_info = get_language_registry().parse_file(uri, content)
    if not language_info:
        return None

    word = get_word_at_position(content, position)
    return next(
        (candidate for candidate in language_info.layouts if candidate.name == word),
        None,
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def format_kernel_hover(kernel: KernelInfo, analysis: dict | None) -> str:
    """Build the markdown hover body for a GPU kernel.

    Sections appear in this order:
    - the title and language;
    - the kernel docstring, if any;
    - the parameter list, if any;
    - when ``analysis`` is non-empty, an "Analysis" section listing whichever
      of the ``layouts``, ``memory_paths`` and ``pipeline_stages`` keys are
      present.

    Args:
        kernel: Kernel to describe.
        analysis: Optional analysis mapping for the kernel.

    Returns:
        Newline-joined markdown.
    """
    language_name = get_language_registry().get_language_name(kernel.language) or kernel.language

    sections = [f"**GPU Kernel: {kernel.name}**", f"*Language: {language_name}*", ""]

    if kernel.docstring:
        sections.extend([kernel.docstring, ""])

    if kernel.parameters:
        joined_params = ", ".join(kernel.parameters)
        sections.extend([f"**Parameters:** `{joined_params}`", ""])

    if analysis:
        sections.append("**Analysis:**")
        labelled_keys = (
            ("layouts", "Layouts"),
            ("memory_paths", "Memory paths"),
            ("pipeline_stages", "Pipeline stages"),
        )
        sections.extend(
            f"- {label}: {analysis[key]}"
            for key, label in labelled_keys
            if key in analysis
        )

    return "\n".join(sections)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def format_kernel_hover_basic(kernel: KernelInfo) -> str:
    """Render kernel hover markdown without any analysis section."""
    return format_kernel_hover(kernel, analysis=None)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def format_decorator_hover(decorator_name: str, function_name: str | None = None) -> str:
|
|
80
|
+
lines = []
|
|
81
|
+
|
|
82
|
+
if decorator_name == "cute.kernel" or decorator_name == "kernel":
|
|
83
|
+
lines.append("**@cute.kernel**")
|
|
84
|
+
lines.append("")
|
|
85
|
+
lines.append("CuTeDSL kernel decorator. Marks a function as a GPU kernel.")
|
|
86
|
+
lines.append("")
|
|
87
|
+
lines.append("**Usage:**")
|
|
88
|
+
lines.append("```python")
|
|
89
|
+
lines.append("@cute.kernel")
|
|
90
|
+
lines.append("def my_kernel(a: cute.Tensor, b: cute.Tensor):")
|
|
91
|
+
lines.append(" # Kernel implementation")
|
|
92
|
+
lines.append(" pass")
|
|
93
|
+
lines.append("```")
|
|
94
|
+
lines.append("")
|
|
95
|
+
lines.append("**Features:**")
|
|
96
|
+
lines.append("- Automatic GPU code generation")
|
|
97
|
+
lines.append("- Tensor layout optimization")
|
|
98
|
+
lines.append("- Memory access pattern analysis")
|
|
99
|
+
|
|
100
|
+
if function_name:
|
|
101
|
+
lines.append("")
|
|
102
|
+
lines.append(f"Applied to: `{function_name}()`")
|
|
103
|
+
|
|
104
|
+
elif decorator_name == "cute.struct" or decorator_name == "struct":
|
|
105
|
+
lines.append("**@cute.struct**")
|
|
106
|
+
lines.append("")
|
|
107
|
+
lines.append("CuTeDSL struct decorator. Marks a class as a GPU struct.")
|
|
108
|
+
lines.append("")
|
|
109
|
+
lines.append("**Usage:**")
|
|
110
|
+
lines.append("```python")
|
|
111
|
+
lines.append("@cute.struct")
|
|
112
|
+
lines.append("class MyStruct:")
|
|
113
|
+
lines.append(" field1: int")
|
|
114
|
+
lines.append(" field2: float")
|
|
115
|
+
lines.append("```")
|
|
116
|
+
|
|
117
|
+
if function_name:
|
|
118
|
+
lines.append("")
|
|
119
|
+
lines.append(f"Applied to: `{function_name}`")
|
|
120
|
+
|
|
121
|
+
else:
|
|
122
|
+
lines.append(f"**{decorator_name}**")
|
|
123
|
+
lines.append("")
|
|
124
|
+
lines.append("CuTeDSL decorator")
|
|
125
|
+
|
|
126
|
+
docs = DocsIndex()
|
|
127
|
+
doc_link = docs.get_doc_for_concept("kernel" if "kernel" in decorator_name else "struct")
|
|
128
|
+
if doc_link:
|
|
129
|
+
lines.append("")
|
|
130
|
+
lines.append(f"[Documentation]({doc_link})")
|
|
131
|
+
|
|
132
|
+
return "\n".join(lines)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
    """Produce hover documentation for the symbol under ``position``.

    Cases are checked in order:
    1. a CuTeDSL decorator;
    2. the ``cute`` module name;
    3. a kernel name;
    4. a layout name.

    If none applies, ``None`` is returned so the client shows no hover from
    this server. The debug banner and the catch-all "Hover is working!"
    placeholder are not emitted.

    Args:
        uri: Document URI.
        position: Cursor position.
        content: Full document text.

    Returns:
        A markdown ``Hover``, or ``None``.
    """
    decorator_info = get_decorator_at_position(content, position)
    if decorator_info:
        decorator_name, function_line = decorator_info
        function_name = _definition_name(content.split("\n"), function_line)
        return _markdown_hover(format_decorator_hover(decorator_name, function_name))

    word = get_word_at_position(content, position)
    if word and (word == "cute" or word.startswith("cute.")):
        return _markdown_hover("\n".join([
            "**cutlass.cute**",
            "",
            "CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
            "",
            "**Key Features:**",
            "- `@cute.kernel` - Define GPU kernels",
            "- `@cute.struct` - Define GPU structs",
            "- `cute.make_layout()` - Create tensor layouts",
            "- `cute.Tensor` - Tensor type annotations",
            "",
            "[Documentation](https://github.com/NVIDIA/cutlass)",
        ]))

    kernel = find_kernel_at_position(content, position, uri)
    if kernel:
        analysis = get_analysis_for_kernel(uri, kernel.name)
        if analysis:
            return _markdown_hover(format_kernel_hover(kernel, analysis))
        return _markdown_hover(format_kernel_hover_basic(kernel))

    layout = find_layout_at_position(content, position, uri)
    if layout:
        hover_lines = [f"**Layout: {layout.name}**", ""]
        if layout.shape:
            hover_lines.append(f"Shape: `{layout.shape}`")
        if layout.stride:
            hover_lines.append(f"Stride: `{layout.stride}`")

        doc_link = DocsIndex().get_doc_for_concept("layout")
        if doc_link:
            hover_lines.extend(["", f"[Documentation]({doc_link})"])

        return _markdown_hover("\n".join(hover_lines))

    return None


def _markdown_hover(value: str) -> Hover:
    """Wrap markdown text in an LSP ``Hover``."""
    return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=value))


def _definition_name(lines: list[str], line_index: int) -> str | None:
    """Return the name declared by a ``def``/``async def``/``class`` line.

    Only the identifier is returned. For example, ``class Foo(Base):`` yields
    ``"Foo"``, not ``"Foo(Base)"``. Returns ``None`` when ``line_index`` is
    out of range (including negative) or the line is not such a statement.
    """
    import re

    if not 0 <= line_index < len(lines):
        return None
    match = re.match(r"(?:async\s+def|def|class)\s+([A-Za-z_]\w*)", lines[line_index].strip())
    return match.group(1) if match else None
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import InlayHint, InlayHintKind, Position, Range
|
|
3
|
+
|
|
4
|
+
from ..languages.registry import get_language_registry
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
    """Return type-style inlay hints for layouts and kernels inside ``range``.

    Layout assignments (``name = ...``):
    - get ``: Layout``, or ``: Layout[Shape<shape>]`` when a shape is known;
    - the hint is placed right after the assignment target, where a Python
      annotation belongs (previously it landed after the ``=``);
    - targets that already contain ``:`` (annotated) are skipped.

    Kernel ``def`` lines:
    - get ``-> Kernel`` right after the signature's closing parenthesis,
      where a return annotation belongs (previously it was before ``(``);
    - lines that already contain ``->``, or whose signature does not close
      on the ``def`` line, are skipped.

    Args:
        uri: Document URI, used by the registry to choose a parser.
        content: Full document text.
        range: Visible range. Only hints on lines from ``range.start.line``
            to ``range.end.line`` inclusive are returned.

    Returns:
        The hints, or an empty list when no parser handles the document.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return []

    hints: list[InlayHint] = []
    lines = content.split("\n")
    first_line, last_line = range.start.line, range.end.line

    for layout in language_info.layouts:
        if not first_line <= layout.line <= last_line:
            continue

        layout_line = lines[layout.line] if layout.line < len(lines) else ""
        equals_pos = layout_line.find("=")
        if equals_pos < 0:
            continue

        # Annotation slot: directly after the target, before whitespace and "=".
        target = layout_line[:equals_pos].rstrip()
        if not target or ":" in target:
            continue

        hint_text = f": Layout[Shape{layout.shape}]" if layout.shape else ": Layout"
        hints.append(InlayHint(
            position=Position(line=layout.line, character=len(target)),
            label=hint_text,
            kind=InlayHintKind.Type,
            padding_left=False,
            padding_right=False
        ))

    for kernel in language_info.kernels:
        if not first_line <= kernel.line <= last_line:
            continue

        kernel_line = lines[kernel.line] if kernel.line < len(lines) else ""
        if "def " not in kernel_line or "(" not in kernel_line or "->" in kernel_line:
            continue

        # Drop a trailing comment so a parenthesis inside it is not mistaken
        # for the signature's.
        signature = kernel_line.split("#", 1)[0]
        open_paren = signature.find("(")
        close_paren = signature.rfind(")")
        if open_paren < 0 or close_paren < open_paren:
            continue  # signature continues on a later line

        hints.append(InlayHint(
            position=Position(line=kernel.line, character=close_paren + 1),
            label="-> Kernel",
            kind=InlayHintKind.Type,
            padding_left=True,
            padding_right=False
        ))

    return hints
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import SemanticTokens, SemanticTokensLegend
|
|
3
|
+
|
|
4
|
+
from ..languages.registry import get_language_registry
|
|
5
|
+
|
|
6
|
+
# Semantic token types advertised to the client. An emitted token refers to
# its type by index into this list.
TOKEN_TYPES = ["kernel", "layout", "struct", "decorator"]

# Semantic token modifiers advertised to the client.
TOKEN_MODIFIERS = ["definition", "declaration"]

# Legend sent in the server capabilities; it pairs the two lists above.
SEMANTIC_TOKENS_LEGEND = SemanticTokensLegend(
    token_types=TOKEN_TYPES, token_modifiers=TOKEN_MODIFIERS
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
    """Compute semantic tokens for kernels, layouts, structs and decorators.

    Tokens are first collected as absolute
    ``(line, start, length, type_index)`` tuples. They are then sorted into
    document order and delta-encoded. The LSP format encodes each token
    relative to the previous one:
    - ``deltaLine`` is measured from the previous token's line;
    - on the same line, ``deltaStart`` is measured from the previous token's
      *start* column.

    Emitting each category in its own pass, as before, produced negative
    deltas whenever categories interleaved in the file (e.g. a decorator
    line above a kernel). Overlapping tokens on a line are dropped, keeping
    the earliest, since LSP forbids overlap.

    Args:
        uri: Document URI, used by the registry to choose a parser.
        content: Full document text.

    Returns:
        ``SemanticTokens`` with the flat 5-integers-per-token data array.
        The data is empty when no parser handles the document.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return SemanticTokens(data=[])

    lines = content.split("\n")
    found: set[tuple[int, int, int, int]] = set()

    named_groups = (
        (language_info.kernels, TOKEN_TYPES.index("kernel")),
        (language_info.layouts, TOKEN_TYPES.index("layout")),
        (language_info.structs, TOKEN_TYPES.index("struct")),
    )
    for items, type_index in named_groups:
        for item in items:
            if not item.name or not 0 <= item.line < len(lines):
                continue
            start = lines[item.line].find(item.name)
            if start >= 0:
                found.add((item.line, start, len(item.name), type_index))

    decorator_type = TOKEN_TYPES.index("decorator")
    for line_no, text in enumerate(lines):
        # Anchor on the decorator marker itself, not the first "@" on the line.
        starts = [pos for pos in (text.find("@cute.kernel"), text.find("@cute.struct")) if pos >= 0]
        if not starts:
            continue
        start = min(starts)
        end = text.find(" ", start)
        if end == -1:
            # Exclude trailing whitespace such as a CRLF "\r".
            end = len(text.rstrip())
        found.add((line_no, start, end - start, decorator_type))

    data: list[int] = []
    prev_line = prev_start = 0
    last_line, last_end = -1, 0
    for line_no, start, length, type_index in sorted(found):
        if line_no == last_line and start < last_end:
            continue  # overlaps the previous token on this line
        delta_line = line_no - prev_line
        delta_start = start - prev_start if delta_line == 0 else start
        data.extend([delta_line, delta_start, length, type_index, 0])
        prev_line, prev_start = line_no, start
        last_line, last_end = line_no, start + length

    return SemanticTokens(data=data)
|