wafer-lsp 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/__init__.py +1 -0
- wafer_lsp/__main__.py +9 -0
- wafer_lsp/analyzers/__init__.py +0 -0
- wafer_lsp/analyzers/compiler_integration.py +16 -0
- wafer_lsp/analyzers/docs_index.py +36 -0
- wafer_lsp/handlers/__init__.py +30 -0
- wafer_lsp/handlers/code_action.py +48 -0
- wafer_lsp/handlers/code_lens.py +48 -0
- wafer_lsp/handlers/completion.py +6 -0
- wafer_lsp/handlers/diagnostics.py +41 -0
- wafer_lsp/handlers/document_symbol.py +176 -0
- wafer_lsp/handlers/hip_diagnostics.py +303 -0
- wafer_lsp/handlers/hover.py +251 -0
- wafer_lsp/handlers/inlay_hint.py +245 -0
- wafer_lsp/handlers/semantic_tokens.py +224 -0
- wafer_lsp/handlers/workspace_symbol.py +87 -0
- wafer_lsp/languages/README.md +195 -0
- wafer_lsp/languages/__init__.py +17 -0
- wafer_lsp/languages/converter.py +88 -0
- wafer_lsp/languages/detector.py +107 -0
- wafer_lsp/languages/parser_manager.py +33 -0
- wafer_lsp/languages/registry.py +120 -0
- wafer_lsp/languages/types.py +37 -0
- wafer_lsp/parsers/__init__.py +36 -0
- wafer_lsp/parsers/base_parser.py +9 -0
- wafer_lsp/parsers/cuda_parser.py +95 -0
- wafer_lsp/parsers/cutedsl_parser.py +114 -0
- wafer_lsp/parsers/hip_parser.py +688 -0
- wafer_lsp/server.py +58 -0
- wafer_lsp/services/__init__.py +38 -0
- wafer_lsp/services/analysis_service.py +22 -0
- wafer_lsp/services/docs_service.py +40 -0
- wafer_lsp/services/document_service.py +20 -0
- wafer_lsp/services/hip_docs.py +806 -0
- wafer_lsp/services/hip_hover_service.py +412 -0
- wafer_lsp/services/hover_service.py +237 -0
- wafer_lsp/services/language_registry_service.py +26 -0
- wafer_lsp/services/position_service.py +77 -0
- wafer_lsp/utils/__init__.py +0 -0
- wafer_lsp/utils/lsp_helpers.py +79 -0
- wafer_lsp-0.1.13.dist-info/METADATA +60 -0
- wafer_lsp-0.1.13.dist-info/RECORD +44 -0
- wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
- wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .base_parser import BaseParser
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass()
class CUDAKernel:
    """One ``__global__`` kernel found by ``CUDAParser``."""

    # Identifier captured right before the parameter list's opening parenthesis.
    name: str
    # 0-based line index of the ``__global__`` keyword that introduced the kernel.
    line: int
    # Parameter names pulled from the kernel signature, in declaration order.
    parameters: list[str]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CUDAParser(BaseParser):
    """Regex-based extractor for CUDA ``__global__`` kernel declarations.

    This is a lightweight textual scan, not a C++ parser: macros are not
    expanded and kernels that appear inside comments or string literals are
    still reported.
    """

    # ``__global__`` (optionally followed by ``__device__``), then either
    # ``void`` or a lazily matched run of qualifiers/attributes on the same
    # line (e.g. ``__launch_bounds__(256) void``), then the kernel name and
    # the opening parenthesis of its parameter list.
    _KERNEL_PATTERN = re.compile(r'__global__\s+(?:__device__\s+)?(?:void|.*?)\s+(\w+)\s*\(')

    def parse_file(self, content: str) -> dict[str, Any]:
        """Find every ``__global__`` kernel declared in ``content``.

        Args:
            content: Full text of a CUDA source file.

        Returns:
            ``{"kernels": [CUDAKernel, ...]}`` in order of appearance. Each
            kernel's ``line`` is the 0-based line of its ``__global__`` keyword.
        """
        kernels: list[CUDAKernel] = []

        for match in self._KERNEL_PATTERN.finditer(content):
            # Count newlines in place rather than slicing a copy of the prefix.
            line = content.count('\n', 0, match.start())
            # The pattern ends with '(', so match.end() is just inside the list.
            params = self._extract_parameters(content, match.end())
            kernels.append(CUDAKernel(
                name=match.group(1),
                line=line,
                parameters=params,
            ))

        return {"kernels": kernels}

    def _extract_parameters(self, content: str, start: int) -> list[str]:
        """Return the parameter names of the parameter list that begins at ``start``.

        Args:
            content: Full source text.
            start: Index just past the opening ``(`` of the parameter list,
                which is what ``parse_file`` passes (``match.end()``).

        Returns:
            Parameter names in declaration order. Returns ``[]`` for an empty
            list, for ``(void)``, or when no matching ``)`` is found.
        """
        # We are already inside the opening '(', so nesting starts at 1 and
        # the matching ')' is where it drops back to 0. (Starting at 0 meant
        # the closing paren took depth to -1 and was never recognised.)
        depth = 1
        end: int | None = None
        for i in range(start, len(content)):
            char = content[i]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    end = i
                    break

        if end is None:
            return []

        params: list[str] = []
        current: list[str] = []
        template_depth = 0
        paren_depth = 0

        for char in content[start:end]:
            # Only commas outside template arguments and nested parens
            # separate parameters, e.g. ``Tensor<float, 2> t`` is one param.
            if char == ',' and template_depth == 0 and paren_depth == 0:
                name = self._param_name(''.join(current))
                if name is not None:
                    params.append(name)
                current = []
                continue
            if char == '<':
                template_depth += 1
            elif char == '>':
                template_depth -= 1
            elif char == '(':
                paren_depth += 1
            elif char == ')':
                paren_depth -= 1
            current.append(char)

        name = self._param_name(''.join(current))
        if name is not None:
            params.append(name)

        return params

    @staticmethod
    def _param_name(declaration: str) -> str | None:
        """Return the identifier declared by one parameter, or ``None`` if empty/``void``."""
        # Drop a default argument (``int n = 0``) and an array suffix (``float a[]``).
        decl = declaration.split('=', 1)[0].split('[', 1)[0].strip()
        if not decl or decl == 'void':
            return None
        # Split on whitespace and pointer/reference markers so ``float*a``,
        # ``float *a`` and ``const T& a`` all yield the trailing identifier.
        tokens = [tok for tok in re.split(r'[\s*&]+', decl) if tok]
        return tokens[-1] if tokens else None
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .base_parser import BaseParser
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class CuTeDSLKernel:
|
|
10
|
+
name: str
|
|
11
|
+
line: int
|
|
12
|
+
parameters: list[str]
|
|
13
|
+
docstring: str | None = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class CuTeDSLLayout:
|
|
18
|
+
name: str
|
|
19
|
+
line: int
|
|
20
|
+
shape: str | None = None
|
|
21
|
+
stride: str | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class CuTeDSLStruct:
|
|
26
|
+
name: str
|
|
27
|
+
line: int
|
|
28
|
+
docstring: str | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CuTeDSLParser(BaseParser):
    """Extract CuTe DSL kernels, structs and layouts from Python source using ``ast``."""

    def parse_file(self, content: str) -> dict[str, Any]:
        """Collect CuTe DSL constructs from ``content``.

        Finds functions decorated with ``cute.kernel``, classes decorated with
        ``cute.struct``, and variables assigned from ``cute.make_layout(...)``.

        Args:
            content: Python source text.

        Returns:
            ``{"kernels": [...], "layouts": [...], "structs": [...]}``. All
            three lists are empty when ``content`` cannot be parsed. Items are
            in ``ast.walk`` (breadth-first) order, not necessarily source order.
        """
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError):
            # Report nothing for unparseable buffers (e.g. mid-edit). Some
            # Python versions raise ValueError rather than SyntaxError for
            # source containing NUL bytes.
            return {"kernels": [], "layouts": [], "structs": []}

        kernels: list[CuTeDSLKernel] = []
        layouts: list[CuTeDSLLayout] = []
        structs: list[CuTeDSLStruct] = []

        # ast.walk also descends into nested scopes (methods, function bodies).
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                if self._has_decorator(node, "cute.kernel"):
                    kernels.append(self._extract_kernel(node))
            elif isinstance(node, ast.ClassDef):
                if self._has_decorator(node, "cute.struct"):
                    structs.append(self._extract_struct(node))
            elif isinstance(node, (ast.Assign, ast.AnnAssign)):
                layout = self._extract_layout(node, content)
                if layout is not None:
                    layouts.append(layout)

        return {"kernels": kernels, "layouts": layouts, "structs": structs}

    def _has_decorator(self, node: ast.FunctionDef | ast.ClassDef, decorator: str) -> bool:
        """Return True if ``node`` is decorated with ``decorator``.

        ``decorator`` is a dotted name such as ``"cute.kernel"``. For that value
        the accepted spellings are ``@cute.kernel``, ``@kernel`` (imported
        directly), and the called forms ``@cute.kernel(...)`` / ``@kernel(...)``.
        """
        module, _, attr = decorator.rpartition(".")
        # Compare against the last module component; default to ``cute``
        # when no prefix is given.
        module = module.rsplit(".", 1)[-1] if module else "cute"

        for dec in node.decorator_list:
            # A decorator with arguments is a Call node; inspect its callee.
            target = dec.func if isinstance(dec, ast.Call) else dec
            if isinstance(target, ast.Attribute):
                if (
                    isinstance(target.value, ast.Name)
                    and target.value.id == module
                    and target.attr == attr
                ):
                    return True
            elif isinstance(target, ast.Name) and target.id == attr:
                return True
        return False

    def _extract_kernel(self, node: ast.FunctionDef) -> CuTeDSLKernel:
        """Build a ``CuTeDSLKernel`` from a decorated function definition.

        ``line`` is converted from ``ast``'s 1-based ``lineno`` to 0-based.
        """
        args = node.args
        # Include positional-only and keyword-only parameters as well as the
        # ordinary positional-or-keyword ones, in signature order.
        parameters = [
            arg.arg for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs)
        ]

        return CuTeDSLKernel(
            name=node.name,
            line=node.lineno - 1,
            parameters=parameters,
            docstring=ast.get_docstring(node),
        )

    def _extract_struct(self, node: ast.ClassDef) -> CuTeDSLStruct:
        """Build a ``CuTeDSLStruct`` (0-based line) from a decorated class definition."""
        return CuTeDSLStruct(
            name=node.name,
            line=node.lineno - 1,
            docstring=ast.get_docstring(node),
        )

    def _extract_layout(
        self, node: ast.Assign | ast.AnnAssign, content: str
    ) -> CuTeDSLLayout | None:
        """Return a layout if ``node`` assigns ``cute.make_layout(...)`` to a plain name.

        Args:
            node: A plain (``x = ...``) or annotated (``x: T = ...``) assignment.
            content: Source text; currently unused, kept for interface stability.

        Returns:
            A ``CuTeDSLLayout`` for the first plain-name target, with the shape
            and stride arguments rendered back to source text, or ``None``.
        """
        call = node.value  # None for a bare annotation (``x: T``)
        if not (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and isinstance(call.func.value, ast.Name)
            and call.func.value.id == "cute"
            and call.func.attr == "make_layout"
        ):
            return None

        targets = node.targets if isinstance(node, ast.Assign) else [node.target]
        name = next((t.id for t in targets if isinstance(t, ast.Name)), None)
        if name is None:
            return None

        # ``kw.arg`` is None for ``**kwargs`` expansions; skip those.
        keywords = {kw.arg: kw.value for kw in call.keywords if kw.arg is not None}
        shape_node = call.args[0] if call.args else keywords.get("shape")
        # Stride is normally passed as ``stride=...``; also accept a second
        # positional argument.
        stride_node = keywords.get("stride")
        if stride_node is None and len(call.args) > 1:
            stride_node = call.args[1]

        return CuTeDSLLayout(
            name=name,
            line=node.lineno - 1,
            shape=ast.unparse(shape_node) if shape_node is not None else None,
            stride=ast.unparse(stride_node) if stride_node is not None else None,
        )
|