sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -214,6 +214,10 @@ class MultimodalDataItem:
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None
 
+    # gemma3n related
+    input_features: Optional[torch.Tensor] = None
+    input_features_mask: Optional[torch.Tensor] = None
+
     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
 
     @staticmethod
@@ -277,7 +281,10 @@ class MultimodalDataItem:
         if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
         elif self.is_audio():
-            self.
+            if self.audio_features is not None:
+                self.hash = hash_feature(self.audio_features)
+            elif self.input_features is not None:
+                self.hash = hash_feature(self.input_features)
         else:
             self.hash = hash_feature(self.pixel_values)
 
@@ -288,6 +295,7 @@ class MultimodalDataItem:
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.audio_features)
+            or not MultimodalDataItem.is_empty_list(self.input_features)
         )
 
     def is_image(self):
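A hedged sketch of how the new gemma3n fields interact with the audio checks above; only the field names come from this diff, while the import path and the constructor call are illustrative assumptions:

import torch

# Import path assumed; MultimodalDataItem and Modality are taken to live in
# sglang's schedule_batch module in this release.
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(modality=Modality.AUDIO)
item.input_features = torch.randn(1, 128, 3000)                   # e.g. log-mel features
item.input_features_mask = torch.ones(1, 3000, dtype=torch.bool)

# is_audio() is now true even when audio_features is absent, and the hash
# computation falls back to input_features in that case.
assert item.is_audio()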
sglang/srt/managers/scheduler.py
CHANGED
@@ -182,6 +182,18 @@ class EmbeddingBatchResult:
     bid: int
 
 
+class KvMetrics:
+    def __init__(self):
+        self.request_active_slots = None
+        self.request_total_slots = None
+        self.kv_active_blocks = None
+        self.kv_total_blocks = None
+        self.num_requests_waiting = None
+        self.gpu_cache_usage_perc = None
+        self.gpu_prefix_cache_hit_rate = None
+        self.data_parallel_rank = None
+
+
 class IdleSleeper:
     """
     In setups which have long inactivity periods it is desirable to reduce
@@ -223,6 +235,7 @@ class Scheduler(
         self.server_args = server_args
         self.tp_rank = tp_rank
         self.pp_rank = pp_rank
+        self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
@@ -261,6 +274,9 @@ class Scheduler(
         self.send_to_tokenizer = get_zmq_socket(
             context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )
+        self.send_metrics_from_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.metrics_ipc_name, False
+        )
 
         if server_args.skip_tokenizer_init:
             # Directly send to the TokenizerManager
@@ -286,6 +302,7 @@ class Scheduler(
         else:
             self.recv_from_tokenizer = None
             self.recv_from_rpc = None
+            self.send_metrics_from_scheduler = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -1239,6 +1256,22 @@ class Scheduler(
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
+    def _emit_kv_metrics(self):
+        kv_metrics = KvMetrics()
+        kv_metrics.request_active_slots = self.stats.num_running_reqs
+        kv_metrics.request_total_slots = self.max_running_requests
+        kv_metrics.kv_active_blocks = int(
+            self.stats.token_usage * self.max_total_num_tokens
+        )
+        kv_metrics.kv_total_blocks = self.max_total_num_tokens
+        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+        if not self.send_metrics_from_scheduler.closed:
+            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
     def log_prefill_stats(
         self,
         adder: PrefillAdder,
@@ -1291,6 +1324,7 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
 
             self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
         self._publish_kv_events()
 
     def log_decode_stats(
@@ -1352,6 +1386,7 @@ class Scheduler(
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
         self._publish_kv_events()
 
     def check_memory(self):
@@ -2201,8 +2236,8 @@ class Scheduler(
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
         if success:
-
-            assert
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
@@ -2219,8 +2254,8 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-
-            assert
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
@@ -2231,10 +2266,11 @@ class Scheduler(
         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
         if success:
             if recv_req.flush_cache:
-
-                assert
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
+        barrier(group=self.tp_cpu_group)
         return UpdateWeightsFromTensorReqOutput(success, message)
 
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
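For context, each KvMetrics object is pushed over the new metrics ZMQ socket (port_args.metrics_ipc_name). A hedged sketch of a consumer on the other end; the endpoint string and the bind/connect roles are illustrative assumptions, since the real receiver lives elsewhere in sglang:

import zmq

def consume_kv_metrics(metrics_ipc_name: str = "ipc:///tmp/sglang_metrics"):
    context = zmq.Context()
    socket = context.socket(zmq.PULL)  # the scheduler side pushes with zmq.PUSH
    socket.bind(metrics_ipc_name)
    while True:
        m = socket.recv_pyobj()  # a KvMetrics instance sent via send_pyobj
        print(
            f"dp={m.data_parallel_rank} "
            f"slots={m.request_active_slots}/{m.request_total_slots} "
            f"kv_blocks={m.kv_active_blocks}/{m.kv_total_blocks} "
            f"waiting={m.num_requests_waiting}"
        )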
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -239,7 +239,7 @@ class ModelRunner:
                 "SGLANG_LOG_EXPERT_LOCATION_METADATA"
             ):
                 logger.info(
-                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                 )
 
         set_global_expert_distribution_recorder(
@@ -547,6 +547,7 @@ class ModelRunner:
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
+            model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
@@ -865,7 +866,9 @@ class ModelRunner:
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
-            if
+            if _is_hip:  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e4m3fnuz
+            else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
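The fp8_e4m3 branch now picks the ROCm-native float8 type. A standalone sketch of the selection logic, approximating sglang's internal _is_hip flag with torch.version.hip:

import torch

def pick_fp8_kv_cache_dtype(kv_cache_dtype: str) -> torch.dtype:
    is_hip = torch.version.hip is not None
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    if kv_cache_dtype == "fp8_e4m3":
        # ROCm natively supports the fnuz variant of e4m3
        return torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
    raise ValueError(f"Unsupported kv cache dtype: {kv_cache_dtype}")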
sglang/srt/model_loader/loader.py
CHANGED
@@ -2,6 +2,7 @@
 
 # ruff: noqa: SIM117
 import collections
+import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -11,14 +12,17 @@ import math
 import os
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import huggingface_hub
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
+from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
+    _BAR_FORMAT,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
     get_quant_config,
     gguf_quant_weights_iterator,
     initialize_dummy_weights,
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator,
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    # default number of thread when enable multithread weight loading
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
 
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
             raise ValueError(
-                f"
-                f"
+                f"Unexpected extra config keys for load format "
+                f"{load_config.load_format}: "
+                f"{unexpected_keys}"
             )
 
     def _maybe_download_from_modelscope(
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt
         )
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
             weight_loader_disable_mmap = global_server_args_dict.get(
                 "weight_loader_disable_mmap"
             )
-
-
-
+
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_safetensors_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                    disable_mmap=weight_loader_disable_mmap,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files, disable_mmap=weight_loader_disable_mmap
+                )
+
         else:
-
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(hf_weights_files)
 
         # Apply the prefix.
         return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config,
             )
 
-
-
-
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
         return model.eval()
 
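Together with the model_runner.py hunk above, these changes route model_loader_extra_config into the weight iterators. A hedged sketch of enabling the multi-threaded load path; the CLI flag spelling and the LoadConfig import path are assumptions based on the field names shown in this diff:

# Server launch sketch (flag spelling assumed from ServerArgs.model_loader_extra_config):
#   python -m sglang.launch_server --model-path <model> \
#     --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}'

# Programmatic sketch of the same config reaching the loader.
from sglang.srt.configs.load_config import LoadConfig  # import path assumed
from sglang.srt.model_loader.loader import DefaultModelLoader

load_config = LoadConfig(
    model_loader_extra_config={"enable_multithread_load": True, "num_threads": 8},
)
loader = DefaultModelLoader(load_config)  # rejects unknown extra-config keys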
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -1,12 +1,14 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
 
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
 import json
 import logging
 import os
+import queue
 import tempfile
 from collections import defaultdict
 from typing import (
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
         yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+    max_workers: int = 4,
+    disable_mmap: bool = False,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, it uses more optimize read by reading an
+    entire file instead of reading each tensor one by one.
+    """
+    if decryption_key:
+        logger.warning(
+            "Multi-Thread loading is not working for encrypted safetensor weights."
+        )
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
+
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
@@ -471,6 +527,39 @@ def pt_weights_iterator(
         del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: List[str],
+    max_workers: int = 4,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model bin/pt files."""
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file, map_location="cpu", weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
+        ]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading pt checkpoint shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+
+
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
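A minimal usage sketch of the new multi-threaded safetensors iterator; the shard file names below are placeholders for locally available .safetensors files:

from sglang.srt.model_loader.weight_utils import (
    multi_thread_safetensors_weights_iterator,
)

shard_files = [
    "model-00001-of-00002.safetensors",  # placeholder paths
    "model-00002-of-00002.safetensors",
]
for name, tensor in multi_thread_safetensors_weights_iterator(
    shard_files, max_workers=8, disable_mmap=False
):
    print(name, tuple(tensor.shape))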
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -28,6 +28,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
@@ -82,7 +85,6 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -108,9 +110,10 @@ class DeepseekModelNextN(nn.Module):
         )
 
         residual = None
-
-
-
+        with get_global_expert_distribution_recorder().disable_this_region():
+            hidden_states, residual = self.decoder(
+                positions, hidden_states, forward_batch, residual, zero_allocator
+            )
 
         if not forward_batch.forward_mode.is_idle():
             if residual is not None:
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -93,6 +93,7 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
+    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -124,8 +125,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-if _use_aiter:
-    from aiter.rotary_embedding import get_rope
 
 logger = logging.getLogger(__name__)
 
@@ -144,6 +143,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -212,8 +214,18 @@ class MoEGate(nn.Module):
             )
         else:
             self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
+        if getattr(self, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits
 
@@ -388,7 +400,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -777,6 +790,37 @@ class DeepseekV2AttentionMLA(nn.Module):
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )
 
+        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
+        # which requires self.w_kc and self.w_vc to be packed.
+        # If not, we will use torch.bmm and weight shouldn't be packed in this case
+        if (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and _is_cpu
+            and _is_cpu_amx_available
+        ):
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8:
+            assert (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                == self.q_b_proj.quant_method.quant_config.weight_block_size
+            )
+            self.weight_block_size = (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+            )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -790,7 +834,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            if
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
+                self, "use_intel_amx_backend", False
+            ):
+                return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+            else:
+                return AttnForwardMethod.MLA
 
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
@@ -904,6 +953,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
             raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -923,6 +976,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
             raise NotImplementedError
 
@@ -1240,6 +1295,57 @@ class DeepseekV2AttentionMLA(nn.Module):
             zero_allocator,
         )
 
+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1313,6 +1419,43 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align shapes of bmm inputs.
+        # Shapes of inputs:
+        # q_nope: [M, B, K]
+        # original self.w_kc: [B, K, N]
+        # current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+
+        # Shapes of inputs to sgl_kernel.cpu.bmm:
+        # out: [B, M, N]
+        # mat1: [B, M, K]
+        # mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,