tile-lsp 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tile_lsp-0.1.0/PKG-INFO +7 -0
- tile_lsp-0.1.0/pyproject.toml +19 -0
- tile_lsp-0.1.0/setup.cfg +4 -0
- tile_lsp-0.1.0/tests/test_completion.py +95 -0
- tile_lsp-0.1.0/tests/test_detection.py +79 -0
- tile_lsp-0.1.0/tests/test_diagnostics.py +158 -0
- tile_lsp-0.1.0/tests/test_hover.py +51 -0
- tile_lsp-0.1.0/tests/test_knowledge.py +94 -0
- tile_lsp-0.1.0/tests/test_server_creation.py +13 -0
- tile_lsp-0.1.0/tile_lsp.egg-info/PKG-INFO +7 -0
- tile_lsp-0.1.0/tile_lsp.egg-info/SOURCES.txt +23 -0
- tile_lsp-0.1.0/tile_lsp.egg-info/dependency_links.txt +1 -0
- tile_lsp-0.1.0/tile_lsp.egg-info/entry_points.txt +2 -0
- tile_lsp-0.1.0/tile_lsp.egg-info/requires.txt +2 -0
- tile_lsp-0.1.0/tile_lsp.egg-info/top_level.txt +1 -0
- tile_lsp-0.1.0/tiled_server/__init__.py +1 -0
- tile_lsp-0.1.0/tiled_server/__main__.py +28 -0
- tile_lsp-0.1.0/tiled_server/completion.py +100 -0
- tile_lsp-0.1.0/tiled_server/detection.py +28 -0
- tile_lsp-0.1.0/tiled_server/diagnostics.py +101 -0
- tile_lsp-0.1.0/tiled_server/hover.py +62 -0
- tile_lsp-0.1.0/tiled_server/knowledge.py +679 -0
- tile_lsp-0.1.0/tiled_server/server.py +88 -0
- tile_lsp-0.1.0/tiled_server/signature.py +45 -0
- tile_lsp-0.1.0/tiled_server/utils.py +35 -0
tile_lsp-0.1.0/pyproject.toml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "tile-lsp"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Language server for the TileLang DSL"
|
|
9
|
+
requires-python = ">=3.9"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"pygls>=1.3.0",
|
|
12
|
+
"lsprotocol>=2023.0.0",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
[project.scripts]
|
|
16
|
+
tiled = "tiled_server.__main__:main"
|
|
17
|
+
|
|
18
|
+
[tool.setuptools.packages.find]
|
|
19
|
+
include = ["tiled_server*"]
|
tile_lsp-0.1.0/tests/test_completion.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Tests for TileLang completion logic."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from lsprotocol import types as lsp
|
|
5
|
+
|
|
6
|
+
from tiled_server.completion import symbol_to_completion
|
|
7
|
+
from tiled_server.knowledge import (
|
|
8
|
+
TileLangSymbol,
|
|
9
|
+
get_t_completions,
|
|
10
|
+
get_tilelang_completions,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestSymbolToCompletion:
    """Checks the mapping from a ``TileLangSymbol`` to an LSP ``CompletionItem``."""

    def test_function_symbol(self):
        completion = symbol_to_completion(
            TileLangSymbol("gemm", "function", "T.gemm(A, B, C)", "Do gemm.")
        )
        assert completion.label == "gemm"
        assert completion.kind == lsp.CompletionItemKind.Function
        assert completion.detail == "T.gemm(A, B, C)"

    def test_type_symbol(self):
        completion = symbol_to_completion(
            TileLangSymbol("Tensor", "type", "T.Tensor", "Tensor type.")
        )
        # Types are surfaced to the editor with the "Class" icon.
        assert completion.kind == lsp.CompletionItemKind.Class

    def test_constant_symbol(self):
        completion = symbol_to_completion(
            TileLangSymbol("float16", "constant", "T.float16", "float16 dtype.")
        )
        assert completion.kind == lsp.CompletionItemKind.Constant
        assert completion.sort_text.startswith("1_")

    def test_function_sort_before_constant(self):
        function_item = symbol_to_completion(
            TileLangSymbol("gemm", "function", "T.gemm()", "gemm")
        )
        constant_item = symbol_to_completion(
            TileLangSymbol("float16", "constant", "T.float16", "f16")
        )
        # sort_text drives ordering in the completion popup.
        assert function_item.sort_text < constant_item.sort_text

    def test_snippet_uses_snippet_format(self):
        symbol = TileLangSymbol(
            "alloc_shared",
            "function",
            "T.alloc_shared()",
            "Alloc.",
            snippet="alloc_shared((${1:M}, ${2:N}), ${3:dtype})",
        )
        completion = symbol_to_completion(symbol)
        assert completion.insert_text_format == lsp.InsertTextFormat.Snippet
        assert "${1:" in completion.insert_text

    def test_no_snippet_uses_plain_text(self):
        completion = symbol_to_completion(
            TileLangSymbol("clear", "function", "T.clear(buf)", "Clear.")
        )
        assert completion.insert_text_format == lsp.InsertTextFormat.PlainText
        assert completion.insert_text == "clear"

    def test_documentation_is_markdown(self):
        docs = "Launch a kernel.\n\n```python\n...\n```"
        completion = symbol_to_completion(
            TileLangSymbol("Kernel", "class", "T.Kernel()", docs)
        )
        assert completion.documentation.kind == lsp.MarkupKind.Markdown
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class TestCompletionLogic:
    """Test the completion filtering logic directly."""

    @staticmethod
    def _names_with_prefix(prefix):
        """Return names of ``T.`` symbols whose lower-cased name starts with *prefix*."""
        return {
            s.name for s in get_t_completions() if s.name.lower().startswith(prefix)
        }

    def test_t_dot_returns_all_symbols(self):
        syms = get_t_completions()
        func_names = {s.name for s in syms if s.kind == "function"}
        const_names = {s.name for s in syms if s.kind == "constant"}
        assert "alloc_shared" in func_names
        assert "float16" in const_names

    def test_filter_by_prefix(self):
        names = self._names_with_prefix("alloc")
        assert "alloc_shared" in names
        assert "alloc_fragment" in names
        assert "alloc_local" in names
        assert "gemm" not in names

    def test_filter_by_prefix_reduce(self):
        names = self._names_with_prefix("reduce")
        assert "reduce_sum" in names
        assert "reduce_max" in names
        assert "reduce_min" in names
        assert "alloc_shared" not in names

    def test_tilelang_completions_have_jit(self):
        names = {s.name for s in get_tilelang_completions()}
        assert "jit" in names

    def test_decorator_completions(self):
        """@ trigger should offer T.prim_func and tilelang.jit.

        The earlier version only asserted that a string literal ending in
        ``@`` ends with ``@``, which can never fail. This version checks that
        both decorator targets are present in the completion sources.
        """
        # NOTE(review): assumes knowledge.py registers ``prim_func`` among the
        # ``T.`` symbols (test_knowledge allows a "decorator" kind there) — confirm.
        t_names = {s.name for s in get_t_completions()}
        tilelang_names = {s.name for s in get_tilelang_completions()}
        assert "prim_func" in t_names
        assert "jit" in tilelang_names
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Tests for TileLang file detection and regex patterns."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from tiled_server.detection import (
|
|
6
|
+
_RE_T_DOT,
|
|
7
|
+
_RE_TILELANG_DOT,
|
|
8
|
+
is_tilelang_file,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestFileDetection:
    """``is_tilelang_file`` must recognise real imports and ignore everything else."""

    def test_import_as_t(self):
        source = "import tilelang.language as T\n"
        assert is_tilelang_file(source)

    def test_import_tilelang(self):
        source = "import tilelang\n"
        assert is_tilelang_file(source)

    def test_from_import(self):
        source = "from tilelang.language import Kernel\n"
        assert is_tilelang_file(source)

    def test_mixed_imports(self):
        source = "import numpy\nimport tilelang\nimport torch\n"
        assert is_tilelang_file(source)

    def test_not_tilelang_numpy(self):
        source = "import numpy as np\n"
        assert is_tilelang_file(source) is False or not is_tilelang_file(source)

    def test_not_tilelang_plain(self):
        source = "print('hello')\n"
        assert not is_tilelang_file(source)

    def test_not_tilelang_empty(self):
        source = ""
        assert not is_tilelang_file(source)

    def test_not_tilelang_comment_only(self):
        # A commented-out import must not trigger detection.
        source = "# import tilelang in comment\nprint(1)\n"
        assert not is_tilelang_file(source)

    def test_import_tilelang_as_alias(self):
        source = "import tilelang.language as TL\n"
        assert is_tilelang_file(source)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TestRegexPatterns:
    """Regexes that recognise ``T.`` / ``tilelang.`` completion triggers."""

    @pytest.mark.parametrize("line,expected_match", [
        (" T.", True),
        (" T.alloc", True),
        ("T.Kernel", True),
        ("T.", True),
        ("x = T.floa", True),
        ("print(T.", True),
        ("S.", False),
        ("no_dot_t", False),
        ("", False),
    ])
    def test_t_dot_regex(self, line, expected_match):
        found = _RE_T_DOT.search(line) is not None
        assert found == expected_match, f"line={line!r} expected={expected_match}"

    @pytest.mark.parametrize("line,expected_match", [
        ("tilelang.", True),
        ("tilelang.jit", True),
        (" tilelang.comp", True),
        ("x = tilelang.", True),
        ("tilelan.", False),
        ("", False),
    ])
    def test_tilelang_dot_regex(self, line, expected_match):
        found = _RE_TILELANG_DOT.search(line) is not None
        assert found == expected_match

    def test_t_dot_captures_prefix(self):
        # Group 1 holds the partially typed member name after "T.".
        assert _RE_T_DOT.search(" T.alloc_sh").group(1) == "alloc_sh"

    def test_t_dot_empty_prefix(self):
        assert _RE_T_DOT.search(" T.").group(1) == ""

    def test_tilelang_dot_captures_prefix(self):
        assert _RE_TILELANG_DOT.search(" tilelang.ji").group(1) == "ji"
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""Tests for TileLang diagnostics."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from lsprotocol import types as lsp
|
|
5
|
+
|
|
6
|
+
from tiled_server.diagnostics import compute_diagnostics
|
|
7
|
+
from .conftest import GEMM_CODE, ELEMENTWISE_CODE, NON_TILELANG_CODE
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestDiagnostics:
    """Checks for the diagnostics produced by ``compute_diagnostics``."""

    def test_no_diagnostics_for_clean_gemm(self):
        diags = compute_diagnostics("file:///test.py", GEMM_CODE)
        gemm_warnings = [d for d in diags if "T.clear" in d.message]
        assert len(gemm_warnings) == 0

    def test_gemm_without_clear_warning(self):
        code = """\
import tilelang
import tilelang.language as T

@T.prim_func
def gemm(A: T.Tensor((128, 128), T.float16)):
    with T.Kernel(1, 1, threads=128) as (bx, by):
        C_local = T.alloc_fragment((128, 128), T.float32)
        T.gemm(A, A, C_local)
"""
        diags = compute_diagnostics("file:///test.py", code)
        gemm_warnings = [d for d in diags if "T.clear" in d.message]
        assert len(gemm_warnings) == 1
        assert gemm_warnings[0].severity == lsp.DiagnosticSeverity.Information

    def test_alloc_buffer_hint(self):
        code = """\
import tilelang
import tilelang.language as T

@T.prim_func
def kernel():
    buf = T.alloc_buffer((128, 128), T.float16)
"""
        diags = compute_diagnostics("file:///test.py", code)
        hints = [d for d in diags if "T.alloc_buffer" in d.message]
        assert len(hints) == 1
        assert hints[0].severity == lsp.DiagnosticSeverity.Hint

    def test_tensor_missing_dtype_warning(self):
        code = """\
import tilelang
import tilelang.language as T

@T.prim_func
def kernel(A: T.Tensor((128, 128))):
    pass
"""
        diags = compute_diagnostics("file:///test.py", code)
        dtype_warnings = [d for d in diags if "dtype" in d.message.lower()]
        assert len(dtype_warnings) == 1
        assert dtype_warnings[0].severity == lsp.DiagnosticSeverity.Warning

    def test_no_diagnostics_for_non_tilelang(self):
        diags = compute_diagnostics("file:///test.py", NON_TILELANG_CODE)
        assert diags == []

    def test_no_diagnostics_for_correct_elementwise(self):
        diags = compute_diagnostics("file:///test.py", ELEMENTWISE_CODE)
        assert all("T.clear" not in d.message for d in diags)

    def test_diagnostic_range_is_correct(self):
        code = """\
import tilelang.language as T

@T.prim_func
def kernel():
    buf = T.alloc_buffer((4,), T.float32)
"""
        diags = compute_diagnostics("file:///test.py", code)
        hints = [d for d in diags if "alloc_buffer" in d.message]
        assert len(hints) == 1
        d = hints[0]
        # Line 4 (0-based) is the ``buf = T.alloc_buffer(...)`` line above.
        assert d.range.start.line == 4
        assert d.range.start.character == code.split("\n")[4].index("T.alloc_buffer")

    def test_diagnostic_source_is_tiled(self):
        # The snippet imports ``tilelang.language as T`` like every other
        # fixture, so the ``T.`` usages below are well-formed.
        code = """\
import tilelang
import tilelang.language as T

@T.prim_func
def kernel():
    buf = T.alloc_buffer((4,), T.float32)
"""
        diags = compute_diagnostics("file:///test.py", code)
        # Without this guard the loop below passes vacuously when nothing is
        # reported (the alloc_buffer hint is expected to fire here).
        assert diags, "expected at least one diagnostic to inspect"
        for d in diags:
            assert d.source == "tiled"

    def test_multiple_gemm_no_clear(self):
        code = """\
import tilelang.language as T

@T.prim_func
def kernel():
    with T.Kernel(1, 1, threads=128) as (bx, by):
        A = T.alloc_shared((128, 128), T.float16)
        C = T.alloc_fragment((128, 128), T.float32)
        T.gemm(A, A, C)
        T.gemm(A, A, C)
"""
        diags = compute_diagnostics("file:///test.py", code)
        gemm_warnings = [d for d in diags if "T.clear" in d.message]
        assert len(gemm_warnings) >= 1

    def test_gemm_with_clear_no_warning(self):
        code = """\
import tilelang.language as T

@T.prim_func
def kernel():
    with T.Kernel(1, 1, threads=128) as (bx, by):
        A = T.alloc_shared((128, 128), T.float16)
        C = T.alloc_fragment((128, 128), T.float32)
        T.clear(C)
        T.gemm(A, A, C)
"""
        diags = compute_diagnostics("file:///test.py", code)
        gemm_warnings = [d for d in diags if "T.clear" in d.message]
        assert len(gemm_warnings) == 0

    def test_empty_source(self):
        diags = compute_diagnostics("file:///test.py", "")
        assert diags == []
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class TestEdgeCases:
    """Unusual inputs that ``compute_diagnostics`` must tolerate."""

    def test_diagnostics_with_only_comments(self):
        # A commented-out import must not make the file look like TileLang.
        source = "# import tilelang\n# This is just a comment file\n"
        assert compute_diagnostics("file:///test.py", source) == []

    def test_large_file_no_crash(self):
        source = "import tilelang.language as T\n" + "x = 1\n" * 10000
        result = compute_diagnostics("file:///test.py", source)
        assert isinstance(result, list)

    def test_unicode_in_source(self):
        source = "import tilelang.language as T\n# 矩阵乘法 — GEMM kernel\n"
        result = compute_diagnostics("file:///test.py", source)
        assert isinstance(result, list)

    def test_tab_indentation(self):
        source = "import tilelang.language as T\n\t\tT.alloc_buffer((4,), T.float32)\n"
        result = compute_diagnostics("file:///test.py", source)
        assert isinstance(result, list)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Tests for TileLang hover and signature help logic."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from tiled_server.knowledge import lookup_symbol
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestHoverLogic:
    """Test the hover lookup logic directly."""

    def test_hover_t_symbol(self):
        symbol = lookup_symbol("alloc_shared")
        assert symbol is not None
        assert "shared memory" in symbol.documentation.lower()

    def test_hover_gemm(self):
        symbol = lookup_symbol("gemm")
        assert symbol is not None
        lowered = symbol.documentation.lower()
        assert any(word in lowered for word in ("matrix", "gemm"))

    def test_hover_pipelined(self):
        symbol = lookup_symbol("Pipelined")
        assert symbol is not None
        assert "pipeline" in symbol.documentation.lower()

    def test_hover_kernel(self):
        symbol = lookup_symbol("Kernel")
        assert symbol is not None
        text = symbol.documentation
        assert "kernel" in text.lower() or "GPU" in text

    def test_hover_tilelang_jit(self):
        symbol = lookup_symbol("jit")
        assert symbol is not None
        # A case-insensitive search covers both "JIT" and "jit".
        assert "jit" in symbol.documentation.lower()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TestSignatureLogic:
    """Symbols must carry detail strings that name their parameters (used by signature help)."""

    @pytest.mark.parametrize("name,expected_fragment", (
        ("alloc_shared", "shape"),
        ("Kernel", "threads"),
        ("Pipelined", "num_stages"),
        ("Tensor", "shape"),
        ("gemm", "A"),
        ("copy", "src"),
    ))
    def test_detail_contains_params(self, name, expected_fragment):
        symbol = lookup_symbol(name)
        assert symbol is not None
        detail = symbol.detail
        assert expected_fragment in detail
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Tests for the TileLang knowledge base."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from tiled_server.knowledge import (
|
|
6
|
+
TileLangSymbol,
|
|
7
|
+
get_t_completions,
|
|
8
|
+
get_tilelang_completions,
|
|
9
|
+
get_snippets,
|
|
10
|
+
lookup_symbol,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestKnowledgeBase:
    """Sanity checks on the static TileLang symbol tables."""

    # Every symbol kind the completion layer is expected to handle.
    _VALID_KINDS = frozenset({"function", "type", "constant", "decorator", "class"})
    # Keys that each entry returned by ``get_snippets`` must provide.
    _SNIPPET_KEYS = ("prefix", "body", "description")

    def test_t_completions_count(self):
        syms = get_t_completions()
        assert len(syms) > 50

    def test_tilelang_completions_count(self):
        syms = get_tilelang_completions()
        assert len(syms) >= 3

    @pytest.mark.parametrize("name", [
        "Kernel", "alloc_shared", "alloc_fragment", "alloc_local",
        "gemm", "copy", "async_copy", "clear", "fill",
        "Pipelined", "Parallel", "Persistent", "serial", "unroll", "vectorized",
        "Tensor", "Buffer", "SharedBuffer", "FragmentBuffer",
        "ceildiv", "print", "device_assert",
        "reduce_sum", "reduce_max", "reduce_min",
        "atomic_add", "atomic_max", "atomic_min",
        "use_swizzle", "reshape", "view",
        "Layout", "Fragment", "GemmWarpPolicy",
        "dynamic", "symbolic",
    ])
    def test_core_symbol_exists(self, name):
        sym = lookup_symbol(name)
        assert sym is not None, f"Missing symbol: {name}"

    @pytest.mark.parametrize("name", [
        "Kernel", "alloc_shared", "gemm", "copy", "Pipelined", "Parallel",
    ])
    def test_core_symbol_has_documentation(self, name):
        sym = lookup_symbol(name)
        # Fail with a clear message rather than an AttributeError on None.
        assert sym is not None, f"Missing symbol: {name}"
        assert sym.documentation, f"Empty docs for: {name}"
        assert len(sym.documentation) > 10

    @pytest.mark.parametrize("name", [
        "Kernel", "alloc_shared", "alloc_fragment", "gemm", "Pipelined", "Parallel",
        "ceildiv", "Tensor",
    ])
    def test_core_symbol_has_snippet(self, name):
        sym = lookup_symbol(name)
        assert sym is not None, f"Missing symbol: {name}"
        assert sym.snippet is not None, f"No snippet for: {name}"

    @pytest.mark.parametrize("dtype", [
        "float16", "float32", "float64", "bfloat16",
        "int8", "int16", "int32", "int64",
        "uint8", "uint16", "uint32", "uint64",
        "half", "float", "double",
        "float8_e4m3fn", "float8_e5m2",
    ])
    def test_dtype_symbol_exists(self, dtype):
        sym = lookup_symbol(dtype)
        assert sym is not None, f"Missing dtype: {dtype}"
        assert sym.kind == "constant"

    @pytest.mark.parametrize("name", [
        "jit", "compile", "autotune", "Profiler",
    ])
    def test_tilelang_top_level_symbol(self, name):
        sym = lookup_symbol(name)
        assert sym is not None, f"Missing tilelang.{name}"

    def test_no_duplicate_symbols(self):
        from collections import Counter

        # One counting pass instead of list.count() per name (O(n) vs O(n^2)).
        counts = Counter(s.name for s in get_t_completions())
        duplicates = sorted(name for name, n in counts.items() if n > 1)
        assert not duplicates, f"Duplicate symbols: {duplicates}"

    def test_symbol_kinds_valid(self):
        for sym in get_t_completions():
            assert sym.kind in self._VALID_KINDS, \
                f"Invalid kind '{sym.kind}' for {sym.name}"

    def test_lookup_nonexistent_returns_none(self):
        assert lookup_symbol("nonexistent_xyz") is None

    def test_snippets_dict(self):
        snips = get_snippets()
        assert isinstance(snips, dict)
        assert len(snips) >= 2
        for key, val in snips.items():
            for field in self._SNIPPET_KEYS:
                assert field in val, f"Snippet {key!r} missing {field!r}"
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Tests for server creation."""
|
|
2
|
+
|
|
3
|
+
from tiled_server.server import create_server
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestServerCreation:
    """``create_server`` must return a server with the expected identity."""

    def test_server_name(self):
        assert create_server().name == "tiled"

    def test_server_version(self):
        assert create_server().version == "v0.1.0"
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
pyproject.toml
|
|
2
|
+
tests/test_completion.py
|
|
3
|
+
tests/test_detection.py
|
|
4
|
+
tests/test_diagnostics.py
|
|
5
|
+
tests/test_hover.py
|
|
6
|
+
tests/test_knowledge.py
|
|
7
|
+
tests/test_server_creation.py
|
|
8
|
+
tile_lsp.egg-info/PKG-INFO
|
|
9
|
+
tile_lsp.egg-info/SOURCES.txt
|
|
10
|
+
tile_lsp.egg-info/dependency_links.txt
|
|
11
|
+
tile_lsp.egg-info/entry_points.txt
|
|
12
|
+
tile_lsp.egg-info/requires.txt
|
|
13
|
+
tile_lsp.egg-info/top_level.txt
|
|
14
|
+
tiled_server/__init__.py
|
|
15
|
+
tiled_server/__main__.py
|
|
16
|
+
tiled_server/completion.py
|
|
17
|
+
tiled_server/detection.py
|
|
18
|
+
tiled_server/diagnostics.py
|
|
19
|
+
tiled_server/hover.py
|
|
20
|
+
tiled_server/knowledge.py
|
|
21
|
+
tiled_server/server.py
|
|
22
|
+
tiled_server/signature.py
|
|
23
|
+
tiled_server/utils.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
tiled_server
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""tiled — Language server for the TileLang DSL."""
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Entry point for the tiled language server."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from .server import create_server
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def main():
    """Parse command-line options, configure logging, and run the server.

    Serves over stdio by default, or over TCP at ``--host``/``--port`` when
    ``--tcp`` is given.
    """
    parser = argparse.ArgumentParser(description="tiled — TileLang Language Server")
    parser.add_argument("--tcp", action="store_true", help="Use TCP transport")
    parser.add_argument("--host", default="127.0.0.1", help="TCP host (default: 127.0.0.1)")
    parser.add_argument("--port", type=int, default=2087, help="TCP port (default: 2087)")
    parser.add_argument("--log-level", default="info", choices=["debug", "info", "warning", "error"])
    options = parser.parse_args()

    # Map the lower-case CLI choice onto the logging module's level constant.
    level = getattr(logging, options.log_level.upper())
    logging.basicConfig(level=level)

    server = create_server()
    if not options.tcp:
        server.start_io()
        return
    server.start_tcp(options.host, options.port)


if __name__ == "__main__":
    main()
|