wafer-lsp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_lsp/__init__.py +1 -0
- wafer_lsp/__main__.py +9 -0
- wafer_lsp/analyzers/__init__.py +0 -0
- wafer_lsp/analyzers/compiler_integration.py +16 -0
- wafer_lsp/analyzers/docs_index.py +36 -0
- wafer_lsp/handlers/__init__.py +0 -0
- wafer_lsp/handlers/code_action.py +48 -0
- wafer_lsp/handlers/code_lens.py +48 -0
- wafer_lsp/handlers/completion.py +6 -0
- wafer_lsp/handlers/diagnostics.py +16 -0
- wafer_lsp/handlers/document_symbol.py +87 -0
- wafer_lsp/handlers/hover.py +215 -0
- wafer_lsp/handlers/inlay_hint.py +65 -0
- wafer_lsp/handlers/semantic_tokens.py +124 -0
- wafer_lsp/handlers/workspace_symbol.py +87 -0
- wafer_lsp/languages/README.md +195 -0
- wafer_lsp/languages/__init__.py +17 -0
- wafer_lsp/languages/converter.py +88 -0
- wafer_lsp/languages/detector.py +34 -0
- wafer_lsp/languages/parser_manager.py +33 -0
- wafer_lsp/languages/registry.py +99 -0
- wafer_lsp/languages/types.py +37 -0
- wafer_lsp/parsers/__init__.py +18 -0
- wafer_lsp/parsers/base_parser.py +9 -0
- wafer_lsp/parsers/cuda_parser.py +95 -0
- wafer_lsp/parsers/cutedsl_parser.py +114 -0
- wafer_lsp/server.py +58 -0
- wafer_lsp/services/__init__.py +21 -0
- wafer_lsp/services/analysis_service.py +22 -0
- wafer_lsp/services/docs_service.py +40 -0
- wafer_lsp/services/document_service.py +20 -0
- wafer_lsp/services/hover_service.py +237 -0
- wafer_lsp/services/language_registry_service.py +26 -0
- wafer_lsp/services/position_service.py +77 -0
- wafer_lsp/utils/__init__.py +0 -0
- wafer_lsp/utils/lsp_helpers.py +79 -0
- wafer_lsp-0.1.0.dist-info/METADATA +57 -0
- wafer_lsp-0.1.0.dist-info/RECORD +40 -0
- wafer_lsp-0.1.0.dist-info/WHEEL +4 -0
- wafer_lsp-0.1.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
|
|
2
|
+
from lsprotocol.types import Location, Position, Range, SymbolKind, WorkspaceSymbol
|
|
3
|
+
|
|
4
|
+
from ..languages.registry import get_language_registry
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _matches_query(name: str, query: str) -> bool:
|
|
8
|
+
if not query:
|
|
9
|
+
return True
|
|
10
|
+
|
|
11
|
+
name_lower = name.lower()
|
|
12
|
+
query_lower = query.lower()
|
|
13
|
+
|
|
14
|
+
query_idx = 0
|
|
15
|
+
for char in name_lower:
|
|
16
|
+
if query_idx < len(query_lower) and char == query_lower[query_idx]:
|
|
17
|
+
query_idx += 1
|
|
18
|
+
|
|
19
|
+
return query_idx == len(query_lower)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def handle_workspace_symbol(query: str) -> list[WorkspaceSymbol]:
    """Answer a ``workspace/symbol`` request without any document contents.

    With no document text available there is nothing to parse, so this
    always returns an empty list; ``query`` is accepted only to keep the
    handler signature. Callers that have document text should use
    ``handle_workspace_symbol_with_documents`` instead.
    """
    return []
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _language_label(registry, language_id: str) -> str:
    """Return a human-readable label for ``language_id``."""
    # get_language_name() returns None for unregistered ids; without this
    # fallback the container name would literally read "GPU Kernel (None)".
    return registry.get_language_name(language_id) or language_id or "unknown"


def _symbol_at_line(
    name: str,
    kind: SymbolKind,
    uri: str,
    line: int,
    container_name: str,
) -> WorkspaceSymbol:
    """Build a ``WorkspaceSymbol`` located at column 0 of ``line`` in ``uri``."""
    # Parsers only report a line number, so the location is an empty range
    # at the start of that line.
    return WorkspaceSymbol(
        name=name,
        kind=kind,
        location=Location(
            uri=uri,
            range=Range(
                start=Position(line=line, character=0),
                end=Position(line=line, character=0),
            ),
        ),
        container_name=container_name,
    )


def handle_workspace_symbol_with_documents(
    query: str,
    document_contents: dict[str, str]
) -> list[WorkspaceSymbol]:
    """Collect kernels, layouts and structs whose names match ``query``.

    Args:
        query: Fuzzy filter; a symbol matches when the query's characters
            appear in its name in order, ignoring case. An empty query
            matches every symbol.
        document_contents: Mapping of document URI to its full text.

    Returns:
        One ``WorkspaceSymbol`` per matching construct. For each document,
        kernels come first, then layouts, then structs. Documents for which
        the language registry returns no result are skipped.
    """
    registry = get_language_registry()
    symbols: list[WorkspaceSymbol] = []

    for uri, content in document_contents.items():
        language_info = registry.parse_file(uri, content)
        if language_info is None:
            # No registered language/parser for this URI.
            continue

        for kernel in language_info.kernels:
            if _matches_query(kernel.name, query):
                label = _language_label(registry, kernel.language)
                symbols.append(_symbol_at_line(
                    kernel.name,
                    SymbolKind.Function,
                    uri,
                    kernel.line,
                    f"GPU Kernel ({label})",
                ))

        for layout in language_info.layouts:
            if _matches_query(layout.name, query):
                symbols.append(_symbol_at_line(
                    layout.name,
                    SymbolKind.Variable,
                    uri,
                    layout.line,
                    "Layout",
                ))

        for struct in language_info.structs:
            if _matches_query(struct.name, query):
                label = _language_label(registry, struct.language)
                symbols.append(_symbol_at_line(
                    struct.name,
                    SymbolKind.Struct,
                    uri,
                    struct.line,
                    f"Struct ({label})",
                ))

    return symbols
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# Adding a New Language to Wafer LSP
|
|
2
|
+
|
|
3
|
+
This guide explains how to add support for a new GPU programming language to the LSP server.
|
|
4
|
+
|
|
5
|
+
## Architecture Overview
|
|
6
|
+
|
|
7
|
+
The LSP server uses a modular, component-based architecture:
|
|
8
|
+
|
|
9
|
+
- **`LanguageDetector`**: Detects language from file extensions/URIs
|
|
10
|
+
- **`ParserManager`**: Manages parser instances for each language
|
|
11
|
+
- **`ParserResultConverter`**: Converts parser-specific types to common types
|
|
12
|
+
- **`LanguageRegistry`**: Coordinates all components (facade pattern)
|
|
13
|
+
|
|
14
|
+
Each component is small, focused, and independently testable.
|
|
15
|
+
|
|
16
|
+
## Overview
|
|
17
|
+
|
|
18
|
+
To add a new language, you need to:
|
|
19
|
+
|
|
20
|
+
1. Create a parser that implements `BaseParser`
|
|
21
|
+
2. Register the parser with the language registry
|
|
22
|
+
3. (Optional) Update the VS Code extension if it does not already send your files to the server
|
|
23
|
+
|
|
24
|
+
## Step 1: Create a Parser
|
|
25
|
+
|
|
26
|
+
Create a new parser file in `wafer_lsp/parsers/`:
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
# wafer_lsp/parsers/my_language_parser.py
|
|
30
|
+
|
|
31
|
+
import re
from typing import List, Dict, Any, Optional
|
|
32
|
+
from dataclasses import dataclass
|
|
33
|
+
from .base_parser import BaseParser
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class MyLanguageKernel:
|
|
37
|
+
name: str
|
|
38
|
+
line: int
|
|
39
|
+
parameters: List[str]
|
|
40
|
+
docstring: Optional[str] = None
|
|
41
|
+
|
|
42
|
+
class MyLanguageParser(BaseParser):
|
|
43
|
+
"""Parser for MyLanguage GPU code."""
|
|
44
|
+
|
|
45
|
+
def parse_file(self, content: str) -> Dict[str, Any]:
|
|
46
|
+
"""Parse file and extract kernels, layouts, etc."""
|
|
47
|
+
kernels: List[MyLanguageKernel] = []
|
|
48
|
+
|
|
49
|
+
# Your parsing logic here
|
|
50
|
+
# Example: find kernel definitions
|
|
51
|
+
for match in re.finditer(r'kernel\s+(\w+)', content):
|
|
52
|
+
kernels.append(MyLanguageKernel(
|
|
53
|
+
name=match.group(1),
|
|
54
|
+
line=content[:match.start()].count('\n'),
|
|
55
|
+
parameters=[],
|
|
56
|
+
))
|
|
57
|
+
|
|
58
|
+
return {
|
|
59
|
+
"kernels": kernels,
|
|
60
|
+
"layouts": [], # If your language has layouts
|
|
61
|
+
"structs": [], # If your language has structs
|
|
62
|
+
}
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
## Step 2: Register the Parser
|
|
66
|
+
|
|
67
|
+
Update `wafer_lsp/languages/registry.py` to register your parser:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from ..parsers.my_language_parser import MyLanguageParser
|
|
71
|
+
|
|
72
|
+
class LanguageRegistry:
|
|
73
|
+
def _register_defaults(self):
|
|
74
|
+
# ... existing registrations ...
|
|
75
|
+
|
|
76
|
+
# Register your new language
|
|
77
|
+
self.register_language(
|
|
78
|
+
language_id="mylang",
|
|
79
|
+
display_name="My Language",
|
|
80
|
+
parser=MyLanguageParser(),
|
|
81
|
+
extensions=[".mylang", ".ml"],
|
|
82
|
+
file_patterns=["*.mylang", "*.ml"]  # accepted but currently unused; detection uses extensions
|
|
83
|
+
)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Step 3: Update VS Code Extension (if needed)
|
|
87
|
+
|
|
88
|
+
If the VS Code extension does not already route your files to the server, add your language to the `documentSelector` in `lspClient.ts`:
|
|
89
|
+
|
|
90
|
+
```typescript
|
|
91
|
+
const clientOptions: LanguageClientOptions = {
|
|
92
|
+
documentSelector: [
|
|
93
|
+
// ... existing selectors ...
|
|
94
|
+
{ scheme: 'file', language: 'mylang' }, // Your language
|
|
95
|
+
],
|
|
96
|
+
// ...
|
|
97
|
+
};
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## Parser Interface
|
|
101
|
+
|
|
102
|
+
Your parser must implement `BaseParser`:
|
|
103
|
+
|
|
104
|
+
```python
|
|
105
|
+
class BaseParser(ABC):
|
|
106
|
+
@abstractmethod
|
|
107
|
+
def parse_file(self, content: str) -> Dict[str, Any]:
|
|
108
|
+
"""Parse file content and extract language-specific constructs.
|
|
109
|
+
|
|
110
|
+
Returns:
|
|
111
|
+
Dictionary with keys:
|
|
112
|
+
- "kernels": List of kernel objects (must have .name, .line, .parameters)
|
|
113
|
+
- "layouts": List of layout objects (must have .name, .line)
|
|
114
|
+
- "structs": List of struct objects (must have .name, .line)
|
|
115
|
+
"""
|
|
116
|
+
pass
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
## Example: Adding OpenCL Support
|
|
120
|
+
|
|
121
|
+
Here's a complete example for adding OpenCL support:
|
|
122
|
+
|
|
123
|
+
```python
|
|
124
|
+
# wafer_lsp/parsers/opencl_parser.py
|
|
125
|
+
|
|
126
|
+
import re
|
|
127
|
+
from typing import List, Dict, Any
|
|
128
|
+
from dataclasses import dataclass
|
|
129
|
+
from .base_parser import BaseParser
|
|
130
|
+
|
|
131
|
+
@dataclass
|
|
132
|
+
class OpenCLKernel:
|
|
133
|
+
name: str
|
|
134
|
+
line: int
|
|
135
|
+
parameters: List[str]
|
|
136
|
+
|
|
137
|
+
class OpenCLParser(BaseParser):
|
|
138
|
+
def parse_file(self, content: str) -> Dict[str, Any]:
|
|
139
|
+
kernels: List[OpenCLKernel] = []
|
|
140
|
+
|
|
141
|
+
# Pattern: __kernel void kernel_name(...)
|
|
142
|
+
pattern = r'__kernel\s+(?:__global\s+)?(?:void|.*?)\s+(\w+)\s*\('
|
|
143
|
+
|
|
144
|
+
for match in re.finditer(pattern, content):
|
|
145
|
+
line = content[:match.start()].count('\n')
|
|
146
|
+
kernel_name = match.group(1)
|
|
147
|
+
params = self._extract_parameters(content, match.end())
|
|
148
|
+
|
|
149
|
+
kernels.append(OpenCLKernel(
|
|
150
|
+
name=kernel_name,
|
|
151
|
+
line=line,
|
|
152
|
+
parameters=params
|
|
153
|
+
))
|
|
154
|
+
|
|
155
|
+
return {"kernels": kernels, "layouts": [], "structs": []}
|
|
156
|
+
|
|
157
|
+
def _extract_parameters(self, content: str, start: int) -> List[str]:
|
|
158
|
+
# Extract parameter list logic
|
|
159
|
+
return []
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
Then register it:
|
|
163
|
+
|
|
164
|
+
```python
|
|
165
|
+
# In registry.py _register_defaults()
|
|
166
|
+
self.register_language(
|
|
167
|
+
language_id="opencl",
|
|
168
|
+
display_name="OpenCL",
|
|
169
|
+
parser=OpenCLParser(),
|
|
170
|
+
extensions=[".cl"],
|
|
171
|
+
file_patterns=["*.cl"]
|
|
172
|
+
)
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
## Benefits of This Architecture
|
|
176
|
+
|
|
177
|
+
1. **Modular**: Each language parser is independent
|
|
178
|
+
2. **Extensible**: Adding a language means writing a parser and registering it; handlers and existing parsers stay untouched
|
|
179
|
+
3. **Type-safe**: Common types (`KernelInfo`, `LayoutInfo`, `StructInfo`) ensure consistency across languages
|
|
180
|
+
4. **Language-agnostic handlers**: Handlers work with any registered language
|
|
181
|
+
5. **Automatic detection**: Language is detected from file extension automatically
|
|
182
|
+
|
|
183
|
+
## Testing
|
|
184
|
+
|
|
185
|
+
Add tests for your parser:
|
|
186
|
+
|
|
187
|
+
```python
|
|
188
|
+
# tests/test_my_language_parser.py
from wafer_lsp.parsers.my_language_parser import MyLanguageParser
|
|
189
|
+
|
|
190
|
+
def test_parse_kernel():
|
|
191
|
+
parser = MyLanguageParser()
|
|
192
|
+
result = parser.parse_file("kernel my_kernel() { }")
|
|
193
|
+
assert len(result["kernels"]) == 1
|
|
194
|
+
assert result["kernels"][0].name == "my_kernel"
|
|
195
|
+
```
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .converter import ParserResultConverter
|
|
2
|
+
from .detector import LanguageDetector
|
|
3
|
+
from .parser_manager import ParserManager
|
|
4
|
+
from .registry import LanguageRegistry, get_language_registry
|
|
5
|
+
from .types import KernelInfo, LanguageInfo, LayoutInfo, StructInfo
|
|
6
|
+
|
|
7
|
+
# Public API of the ``languages`` package: the shared data types
# (KernelInfo, LayoutInfo, StructInfo, LanguageInfo), the detection /
# parsing / conversion components, and the registry facade plus its
# accessor.
__all__ = [
    "KernelInfo",
    "LanguageDetector",
    "LanguageInfo",
    "LanguageRegistry",
    "LayoutInfo",
    "ParserManager",
    "ParserResultConverter",
    "StructInfo",
    "get_language_registry",
]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from .types import KernelInfo, LanguageInfo, LayoutInfo, StructInfo
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ParserResultConverter:
    """Translate a parser's raw ``parse_file`` output into a ``LanguageInfo``.

    Parser results are duck-typed. An entry under ``"kernels"``,
    ``"layouts"`` or ``"structs"`` is kept only if it has ``name`` and
    ``line`` attributes. Optional fields are read with ``getattr`` and fall
    back to defaults when absent.
    """

    def convert(
        self,
        parsed_data: dict[str, Any],
        language_id: str
    ) -> LanguageInfo:
        """Build a ``LanguageInfo`` for ``language_id`` from ``parsed_data``.

        Missing keys, or keys whose value is ``None``, yield empty lists
        instead of raising while iterating. ``parsed_data`` itself is stored
        unchanged as ``raw_data``; ``None`` is replaced by an empty dict.
        """
        # Tolerate a parser that returns None instead of a dict.
        data = parsed_data if parsed_data is not None else {}

        # ``or []`` also covers keys explicitly set to None.
        kernels = self._convert_kernels(data.get("kernels") or [], language_id)
        layouts = self._convert_layouts(data.get("layouts") or [], language_id)
        structs = self._convert_structs(data.get("structs") or [], language_id)

        return LanguageInfo(
            kernels=kernels,
            layouts=layouts,
            structs=structs,
            language=language_id,
            raw_data=data,
        )

    @staticmethod
    def _has_identity(item: Any) -> bool:
        """Return True if ``item`` has the ``name`` and ``line`` every symbol needs."""
        return hasattr(item, "name") and hasattr(item, "line")

    def _convert_kernels(
        self,
        kernels: list[Any],
        language_id: str
    ) -> list[KernelInfo]:
        """Convert parser kernel objects to ``KernelInfo``, skipping malformed ones."""
        return [
            KernelInfo(
                name=kernel.name,
                line=kernel.line,
                # Normalise an explicit ``parameters=None`` to an empty list so
                # consumers can always iterate it.
                parameters=getattr(kernel, "parameters", None) or [],
                docstring=getattr(kernel, "docstring", None),
                language=language_id,
            )
            for kernel in kernels
            if self._has_identity(kernel)
        ]

    def _convert_layouts(
        self,
        layouts: list[Any],
        language_id: str
    ) -> list[LayoutInfo]:
        """Convert parser layout objects to ``LayoutInfo``, skipping malformed ones."""
        return [
            LayoutInfo(
                name=layout.name,
                line=layout.line,
                shape=getattr(layout, "shape", None),
                stride=getattr(layout, "stride", None),
                language=language_id,
            )
            for layout in layouts
            if self._has_identity(layout)
        ]

    def _convert_structs(
        self,
        structs: list[Any],
        language_id: str
    ) -> list[StructInfo]:
        """Convert parser struct objects to ``StructInfo``, skipping malformed ones."""
        return [
            StructInfo(
                name=struct.name,
                line=struct.line,
                docstring=getattr(struct, "docstring", None),
                language=language_id,
            )
            for struct in structs
            if self._has_identity(struct)
        ]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class LanguageDetector:
|
|
5
|
+
|
|
6
|
+
def __init__(self):
|
|
7
|
+
self._extensions: dict[str, str] = {}
|
|
8
|
+
|
|
9
|
+
def register_extension(self, extension: str, language_id: str):
|
|
10
|
+
normalized_ext = extension if extension.startswith(".") else f".{ext}"
|
|
11
|
+
self._extensions[normalized_ext] = language_id
|
|
12
|
+
|
|
13
|
+
def detect_from_uri(self, uri: str) -> str | None:
|
|
14
|
+
if uri.startswith("file://"):
|
|
15
|
+
file_path = uri[7:]
|
|
16
|
+
else:
|
|
17
|
+
file_path = uri
|
|
18
|
+
|
|
19
|
+
return self.detect_from_path(file_path)
|
|
20
|
+
|
|
21
|
+
def detect_from_path(self, file_path: str) -> str | None:
|
|
22
|
+
path = Path(file_path)
|
|
23
|
+
ext = path.suffix.lower()
|
|
24
|
+
return self._extensions.get(ext)
|
|
25
|
+
|
|
26
|
+
def detect_from_extension(self, extension: str) -> str | None:
|
|
27
|
+
normalized_ext = extension if extension.startswith(".") else f".{ext}"
|
|
28
|
+
return self._extensions.get(normalized_ext)
|
|
29
|
+
|
|
30
|
+
def get_supported_extensions(self) -> list[str]:
|
|
31
|
+
return list(self._extensions.keys())
|
|
32
|
+
|
|
33
|
+
def is_supported(self, uri: str) -> bool:
|
|
34
|
+
return self.detect_from_uri(uri) is not None
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
|
|
2
|
+
from ..parsers.base_parser import BaseParser
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ParserManager:
    """Hold one parser instance and one display name per language id."""

    def __init__(self):
        self._parsers: dict[str, BaseParser] = {}
        self._language_names: dict[str, str] = {}

    def register_parser(
        self,
        language_id: str,
        display_name: str,
        parser: BaseParser
    ):
        """Register ``parser`` and its human-readable ``display_name``.

        Raises:
            ValueError: if ``language_id`` is already registered.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let a duplicate silently overwrite the
        # existing parser.
        if language_id in self._parsers:
            raise ValueError(f"Language {language_id} already registered")

        self._parsers[language_id] = parser
        self._language_names[language_id] = display_name

    def get_parser(self, language_id: str) -> BaseParser | None:
        """Return the parser registered for ``language_id``, or None."""
        return self._parsers.get(language_id)

    def get_language_name(self, language_id: str) -> str | None:
        """Return the display name registered for ``language_id``, or None."""
        return self._language_names.get(language_id)

    def has_parser(self, language_id: str) -> bool:
        """Return True if a parser is registered for ``language_id``."""
        return language_id in self._parsers

    def list_languages(self) -> list[str]:
        """Return registered language ids in registration order."""
        return list(self._parsers.keys())
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
|
|
2
|
+
from ..parsers.cuda_parser import CUDAParser
|
|
3
|
+
from ..parsers.cutedsl_parser import CuTeDSLParser
|
|
4
|
+
from .converter import ParserResultConverter
|
|
5
|
+
from .detector import LanguageDetector
|
|
6
|
+
from .parser_manager import ParserManager
|
|
7
|
+
from .types import LanguageInfo
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class LanguageRegistry:
    """Facade that wires language detection, parsing and result conversion.

    A ``LanguageDetector`` maps URIs to language ids by file extension, a
    ``ParserManager`` holds the parser for each language id, and a
    ``ParserResultConverter`` turns raw parser output into ``LanguageInfo``.
    """

    def __init__(self):
        self._detector = LanguageDetector()
        self._parser_manager = ParserManager()
        self._converter = ParserResultConverter()

        self._register_defaults()

    def _register_defaults(self):
        """Register the built-in CuTeDSL, CUDA and C++ languages."""
        self.register_language(
            language_id="cutedsl",
            display_name="CuTeDSL",
            parser=CuTeDSLParser(),
            extensions=[".py"],
            file_patterns=["*.py"]
        )

        self.register_language(
            language_id="cuda",
            display_name="CUDA",
            parser=CUDAParser(),
            extensions=[".cu", ".cuh"],
            file_patterns=["*.cu", "*.cuh"]
        )

        # Plain C++ sources are handled by (a separate instance of) the CUDA parser.
        self.register_language(
            language_id="cpp",
            display_name="C++",
            parser=CUDAParser(),
            extensions=[".cpp", ".hpp", ".cc", ".cxx"],
            file_patterns=["*.cpp", "*.hpp", "*.cc", "*.cxx"]
        )

    def register_language(
        self,
        language_id: str,
        display_name: str,
        parser,
        extensions: list[str],
        file_patterns: list[str] | None = None
    ):
        """Register ``parser`` under ``language_id`` and map ``extensions`` to it.

        Args:
            language_id: Internal identifier (e.g. ``"cuda"``).
            display_name: Human-readable name shown in editor UI.
            parser: Object with a ``parse_file(content)`` method.
            extensions: File extensions that select this language.
            file_patterns: Accepted for API compatibility but currently
                unused; detection relies on ``extensions`` only.

        Duplicate ``language_id`` handling is delegated to
        ``ParserManager.register_parser``.
        """
        self._parser_manager.register_parser(language_id, display_name, parser)

        for ext in extensions:
            self._detector.register_extension(ext, language_id)

    def detect_language(self, uri: str) -> str | None:
        """Return the language id for ``uri`` based on its extension, or None."""
        return self._detector.detect_from_uri(uri)

    def get_parser(self, language_id: str):
        """Return the parser registered for ``language_id``, or None."""
        return self._parser_manager.get_parser(language_id)

    def parse_file(self, uri: str, content: str) -> LanguageInfo | None:
        """Detect the language of ``uri`` and parse ``content`` with its parser.

        Returns:
            ``None`` if no language or parser is registered for the URI.
            Otherwise returns the converted ``LanguageInfo``. If the parser
            raises, an empty ``LanguageInfo`` for the detected language is
            returned so one unparsable file does not break LSP features.
        """
        language_id = self.detect_language(uri)
        if not language_id:
            return None

        parser = self.get_parser(language_id)
        if parser is None:
            return None

        try:
            parsed_data = parser.parse_file(content)
        except Exception:
            # Best-effort by design: keep serving requests, but leave a trace
            # (at debug level, since half-typed code may fail on every edit)
            # instead of swallowing the error silently.
            import logging

            logging.getLogger(__name__).debug(
                "Failed to parse %s as %s", uri, language_id, exc_info=True
            )
            return LanguageInfo(
                kernels=[],
                layouts=[],
                structs=[],
                language=language_id,
                raw_data={}
            )

        return self._converter.convert(parsed_data, language_id)

    def get_supported_extensions(self) -> list[str]:
        """Return every registered file extension."""
        return self._detector.get_supported_extensions()

    def get_language_name(self, language_id: str) -> str | None:
        """Return the display name for ``language_id``, or None if unknown."""
        return self._parser_manager.get_language_name(language_id)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Module-level cache for the registry; populated on the first call to
# get_language_registry().
_registry: LanguageRegistry | None = None


def get_language_registry() -> LanguageRegistry:
    """Return the shared ``LanguageRegistry``, building it on first use.

    There is no locking: two threads making the very first call at the same
    time could each construct a registry.
    """
    global _registry
    if _registry is not None:
        return _registry

    _registry = LanguageRegistry()
    return _registry
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass
class KernelInfo:
    """Language-neutral description of a GPU kernel reported by a parser."""

    name: str
    # Line index as reported by the parser; workspace symbols pass it
    # directly to an LSP Position, which is 0-based.
    line: int
    # Parameter descriptions exactly as the parser produced them.
    parameters: list[str]
    docstring: str | None = None
    # Language id (e.g. "cuda", "cutedsl") filled in by ParserResultConverter.
    language: str = ""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class LayoutInfo:
    """Language-neutral description of a layout definition reported by a parser."""

    name: str
    # Line index as reported by the parser (used as a 0-based LSP line).
    line: int
    # Textual shape / stride as extracted by the parser; None when the
    # parser did not provide them.
    shape: str | None = None
    stride: str | None = None
    # Language id filled in by ParserResultConverter.
    language: str = ""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class StructInfo:
    """Language-neutral description of a struct definition reported by a parser."""

    name: str
    # Line index as reported by the parser (used as a 0-based LSP line).
    line: int
    docstring: str | None = None
    # Language id filled in by ParserResultConverter.
    language: str = ""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class LanguageInfo:
    """Everything extracted from one document, in language-neutral form."""

    kernels: list[KernelInfo]
    layouts: list[LayoutInfo]
    structs: list[StructInfo]
    # Language id of the document (e.g. "cuda").
    language: str
    # The dict returned by the parser; LanguageRegistry.parse_file stores an
    # empty dict when the parser raised.
    raw_data: dict[str, Any]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .base_parser import BaseParser
|
|
2
|
+
from .cuda_parser import CUDAKernel, CUDAParser
|
|
3
|
+
from .cutedsl_parser import (
|
|
4
|
+
CuTeDSLKernel,
|
|
5
|
+
CuTeDSLLayout,
|
|
6
|
+
CuTeDSLParser,
|
|
7
|
+
CuTeDSLStruct,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
# Public API of the ``parsers`` package: the abstract BaseParser plus the
# concrete CUDA and CuTeDSL parsers and the record types they produce.
__all__ = [
    "BaseParser",
    "CUDAKernel",
    "CUDAParser",
    "CuTeDSLKernel",
    "CuTeDSLLayout",
    "CuTeDSLParser",
    "CuTeDSLStruct",
]
|