sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py

```diff
@@ -1,12 +1,14 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
 
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
 import json
 import logging
 import os
+import queue
 import tempfile
 from collections import defaultdict
 from typing import (
```
```diff
@@ -34,6 +36,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -206,7 +209,10 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
-            return quant_cls.from_config(config)
+            if "FP4" in config["quantization"]["quant_algo"]:
+                return ModelOptFp4Config.from_config(config)
+            else:
+                return quant_cls.from_config(config)
         else:
             raise ValueError(
                 f"Unsupported quantization config"
```
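A note on the hunk above: the dispatch reads only two keys from the checkpoint's quantization config. The following is a minimal illustrative sketch, not code from the package; the config shape is inferred from the keys the hunk accesses, and "NVFP4" is just an example algo string containing "FP4".

```python
# Illustrative sketch of the new dispatch in get_quant_config (0.4.8).
# Only the keys read by the hunk above -- ["producer"]["name"] and
# ["quantization"]["quant_algo"] -- are grounded in the diff; the
# values here are example data.
config = {
    "producer": {"name": "modelopt"},
    "quantization": {"quant_algo": "NVFP4"},  # any algo string containing "FP4"
}

if config["producer"]["name"] == "modelopt":
    if "FP4" in config["quantization"]["quant_algo"]:
        chosen = "ModelOptFp4Config.from_config"  # new FP4 path
    else:
        chosen = "quant_cls.from_config"  # pre-0.4.8 behavior
print(chosen)  # -> ModelOptFp4Config.from_config
```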
```diff
@@ -418,6 +424,7 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
     decryption_key: Optional[str] = None,
+    disable_mmap: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.
 
```
```diff
@@ -439,11 +446,69 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        result = safetensors.torch.load_file(st_file, device="cpu")
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+    max_workers: int = 4,
+    disable_mmap: bool = False,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, it uses more optimize read by reading an
+    entire file instead of reading each tensor one by one.
+    """
+    if decryption_key:
+        logger.warning(
+            "Multi-Thread loading is not working for encrypted safetensor weights."
+        )
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
+
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
```
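The new iterator preserves the `(name, tensor)` streaming contract of `safetensors_weights_iterator`, so callers can swap between the two. Below is a hypothetical usage sketch; the glob pattern and worker count are illustrative and not sglang defaults.

```python
# Hypothetical caller of the new multi-threaded iterator (assumes it is
# imported from sglang.srt.model_loader.weight_utils); the paths here
# are placeholders, not real checkpoint shards.
import glob

shard_files = sorted(glob.glob("/path/to/checkpoint/*.safetensors"))
state_dict = {}
for name, tensor in multi_thread_safetensors_weights_iterator(
    shard_files,
    max_workers=8,       # number of shard files read concurrently
    disable_mmap=True,   # read whole files into memory instead of mmap-ing
):
    state_dict[name] = tensor
```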
```diff
@@ -462,6 +527,39 @@ def pt_weights_iterator(
         del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: List[str],
+    max_workers: int = 4,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model bin/pt files."""
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file, map_location="cpu", weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
+        ]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading pt checkpoint shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+
+
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
```
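Both multi-thread iterators drain futures with `concurrent.futures.as_completed`, so tensors are yielded in shard-completion order rather than in `hf_weights_files` order; callers that key tensors by name (as the loaders do) are unaffected. A self-contained illustration of that ordering behavior:

```python
# Standalone demo of the submit/as_completed pattern used above:
# results arrive in completion order, not submission order.
import concurrent.futures
import time

def slow_load(i: int) -> int:
    time.sleep(0.05 * (3 - i))  # earlier submissions sleep longer
    return i

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(slow_load, i) for i in range(3)]
    done = [f.result() for f in concurrent.futures.as_completed(futures)]
print(done)  # typically [2, 1, 0]
```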
sglang/srt/models/deepseek_nextn.py

```diff
@@ -22,13 +22,15 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
```
```diff
@@ -45,6 +47,12 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
```
```diff
@@ -90,23 +98,29 @@ class DeepseekModelNextN(nn.Module):
         else:
             hidden_states = input_embeds
 
-        hidden_states = self.eh_proj(
-            torch.cat(
-                (
-                    self.enorm(hidden_states),
-                    self.hnorm(forward_batch.spec_info.hidden_states),
-                ),
-                dim=-1,
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
             )
-        )
 
         residual = None
-        hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual, zero_allocator
-        )
+        with get_global_expert_distribution_recorder().disable_this_region():
+            hidden_states, residual = self.decoder(
+                positions, hidden_states, forward_batch, residual, zero_allocator
+            )
 
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            if residual is not None:
+                hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            else:
+                hidden_states = self.shared_head.norm(hidden_states)
+
         return hidden_states
 
 
```
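The new `hidden_states.shape[0] > 0` guard skips the `eh_proj` projection when the batch carries no tokens (an empty or idle micro-batch is the plausible case, though the diff does not say so explicitly). A tiny sketch of the zero-row input it sidesteps; the layer sizes are invented demo values:

```python
# Zero-row inputs are legal for nn.Linear but produce nothing useful;
# the guard above simply avoids running the projection in that case.
# Dimensions are arbitrary, chosen only for the demo.
import torch

eh_proj = torch.nn.Linear(8, 4)
empty_batch = torch.empty(0, 8)
print(eh_proj(empty_batch).shape)  # torch.Size([0, 4])
```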
```diff
@@ -127,23 +141,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-        [8 removed lines (old 130-137) not preserved in this rendering]
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(
```
|