wafer-lsp 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. wafer_lsp/__init__.py +1 -0
  2. wafer_lsp/__main__.py +9 -0
  3. wafer_lsp/analyzers/__init__.py +0 -0
  4. wafer_lsp/analyzers/compiler_integration.py +16 -0
  5. wafer_lsp/analyzers/docs_index.py +36 -0
  6. wafer_lsp/handlers/__init__.py +30 -0
  7. wafer_lsp/handlers/code_action.py +48 -0
  8. wafer_lsp/handlers/code_lens.py +48 -0
  9. wafer_lsp/handlers/completion.py +6 -0
  10. wafer_lsp/handlers/diagnostics.py +41 -0
  11. wafer_lsp/handlers/document_symbol.py +176 -0
  12. wafer_lsp/handlers/hip_diagnostics.py +303 -0
  13. wafer_lsp/handlers/hover.py +251 -0
  14. wafer_lsp/handlers/inlay_hint.py +245 -0
  15. wafer_lsp/handlers/semantic_tokens.py +224 -0
  16. wafer_lsp/handlers/workspace_symbol.py +87 -0
  17. wafer_lsp/languages/README.md +195 -0
  18. wafer_lsp/languages/__init__.py +17 -0
  19. wafer_lsp/languages/converter.py +88 -0
  20. wafer_lsp/languages/detector.py +107 -0
  21. wafer_lsp/languages/parser_manager.py +33 -0
  22. wafer_lsp/languages/registry.py +120 -0
  23. wafer_lsp/languages/types.py +37 -0
  24. wafer_lsp/parsers/__init__.py +36 -0
  25. wafer_lsp/parsers/base_parser.py +9 -0
  26. wafer_lsp/parsers/cuda_parser.py +95 -0
  27. wafer_lsp/parsers/cutedsl_parser.py +114 -0
  28. wafer_lsp/parsers/hip_parser.py +688 -0
  29. wafer_lsp/server.py +58 -0
  30. wafer_lsp/services/__init__.py +38 -0
  31. wafer_lsp/services/analysis_service.py +22 -0
  32. wafer_lsp/services/docs_service.py +40 -0
  33. wafer_lsp/services/document_service.py +20 -0
  34. wafer_lsp/services/hip_docs.py +806 -0
  35. wafer_lsp/services/hip_hover_service.py +412 -0
  36. wafer_lsp/services/hover_service.py +237 -0
  37. wafer_lsp/services/language_registry_service.py +26 -0
  38. wafer_lsp/services/position_service.py +77 -0
  39. wafer_lsp/utils/__init__.py +0 -0
  40. wafer_lsp/utils/lsp_helpers.py +79 -0
  41. wafer_lsp-0.1.13.dist-info/METADATA +60 -0
  42. wafer_lsp-0.1.13.dist-info/RECORD +44 -0
  43. wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
  44. wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,95 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from .base_parser import BaseParser
6
+
7
+
8
@dataclass
class CUDAKernel:
    """A ``__global__`` kernel declaration reported by CUDAParser.

    Attributes:
        name: Kernel function name.
        line: 0-based line on which the ``__global__`` declaration starts.
        parameters: Names of the kernel's parameters, in declaration order.
    """

    # Identifier captured from the declaration.
    name: str
    # 0-based line number (count of newlines before the match).
    line: int
    # One entry per declared parameter.
    parameters: list[str]
13
+
14
+
15
class CUDAParser(BaseParser):
    """Extract ``__global__`` kernel declarations from CUDA source text.

    This is a lightweight regex scan, not a C++ parser: declarations inside
    comments, string literals, or preprocessor-disabled regions are reported
    as well.
    """

    # Captures the kernel name in "__global__ [__device__] <ret/qualifiers> name(".
    # The lazy ".*?" lets qualifiers such as "__launch_bounds__(256)" sit
    # between "__global__" and the name; "." never matches a newline.
    _KERNEL_PATTERN = re.compile(r'__global__\s+(?:__device__\s+)?(?:void|.*?)\s+(\w+)\s*\(')

    def parse_file(self, content: str) -> dict[str, Any]:
        """Find every ``__global__`` kernel declared in ``content``.

        Args:
            content: Full text of a CUDA source file.

        Returns:
            ``{"kernels": [CUDAKernel, ...]}`` in source order. Each kernel's
            ``line`` is 0-based.
        """
        kernels: list[CUDAKernel] = []
        line = 0
        counted_upto = 0

        for match in self._KERNEL_PATTERN.finditer(content):
            # finditer yields matches in increasing position order, so newlines
            # are counted incrementally instead of rescanning the whole prefix
            # for every kernel (which was quadratic in file size).
            line += content.count('\n', counted_upto, match.start())
            counted_upto = match.start()

            # The pattern ends with "\(", so the last matched character is the
            # opening parenthesis of the parameter list. Passing match.end()
            # (one past it) broke the depth tracking in _extract_parameters.
            open_paren = match.end() - 1

            kernels.append(CUDAKernel(
                name=match.group(1),
                line=line,
                parameters=self._extract_parameters(content, open_paren),
            ))

        return {"kernels": kernels}

    def _extract_parameters(self, content: str, start: int) -> list[str]:
        """Return the parameter names of the parenthesised list at ``start``.

        Args:
            content: Full source text.
            start: Index of the opening ``(`` of the parameter list. For
                tolerance, the index just past that ``(`` is accepted too.

        Returns:
            Parameter names in declaration order. ``[]`` when there is no
            parameter list at ``start``, the list is never closed, or the list
            is empty / ``(void)``.
        """
        # Accept callers that point one character past the "(".
        if 0 < start < len(content) and content[start] != '(' and content[start - 1] == '(':
            start -= 1
        if not 0 <= start < len(content) or content[start] != '(':
            return []

        # Locate the matching ")" so nested parentheses in parameter types
        # (e.g. function pointers) do not end the list early.
        depth = 0
        close = -1
        for i in range(start, len(content)):
            char = content[i]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    close = i
                    break
        if close == -1:
            return []

        inner = content[start + 1:close]
        # "(void)" declares no parameters in C/C++.
        if inner.strip() == 'void':
            return []

        names: list[str] = []
        for declaration in self._split_top_level_commas(inner):
            name = self._param_name(declaration)
            if name is not None:
                names.append(name)
        return names

    @staticmethod
    def _split_top_level_commas(text: str) -> list[str]:
        """Split ``text`` on commas not nested inside ``<...>`` or ``(...)``."""
        pieces: list[str] = []
        current: list[str] = []
        template_depth = 0
        paren_depth = 0
        for char in text:
            if char == ',' and template_depth == 0 and paren_depth == 0:
                pieces.append(''.join(current))
                current = []
                continue
            if char == '<':
                template_depth += 1
            elif char == '>':
                template_depth -= 1
            elif char == '(':
                paren_depth += 1
            elif char == ')':
                paren_depth -= 1
            current.append(char)
        pieces.append(''.join(current))
        return pieces

    @staticmethod
    def _param_name(declaration: str) -> str | None:
        """Return the identifier declared by one parameter, or None if blank.

        A default argument (``int n = 4``) and array extents (``float a[16]``)
        are dropped. Then the last whitespace-separated token is taken, and
        pointer/reference markers are stripped: ``float* __restrict__ out``
        gives ``out``.
        """
        declaration = declaration.split('=', 1)[0].split('[', 1)[0]
        tokens = declaration.split()
        if not tokens:
            return None
        return tokens[-1].strip('*&')
@@ -0,0 +1,114 @@
1
+ import ast
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from .base_parser import BaseParser
6
+
7
+
8
+ @dataclass
9
+ class CuTeDSLKernel:
10
+ name: str
11
+ line: int
12
+ parameters: list[str]
13
+ docstring: str | None = None
14
+
15
+
16
+ @dataclass
17
+ class CuTeDSLLayout:
18
+ name: str
19
+ line: int
20
+ shape: str | None = None
21
+ stride: str | None = None
22
+
23
+
24
+ @dataclass
25
+ class CuTeDSLStruct:
26
+ name: str
27
+ line: int
28
+ docstring: str | None = None
29
+
30
+
31
class CuTeDSLParser(BaseParser):
    """Extract CuTe DSL kernels, structs and layouts from Python source using :mod:`ast`."""

    def parse_file(self, content: str) -> dict[str, Any]:
        """Parse ``content`` and collect CuTe DSL symbols.

        Args:
            content: Full text of a Python source file.

        Returns:
            ``{"kernels": [...], "layouts": [...], "structs": [...]}`` holding
            CuTeDSLKernel / CuTeDSLLayout / CuTeDSLStruct entries, each with a
            0-based ``line``. All lists are empty when ``content`` cannot be
            parsed.
        """
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError):
            # Half-typed buffers routinely fail to parse. Python < 3.12 reports
            # NUL bytes in the source as ValueError rather than SyntaxError.
            return {"kernels": [], "layouts": [], "structs": []}

        kernels: list[CuTeDSLKernel] = []
        layouts: list[CuTeDSLLayout] = []
        structs: list[CuTeDSLStruct] = []

        # ast.walk also visits nested scopes (methods, functions inside functions).
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                if self._has_decorator(node, "cute.kernel"):
                    kernels.append(self._extract_kernel(node))
            elif isinstance(node, ast.ClassDef):
                if self._has_decorator(node, "cute.struct"):
                    structs.append(self._extract_struct(node))
            elif isinstance(node, (ast.Assign, ast.AnnAssign)):
                layout = self._extract_layout(node, content)
                if layout is not None:
                    layouts.append(layout)

        return {"kernels": kernels, "layouts": layouts, "structs": structs}

    @staticmethod
    def _dotted_name(expr: ast.expr) -> str | None:
        """Return ``"a.b.c"`` for a Name/Attribute chain, or None for anything else."""
        parts: list[str] = []
        while isinstance(expr, ast.Attribute):
            parts.append(expr.attr)
            expr = expr.value
        if not isinstance(expr, ast.Name):
            return None
        parts.append(expr.id)
        return ".".join(reversed(parts))

    def _has_decorator(self, node: ast.FunctionDef | ast.ClassDef, decorator: str) -> bool:
        """Return True if ``node`` carries the dotted ``decorator`` (e.g. ``"cute.kernel"``).

        A decorator matches when it is spelled exactly that way, as its bare
        last component (``@kernel``), or with extra leading qualifiers
        (``@cutlass.cute.kernel``). Call forms such as ``@cute.kernel(...)``
        are matched by their callee.
        """
        short_name = decorator.rsplit(".", 1)[-1]
        for dec in node.decorator_list:
            target = dec.func if isinstance(dec, ast.Call) else dec
            dotted = self._dotted_name(target)
            if dotted is None:
                continue
            if dotted in (decorator, short_name) or dotted.endswith("." + decorator):
                return True
        return False

    def _extract_kernel(self, node: ast.FunctionDef) -> CuTeDSLKernel:
        """Build a CuTeDSLKernel from a decorated function definition."""
        args = node.args
        # Positional-only and keyword-only parameters are parameters too;
        # reading only ``args.args`` silently dropped them.
        parameters = [arg.arg for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs)]
        return CuTeDSLKernel(
            name=node.name,
            line=node.lineno - 1,
            parameters=parameters,
            docstring=ast.get_docstring(node),
        )

    def _extract_struct(self, node: ast.ClassDef) -> CuTeDSLStruct:
        """Build a CuTeDSLStruct from a decorated class definition."""
        return CuTeDSLStruct(
            name=node.name,
            line=node.lineno - 1,
            docstring=ast.get_docstring(node),
        )

    def _extract_layout(
        self, node: ast.Assign | ast.AnnAssign, content: str
    ) -> CuTeDSLLayout | None:
        """Return a layout if ``node`` assigns ``cute.make_layout(...)`` to a plain name.

        Args:
            node: A plain or annotated assignment.
            content: Source text. It is currently unused and kept for interface
                compatibility.

        Returns:
            A CuTeDSLLayout for the first simple-name target, or None when the
            value is not a ``cute.make_layout`` call or no target is a plain name.
        """
        call = node.value  # None for a bare annotation such as "x: int"
        if not isinstance(call, ast.Call) or self._dotted_name(call.func) != "cute.make_layout":
            return None

        targets = node.targets if isinstance(node, ast.Assign) else [node.target]
        name = next((t.id for t in targets if isinstance(t, ast.Name)), None)
        if name is None:
            return None

        keywords = {kw.arg: kw.value for kw in call.keywords if kw.arg is not None}
        shape_node = call.args[0] if call.args else keywords.get("shape")
        stride_node = keywords.get("stride")
        if stride_node is None and len(call.args) > 1:
            stride_node = call.args[1]

        return CuTeDSLLayout(
            name=name,
            line=node.lineno - 1,
            # ast.unparse always exists here: the file's "X | None" annotations
            # already require Python 3.10+.
            shape=ast.unparse(shape_node) if shape_node is not None else None,
            stride=ast.unparse(stride_node) if stride_node is not None else None,
        )