sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (new file)

@@ -0,0 +1,278 @@
+import atexit
+import concurrent.futures
+import json
+import logging
+import os
+import signal
+import threading
+from collections import OrderedDict
+from functools import wraps
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
+
+logger = logging.getLogger(__name__)
+
+
+class AtomicCounter:
+    def __init__(self, n: int):
+        assert n > 0
+        self.n = n
+        self._value = 0
+        self._lock = threading.Lock()
+
+    def next(self) -> int:
+        with self._lock:
+            current = self._value
+            self._value = (current + 1) % self.n
+            return current
+
+
+def synchronized():
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            with self.lock:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return _decorator
+
+
+class HiCacheHF3FS(HiCacheStorage):
+    default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
+
+    def __init__(
+        self,
+        file_path: str,
+        file_size: int,
+        numjobs: int,
+        bytes_per_page: int,
+        entries: int,
+        dtype: torch.dtype,
+    ):
+        self.file_path = file_path
+        self.file_size = file_size
+        self.numjobs = numjobs
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+        self.dtype = dtype
+
+        self.numel = self.bytes_per_page // self.dtype.itemsize
+
+        self.num_pages = self.file_size // self.bytes_per_page
+
+        logger.info(
+            "HiCacheHF3FS "
+            f"file_path = {self.file_path}, "
+            f"file_size = {self.file_size/(2**30):.2f} GB, "
+            f"numjobs = {self.numjobs}, "
+            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
+            f"entries = {self.entries}, "
+            f"num_pages = {self.num_pages}"
+        )
+
+        self.ac = AtomicCounter(self.numjobs)
+        self.clients = [
+            Hf3fsClient(
+                self.file_path, self.file_size, self.bytes_per_page, self.entries
+            )
+            for _ in range(numjobs)
+        ]
+        self.executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+        )
+
+        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
+        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
+        # through centralized metadata orchestration.
+        self.lock = threading.RLock()
+        self.free_pages = list(range(self.num_pages))
+        self.key_to_index = OrderedDict()
+
+        atexit.register(self.close)
+
+        signal.signal(signal.SIGINT, lambda sig, frame: self.close())
+        signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
+        signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
+
+    @staticmethod
+    def from_env_config(
+        rank: int, bytes_per_page: int, dtype: torch.dtype
+    ) -> "HiCacheHF3FS":
+        config_path = os.getenv(HiCacheHF3FS.default_env_var)
+        if not config_path:
+            return HiCacheHF3FS(
+                file_path=f"/data/hicache.{rank}.bin",
+                file_size=1 << 40,
+                numjobs=16,
+                bytes_per_page=bytes_per_page,
+                entries=8,
+                dtype=dtype,
+            )
+
+        try:
+            with open(config_path, "r") as f:
+                config = json.load(f)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
+
+        required_keys = {
+            "file_path_prefix",
+            "file_size",
+            "numjobs",
+            "entries",
+        }
+        missing_keys = required_keys - set(config.keys())
+        if missing_keys:
+            raise ValueError(f"Missing required keys in config: {missing_keys}")
+
+        return HiCacheHF3FS(
+            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            file_size=int(config["file_size"]),
+            numjobs=int(config["numjobs"]),
+            bytes_per_page=bytes_per_page,
+            entries=int(config["entries"]),
+            dtype=dtype,
+        )
+
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        return self.batch_get([key], target_location)[0]
+
+    @synchronized()
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor]] = None,
+    ) -> List[torch.Tensor | None]:
+        batch_indices, file_offsets = [], []
+        for i, key in enumerate(keys):
+            if key not in self.key_to_index:
+                continue
+            batch_indices.append(i)
+            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
+            self.key_to_index.move_to_end(key)
+        # TODO: target_locations
+        file_results = [
+            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
+        ]
+
+        futures = [
+            self.executor.submit(
+                self.clients[self.ac.next()].batch_read,
+                file_offsets[i : i + self.entries],
+                file_results[i : i + self.entries],
+            )
+            for i in range(0, len(batch_indices), self.entries)
+        ]
+        read_results = [result for future in futures for result in future.result()]
+
+        results = [None] * len(keys)
+        for batch_index, file_result, read_result in zip(
+            batch_indices, file_results, read_results
+        ):
+            if read_result == self.bytes_per_page:
+                results[batch_index] = file_result
+            else:
+                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+
+        return results
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        return self.batch_set([key], [value])
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        indices = self.get_batch_set_indices(keys)
+        batch_indices, file_offsets, file_values = [], [], []
+        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
+            if is_written or index == -1:
+                continue
+            batch_indices.append(i)
+            file_offsets.append(index * self.bytes_per_page)
+            file_values.append(value.contiguous())
+
+        futures = [
+            self.executor.submit(
+                self.clients[self.ac.next()].batch_write,
+                file_offsets[i : i + self.entries],
+                file_values[i : i + self.entries],
+            )
+            for i in range(0, len(batch_indices), self.entries)
+        ]
+        write_results = [
+            result == self.bytes_per_page
+            for future in futures
+            for result in future.result()
+        ]
+
+        results = [index[0] for index in indices]
+        for batch_index, write_result in zip(batch_indices, write_results):
+            key = keys[batch_index]
+            index = indices[batch_index][1]
+            if write_result:
+                self.key_to_index[key] = index
+                self.key_to_index.move_to_end(key)
+            else:
+                logger.error(f"HiCacheHF3FS set {key} failed")
+                self.free_pages.append(index)
+            results[batch_index] = write_result
+        return all(results)
+
+    @synchronized()
+    def get_batch_set_indices(self, keys: List[str]) -> list:
+        ionum = len(keys)
+        # results: tuples of (is_written: bool, page_idx: int)
+        # - is_written: True = hit (no I/O), False = write (miss)
+        # - page_idx: page storing data
+        results = [None] * min(ionum, self.num_pages)
+        if ionum > self.num_pages:
+            results.extend([(False, -1)] * (ionum - self.num_pages))
+
+        new_keys = []
+        for batch_index, key in enumerate(keys[: self.num_pages]):
+            if key in self.key_to_index:
+                results[batch_index] = (True, self.key_to_index[key])
+                self.key_to_index.move_to_end(key)
+            else:
+                new_keys.append((batch_index, key))
+
+        for batch_index, _ in new_keys:
+            index = (
+                self.free_pages.pop()
+                if len(self.free_pages) > 0
+                else self.key_to_index.popitem(last=False)[1]
+            )
+            results[batch_index] = (False, index)
+
+        return results
+
+    @synchronized()
+    def delete(self, key: str) -> None:
+        if key not in self.key_to_index:
+            return
+        index = self.key_to_index.pop(key)
+        self.free_pages.append(index)
+
+    @synchronized()
+    def exists(self, key: str) -> bool:
+        return key in self.key_to_index
+
+    @synchronized()
+    def clear(self) -> None:
+        self.free_pages = list(range(self.num_pages))
+        self.key_to_index.clear()
+
+    def close(self) -> None:
+        try:
+            for c in self.clients:
+                c.close()
+            self.executor.shutdown(wait=True)
+        except Exception as e:
+            logger.error(f"close HiCacheHF3FS: {e}")
+        logger.info("close HiCacheHF3FS")
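For reference, this new storage backend is configured through the JSON file pointed to by SGLANG_HICACHE_HF3FS_CONFIG_PATH; from_env_config requires the keys file_path_prefix, file_size, numjobs, and entries, and falls back to built-in defaults when the variable is unset. A minimal sketch of wiring this up is shown below. The paths and sizes are made-up illustration values, not sglang defaults, and actually constructing the storage requires a mounted HF3FS volume plus the compiled hf3fs_utils extension used by Hf3fsClient.

# Hypothetical config for HiCacheHF3FS.from_env_config; all values are illustrative only.
import json
import os
import tempfile

import torch

from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

config = {
    "file_path_prefix": "/data/hicache",  # per-rank backing file becomes /data/hicache.<rank>.bin
    "file_size": 1 << 38,                 # 256 GiB backing file (example value)
    "numjobs": 16,                        # number of Hf3fsClient instances / worker threads
    "entries": 8,                         # pages handed to one batch_read/batch_write call
}
config_path = os.path.join(tempfile.gettempdir(), "hf3fs_hicache.json")
with open(config_path, "w") as f:
    json.dump(config, f)

os.environ[HiCacheHF3FS.default_env_var] = config_path  # SGLANG_HICACHE_HF3FS_CONFIG_PATH
storage = HiCacheHF3FS.from_env_config(rank=0, bytes_per_page=1 << 20, dtype=torch.bfloat16)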
sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py (new file)

@@ -0,0 +1,43 @@
+import multiprocessing.shared_memory
+from pathlib import Path
+
+import pytest
+import torch
+from torch.utils.cpp_extension import load
+from tqdm import tqdm
+
+root = Path(__file__).parent.resolve()
+hf3fs_utils = load(
+    name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"], verbose=True
+)
+
+
+def test_rw_shm():
+    numel = 8 << 20
+    dtype = torch.bfloat16
+    page_num = 128
+    page_bytes = numel * dtype.itemsize
+    shm = multiprocessing.shared_memory.SharedMemory(
+        size=page_num * page_bytes, create=True
+    )
+    tshm = torch.frombuffer(shm.buf, dtype=torch.uint8)
+    a = [
+        torch.randn(numel, dtype=dtype)
+        for _ in tqdm(range(page_num), desc="prepare input")
+    ]
+    b = [
+        torch.empty(numel, dtype=dtype)
+        for _ in tqdm(range(page_num), desc="prepare output")
+    ]
+    hf3fs_utils.write_shm(a, tshm)
+    hf3fs_utils.read_shm(tshm, b)
+    for _a, _b in tqdm(zip(a, b), desc="assert_close"):
+        torch.testing.assert_close(_a, _b)
+
+    del tshm
+    shm.close()
+    shm.unlink()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
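The test JIT-compiles the hf3fs_utils.cpp extension with torch.utils.cpp_extension.load and checks that a list of bfloat16 pages survives a round trip through a raw shared-memory buffer. Judging only from how write_shm and read_shm are called here, they appear to pack pages into, and unpack them out of, the uint8 shared-memory tensor; the sketch below is an assumed pure-PyTorch equivalent of that behavior, not the extension's actual implementation, which presumably performs the copies in C++.

# Assumed semantics of hf3fs_utils.write_shm / read_shm, inferred from the test above.
# Pages are taken to be contiguous 1-D tensors.
import torch

def write_shm_py(pages, shm_tensor):
    off = 0
    for page in pages:
        nbytes = page.numel() * page.element_size()
        shm_tensor[off : off + nbytes] = page.view(torch.uint8)  # copy page bytes in
        off += nbytes

def read_shm_py(shm_tensor, pages):
    off = 0
    for page in pages:
        nbytes = page.numel() * page.element_size()
        page.copy_(shm_tensor[off : off + nbytes].view(page.dtype))  # copy page bytes out
        off += nbytes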
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import gc
 import inspect
 import logging
 import os
@@ -75,6 +76,24 @@ def model_capture_mode():
     is_capture_mode = False


+@contextmanager
+def freeze_gc(enable_cudagraph_gc: bool):
+    """
+    Optimize garbage collection during CUDA graph capture.
+    Clean up, then freeze all remaining objects from being included
+    in future collections if GC is disabled during capture.
+    """
+    gc.collect()
+    should_freeze = not enable_cudagraph_gc
+    if should_freeze:
+        gc.freeze()
+    try:
+        yield
+    finally:
+        if should_freeze:
+            gc.unfreeze()
+
+
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
@@ -423,7 +442,12 @@ class CudaGraphRunner:
                 record_shapes=True,
             )

-
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        with freeze_gc(
+            self.model_runner.server_args.enable_cudagraph_gc
+        ), graph_capture() as graph_capture_context:
             with profile_context as prof:
                 self.stream = graph_capture_context.stream
                 avail_mem = get_available_gpu_memory(
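freeze_gc wraps graph capture so that, unless the enable_cudagraph_gc server arg is set, objects that are alive before capture are moved out of the collector's reach and any automatic collections during capture stay cheap. Below is a standalone illustration of the same stdlib pattern; the placeholder capture step and the printed count are illustrative, not sglang behavior.

# Standalone sketch of the gc.freeze()/gc.unfreeze() pattern used by freeze_gc above.
import gc

gc.collect()                  # drop garbage left over from setup/model loading
gc.freeze()                   # move all surviving objects to the permanent generation
print(gc.get_freeze_count())  # number of objects now exempt from collection
try:
    pass                      # CUDA graph capture would happen here
finally:
    gc.unfreeze()             # return frozen objects to the oldest generation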
sglang/srt/model_executor/model_runner.py

@@ -157,6 +157,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        moe_ep_rank: int,
+        moe_ep_size: int,
         pp_rank: int,
         pp_size: int,
         nccl_port: int,
@@ -175,6 +177,8 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.moe_ep_rank = moe_ep_rank
+        self.moe_ep_size = moe_ep_size
         self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
@@ -432,6 +436,7 @@ class ModelRunner:
             "triton",
             "flashmla",
             "cutlass_mla",
+            "trtllm_mla",
             "ascend",
         ]:
             logger.info(
@@ -549,6 +554,7 @@ class ModelRunner:
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
             pipeline_model_parallel_size=self.pp_size,
+            expert_model_parallel_size=self.moe_ep_size,
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
@@ -666,7 +672,7 @@ class ModelRunner:
             self.sliding_window_size = self.model.get_attention_sliding_window_size()
         elif self.model_config.attention_chunk_size is not None:
             self.sliding_window_size = self.model_config.attention_chunk_size
-
+            logger.info(
                 f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
             )

@@ -1432,6 +1438,12 @@ class ModelRunner:
             )

             return CutlassMLABackend(self)
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not self.use_mla_backend:
+                raise ValueError("trtllm_mla backend can only be used with MLA models.")
+            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+            return TRTLLMMLABackend(self)
         elif self.server_args.attention_backend == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
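With the backend registered above, trtllm_mla can be requested like any other attention backend. A hedged sketch using the offline Engine API follows; it assumes, as in other sglang examples, that Engine keyword arguments map onto ServerArgs fields, and the model path is only a placeholder for an MLA-style model (a non-MLA model would hit the ValueError added above).

# Sketch only: model_path is a placeholder; requires a build with the TRT-LLM MLA kernels available.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",  # placeholder MLA model
    attention_backend="trtllm_mla",        # selects TRTLLMMLABackend per the diff above
)
print(llm.generate("Hello", {"max_new_tokens": 8}))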
sglang/srt/model_loader/weight_utils.py

@@ -229,6 +229,8 @@ def get_quant_config(
                 f"Unsupported quantization config"
                 f" found for {model_config.quantization} in {f}."
             )
+        elif model_config.quantization == "w8a8_int8":
+            config["packed_modules_mapping"] = packed_modules_mapping

     return quant_cls.from_config(config)
