wafer-lsp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. wafer_lsp/__init__.py +1 -0
  2. wafer_lsp/__main__.py +9 -0
  3. wafer_lsp/analyzers/__init__.py +0 -0
  4. wafer_lsp/analyzers/compiler_integration.py +16 -0
  5. wafer_lsp/analyzers/docs_index.py +36 -0
  6. wafer_lsp/handlers/__init__.py +0 -0
  7. wafer_lsp/handlers/code_action.py +48 -0
  8. wafer_lsp/handlers/code_lens.py +48 -0
  9. wafer_lsp/handlers/completion.py +6 -0
  10. wafer_lsp/handlers/diagnostics.py +16 -0
  11. wafer_lsp/handlers/document_symbol.py +87 -0
  12. wafer_lsp/handlers/hover.py +215 -0
  13. wafer_lsp/handlers/inlay_hint.py +65 -0
  14. wafer_lsp/handlers/semantic_tokens.py +124 -0
  15. wafer_lsp/handlers/workspace_symbol.py +87 -0
  16. wafer_lsp/languages/README.md +195 -0
  17. wafer_lsp/languages/__init__.py +17 -0
  18. wafer_lsp/languages/converter.py +88 -0
  19. wafer_lsp/languages/detector.py +34 -0
  20. wafer_lsp/languages/parser_manager.py +33 -0
  21. wafer_lsp/languages/registry.py +99 -0
  22. wafer_lsp/languages/types.py +37 -0
  23. wafer_lsp/parsers/__init__.py +18 -0
  24. wafer_lsp/parsers/base_parser.py +9 -0
  25. wafer_lsp/parsers/cuda_parser.py +95 -0
  26. wafer_lsp/parsers/cutedsl_parser.py +114 -0
  27. wafer_lsp/server.py +58 -0
  28. wafer_lsp/services/__init__.py +21 -0
  29. wafer_lsp/services/analysis_service.py +22 -0
  30. wafer_lsp/services/docs_service.py +40 -0
  31. wafer_lsp/services/document_service.py +20 -0
  32. wafer_lsp/services/hover_service.py +237 -0
  33. wafer_lsp/services/language_registry_service.py +26 -0
  34. wafer_lsp/services/position_service.py +77 -0
  35. wafer_lsp/utils/__init__.py +0 -0
  36. wafer_lsp/utils/lsp_helpers.py +79 -0
  37. wafer_lsp-0.1.0.dist-info/METADATA +57 -0
  38. wafer_lsp-0.1.0.dist-info/RECORD +40 -0
  39. wafer_lsp-0.1.0.dist-info/WHEEL +4 -0
  40. wafer_lsp-0.1.0.dist-info/entry_points.txt +2 -0
wafer_lsp/__init__.py ADDED
@@ -0,0 +1 @@
1
# Version string exposed by the wafer_lsp package.
__version__ = "0.1.0"
wafer_lsp/__main__.py ADDED
@@ -0,0 +1,9 @@
1
+ from wafer_lsp.server import server
2
+
3
+
4
def main() -> None:
    """Run the language server by calling ``server.start_io()``."""
    server.start_io()


# Allow ``python -m wafer_lsp`` to launch the server.
if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,16 @@
1
+ from typing import Any
2
+
3
+ _analysis_cache: dict[str, dict[str, Any]] = {}
4
+
5
+
6
+ def get_analysis_for_kernel(uri: str, kernel_name: str) -> dict[str, Any] | None:
7
+ cache_key = f"{uri}:{kernel_name}"
8
+
9
+ if cache_key in _analysis_cache:
10
+ return _analysis_cache[cache_key]
11
+
12
+ return None
13
+
14
+
15
+ def clear_cache():
16
+ _analysis_cache.clear()
@@ -0,0 +1,36 @@
1
+ from pathlib import Path
2
+
3
+
4
+ class DocsIndex:
5
+
6
+ def __init__(self, docs_path: str | None = None):
7
+ if docs_path:
8
+ self.docs_path = Path(docs_path)
9
+ else:
10
+ self.docs_path = Path(__file__).parent.parent.parent.parent.parent / \
11
+ "curriculum" / "cutlass-docs" / "cutedsl-docs"
12
+
13
+ self.index = self._build_index()
14
+
15
+ def _build_index(self) -> dict[str, list[str]]:
16
+ return {
17
+ "layout": [
18
+ "intro-to-cutedsl.md",
19
+ "partitioning-strategies-inner-outer-threadvalue.md"
20
+ ],
21
+ "TMA": ["blackwell-tutorial-fp16-gemm-0.md"],
22
+ "TMEM": ["colfax-blackwell-umma-tensor-memory-part1.md"],
23
+ "kernel": ["blackwell-tutorial-fp16-gemm-0.md"],
24
+ "struct": ["blackwell-tutorial-fp16-gemm-0.md"],
25
+ "pipeline": ["blackwell-tutorial-fp16-gemm-0.md"],
26
+ "MMA": ["mma-atoms-fundamentals-sm70-example.md"],
27
+ }
28
+
29
+ def get_doc_for_concept(self, concept: str) -> str | None:
30
+ concept_lower = concept.lower()
31
+ if self.index.get(concept_lower):
32
+ doc_file = self.index[concept_lower][0]
33
+ doc_path = self.docs_path / doc_file
34
+ if doc_path.exists():
35
+ return str(doc_path)
36
+ return None
File without changes
@@ -0,0 +1,48 @@
1
+
2
+ from lsprotocol.types import CodeAction, CodeActionKind, Command, Range
3
+
4
+ from ..languages.registry import get_language_registry
5
+ from ..languages.types import KernelInfo
6
+
7
+
8
def find_kernel_at_range(content: str, range: Range, uri: str) -> KernelInfo | None:
    """Return the first kernel whose window covers the start of ``range``.

    A kernel is considered to cover a line when that line lies between the
    kernel's own line and 50 lines after it (inclusive).

    Args:
        content: Full text of the document.
        range: Range supplied by the client; only ``range.start.line`` is used.
        uri: Document URI, passed to the language registry for parsing.

    Returns:
        The matching kernel, or ``None`` if the document cannot be parsed or
        no kernel matches.
    """
    parsed = get_language_registry().parse_file(uri, content)
    if not parsed:
        return None

    return next(
        (
            candidate
            for candidate in parsed.kernels
            if candidate.line <= range.start.line <= candidate.line + 50
        ),
        None,
    )
20
+
21
+
22
def handle_code_action(uri: str, range: Range, content: str) -> list[CodeAction]:
    """Offer "Analyze Kernel" and "Profile Kernel" actions for a kernel.

    Args:
        uri: Document URI; forwarded as the first command argument.
        range: Range the client asked code actions for.
        content: Full text of the document.

    Returns:
        Two source actions (analyze, then profile) when a kernel is found at
        ``range``; an empty list otherwise.
    """
    kernel = find_kernel_at_range(content, range, uri)
    if not kernel:
        return []

    # (title prefix, command identifier), in the order they are offered.
    command_specs = (
        ("Analyze Kernel", "wafer.analyzeKernel"),
        ("Profile Kernel", "wafer.profileKernel"),
    )

    actions: list[CodeAction] = []
    for label, command_id in command_specs:
        title = f"{label}: {kernel.name}"
        actions.append(
            CodeAction(
                title=title,
                kind=CodeActionKind.Source,
                command=Command(
                    title=title,
                    command=command_id,
                    arguments=[uri, kernel.name],
                ),
            )
        )
    return actions
@@ -0,0 +1,48 @@
1
+
2
+ from lsprotocol.types import CodeLens, Command, Position, Range
3
+
4
+ from ..languages.registry import get_language_registry
5
+
6
+
7
def handle_code_lens(uri: str, content: str) -> list[CodeLens]:
    """Create an "Analyze" and a "Profile" code lens for every kernel.

    For each kernel the analyze lens is anchored at column 0 of the kernel's
    line and the profile lens at column 20 of the same line; both ranges are
    empty (start equals end).

    Args:
        uri: Document URI; forwarded as the first command argument.
        content: Full text of the document.

    Returns:
        The lenses in kernel order (analyze before profile), or an empty
        list when the document cannot be parsed.
    """
    parsed = get_language_registry().parse_file(uri, content)
    if not parsed:
        return []

    # (anchor column, title verb, command identifier)
    lens_specs = (
        (0, "Analyze", "wafer.analyzeKernel"),
        (20, "Profile", "wafer.profileKernel"),
    )

    lenses: list[CodeLens] = []
    for kernel in parsed.kernels:
        for column, verb, command_id in lens_specs:
            anchor = Range(
                start=Position(line=kernel.line, character=column),
                end=Position(line=kernel.line, character=column),
            )
            action = Command(
                title=f"{verb} {kernel.name}",
                command=command_id,
                arguments=[uri, kernel.name],
            )
            lenses.append(CodeLens(range=anchor, command=action))

    return lenses
@@ -0,0 +1,6 @@
1
+
2
+ from lsprotocol.types import CompletionItem
3
+
4
+
5
def handle_completion(uri: str, position, content: str) -> list[CompletionItem]:
    """Return completion items for ``position``; currently always none.

    All three arguments are accepted but unused.
    """
    no_items: list[CompletionItem] = []
    return no_items
@@ -0,0 +1,16 @@
1
+
2
+ from lsprotocol.types import Diagnostic
3
+
4
+ from ..languages.registry import get_language_registry
5
+
6
+
7
def handle_diagnostics(uri: str, content: str) -> list[Diagnostic]:
    """Parse the document and return its diagnostics; currently always none.

    The document is still handed to the language registry, but nothing in
    the parse result is turned into a diagnostic.

    Args:
        uri: Document URI, passed to the language registry.
        content: Full text of the document.

    Returns:
        A new empty list.
    """
    # The parse result is not inspected; the call is kept so the registry
    # is invoked exactly as before.
    get_language_registry().parse_file(uri, content)

    found: list[Diagnostic] = []
    return found
@@ -0,0 +1,87 @@
1
+
2
+ from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
3
+
4
+ from ..languages.registry import get_language_registry
5
+
6
+
7
def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
    """Build the document outline: kernels, then layouts, then structs.

    Args:
        uri: Document URI, passed to the language registry for parsing.
        content: Full text of the document.

    Returns:
        One ``DocumentSymbol`` per kernel (Function), layout (Variable) and
        struct (Struct); an empty list when the document cannot be parsed.
        Kernel and struct ranges span up to 10 lines below the declaration,
        clamped to the end of the document; layout ranges span their line.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return []

    # Split once up front rather than once per symbol.
    lines = content.split("\n")
    last_line = len(lines) - 1

    def text_at(line_no: int) -> str:
        """Text of ``line_no``, or "" when the line is outside the document."""
        return lines[line_no] if 0 <= line_no < len(lines) else ""

    def name_range(line_no: int, name: str) -> Range:
        """Range of ``name`` on its line; empty at column 0 when not found."""
        start = text_at(line_no).find(name)
        if start < 0:
            start = end = 0
        else:
            end = start + len(name)
        return Range(
            start=Position(line=line_no, character=start),
            end=Position(line=line_no, character=end),
        )

    def line_range(line_no: int) -> Range:
        """Range covering the whole of ``line_no``."""
        return Range(
            start=Position(line=line_no, character=0),
            end=Position(line=line_no, character=len(text_at(line_no))),
        )

    def block_range(line_no: int) -> Range:
        """Range from the declaration line to 10 lines below, clamped."""
        end_line = min(line_no + 10, last_line)
        if end_line <= line_no:
            # Declaration on the last line: ending at column 0 would leave
            # the name outside the range, and the protocol requires
            # selection_range to be contained in range.
            return line_range(line_no)
        return Range(
            start=Position(line=line_no, character=0),
            end=Position(line=end_line, character=0),
        )

    symbols: list[DocumentSymbol] = []

    for kernel in language_info.kernels:
        symbols.append(DocumentSymbol(
            name=kernel.name,
            kind=SymbolKind.Function,
            range=block_range(kernel.line),
            selection_range=name_range(kernel.line, kernel.name),
            detail=f"GPU Kernel ({registry.get_language_name(kernel.language)})",
        ))

    for layout in language_info.layouts:
        symbols.append(DocumentSymbol(
            name=layout.name,
            kind=SymbolKind.Variable,
            range=line_range(layout.line),
            selection_range=name_range(layout.line, layout.name),
            detail=f"Layout: {layout.shape}" if layout.shape else "Layout",
        ))

    for struct in language_info.structs:
        symbols.append(DocumentSymbol(
            name=struct.name,
            kind=SymbolKind.Struct,
            range=block_range(struct.line),
            selection_range=name_range(struct.line, struct.name),
            detail=f"Struct ({registry.get_language_name(struct.language)})",
        ))

    return symbols
@@ -0,0 +1,215 @@
1
+
2
+ from lsprotocol.types import Hover, MarkupContent, MarkupKind, Position
3
+
4
+ from ..analyzers.compiler_integration import get_analysis_for_kernel
5
+ from ..analyzers.docs_index import DocsIndex
6
+ from ..languages.registry import get_language_registry
7
+ from ..languages.types import KernelInfo, LayoutInfo
8
+ from ..utils.lsp_helpers import get_decorator_at_position, get_word_at_position
9
+
10
+
11
def find_kernel_at_position(
    content: str, position: Position, uri: str
) -> KernelInfo | None:
    """Return the kernel named by the word under the cursor, if any.

    A kernel matches when its name equals the word at ``position`` and the
    cursor is on or below the kernel's own line.

    Args:
        content: Full text of the document.
        position: Cursor position.
        uri: Document URI, passed to the language registry for parsing.

    Returns:
        The first matching kernel, or ``None`` when the document cannot be
        parsed or nothing matches.
    """
    parsed = get_language_registry().parse_file(uri, content)
    if not parsed:
        return None

    target = get_word_at_position(content, position)
    matches = (
        candidate
        for candidate in parsed.kernels
        if candidate.name == target and position.line >= candidate.line
    )
    return next(matches, None)
28
+
29
+
30
def find_layout_at_position(
    content: str, position: Position, uri: str
) -> LayoutInfo | None:
    """Return the layout whose name equals the word under the cursor.

    Args:
        content: Full text of the document.
        position: Cursor position.
        uri: Document URI, passed to the language registry for parsing.

    Returns:
        The first layout with a matching name, or ``None`` when the document
        cannot be parsed or no layout matches.
    """
    parsed = get_language_registry().parse_file(uri, content)
    if not parsed:
        return None

    target = get_word_at_position(content, position)
    return next(
        (candidate for candidate in parsed.layouts if candidate.name == target),
        None,
    )
46
+
47
+
48
def format_kernel_hover(kernel: KernelInfo, analysis: dict | None) -> str:
    """Render the Markdown hover text for a kernel.

    Args:
        kernel: Kernel to describe; its name, language, docstring and
            parameters are read.
        analysis: Optional analysis result. The keys ``layouts``,
            ``memory_paths`` and ``pipeline_stages`` are listed when present.

    Returns:
        Markdown lines joined with newlines.
    """
    language_name = (
        get_language_registry().get_language_name(kernel.language) or kernel.language
    )

    parts = [f"**GPU Kernel: {kernel.name}**", f"*Language: {language_name}*", ""]

    if kernel.docstring:
        parts += [kernel.docstring, ""]

    if kernel.parameters:
        joined_params = ", ".join(kernel.parameters)
        parts += [f"**Parameters:** `{joined_params}`", ""]

    if analysis:
        parts.append("**Analysis:**")
        # (analysis key, label shown in the hover), in display order.
        sections = (
            ("layouts", "Layouts"),
            ("memory_paths", "Memory paths"),
            ("pipeline_stages", "Pipeline stages"),
        )
        for key, label in sections:
            if key in analysis:
                parts.append(f"- {label}: {analysis[key]}")

    return "\n".join(parts)
73
+
74
+
75
def format_kernel_hover_basic(kernel: KernelInfo) -> str:
    """Render the kernel hover text without any analysis section."""
    return format_kernel_hover(kernel, analysis=None)
77
+
78
+
79
def format_decorator_hover(decorator_name: str, function_name: str | None = None) -> str:
    """Render the Markdown hover text for a decorator.

    Args:
        decorator_name: Decorator as written, e.g. ``"cute.kernel"`` or
            ``"struct"``. Anything other than the kernel/struct names gets a
            generic description.
        function_name: Name of the decorated function or class, if known;
            adds an "Applied to" line for kernel and struct decorators.

    Returns:
        Markdown lines joined with newlines, followed by a documentation
        link when the docs index finds a file for the concept.
    """
    if decorator_name in ("cute.kernel", "kernel"):
        parts = [
            "**@cute.kernel**",
            "",
            "CuTeDSL kernel decorator. Marks a function as a GPU kernel.",
            "",
            "**Usage:**",
            "```python",
            "@cute.kernel",
            "def my_kernel(a: cute.Tensor, b: cute.Tensor):",
            " # Kernel implementation",
            " pass",
            "```",
            "",
            "**Features:**",
            "- Automatic GPU code generation",
            "- Tensor layout optimization",
            "- Memory access pattern analysis",
        ]
        if function_name:
            parts += ["", f"Applied to: `{function_name}()`"]
    elif decorator_name in ("cute.struct", "struct"):
        parts = [
            "**@cute.struct**",
            "",
            "CuTeDSL struct decorator. Marks a class as a GPU struct.",
            "",
            "**Usage:**",
            "```python",
            "@cute.struct",
            "class MyStruct:",
            " field1: int",
            " field2: float",
            "```",
        ]
        if function_name:
            parts += ["", f"Applied to: `{function_name}`"]
    else:
        parts = [f"**{decorator_name}**", "", "CuTeDSL decorator"]

    # Any decorator without "kernel" in its name is looked up as "struct".
    concept = "kernel" if "kernel" in decorator_name else "struct"
    doc_link = DocsIndex().get_doc_for_concept(concept)
    if doc_link:
        parts += ["", f"[Documentation]({doc_link})"]

    return "\n".join(parts)
133
+
134
+
135
def handle_hover(uri: str, position: Position, content: str) -> Hover | None:
    """Produce hover information for the symbol under the cursor.

    Candidates are tried in this order: a decorator at the position, the
    ``cute`` module name, a kernel name, a layout name.

    Args:
        uri: Document URI, used for parsing and the analysis lookup.
        position: Cursor position.
        content: Full text of the document.

    Returns:
        A Markdown ``Hover``, or ``None`` when nothing at the position has
        hover information. No debug banner or placeholder hover is emitted.
    """

    def markdown_hover(text: str) -> Hover:
        """Wrap Markdown text in a Hover response."""
        return Hover(contents=MarkupContent(kind=MarkupKind.Markdown, value=text))

    def decorated_name(declaration: str) -> str | None:
        """Extract the function or class name from a ``def``/``class`` line."""
        stripped = declaration.strip()
        if stripped.startswith("def "):
            name, paren, _ = stripped[len("def "):].partition("(")
            return (name.strip() or None) if paren else None
        if stripped.startswith("class "):
            header = stripped[len("class "):]
            # The name ends at the base-class list or the colon, whichever
            # comes first, so "class Foo(Base):" yields "Foo".
            cuts = [i for i in (header.find("("), header.find(":")) if i >= 0]
            if not cuts:
                return None
            return header[:min(cuts)].strip() or None
        return None

    decorator_info = get_decorator_at_position(content, position)
    if decorator_info:
        decorator_name, function_line = decorator_info

        function_name = None
        source_lines = content.split("\n")
        if function_line < len(source_lines):
            function_name = decorated_name(source_lines[function_line])

        return markdown_hover(format_decorator_hover(decorator_name, function_name))

    word = get_word_at_position(content, position)
    # Guard against a missing word before calling str methods on it.
    if word and (word == "cute" or word.startswith("cute.")):
        hover_lines = [
            "**cutlass.cute**",
            "",
            "CuTeDSL (CUDA Unified Tensor Expression) library for GPU programming.",
            "",
            "**Key Features:**",
            "- `@cute.kernel` - Define GPU kernels",
            "- `@cute.struct` - Define GPU structs",
            "- `cute.make_layout()` - Create tensor layouts",
            "- `cute.Tensor` - Tensor type annotations",
            "",
            "[Documentation](https://github.com/NVIDIA/cutlass)",
        ]
        return markdown_hover("\n".join(hover_lines))

    kernel = find_kernel_at_position(content, position, uri)
    if kernel:
        # format_kernel_hover omits the analysis section when analysis is
        # None or empty, so no separate "basic" path is needed.
        analysis = get_analysis_for_kernel(uri, kernel.name)
        return markdown_hover(format_kernel_hover(kernel, analysis))

    layout = find_layout_at_position(content, position, uri)
    if layout:
        doc_link = DocsIndex().get_doc_for_concept("layout")

        hover_lines = [f"**Layout: {layout.name}**", ""]
        if layout.shape:
            hover_lines.append(f"Shape: `{layout.shape}`")
        if layout.stride:
            hover_lines.append(f"Stride: `{layout.stride}`")
        if doc_link:
            hover_lines.append("")
            hover_lines.append(f"[Documentation]({doc_link})")

        return markdown_hover("\n".join(hover_lines))

    # Nothing recognisable under the cursor: show no hover at all.
    return None
@@ -0,0 +1,65 @@
1
+
2
+ from lsprotocol.types import InlayHint, InlayHintKind, Position, Range
3
+
4
+ from ..languages.registry import get_language_registry
5
+
6
+
7
def handle_inlay_hint(uri: str, content: str, range: Range) -> list[InlayHint]:
    """Create type-style inlay hints for layouts and kernels in ``range``.

    * Layout assignments get ``: Layout`` (or ``: Layout[Shape...]`` when a
      shape is known) directly after the assigned name, before the ``=``.
    * Kernel definitions get `` -> Kernel`` after the closing parenthesis of
      the parameter list, where a return annotation is written.

    Args:
        uri: Document URI, passed to the language registry for parsing.
        content: Full text of the document.
        range: Only symbols whose line lies within ``range.start.line`` and
            ``range.end.line`` (inclusive) receive hints.

    Returns:
        Layout hints followed by kernel hints; an empty list when the
        document cannot be parsed.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return []

    hints: list[InlayHint] = []
    lines = content.split("\n")

    def in_requested_range(line_no: int) -> bool:
        return range.start.line <= line_no <= range.end.line

    def text_at(line_no: int) -> str:
        return lines[line_no] if 0 <= line_no < len(lines) else ""

    for layout in language_info.layouts:
        if not in_requested_range(layout.line):
            continue

        layout_line = text_at(layout.line)
        equals_pos = layout_line.find("=")
        if equals_pos < 0:
            continue

        hint_text = f": Layout[Shape{layout.shape}]" if layout.shape else ": Layout"

        # Anchor at the end of the assignment target so the hint renders as
        # "name: Layout = ..." rather than after the "=" sign.
        name_end = len(layout_line[:equals_pos].rstrip())

        hints.append(InlayHint(
            position=Position(line=layout.line, character=name_end),
            label=hint_text,
            kind=InlayHintKind.Type,
            padding_left=False,
            padding_right=False,
        ))

    for kernel in language_info.kernels:
        if not in_requested_range(kernel.line):
            continue

        kernel_line = text_at(kernel.line)
        if "def " not in kernel_line:
            continue

        open_paren = kernel_line.find("(")
        close_paren = kernel_line.rfind(")")
        # A signature that continues on later lines has no closing
        # parenthesis here; no hint is emitted rather than a misplaced one.
        if open_paren < 0 or close_paren < open_paren:
            continue

        hints.append(InlayHint(
            position=Position(line=kernel.line, character=close_paren + 1),
            label=" -> Kernel",
            kind=InlayHintKind.Type,
            padding_left=True,
            padding_right=True,
        ))

    return hints
@@ -0,0 +1,124 @@
1
+
2
+ from lsprotocol.types import SemanticTokens, SemanticTokensLegend
3
+
4
+ from ..languages.registry import get_language_registry
5
+
6
# Semantic token types; a token's type is sent as its index in this list.
TOKEN_TYPES = ["kernel", "layout", "struct", "decorator"]

# Semantic token modifiers declared in the legend.
TOKEN_MODIFIERS = ["definition", "declaration"]

# Legend built from the two lists above.
SEMANTIC_TOKENS_LEGEND = SemanticTokensLegend(
    token_types=TOKEN_TYPES,
    token_modifiers=TOKEN_MODIFIERS,
)
22
+
23
+
24
def handle_semantic_tokens(uri: str, content: str) -> SemanticTokens:
    """Compute semantic tokens for kernels, layouts, structs and decorators.

    Tokens are gathered with absolute positions, sorted into document order
    and only then delta-encoded. The relative encoding requires tokens in
    document order, with each start column measured from the *start* of the
    previous token on the same line.

    Args:
        uri: Document URI, passed to the language registry for parsing.
        content: Full text of the document.

    Returns:
        ``SemanticTokens`` whose ``data`` holds five integers per token:
        delta line, delta start, length, token type index, modifier bits
        (always 0). Empty when the document cannot be parsed.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return SemanticTokens(data=[])

    lines = content.split("\n")

    # Absolute tokens as (line, start column, length, token type index).
    found: list[tuple[int, int, int, int]] = []

    def collect_named(symbols, token_type: str) -> None:
        """Record a token for the first occurrence of each symbol's name."""
        type_index = TOKEN_TYPES.index(token_type)
        for symbol in symbols:
            if not 0 <= symbol.line < len(lines):
                continue
            start = lines[symbol.line].find(symbol.name)
            if start >= 0:
                found.append((symbol.line, start, len(symbol.name), type_index))

    collect_named(language_info.kernels, "kernel")
    collect_named(language_info.layouts, "layout")
    collect_named(language_info.structs, "struct")

    decorator_index = TOKEN_TYPES.index("decorator")
    for line_no, line in enumerate(lines):
        if "@cute.kernel" not in line and "@cute.struct" not in line:
            continue
        start = line.find("@")
        if start < 0:
            continue
        # The decorator token runs up to the next space or the line end.
        end = line.find(" ", start)
        if end == -1:
            end = len(line)
        found.append((line_no, start, end - start, decorator_index))

    data: list[int] = []
    prev_line = 0
    prev_start = 0
    prev_end = 0
    for line_no, start, length, type_index in sorted(found):
        same_line = line_no == prev_line
        # Drop tokens overlapping the previous one on the same line, since
        # clients are not required to support overlapping tokens.
        if data and same_line and start < prev_end:
            continue

        data.extend([
            line_no - prev_line,
            start - prev_start if same_line else start,
            length,
            type_index,
            0,
        ])
        prev_line, prev_start, prev_end = line_no, start, start + length

    return SemanticTokens(data=data)