wafer-lsp 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. wafer_lsp/__init__.py +1 -0
  2. wafer_lsp/__main__.py +9 -0
  3. wafer_lsp/analyzers/__init__.py +0 -0
  4. wafer_lsp/analyzers/compiler_integration.py +16 -0
  5. wafer_lsp/analyzers/docs_index.py +36 -0
  6. wafer_lsp/handlers/__init__.py +30 -0
  7. wafer_lsp/handlers/code_action.py +48 -0
  8. wafer_lsp/handlers/code_lens.py +48 -0
  9. wafer_lsp/handlers/completion.py +6 -0
  10. wafer_lsp/handlers/diagnostics.py +41 -0
  11. wafer_lsp/handlers/document_symbol.py +176 -0
  12. wafer_lsp/handlers/hip_diagnostics.py +303 -0
  13. wafer_lsp/handlers/hover.py +251 -0
  14. wafer_lsp/handlers/inlay_hint.py +245 -0
  15. wafer_lsp/handlers/semantic_tokens.py +224 -0
  16. wafer_lsp/handlers/workspace_symbol.py +87 -0
  17. wafer_lsp/languages/README.md +195 -0
  18. wafer_lsp/languages/__init__.py +17 -0
  19. wafer_lsp/languages/converter.py +88 -0
  20. wafer_lsp/languages/detector.py +107 -0
  21. wafer_lsp/languages/parser_manager.py +33 -0
  22. wafer_lsp/languages/registry.py +120 -0
  23. wafer_lsp/languages/types.py +37 -0
  24. wafer_lsp/parsers/__init__.py +36 -0
  25. wafer_lsp/parsers/base_parser.py +9 -0
  26. wafer_lsp/parsers/cuda_parser.py +95 -0
  27. wafer_lsp/parsers/cutedsl_parser.py +114 -0
  28. wafer_lsp/parsers/hip_parser.py +688 -0
  29. wafer_lsp/server.py +58 -0
  30. wafer_lsp/services/__init__.py +38 -0
  31. wafer_lsp/services/analysis_service.py +22 -0
  32. wafer_lsp/services/docs_service.py +40 -0
  33. wafer_lsp/services/document_service.py +20 -0
  34. wafer_lsp/services/hip_docs.py +806 -0
  35. wafer_lsp/services/hip_hover_service.py +412 -0
  36. wafer_lsp/services/hover_service.py +237 -0
  37. wafer_lsp/services/language_registry_service.py +26 -0
  38. wafer_lsp/services/position_service.py +77 -0
  39. wafer_lsp/utils/__init__.py +0 -0
  40. wafer_lsp/utils/lsp_helpers.py +79 -0
  41. wafer_lsp-0.1.13.dist-info/METADATA +60 -0
  42. wafer_lsp-0.1.13.dist-info/RECORD +44 -0
  43. wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
  44. wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
wafer_lsp/__init__.py ADDED
@@ -0,0 +1 @@
1
# Package version. Kept equal to the distribution version this file ships in
# (wafer-lsp 0.1.13); the previous value "0.1.0" was stale.
__version__ = "0.1.13"
wafer_lsp/__main__.py ADDED
@@ -0,0 +1,9 @@
1
+ from wafer_lsp.server import server
2
+
3
+
4
def main():
    """Entry point: start the language server via ``server.start_io()``."""
    server.start_io()


# Allow ``python -m wafer_lsp`` to launch the server directly.
if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,16 @@
1
+ from typing import Any
2
+
3
+ _analysis_cache: dict[str, dict[str, Any]] = {}
4
+
5
+
6
+ def get_analysis_for_kernel(uri: str, kernel_name: str) -> dict[str, Any] | None:
7
+ cache_key = f"{uri}:{kernel_name}"
8
+
9
+ if cache_key in _analysis_cache:
10
+ return _analysis_cache[cache_key]
11
+
12
+ return None
13
+
14
+
15
def clear_cache() -> None:
    """Remove every entry from the module-level analysis cache, in place."""
    _analysis_cache.clear()
@@ -0,0 +1,36 @@
1
+ from pathlib import Path
2
+
3
+
4
+ class DocsIndex:
5
+
6
+ def __init__(self, docs_path: str | None = None):
7
+ if docs_path:
8
+ self.docs_path = Path(docs_path)
9
+ else:
10
+ self.docs_path = Path(__file__).parent.parent.parent.parent.parent / \
11
+ "curriculum" / "cutlass-docs" / "cutedsl-docs"
12
+
13
+ self.index = self._build_index()
14
+
15
+ def _build_index(self) -> dict[str, list[str]]:
16
+ return {
17
+ "layout": [
18
+ "intro-to-cutedsl.md",
19
+ "partitioning-strategies-inner-outer-threadvalue.md"
20
+ ],
21
+ "TMA": ["blackwell-tutorial-fp16-gemm-0.md"],
22
+ "TMEM": ["colfax-blackwell-umma-tensor-memory-part1.md"],
23
+ "kernel": ["blackwell-tutorial-fp16-gemm-0.md"],
24
+ "struct": ["blackwell-tutorial-fp16-gemm-0.md"],
25
+ "pipeline": ["blackwell-tutorial-fp16-gemm-0.md"],
26
+ "MMA": ["mma-atoms-fundamentals-sm70-example.md"],
27
+ }
28
+
29
+ def get_doc_for_concept(self, concept: str) -> str | None:
30
+ concept_lower = concept.lower()
31
+ if self.index.get(concept_lower):
32
+ doc_file = self.index[concept_lower][0]
33
+ doc_path = self.docs_path / doc_file
34
+ if doc_path.exists():
35
+ return str(doc_path)
36
+ return None
@@ -0,0 +1,30 @@
1
+ from .code_action import handle_code_action
2
+ from .code_lens import handle_code_lens
3
+ from .completion import handle_completion
4
+ from .diagnostics import handle_diagnostics
5
+ from .document_symbol import handle_document_symbol
6
+ from .hip_diagnostics import (
7
+ HIPDiagnosticsProvider,
8
+ create_hip_diagnostics_provider,
9
+ get_hip_diagnostics,
10
+ )
11
+ from .hover import handle_hover
12
+ from .inlay_hint import handle_inlay_hint
13
+ from .semantic_tokens import handle_semantic_tokens, SEMANTIC_TOKENS_LEGEND
14
+ from .workspace_symbol import handle_workspace_symbol
15
+
16
+ __all__ = [
17
+ "handle_code_action",
18
+ "handle_code_lens",
19
+ "handle_completion",
20
+ "handle_diagnostics",
21
+ "handle_document_symbol",
22
+ "handle_hover",
23
+ "handle_inlay_hint",
24
+ "handle_semantic_tokens",
25
+ "handle_workspace_symbol",
26
+ "SEMANTIC_TOKENS_LEGEND",
27
+ "HIPDiagnosticsProvider",
28
+ "create_hip_diagnostics_provider",
29
+ "get_hip_diagnostics",
30
+ ]
@@ -0,0 +1,48 @@
1
+
2
+ from lsprotocol.types import CodeAction, CodeActionKind, Command, Range
3
+
4
+ from ..languages.registry import get_language_registry
5
+ from ..languages.types import KernelInfo
6
+
7
+
8
def find_kernel_at_range(
    content: str,
    range: Range,
    uri: str,
    *,
    max_kernel_lines: int = 50,
) -> KernelInfo | None:
    """Find the kernel that most plausibly encloses the start of ``range``.

    A kernel is a candidate when the range starts on its declaration line or
    on one of the ``max_kernel_lines`` lines after it. Among the candidates
    the one declared closest to the range start is returned, so a cursor in
    the second of two nearby kernels resolves to the second one rather than
    to whichever kernel was parsed first.

    Args:
        content: Full text of the document.
        range: Range whose start line is used for the lookup.
        uri: Document URI, passed to the language registry for parsing.
        max_kernel_lines: Size of the window after a kernel's declaration
            line that is attributed to that kernel.

    Returns:
        The matching kernel, or ``None`` if the document cannot be parsed or
        no kernel window contains the start line.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return None

    target_line = range.start.line
    candidates = [
        kernel
        for kernel in language_info.kernels
        if kernel.line <= target_line <= kernel.line + max_kernel_lines
    ]
    if not candidates:
        return None

    # max() keeps the first of equally-close kernels, matching parse order.
    return max(candidates, key=lambda kernel: kernel.line)
20
+
21
+
22
def handle_code_action(uri: str, range: Range, content: str) -> list[CodeAction]:
    """Offer "analyze" and "profile" source actions for the kernel at ``range``.

    Args:
        uri: Document URI; forwarded as the first command argument.
        range: Range the client requested actions for.
        content: Full text of the document.

    Returns:
        Two code actions (analyze first, then profile) for the kernel found by
        ``find_kernel_at_range``, or an empty list when no kernel is found.
    """
    kernel = find_kernel_at_range(content, range, uri)
    if not kernel:
        return []

    # (title prefix, command id), in the order the actions are returned.
    action_specs = (
        ("Analyze Kernel", "wafer.analyzeKernel"),
        ("Profile Kernel", "wafer.profileKernel"),
    )

    actions: list[CodeAction] = []
    for label, command_id in action_specs:
        title = f"{label}: {kernel.name}"
        run_command = Command(
            title=title,
            command=command_id,
            arguments=[uri, kernel.name],
        )
        actions.append(
            CodeAction(
                title=title,
                kind=CodeActionKind.Source,
                command=run_command,
            )
        )
    return actions
@@ -0,0 +1,48 @@
1
+
2
+ from lsprotocol.types import CodeLens, Command, Position, Range
3
+
4
+ from ..languages.registry import get_language_registry
5
+
6
+
7
def handle_code_lens(uri: str, content: str) -> list[CodeLens]:
    """Produce "Analyze" and "Profile" code lenses for every kernel in a document.

    Args:
        uri: Document URI; used for parsing and as the first command argument.
        content: Full text of the document.

    Returns:
        For each parsed kernel, an analyze lens at column 0 followed by a
        profile lens at column 20 of the kernel's line. Empty when the
        document cannot be parsed.
    """
    language_info = get_language_registry().parse_file(uri, content)
    if not language_info:
        return []

    # (title prefix, command id, column), in the order lenses are emitted.
    lens_specs = (
        ("Analyze", "wafer.analyzeKernel", 0),
        ("Profile", "wafer.profileKernel", 20),
    )

    lenses: list[CodeLens] = []
    for kernel in language_info.kernels:
        for label, command_id, column in lens_specs:
            # Zero-width range: the lens is anchored to a single position.
            anchor = Range(
                start=Position(line=kernel.line, character=column),
                end=Position(line=kernel.line, character=column),
            )
            lens_command = Command(
                title=f"{label} {kernel.name}",
                command=command_id,
                arguments=[uri, kernel.name],
            )
            lenses.append(CodeLens(range=anchor, command=lens_command))
    return lenses
@@ -0,0 +1,6 @@
1
+
2
+ from lsprotocol.types import CompletionItem
3
+
4
+
5
def handle_completion(uri: str, position, content: str) -> list[CompletionItem]:
    """Return completion items for a position in a document.

    Completion is currently a stub: all arguments are ignored and the result
    is always an empty list.
    """
    no_items: list[CompletionItem] = []
    return no_items
@@ -0,0 +1,41 @@
1
+
2
+ from lsprotocol.types import Diagnostic
3
+
4
+ from ..languages.registry import get_language_registry
5
+ from .hip_diagnostics import get_hip_diagnostics
6
+
7
+
8
def handle_diagnostics(
    uri: str,
    content: str,
    enable_wavefront_diagnostics: bool = True,
) -> list[Diagnostic]:
    """Compute the diagnostics for a document.

    Args:
        uri: Document URI.
        content: Document content.
        enable_wavefront_diagnostics: Forwarded to the HIP diagnostics
            provider; controls whether HIP wavefront warnings are produced.

    Returns:
        The HIP diagnostics when the document parses as "hip", "cuda" or
        "cpp"; otherwise an empty list.
    """
    language_info = get_language_registry().parse_file(uri, content)
    if not language_info:
        return []

    # C++ is included because such files may contain HIP code.
    if language_info.language not in ("hip", "cuda", "cpp"):
        return []

    return list(
        get_hip_diagnostics(
            content,
            uri,
            enable_wavefront_diagnostics=enable_wavefront_diagnostics,
        )
    )
@@ -0,0 +1,176 @@
1
+
2
+ from lsprotocol.types import DocumentSymbol, Position, Range, SymbolKind
3
+
4
+ from ..languages.registry import get_language_registry
5
+
6
+
7
def _build_symbol(
    name: str,
    kind: SymbolKind,
    detail: str,
    start_line: int,
    end_line: int,
    lines: list[str],
) -> DocumentSymbol:
    """Build a DocumentSymbol whose selection range lies inside its full range.

    Args:
        name: Symbol name; also searched for on ``start_line`` to place the
            selection range.
        kind: LSP symbol kind.
        detail: Detail text shown next to the symbol.
        start_line: Zero-based line the symbol is declared on.
        end_line: Requested last line of the symbol; clamped to the document.
        lines: Document content split into lines.

    Returns:
        The symbol. The selection range covers the first occurrence of
        ``name`` on ``start_line`` (or is empty at column 0 if the name is not
        found there).
    """
    line_text = lines[start_line] if start_line < len(lines) else ""

    name_start = line_text.find(name)
    if name_start < 0:
        name_start = name_end = 0
    else:
        name_end = name_start + len(name)

    last_line = min(end_line, len(lines) - 1)
    if last_line > start_line:
        range_end = Position(line=last_line, character=0)
    else:
        # Single-line symbol, or clamping pushed the end up to/before the
        # start: extend to the end of the declaration line so the full range
        # never inverts and always contains the selection range.
        range_end = Position(line=start_line, character=len(line_text))

    return DocumentSymbol(
        name=name,
        kind=kind,
        range=Range(
            start=Position(line=start_line, character=0),
            end=range_end,
        ),
        selection_range=Range(
            start=Position(line=start_line, character=name_start),
            end=Position(line=start_line, character=name_end),
        ),
        detail=detail,
    )


def handle_document_symbol(uri: str, content: str) -> list[DocumentSymbol]:
    """List the outline symbols of a document.

    Args:
        uri: Document URI, passed to the language registry for parsing.
        content: Full text of the document.

    Returns:
        Symbols for the parsed kernels, layouts and structs (in that order),
        followed by device functions and shared-memory declarations for
        "hip", "cuda" and "cpp" documents. Empty when parsing fails.
    """
    registry = get_language_registry()
    language_info = registry.parse_file(uri, content)

    if not language_info:
        return []

    symbols: list[DocumentSymbol] = []
    lines = content.split("\n")

    # Kernels: the real end line is not used here, so assume a 10-line body.
    for kernel in language_info.kernels:
        if kernel.language == "hip":
            detail = "🚀 HIP Kernel (AMD GPU)"
        elif kernel.language in ("cuda", "cpp"):
            detail = "🚀 CUDA Kernel"
        else:
            detail = f"GPU Kernel ({registry.get_language_name(kernel.language)})"

        symbols.append(_build_symbol(
            kernel.name,
            SymbolKind.Function,
            detail,
            kernel.line,
            kernel.line + 10,
            lines,
        ))

    # Layouts (CuTeDSL): single-line symbols.
    for layout in language_info.layouts:
        detail = f"Layout: {layout.shape}" if layout.shape else "Layout"
        symbols.append(_build_symbol(
            layout.name,
            SymbolKind.Variable,
            detail,
            layout.line,
            layout.line,
            lines,
        ))

    # Structs: same 10-line assumption as kernels.
    for struct in language_info.structs:
        symbols.append(_build_symbol(
            struct.name,
            SymbolKind.Struct,
            f"Struct ({registry.get_language_name(struct.language)})",
            struct.line,
            struct.line + 10,
            lines,
        ))

    # HIP-specific: device functions and shared memory
    if language_info.language in ("hip", "cuda", "cpp"):
        symbols.extend(_get_hip_symbols(language_info.raw_data, lines))

    return symbols


def _get_hip_symbols(raw_data: dict, lines: list[str]) -> list[DocumentSymbol]:
    """Extract HIP-specific symbols: device functions, shared memory allocations.

    Args:
        raw_data: Parser output; the "device_functions" and "shared_memory"
            entries are read, each defaulting to an empty list.
        lines: Document content split into lines.

    Returns:
        Device-function symbols followed by shared-memory symbols. Entries
        lacking a ``line`` or ``name`` attribute are skipped.
    """
    symbols: list[DocumentSymbol] = []

    # Device functions (from HIP parser)
    for func in raw_data.get("device_functions", []):
        if not hasattr(func, "line") or not hasattr(func, "name"):
            continue

        # Fall back to a 10-line body when the end line is missing or None.
        end_line = getattr(func, "end_line", None)
        if end_line is None:
            end_line = func.line + 10

        return_type = getattr(func, "return_type", "void")
        symbols.append(_build_symbol(
            func.name,
            SymbolKind.Method,
            f"⚡ Device Function -> {return_type}",
            func.line,
            end_line,
            lines,
        ))

    # Shared memory allocations: single-line symbols.
    for shared in raw_data.get("shared_memory", []):
        if not hasattr(shared, "line") or not hasattr(shared, "name"):
            continue

        type_str = getattr(shared, "type_str", "")
        size_bytes = getattr(shared, "size_bytes", None)
        if not size_bytes:
            detail = f"📦 __shared__ {type_str}"
        elif size_bytes >= 1024:
            detail = f"📦 __shared__ {type_str} ({size_bytes / 1024:.1f} KB)"
        else:
            detail = f"📦 __shared__ {type_str} ({size_bytes} bytes)"

        symbols.append(_build_symbol(
            shared.name,
            SymbolKind.Variable,
            detail,
            shared.line,
            shared.line,
            lines,
        ))

    return symbols