wafer-lsp 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/__init__.py +1 -0
- wafer_lsp/__main__.py +9 -0
- wafer_lsp/analyzers/__init__.py +0 -0
- wafer_lsp/analyzers/compiler_integration.py +16 -0
- wafer_lsp/analyzers/docs_index.py +36 -0
- wafer_lsp/handlers/__init__.py +30 -0
- wafer_lsp/handlers/code_action.py +48 -0
- wafer_lsp/handlers/code_lens.py +48 -0
- wafer_lsp/handlers/completion.py +6 -0
- wafer_lsp/handlers/diagnostics.py +41 -0
- wafer_lsp/handlers/document_symbol.py +176 -0
- wafer_lsp/handlers/hip_diagnostics.py +303 -0
- wafer_lsp/handlers/hover.py +251 -0
- wafer_lsp/handlers/inlay_hint.py +245 -0
- wafer_lsp/handlers/semantic_tokens.py +224 -0
- wafer_lsp/handlers/workspace_symbol.py +87 -0
- wafer_lsp/languages/README.md +195 -0
- wafer_lsp/languages/__init__.py +17 -0
- wafer_lsp/languages/converter.py +88 -0
- wafer_lsp/languages/detector.py +107 -0
- wafer_lsp/languages/parser_manager.py +33 -0
- wafer_lsp/languages/registry.py +120 -0
- wafer_lsp/languages/types.py +37 -0
- wafer_lsp/parsers/__init__.py +36 -0
- wafer_lsp/parsers/base_parser.py +9 -0
- wafer_lsp/parsers/cuda_parser.py +95 -0
- wafer_lsp/parsers/cutedsl_parser.py +114 -0
- wafer_lsp/parsers/hip_parser.py +688 -0
- wafer_lsp/server.py +58 -0
- wafer_lsp/services/__init__.py +38 -0
- wafer_lsp/services/analysis_service.py +22 -0
- wafer_lsp/services/docs_service.py +40 -0
- wafer_lsp/services/document_service.py +20 -0
- wafer_lsp/services/hip_docs.py +806 -0
- wafer_lsp/services/hip_hover_service.py +412 -0
- wafer_lsp/services/hover_service.py +237 -0
- wafer_lsp/services/language_registry_service.py +26 -0
- wafer_lsp/services/position_service.py +77 -0
- wafer_lsp/utils/__init__.py +0 -0
- wafer_lsp/utils/lsp_helpers.py +79 -0
- wafer_lsp-0.1.13.dist-info/METADATA +60 -0
- wafer_lsp-0.1.13.dist-info/RECORD +44 -0
- wafer_lsp-0.1.13.dist-info/WHEEL +4 -0
- wafer_lsp-0.1.13.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# Adding a New Language to Wafer LSP
|
|
2
|
+
|
|
3
|
+
This guide explains how to add support for a new GPU programming language to the LSP server.
|
|
4
|
+
|
|
5
|
+
## Architecture Overview
|
|
6
|
+
|
|
7
|
+
The LSP server uses a modular, component-based architecture:
|
|
8
|
+
|
|
9
|
+
- **`LanguageDetector`**: Detects language from file extensions/URIs
|
|
10
|
+
- **`ParserManager`**: Manages parser instances for each language
|
|
11
|
+
- **`ParserResultConverter`**: Converts parser-specific types to common types
|
|
12
|
+
- **`LanguageRegistry`**: Coordinates all components (facade pattern)
|
|
13
|
+
|
|
14
|
+
Each component is small, focused, and independently testable.
|
|
15
|
+
|
|
16
|
+
## Steps at a Glance
|
|
17
|
+
|
|
18
|
+
To add a new language, you need to:
|
|
19
|
+
|
|
20
|
+
1. Create a parser that implements `BaseParser`
|
|
21
|
+
2. Register the parser with the language registry
|
|
22
|
+
3. (Optional) Add language-specific handlers
|
|
23
|
+
|
|
24
|
+
## Step 1: Create a Parser
|
|
25
|
+
|
|
26
|
+
Create a new parser file in `wafer_lsp/parsers/`:
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
# wafer_lsp/parsers/my_language_parser.py
|
|
30
|
+
|
|
31
|
+
import re
from typing import List, Dict, Any, Optional
|
|
32
|
+
from dataclasses import dataclass
|
|
33
|
+
from .base_parser import BaseParser
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class MyLanguageKernel:
|
|
37
|
+
name: str
|
|
38
|
+
line: int
|
|
39
|
+
parameters: List[str]
|
|
40
|
+
docstring: Optional[str] = None
|
|
41
|
+
|
|
42
|
+
class MyLanguageParser(BaseParser):
|
|
43
|
+
"""Parser for MyLanguage GPU code."""
|
|
44
|
+
|
|
45
|
+
def parse_file(self, content: str) -> Dict[str, Any]:
|
|
46
|
+
"""Parse file and extract kernels, layouts, etc."""
|
|
47
|
+
kernels: List[MyLanguageKernel] = []
|
|
48
|
+
|
|
49
|
+
# Your parsing logic here
|
|
50
|
+
# Example: find kernel definitions
|
|
51
|
+
for match in re.finditer(r'kernel\s+(\w+)', content):
|
|
52
|
+
kernels.append(MyLanguageKernel(
|
|
53
|
+
name=match.group(1),
|
|
54
|
+
line=content[:match.start()].count('\n'),
|
|
55
|
+
parameters=[],
|
|
56
|
+
))
|
|
57
|
+
|
|
58
|
+
return {
|
|
59
|
+
"kernels": kernels,
|
|
60
|
+
"layouts": [], # If your language has layouts
|
|
61
|
+
"structs": [], # If your language has structs
|
|
62
|
+
}
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
## Step 2: Register the Parser
|
|
66
|
+
|
|
67
|
+
Update `wafer_lsp/languages/registry.py` to register your parser:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from ..parsers.my_language_parser import MyLanguageParser
|
|
71
|
+
|
|
72
|
+
class LanguageRegistry:
|
|
73
|
+
def _register_defaults(self):
|
|
74
|
+
# ... existing registrations ...
|
|
75
|
+
|
|
76
|
+
# Register your new language
|
|
77
|
+
self.register_language(
|
|
78
|
+
language_id="mylang",
|
|
79
|
+
display_name="My Language",
|
|
80
|
+
parser=MyLanguageParser(),
|
|
81
|
+
extensions=[".mylang", ".ml"],
|
|
82
|
+
file_patterns=["*.mylang", "*.ml"]
|
|
83
|
+
)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Step 3: Update VS Code Extension (if needed)
|
|
87
|
+
|
|
88
|
+
If your language needs special handling in the VS Code extension, update `lspClient.ts`:
|
|
89
|
+
|
|
90
|
+
```typescript
|
|
91
|
+
const clientOptions: LanguageClientOptions = {
|
|
92
|
+
documentSelector: [
|
|
93
|
+
// ... existing selectors ...
|
|
94
|
+
{ scheme: 'file', language: 'mylang' }, // Your language
|
|
95
|
+
],
|
|
96
|
+
// ...
|
|
97
|
+
};
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## Parser Interface
|
|
101
|
+
|
|
102
|
+
Your parser must implement `BaseParser`:
|
|
103
|
+
|
|
104
|
+
```python
|
|
105
|
+
class BaseParser(ABC):
|
|
106
|
+
@abstractmethod
|
|
107
|
+
def parse_file(self, content: str) -> Dict[str, Any]:
|
|
108
|
+
"""Parse file content and extract language-specific constructs.
|
|
109
|
+
|
|
110
|
+
Returns:
|
|
111
|
+
Dictionary with keys:
|
|
112
|
+
- "kernels": List of kernel objects (must have .name and .line; .parameters and .docstring are optional)
|
|
113
|
+
- "layouts": List of layout objects (must have .name and .line; .shape and .stride are optional)
|
|
114
|
+
- "structs": List of struct objects (must have .name and .line; .docstring is optional)
|
|
115
|
+
"""
|
|
116
|
+
pass
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
## Example: Adding OpenCL Support
|
|
120
|
+
|
|
121
|
+
Here's a complete example for adding OpenCL support:
|
|
122
|
+
|
|
123
|
+
```python
|
|
124
|
+
# wafer_lsp/parsers/opencl_parser.py
|
|
125
|
+
|
|
126
|
+
import re
|
|
127
|
+
from typing import List, Dict, Any
|
|
128
|
+
from dataclasses import dataclass
|
|
129
|
+
from .base_parser import BaseParser
|
|
130
|
+
|
|
131
|
+
@dataclass
|
|
132
|
+
class OpenCLKernel:
|
|
133
|
+
name: str
|
|
134
|
+
line: int
|
|
135
|
+
parameters: List[str]
|
|
136
|
+
|
|
137
|
+
class OpenCLParser(BaseParser):
|
|
138
|
+
def parse_file(self, content: str) -> Dict[str, Any]:
|
|
139
|
+
kernels: List[OpenCLKernel] = []
|
|
140
|
+
|
|
141
|
+
# Pattern: __kernel void kernel_name(...)
|
|
142
|
+
pattern = r'__kernel\s+(?:__global\s+)?(?:void|.*?)\s+(\w+)\s*\('
|
|
143
|
+
|
|
144
|
+
for match in re.finditer(pattern, content):
|
|
145
|
+
line = content[:match.start()].count('\n')
|
|
146
|
+
kernel_name = match.group(1)
|
|
147
|
+
params = self._extract_parameters(content, match.end())
|
|
148
|
+
|
|
149
|
+
kernels.append(OpenCLKernel(
|
|
150
|
+
name=kernel_name,
|
|
151
|
+
line=line,
|
|
152
|
+
parameters=params
|
|
153
|
+
))
|
|
154
|
+
|
|
155
|
+
return {"kernels": kernels, "layouts": [], "structs": []}
|
|
156
|
+
|
|
157
|
+
def _extract_parameters(self, content: str, start: int) -> List[str]:
|
|
158
|
+
# Extract parameter list logic
|
|
159
|
+
return []
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
Then register it:
|
|
163
|
+
|
|
164
|
+
```python
|
|
165
|
+
# In registry.py _register_defaults()
|
|
166
|
+
self.register_language(
|
|
167
|
+
language_id="opencl",
|
|
168
|
+
display_name="OpenCL",
|
|
169
|
+
parser=OpenCLParser(),
|
|
170
|
+
extensions=[".cl"],
|
|
171
|
+
file_patterns=["*.cl"]
|
|
172
|
+
)
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
## Benefits of This Architecture
|
|
176
|
+
|
|
177
|
+
1. **Modular**: Each language parser is independent
|
|
178
|
+
2. **Extensible**: Adding a language needs only a new parser and one `register_language` call; existing parsers and handlers stay untouched
|
|
179
|
+
3. **Type-safe**: Common types (`KernelInfo`, `LayoutInfo`) ensure consistency
|
|
180
|
+
4. **Language-agnostic handlers**: Handlers work with any registered language
|
|
181
|
+
5. **Automatic detection**: The language is detected from the file extension, with compound extensions such as `.hip.cpp` taking precedence over plain ones
|
|
182
|
+
|
|
183
|
+
## Testing
|
|
184
|
+
|
|
185
|
+
Add tests for your parser:
|
|
186
|
+
|
|
187
|
+
```python
|
|
188
|
+
# tests/test_my_language_parser.py
from wafer_lsp.parsers.my_language_parser import MyLanguageParser
|
|
189
|
+
|
|
190
|
+
def test_parse_kernel():
|
|
191
|
+
parser = MyLanguageParser()
|
|
192
|
+
result = parser.parse_file("kernel my_kernel() { }")
|
|
193
|
+
assert len(result["kernels"]) == 1
|
|
194
|
+
assert result["kernels"][0].name == "my_kernel"
|
|
195
|
+
```
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .converter import ParserResultConverter
|
|
2
|
+
from .detector import LanguageDetector
|
|
3
|
+
from .parser_manager import ParserManager
|
|
4
|
+
from .registry import LanguageRegistry, get_language_registry
|
|
5
|
+
from .types import KernelInfo, LanguageInfo, LayoutInfo, StructInfo
|
|
6
|
+
|
|
7
|
+
# Public API of the ``languages`` package: the registry facade and its
# components (detector, parser manager, converter) plus the
# language-neutral result types they produce.
__all__ = [
    "KernelInfo",
    "LanguageDetector",
    "LanguageInfo",
    "LanguageRegistry",
    "LayoutInfo",
    "ParserManager",
    "ParserResultConverter",
    "StructInfo",
    "get_language_registry",
]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from .types import KernelInfo, LanguageInfo, LayoutInfo, StructInfo
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ParserResultConverter:
    """Convert a parser's ``parse_file`` dict into a language-neutral LanguageInfo.

    Parser results are duck-typed. An entry under ``"kernels"``,
    ``"layouts"`` or ``"structs"`` is kept only if it has both ``name`` and
    ``line`` attributes; every other attribute is optional and falls back to
    a default. Entries without ``name``/``line`` are dropped.
    """

    def convert(
        self,
        parsed_data: dict[str, Any],
        language_id: str
    ) -> LanguageInfo:
        """Build a LanguageInfo for ``language_id`` from ``parsed_data``.

        Args:
            parsed_data: Dict returned by a parser's ``parse_file``. Missing
                keys, or keys whose value is ``None``, are treated as empty
                collections. ``None`` itself is treated as an empty dict.
            language_id: Language ID stamped onto every converted item.

        Returns:
            A LanguageInfo whose ``raw_data`` is ``parsed_data`` (or ``{}``
            when ``parsed_data`` was ``None``).
        """
        if parsed_data is None:
            parsed_data = {}

        # ``or []`` so that a parser reporting e.g. ``"layouts": None`` does
        # not make conversion crash while iterating.
        kernels = self._convert_kernels(
            parsed_data.get("kernels") or [],
            language_id
        )
        layouts = self._convert_layouts(
            parsed_data.get("layouts") or [],
            language_id
        )
        structs = self._convert_structs(
            parsed_data.get("structs") or [],
            language_id
        )

        return LanguageInfo(
            kernels=kernels,
            layouts=layouts,
            structs=structs,
            language=language_id,
            raw_data=parsed_data
        )

    @staticmethod
    def _has_identity(item: Any) -> bool:
        """Return True if ``item`` has the ``name`` and ``line`` every result type needs."""
        return hasattr(item, "name") and hasattr(item, "line")

    def _convert_kernels(
        self,
        kernels: list[Any],
        language_id: str
    ) -> list[KernelInfo]:
        """Convert parser kernel objects into KernelInfo, skipping unusable ones."""
        result: list[KernelInfo] = []

        for kernel in kernels:
            if not self._has_identity(kernel):
                continue
            parameters = getattr(kernel, "parameters", None)
            result.append(KernelInfo(
                name=kernel.name,
                line=kernel.line,
                # A parser may leave parameters unset (None); KernelInfo
                # promises a list.
                parameters=[] if parameters is None else parameters,
                docstring=getattr(kernel, "docstring", None),
                language=language_id
            ))

        return result

    def _convert_layouts(
        self,
        layouts: list[Any],
        language_id: str
    ) -> list[LayoutInfo]:
        """Convert parser layout objects into LayoutInfo, skipping unusable ones."""
        result: list[LayoutInfo] = []

        for layout in layouts:
            if not self._has_identity(layout):
                continue
            result.append(LayoutInfo(
                name=layout.name,
                line=layout.line,
                shape=getattr(layout, "shape", None),
                stride=getattr(layout, "stride", None),
                language=language_id
            ))

        return result

    def _convert_structs(
        self,
        structs: list[Any],
        language_id: str
    ) -> list[StructInfo]:
        """Convert parser struct objects into StructInfo, skipping unusable ones."""
        result: list[StructInfo] = []

        for struct in structs:
            if not self._has_identity(struct):
                continue
            result.append(StructInfo(
                name=struct.name,
                line=struct.line,
                docstring=getattr(struct, "docstring", None),
                language=language_id
            ))

        return result
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class LanguageDetector:
|
|
5
|
+
"""Detects language based on file extension and content markers.
|
|
6
|
+
|
|
7
|
+
Supports both extension-based detection (fast) and content-based detection
|
|
8
|
+
(for files that share extensions, e.g., .cpp files that could be HIP or CUDA).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
def __init__(self):
|
|
12
|
+
self._extensions: dict[str, str] = {}
|
|
13
|
+
self._content_markers: dict[str, list[str]] = {} # language_id -> markers
|
|
14
|
+
# Compound extensions like .hip.cpp need special handling
|
|
15
|
+
self._compound_extensions: dict[str, str] = {}
|
|
16
|
+
|
|
17
|
+
def register_extension(self, extension: str, language_id: str):
|
|
18
|
+
normalized_ext = extension if extension.startswith(".") else f".{extension}"
|
|
19
|
+
|
|
20
|
+
# Check if this is a compound extension (e.g., .hip.cpp)
|
|
21
|
+
if normalized_ext.count(".") > 1:
|
|
22
|
+
self._compound_extensions[normalized_ext] = language_id
|
|
23
|
+
else:
|
|
24
|
+
self._extensions[normalized_ext] = language_id
|
|
25
|
+
|
|
26
|
+
def register_content_markers(self, language_id: str, markers: list[str]):
|
|
27
|
+
"""Register content markers for content-based language detection."""
|
|
28
|
+
self._content_markers[language_id] = markers
|
|
29
|
+
|
|
30
|
+
def detect_from_uri(self, uri: str, content: str | None = None) -> str | None:
|
|
31
|
+
"""Detect language from URI and optionally content.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
uri: File URI or path
|
|
35
|
+
content: Optional file content for content-based detection
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
Language ID or None
|
|
39
|
+
"""
|
|
40
|
+
if uri.startswith("file://"):
|
|
41
|
+
file_path = uri[7:]
|
|
42
|
+
else:
|
|
43
|
+
file_path = uri
|
|
44
|
+
|
|
45
|
+
return self.detect_from_path(file_path, content)
|
|
46
|
+
|
|
47
|
+
def detect_from_path(self, file_path: str, content: str | None = None) -> str | None:
|
|
48
|
+
"""Detect language from file path and optionally content.
|
|
49
|
+
|
|
50
|
+
Order of detection:
|
|
51
|
+
1. Compound extensions (e.g., .hip.cpp) - most specific
|
|
52
|
+
2. Content markers (for shared extensions like .cpp)
|
|
53
|
+
3. Simple extension
|
|
54
|
+
"""
|
|
55
|
+
path = Path(file_path)
|
|
56
|
+
|
|
57
|
+
# 1. Check compound extensions first
|
|
58
|
+
# Get the last two suffixes for compound extension detection
|
|
59
|
+
suffixes = path.suffixes
|
|
60
|
+
if len(suffixes) >= 2:
|
|
61
|
+
compound_ext = "".join(suffixes[-2:]).lower()
|
|
62
|
+
if compound_ext in self._compound_extensions:
|
|
63
|
+
return self._compound_extensions[compound_ext]
|
|
64
|
+
|
|
65
|
+
# 2. If content is provided, check content markers
|
|
66
|
+
if content:
|
|
67
|
+
content_lang = self._detect_from_content(content)
|
|
68
|
+
if content_lang:
|
|
69
|
+
return content_lang
|
|
70
|
+
|
|
71
|
+
# 3. Fall back to simple extension
|
|
72
|
+
ext = path.suffix.lower()
|
|
73
|
+
return self._extensions.get(ext)
|
|
74
|
+
|
|
75
|
+
def _detect_from_content(self, content: str) -> str | None:
|
|
76
|
+
"""Detect language based on content markers.
|
|
77
|
+
|
|
78
|
+
Returns the language with the most matching markers.
|
|
79
|
+
"""
|
|
80
|
+
best_match: str | None = None
|
|
81
|
+
best_count = 0
|
|
82
|
+
|
|
83
|
+
for language_id, markers in self._content_markers.items():
|
|
84
|
+
match_count = sum(1 for marker in markers if marker in content)
|
|
85
|
+
if match_count > best_count:
|
|
86
|
+
best_count = match_count
|
|
87
|
+
best_match = language_id
|
|
88
|
+
|
|
89
|
+
# Require at least one marker match
|
|
90
|
+
return best_match if best_count > 0 else None
|
|
91
|
+
|
|
92
|
+
def detect_from_extension(self, extension: str) -> str | None:
|
|
93
|
+
normalized_ext = extension if extension.startswith(".") else f".{extension}"
|
|
94
|
+
normalized_ext = normalized_ext.lower() # Case insensitive
|
|
95
|
+
return self._extensions.get(normalized_ext)
|
|
96
|
+
|
|
97
|
+
def get_supported_extensions(self) -> list[str]:
|
|
98
|
+
all_extensions = list(self._extensions.keys())
|
|
99
|
+
all_extensions.extend(self._compound_extensions.keys())
|
|
100
|
+
return all_extensions
|
|
101
|
+
|
|
102
|
+
def is_supported(self, uri: str, content: str | None = None) -> bool:
|
|
103
|
+
return self.detect_from_uri(uri, content) is not None
|
|
104
|
+
|
|
105
|
+
def get_content_markers(self, language_id: str) -> list[str]:
|
|
106
|
+
"""Get content markers for a language."""
|
|
107
|
+
return self._content_markers.get(language_id, [])
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
|
|
2
|
+
from ..parsers.base_parser import BaseParser
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ParserManager:
    """Hold one parser instance and one display name per language ID."""

    def __init__(self):
        self._parsers: dict[str, BaseParser] = {}
        self._language_names: dict[str, str] = {}

    def register_parser(
        self,
        language_id: str,
        display_name: str,
        parser: BaseParser
    ):
        """Register ``parser`` and ``display_name`` under ``language_id``.

        Raises:
            ValueError: if ``language_id`` is already registered. This is an
                explicit raise rather than ``assert`` so the check is not
                stripped under ``python -O``, which would otherwise silently
                replace the existing parser.
        """
        if language_id in self._parsers:
            raise ValueError(f"Language {language_id} already registered")

        self._parsers[language_id] = parser
        self._language_names[language_id] = display_name

    def get_parser(self, language_id: str) -> BaseParser | None:
        """Return the parser for ``language_id``, or None if unregistered."""
        return self._parsers.get(language_id)

    def get_language_name(self, language_id: str) -> str | None:
        """Return the display name for ``language_id``, or None if unregistered."""
        return self._language_names.get(language_id)

    def has_parser(self, language_id: str) -> bool:
        """Return True if a parser is registered for ``language_id``."""
        return language_id in self._parsers

    def list_languages(self) -> list[str]:
        """Return registered language IDs in registration order."""
        return list(self._parsers.keys())
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
|
|
2
|
+
from ..parsers.cuda_parser import CUDAParser
|
|
3
|
+
from ..parsers.cutedsl_parser import CuTeDSLParser
|
|
4
|
+
from ..parsers.hip_parser import HIPParser
|
|
5
|
+
from .converter import ParserResultConverter
|
|
6
|
+
from .detector import LanguageDetector
|
|
7
|
+
from .parser_manager import ParserManager
|
|
8
|
+
from .types import LanguageInfo
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LanguageRegistry:
    """Facade wiring together language detection, parser lookup and result conversion.

    The built-in languages (CuTeDSL, CUDA, HIP and C++) are registered on
    construction. A shared instance is available from
    ``get_language_registry()``.
    """

    def __init__(self):
        self._detector = LanguageDetector()
        self._parser_manager = ParserManager()
        self._converter = ParserResultConverter()

        self._register_defaults()

    def _register_defaults(self):
        """Register the built-in languages."""
        self.register_language(
            language_id="cutedsl",
            display_name="CuTeDSL",
            parser=CuTeDSLParser(),
            extensions=[".py"],
            file_patterns=["*.py"]
        )

        self.register_language(
            language_id="cuda",
            display_name="CUDA",
            parser=CUDAParser(),
            extensions=[".cu", ".cuh"],
            file_patterns=["*.cu", "*.cuh"]
        )

        # HIP (AMD GPU). ".hip.cpp"/".hip.hpp" are compound extensions, which
        # the detector checks before simple ones. They therefore win over the
        # ".cpp"/".hpp" C++ mapping below regardless of registration order.
        # The content markers let detection classify other files (e.g. a plain
        # ".cpp" including the HIP runtime) as HIP when their text is supplied.
        self.register_language(
            language_id="hip",
            display_name="HIP (AMD GPU)",
            parser=HIPParser(),
            extensions=[".hip", ".hip.cpp", ".hip.hpp", ".hipcc"],
            file_patterns=["*.hip", "*.hip.cpp", "*.hip.hpp", "*.hipcc"],
            content_markers=[
                "#include <hip/hip_runtime.h>",
                "#include \"hip/hip_runtime.h\"",
                "hipMalloc",
                "hipLaunchKernelGGL",
                "__HIP_PLATFORM_AMD__",
            ]
        )

        # Plain C++ sources are parsed with the CUDA parser.
        self.register_language(
            language_id="cpp",
            display_name="C++",
            parser=CUDAParser(),
            extensions=[".cpp", ".hpp", ".cc", ".cxx"],
            file_patterns=["*.cpp", "*.hpp", "*.cc", "*.cxx"]
        )

    def register_language(
        self,
        language_id: str,
        display_name: str,
        parser,
        extensions: list[str],
        file_patterns: list[str] | None = None,
        content_markers: list[str] | None = None
    ):
        """Register a language: its parser, file extensions and content markers.

        Args:
            language_id: Unique ID for the language. Registering the same ID
                twice is rejected by the parser manager.
            display_name: Human-readable name.
            parser: Object with a ``parse_file(content)`` method.
            extensions: Extensions that identify the language (compound ones
                such as ".hip.cpp" are allowed).
            file_patterns: Accepted for API compatibility; not used by the
                registry itself.
            content_markers: Optional substrings used for content-based
                detection.
        """
        self._parser_manager.register_parser(language_id, display_name, parser)

        for ext in extensions:
            self._detector.register_extension(ext, language_id)

        if content_markers:
            self._detector.register_content_markers(language_id, content_markers)

    def detect_language(self, uri: str, content: str | None = None) -> str | None:
        """Return the language ID for ``uri``, or None if unsupported.

        When ``content`` is given, registered content markers also take part
        in detection.
        """
        return self._detector.detect_from_uri(uri, content)

    def get_parser(self, language_id: str):
        """Return the parser registered for ``language_id``, or None."""
        return self._parser_manager.get_parser(language_id)

    def parse_file(self, uri: str, content: str) -> LanguageInfo | None:
        """Detect the language of a document and parse it.

        Returns:
            None when the language is unknown or has no parser. If the parser
            raises, the error is logged and an empty LanguageInfo for the
            detected language is returned instead.
        """
        # Pass the content so content markers (e.g. a HIP runtime include in a
        # plain .cpp file) are taken into account, not just the extension.
        language_id = self.detect_language(uri, content)
        if not language_id:
            return None

        parser = self.get_parser(language_id)
        if not parser:
            return None

        try:
            parsed_data = parser.parse_file(content)
        except Exception:
            # Best-effort: return an empty result rather than propagating the
            # error, but log it so parser failures remain diagnosable.
            import logging

            logging.getLogger(__name__).exception(
                "Failed to parse %s as %s", uri, language_id
            )
            return LanguageInfo(
                kernels=[],
                layouts=[],
                structs=[],
                language=language_id,
                raw_data={}
            )

        return self._converter.convert(parsed_data, language_id)

    def get_supported_extensions(self) -> list[str]:
        """Return every registered extension (simple, then compound)."""
        return self._detector.get_supported_extensions()

    def get_language_name(self, language_id: str) -> str | None:
        """Return the display name for ``language_id``, or None."""
        return self._parser_manager.get_language_name(language_id)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Lazily created process-wide registry, populated by get_language_registry().
_registry: LanguageRegistry | None = None


def get_language_registry() -> LanguageRegistry:
    """Return the shared LanguageRegistry, constructing it on the first call.

    No locking is performed, so threads racing on the very first call could
    each construct a registry. After that, the cached instance is returned.
    """
    global _registry
    if _registry is not None:
        return _registry
    _registry = LanguageRegistry()
    return _registry
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass
class KernelInfo:
    """A kernel definition found in a source file, in language-neutral form."""

    # Kernel name as reported by the parser.
    name: str
    # Line of the definition as reported by the parser; presumably 0-based —
    # TODO confirm for every parser.
    line: int
    # Kernel parameters as reported by the parser.
    parameters: list[str]
    # Documentation attached to the kernel, if the parser extracted any.
    docstring: str | None = None
    # Language ID of the source file (e.g. "cuda", "hip"); "" when unset.
    language: str = ""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class LayoutInfo:
    """A layout definition found in a source file, in language-neutral form."""

    # Layout name as reported by the parser.
    name: str
    # Line of the definition as reported by the parser; presumably 0-based —
    # TODO confirm for every parser.
    line: int
    # Layout shape as text, if the parser captured it.
    shape: str | None = None
    # Layout stride as text, if the parser captured it.
    stride: str | None = None
    # Language ID of the source file; "" when unset.
    language: str = ""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class StructInfo:
    """A struct definition found in a source file, in language-neutral form."""

    # Struct name as reported by the parser.
    name: str
    # Line of the definition as reported by the parser; presumably 0-based —
    # TODO confirm for every parser.
    line: int
    # Documentation attached to the struct, if the parser extracted any.
    docstring: str | None = None
    # Language ID of the source file; "" when unset.
    language: str = ""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class LanguageInfo:
    """Everything extracted from one file, converted to language-neutral types."""

    # Kernels found in the file.
    kernels: list[KernelInfo]
    # Layouts found in the file (empty for languages without layouts).
    layouts: list[LayoutInfo]
    # Structs found in the file.
    structs: list[StructInfo]
    # Language ID the file was parsed as.
    language: str
    # The parser's unconverted result dict; may be empty, e.g. when parsing failed.
    raw_data: dict[str, Any]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from .base_parser import BaseParser
|
|
2
|
+
from .cuda_parser import CUDAKernel, CUDAParser
|
|
3
|
+
from .cutedsl_parser import (
|
|
4
|
+
CuTeDSLKernel,
|
|
5
|
+
CuTeDSLLayout,
|
|
6
|
+
CuTeDSLParser,
|
|
7
|
+
CuTeDSLStruct,
|
|
8
|
+
)
|
|
9
|
+
from .hip_parser import (
|
|
10
|
+
HIPKernel,
|
|
11
|
+
HIPDeviceFunction,
|
|
12
|
+
HIPParameter,
|
|
13
|
+
HIPParser,
|
|
14
|
+
KernelLaunchSite,
|
|
15
|
+
SharedMemoryAllocation,
|
|
16
|
+
WavefrontPattern,
|
|
17
|
+
is_hip_file,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Public API of the ``parsers`` package: the parser base class, one parser
# per supported language, and the parser-specific result types.
__all__ = [
    "BaseParser",
    "CUDAKernel",
    "CUDAParser",
    "CuTeDSLKernel",
    "CuTeDSLLayout",
    "CuTeDSLParser",
    "CuTeDSLStruct",
    "HIPKernel",
    "HIPDeviceFunction",
    "HIPParameter",
    "HIPParser",
    "KernelLaunchSite",
    "SharedMemoryAllocation",
    "WavefrontPattern",
    "is_hip_file",
]
|