sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in that registry.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +23 -3
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +98 -603
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +6 -0
- sglang/srt/managers/io_struct.py +12 -2
- sglang/srt/managers/scheduler.py +116 -669
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +166 -83
- sglang/srt/managers/tp_worker.py +9 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +20 -13
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +15 -56
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +18 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (new file):

@@ -0,0 +1,278 @@
+import atexit
+import concurrent.futures
+import json
+import logging
+import os
+import signal
+import threading
+from collections import OrderedDict
+from functools import wraps
+from typing import List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
+
+logger = logging.getLogger(__name__)
+
+
+class AtomicCounter:
+    def __init__(self, n: int):
+        assert n > 0
+        self.n = n
+        self._value = 0
+        self._lock = threading.Lock()
+
+    def next(self) -> int:
+        with self._lock:
+            current = self._value
+            self._value = (current + 1) % self.n
+            return current
+
+
+def synchronized():
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            with self.lock:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return _decorator
+
+
+class HiCacheHF3FS(HiCacheStorage):
+    default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
+
+    def __init__(
+        self,
+        file_path: str,
+        file_size: int,
+        numjobs: int,
+        bytes_per_page: int,
+        entries: int,
+        dtype: torch.dtype,
+    ):
+        self.file_path = file_path
+        self.file_size = file_size
+        self.numjobs = numjobs
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+        self.dtype = dtype
+
+        self.numel = self.bytes_per_page // self.dtype.itemsize
+
+        self.num_pages = self.file_size // self.bytes_per_page
+
+        logger.info(
+            "HiCacheHF3FS "
+            f"file_path = {self.file_path}, "
+            f"file_size = {self.file_size/(2**30):.2f} GB, "
+            f"numjobs = {self.numjobs}, "
+            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
+            f"entries = {self.entries}, "
+            f"num_pages = {self.num_pages}"
+        )
+
+        self.ac = AtomicCounter(self.numjobs)
+        self.clients = [
+            Hf3fsClient(
+                self.file_path, self.file_size, self.bytes_per_page, self.entries
+            )
+            for _ in range(numjobs)
+        ]
+        self.executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+        )
+
+        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
+        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
+        # through centralized metadata orchestration.
+        self.lock = threading.RLock()
+        self.free_pages = list(range(self.num_pages))
+        self.key_to_index = OrderedDict()
+
+        atexit.register(self.close)
+
+        signal.signal(signal.SIGINT, lambda sig, frame: self.close())
+        signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
+        signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
+
+    @staticmethod
+    def from_env_config(
+        rank: int, bytes_per_page: int, dtype: torch.dtype
+    ) -> "HiCacheHF3FS":
+        config_path = os.getenv(HiCacheHF3FS.default_env_var)
+        if not config_path:
+            return HiCacheHF3FS(
+                file_path=f"/data/hicache.{rank}.bin",
+                file_size=1 << 40,
+                numjobs=16,
+                bytes_per_page=bytes_per_page,
+                entries=8,
+                dtype=dtype,
+            )
+
+        try:
+            with open(config_path, "r") as f:
+                config = json.load(f)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
+
+        required_keys = {
+            "file_path_prefix",
+            "file_size",
+            "numjobs",
+            "entries",
+        }
+        missing_keys = required_keys - set(config.keys())
+        if missing_keys:
+            raise ValueError(f"Missing required keys in config: {missing_keys}")
+
+        return HiCacheHF3FS(
+            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            file_size=int(config["file_size"]),
+            numjobs=int(config["numjobs"]),
+            bytes_per_page=bytes_per_page,
+            entries=int(config["entries"]),
+            dtype=dtype,
+        )
+
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        return self.batch_get([key], target_location)[0]
+
+    @synchronized()
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor]] = None,
+    ) -> List[torch.Tensor | None]:
+        batch_indices, file_offsets = [], []
+        for i, key in enumerate(keys):
+            if key not in self.key_to_index:
+                continue
+            batch_indices.append(i)
+            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
+            self.key_to_index.move_to_end(key)
+        # TODO: target_locations
+        file_results = [
+            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
+        ]
+
+        futures = [
+            self.executor.submit(
+                self.clients[self.ac.next()].batch_read,
+                file_offsets[i : i + self.entries],
+                file_results[i : i + self.entries],
+            )
+            for i in range(0, len(batch_indices), self.entries)
+        ]
+        read_results = [result for future in futures for result in future.result()]
+
+        results = [None] * len(keys)
+        for batch_index, file_result, read_result in zip(
+            batch_indices, file_results, read_results
+        ):
+            if read_result == self.bytes_per_page:
+                results[batch_index] = file_result
+            else:
+                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+
+        return results
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        return self.batch_set([key], [value])
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        indices = self.get_batch_set_indices(keys)
+        batch_indices, file_offsets, file_values = [], [], []
+        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
+            if is_written or index == -1:
+                continue
+            batch_indices.append(i)
+            file_offsets.append(index * self.bytes_per_page)
+            file_values.append(value.contiguous())
+
+        futures = [
+            self.executor.submit(
+                self.clients[self.ac.next()].batch_write,
+                file_offsets[i : i + self.entries],
+                file_values[i : i + self.entries],
+            )
+            for i in range(0, len(batch_indices), self.entries)
+        ]
+        write_results = [
+            result == self.bytes_per_page
+            for future in futures
+            for result in future.result()
+        ]
+
+        results = [index[0] for index in indices]
+        for batch_index, write_result in zip(batch_indices, write_results):
+            key = keys[batch_index]
+            index = indices[batch_index][1]
+            if write_result:
+                self.key_to_index[key] = index
+                self.key_to_index.move_to_end(key)
+            else:
+                logger.error(f"HiCacheHF3FS set {key} failed")
+                self.free_pages.append(index)
+            results[batch_index] = write_result
+        return all(results)
+
+    @synchronized()
+    def get_batch_set_indices(self, keys: List[str]) -> list:
+        ionum = len(keys)
+        # results: tuples of (is_written: bool, page_idx: int)
+        # - is_written: True = hit (no I/O), False = write (miss)
+        # - page_idx: page storing data
+        results = [None] * min(ionum, self.num_pages)
+        if ionum > self.num_pages:
+            results.extend([(False, -1)] * (ionum - self.num_pages))
+
+        new_keys = []
+        for batch_index, key in enumerate(keys[: self.num_pages]):
+            if key in self.key_to_index:
+                results[batch_index] = (True, self.key_to_index[key])
+                self.key_to_index.move_to_end(key)
+            else:
+                new_keys.append((batch_index, key))
+
+        for batch_index, _ in new_keys:
+            index = (
+                self.free_pages.pop()
+                if len(self.free_pages) > 0
+                else self.key_to_index.popitem(last=False)[1]
+            )
+            results[batch_index] = (False, index)
+
+        return results
+
+    @synchronized()
+    def delete(self, key: str) -> None:
+        if key not in self.key_to_index:
+            return
+        index = self.key_to_index.pop(key)
+        self.free_pages.append(index)
+
+    @synchronized()
+    def exists(self, key: str) -> bool:
+        return key in self.key_to_index
+
+    @synchronized()
+    def clear(self) -> None:
+        self.free_pages = list(range(self.num_pages))
+        self.key_to_index.clear()
+
+    def close(self) -> None:
+        try:
+            for c in self.clients:
+                c.close()
+            self.executor.shutdown(wait=True)
+        except Exception as e:
+            logger.error(f"close HiCacheHF3FS: {e}")
+        logger.info("close HiCacheHF3FS")
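The JSON config consumed by `HiCacheHF3FS.from_env_config` is not shipped anywhere in the wheel, but its shape follows directly from the `required_keys` check above. A minimal sketch of producing such a config and pointing the env var at it; the path and sizes are illustrative assumptions, not sglang defaults (only the fallback branch in the code hard-codes values):

# Sketch: a config satisfying from_env_config's required keys.
# Path and sizes are illustrative assumptions, not shipped defaults.
import json
import os

config = {
    "file_path_prefix": "/data/hicache",  # expands to /data/hicache.<rank>.bin
    "file_size": 1 << 40,  # bytes of HF3FS backing file per rank (1 TiB here)
    "numjobs": 16,  # number of Hf3fsClient instances / executor threads
    "entries": 8,  # pages handed to each batch_read/batch_write call
}
with open("/tmp/hf3fs_config.json", "w") as f:
    json.dump(config, f)

os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = "/tmp/hf3fs_config.json"
# HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype) would now build
# the backend from this file instead of its hard-coded fallback.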
sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py (new file):

@@ -0,0 +1,43 @@
+import multiprocessing.shared_memory
+from pathlib import Path
+
+import pytest
+import torch
+from torch.utils.cpp_extension import load
+from tqdm import tqdm
+
+root = Path(__file__).parent.resolve()
+hf3fs_utils = load(
+    name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"], verbose=True
+)
+
+
+def test_rw_shm():
+    numel = 8 << 20
+    dtype = torch.bfloat16
+    page_num = 128
+    page_bytes = numel * dtype.itemsize
+    shm = multiprocessing.shared_memory.SharedMemory(
+        size=page_num * page_bytes, create=True
+    )
+    tshm = torch.frombuffer(shm.buf, dtype=torch.uint8)
+    a = [
+        torch.randn(numel, dtype=dtype)
+        for _ in tqdm(range(page_num), desc="prepare input")
+    ]
+    b = [
+        torch.empty(numel, dtype=dtype)
+        for _ in tqdm(range(page_num), desc="prepare output")
+    ]
+    hf3fs_utils.write_shm(a, tshm)
+    hf3fs_utils.read_shm(tshm, b)
+    for _a, _b in tqdm(zip(a, b), desc="assert_close"):
+        torch.testing.assert_close(_a, _b)
+
+    del tshm
+    shm.close()
+    shm.unlink()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
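Setting the HF3FS I/O aside, the index that storage_hf3fs.py maintains is a plain LRU page table: `key_to_index` is an `OrderedDict` ordered by recency, and `get_batch_set_indices` hands out free pages before evicting the least-recently-used key. A storage-free sketch of that policy (`LruPageTable` is a hypothetical name used only to illustrate the eviction order; the real class defers registering a new key until `batch_set` confirms the write):

from collections import OrderedDict


class LruPageTable:
    # Illustration only: mirrors HiCacheHF3FS's page bookkeeping without I/O.
    def __init__(self, num_pages: int):
        self.free_pages = list(range(num_pages))
        self.key_to_index = OrderedDict()  # key -> page index, oldest first

    def assign(self, key: str) -> tuple:
        # Returns (is_hit, page_index), like get_batch_set_indices' tuples.
        if key in self.key_to_index:
            self.key_to_index.move_to_end(key)  # refresh recency on hit
            return True, self.key_to_index[key]
        if self.free_pages:
            index = self.free_pages.pop()
        else:
            # No free pages: evict the least-recently-used key, reuse its page.
            _, index = self.key_to_index.popitem(last=False)
        # Simplification: register immediately; HiCacheHF3FS registers the key
        # in batch_set only after the write succeeds.
        self.key_to_index[key] = index
        return False, index


table = LruPageTable(num_pages=2)
print(table.assign("a"))  # (False, 1): miss, takes a free page
print(table.assign("b"))  # (False, 0): miss, takes the last free page
print(table.assign("a"))  # (True, 1): hit, "a" becomes most recent
print(table.assign("c"))  # (False, 0): evicts "b", the LRU entry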
sglang/srt/model_executor/model_runner.py:

@@ -157,6 +157,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        moe_ep_rank: int,
+        moe_ep_size: int,
         pp_rank: int,
         pp_size: int,
         nccl_port: int,
@@ -175,6 +177,8 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.moe_ep_rank = moe_ep_rank
+        self.moe_ep_size = moe_ep_size
         self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
@@ -285,11 +289,21 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        self.start_layer = getattr(self.model, "start_layer", 0)
-        self.end_layer = getattr(
-            self.model, "end_layer", self.model_config.num_hidden_layers
+        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
+        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
+        # determine the number of layers.
+        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
+        model_num_layers = (
+            self.model_config.num_nextn_predict_layers
+            if self.is_draft_worker and model_has_mtp_layers
+            else self.model_config.num_hidden_layers
         )
+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
+        assert (not model_has_mtp_layers) or (
+            self.num_effective_layers == model_num_layers
+        ), "PP is not compatible with MTP models."

         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
@@ -539,6 +553,7 @@ class ModelRunner:
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
             pipeline_model_parallel_size=self.pp_size,
+            expert_model_parallel_size=self.moe_ep_size,
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
@@ -1178,11 +1193,7 @@ class ModelRunner:
             dtype=self.kv_cache_dtype,
             kv_lora_rank=self.model_config.kv_lora_rank,
             qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-            layer_num=(
-                self.model_config.num_hidden_layers
-                if not self.is_draft_worker
-                else self.model_config.hf_config.num_nextn_predict_layers
-            ),  # PP is not compatible with mla backend
+            layer_num=self.num_effective_layers,
             device=self.device,
             enable_memory_saver=self.server_args.enable_memory_saver,
             start_layer=self.start_layer,
@@ -1195,11 +1206,7 @@ class ModelRunner:
             dtype=self.kv_cache_dtype,
             kv_lora_rank=self.model_config.kv_lora_rank,
             qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-            layer_num=(
-                self.model_config.num_hidden_layers
-                if not self.is_draft_worker
-                else self.model_config.hf_config.num_nextn_predict_layers
-            ),  # PP is not compatible with mla backend
+            layer_num=self.num_effective_layers,
             device=self.device,
             enable_memory_saver=self.server_args.enable_memory_saver,
             start_layer=self.start_layer,
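These model_runner.py hunks are the plumbing for explicit expert parallelism: `ModelRunner` now receives a dedicated MoE EP rank/size and forwards the size to `initialize_model_parallel`. A hedged usage sketch follows; `tp_size` and `ep_size` are existing `ServerArgs` fields (note the `server_args.py` change in the file list above), but the model path and sizes are placeholders, and the exact mapping from `ep_size` to `moe_ep_size` is inferred from these hunks, not spelled out by this diff.

# Sketch: enabling expert parallelism through the offline engine API.
# Placeholders: the model path and parallel sizes are examples only.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # any MoE checkpoint
    tp_size=4,  # tensor-parallel ranks
    ep_size=4,  # expert-parallel ranks; presumably surfaced in ModelRunner
                # as moe_ep_rank / moe_ep_size per the hunks above
)
print(llm.generate("Hello", {"max_new_tokens": 8}))
llm.shutdown()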