sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +84 -22
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +25 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +37 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +68 -14
- sglang/srt/models/deepseek_v2.py +62 -28
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +5 -2
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +57 -6
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/multimodal/processors/mllama4.py
CHANGED
@@ -18,16 +18,16 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Mllama4ImageProcessor(BaseMultimodalProcessor):
     models = [Llama4ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
-        self.
-        self.
-        self.
-        self.
+        self.IM_START_TOKEN_ID = hf_config.boi_token_index
+        self.IM_END_TOKEN_ID = hf_config.eoi_token_index
+        self.IM_TOKEN_ID = hf_config.image_token_index
+        self.mm_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
-            image_token_id=self.
+            image_token_id=self.IM_TOKEN_ID,
         ).build(_processor)
 
     async def process_mm_data_async(
@@ -37,114 +37,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-
-        assert len(input_text) and isinstance(input_text[0], int)
-        input_text = self._processor.tokenizer.decode(input_text)
-
-        # Process images and text using the base processor's load_mm_data method
-        processed_data = self.load_mm_data(
+        base_output = self.load_mm_data(
             prompt=input_text,
-            multimodal_tokens=self.multimodal_tokens,
             image_data=image_data,
-
+            multimodal_tokens=self.mm_tokens,
         )
 
-        # Process the images using the processor
-        processor = self._processor
-
         # Process the prompt and images
-
-
-            images=processed_data.images,
-        )
-
-        # Handle image resolutions and aspect ratios
-        if "pixel_values" not in processor_output:  # no image processed
-            return None
-
-        image_processor = processor.image_processor
-        tokenizer = self._processor.tokenizer
-
-        # Calculate tile size and find supported resolutions
-        tile_size = self.vision_config.image_size
-        max_num_tiles = getattr(self.vision_config, "max_patches", 1)
-
-        possible_resolutions = find_supported_resolutions(
-            max_num_chunks=max_num_tiles,
-            patch_size=SizeDict(height=tile_size, width=tile_size),
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
         )
 
-
-
-
-
-
-
-
-            for image in processed_data.images
-        ]
-
-        # Calculate aspect ratios and patches per image
-        aspect_ratios = [
-            (image_size[0] // tile_size, image_size[1] // tile_size)
-            for image_size in best_fit_sizes
-        ]
-
-        patches_per_image = [
-            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
-        ]
-
-        # Add to image_inputs
-        processor_output["aspect_ratios"] = aspect_ratios
-        processor_output["patches_per_image"] = torch.tensor(patches_per_image)
-
-        # Process embed_is_patch
-        vocab = tokenizer.get_vocab()
-        patch_id = vocab.get(processor.img_patch_token, -1)
-        image_end_id = vocab.get(processor.end_of_img_token, -1)
-
-        if patch_id != -1 and image_end_id != -1:
-            input_ids = processor_output["input_ids"].view(-1)
-
-            # Remove BOS token if present
-            if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
-                input_ids = input_ids[1:]
-
-            # Find image end indices and split input_ids
-            image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
-
-            if image_end_indices.size(0) > 0:
-                # Split at image boundaries
-                split_indices = (image_end_indices + 1)[:-1]
-                split_input_ids = torch.tensor_split(input_ids, split_indices)
-                split_input_ids = [x for x in split_input_ids if x.numel() > 0]
-
-                # Create embed_is_patch for each image
-                embed_is_patch = []
-                for per_image_input_ids in split_input_ids:
-                    embed_is_patch.append(per_image_input_ids == patch_id)
-
-                processor_output["embed_is_patch"] = embed_is_patch
-
-        # Convert to the format expected by SGLang
-        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
-
-        processor_output["im_start_id"] = self.boi_token_index
-        processor_output["im_end_id"] = self.eoi_token_index
-        processor_output["im_token_id"] = self.image_token_index
-
-        image_offsets = self.get_mm_items_offset(
-            input_ids=torch.tensor(processor_output["input_ids"]),
-            mm_token_id=self.image_token_index,
-        )
-
-        # Add metadata for image processing
-        processor_output["mm_items"] = [
-            MultimodalDataItem(
-                feature=processor_output["pixel_values"],
-                modality=Modality.IMAGE,
-                offsets=image_offsets,
-            )
-        ]
-
-        return processor_output
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "im_token_id": self.IM_TOKEN_ID,
+        }
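The hunk above and the processor hunks that follow share one refactor pattern: every multimodal processor `__init__` now forwards `*args, **kwargs` to `BaseMultimodalProcessor`, and per-model tile/offset bookkeeping is replaced by the shared `load_mm_data` / `process_and_combine_mm_data` helpers. A minimal sketch of that pattern, assuming the base-class helpers behave as used in these hunks; the class name and `models` list here are illustrative placeholders, not part of the release:

from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class ExampleImageProcessor(BaseMultimodalProcessor):
    models = []  # real processors list their HF model classes here

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        # Forward extra positional/keyword arguments to the base class.
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
        self.IM_TOKEN_ID = hf_config.image_token_index
        self.mm_tokens = MultimodalSpecialTokens(
            image_token=_processor.image_token,
            image_token_id=self.IM_TOKEN_ID,
        ).build(_processor)

    async def process_mm_data_async(self, image_data, input_text, *args, **kwargs):
        # Load text + images, then let the base class expand tokens and build items.
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
        )
        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )
        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_token_id": self.IM_TOKEN_ID,
        }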
sglang/srt/multimodal/processors/phi4mm.py
CHANGED
@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
 class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
     models = [Phi4MMForCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         self.processor = Phi4MMProcessorAdapter(_processor)
-        super().__init__(hf_config, server_args, self.processor)
+        super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
 
     # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
     # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
sglang/srt/multimodal/processors/pixtral.py
CHANGED
@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
 
         return ncols, nrows
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_TOKEN_ID = getattr(
             hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
         )
sglang/srt/multimodal/processors/qwen_audio.py
CHANGED
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
     models = [Qwen2AudioForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
         self.AUDIO_TOKEN_REGEX = re.compile(
             r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
sglang/srt/multimodal/processors/qwen_vl.py
CHANGED
@@ -201,8 +201,8 @@ async def preprocess_video(
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # The regex that matches expanded image tokens.
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
sglang/srt/multimodal/processors/vila.py
CHANGED
@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         hf_config: PretrainedConfig,
         server_args: ServerArgs,
         _processor: VILAProcessor,
+        *args,
+        **kwargs,
     ) -> None:
-        super().__init__(hf_config, server_args, _processor)
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.tokenizer.image_token,
             image_token_id=hf_config.image_token_id,
sglang/srt/reasoning_parser.py
CHANGED
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        in_reasoning = self._in_reasoning or
+        in_reasoning = self._in_reasoning or self.think_start_token in text
 
         if not in_reasoning:
             return StreamingParseResult(normal_text=text)
@@ -231,6 +231,7 @@ class ReasoningParser:
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
         "qwen3-thinking": Qwen3ThinkingDetector,
+        "glm45": Qwen3Detector,
         "kimi": KimiDetector,
     }
 
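With the "glm45" entry above, GLM-4.5-style `<think>` output is routed through the existing Qwen3 detector. A hedged usage sketch; the `ReasoningParser` constructor and the shape of its non-streaming result are assumed from the surrounding module rather than shown in this diff:

from sglang.srt.reasoning_parser import ReasoningParser

# Assumed API: model_type selects the detector from the DetectorMap shown above.
parser = ReasoningParser(model_type="glm45")
reasoning_text, normal_text = parser.parse_non_stream(
    "<think>Compare 9.11 and 9.8 digit by digit.</think>9.8 is larger."
)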
sglang/srt/server_args.py
CHANGED
@@ -151,6 +151,8 @@ class ServerArgs:
 
     # Kernel backend
     attention_backend: Optional[str] = None
+    decode_attention_backend: Optional[str] = None
+    prefill_attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
@@ -169,7 +171,8 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
@@ -386,13 +389,19 @@ class ServerArgs:
             )
             self.page_size = 128
 
-        if
+        if (
+            self.attention_backend == "flashmla"
+            or self.decode_attention_backend == "flashmla"
+        ):
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
 
-        if
+        if (
+            self.attention_backend == "cutlass_mla"
+            or self.decode_attention_backend == "cutlass_mla"
+        ):
             logger.warning(
                 "Cutlass MLA only supports a page_size of 128, change page_size to 128."
             )
@@ -428,12 +437,16 @@ class ServerArgs:
         ), "Please enable dp attention when setting enable_dp_lm_head. "
 
         # MoE kernel
-        if self.
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
@@ -458,6 +471,9 @@ class ServerArgs:
                 "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
             )
 
+        if self.enable_eplb:
+            assert self.enable_ep_moe or self.enable_deepep_moe
+
         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None
         ):
@@ -497,7 +513,7 @@ class ServerArgs:
             )
 
         model_arch = self.get_hf_config().architectures[0]
-        if model_arch
+        if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
             # Auto set draft_model_path DeepSeek-V3/R1
             if self.speculative_draft_model_path is None:
                 self.speculative_draft_model_path = self.model_path
@@ -1092,6 +1108,7 @@ class ServerArgs:
                 "pythonic",
                 "kimi_k2",
                 "qwen3_coder",
+                "glm45",
             ],
             default=ServerArgs.tool_call_parser,
             help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
@@ -1205,6 +1222,35 @@ class ServerArgs:
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+        )
+
+        parser.add_argument(
+            "--prefill-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.prefill_attention_backend,
+            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
@@ -1290,10 +1336,15 @@ class ServerArgs:
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
        )
        parser.add_argument(
-            "--enable-flashinfer-moe",
+            "--enable-flashinfer-cutlass-moe",
            action="store_true",
            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
        )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+        )
        parser.add_argument(
            "--enable-flashinfer-allreduce-fusion",
            action="store_true",
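The fields and flags added above let the prefill and decode phases use different attention kernels and split the FlashInfer MoE backends into two explicit switches. A hedged sketch of the corresponding `ServerArgs` fields; the model path is a placeholder and the backend combination is illustrative, not a recommendation (constructing `ServerArgs` assumes a working sglang/CUDA environment):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",          # placeholder
    attention_backend="flashinfer",       # global default
    prefill_attention_backend="fa3",      # overrides --attention-backend for prefill
    decode_attention_backend="flashmla",  # overrides --attention-backend for decode;
                                          # forces page_size=64 per the check above
)

# The MoE backend flags added in this release:
#   --enable-flashinfer-cutlass-moe  (requires --quantization modelopt_fp4)
#   --enable-flashinfer-trtllm-moe   (requires --enable-ep-moe)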
sglang/srt/utils.py
CHANGED
@@ -15,6 +15,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import builtins
 import ctypes
 import dataclasses
@@ -85,6 +86,8 @@ from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
 from triton.runtime.cache import FileCacheManager
 
+from sglang.srt.metrics.func_timer import enable_func_timer
+
 logger = logging.getLogger(__name__)
 
 show_time_cost = False
@@ -2049,7 +2052,7 @@ def rank0_log(msg: str):
         logger.info(msg)
 
 
-def launch_dummy_health_check_server(host, port):
+def launch_dummy_health_check_server(host, port, enable_metrics):
     import asyncio
 
     import uvicorn
@@ -2067,6 +2070,11 @@ def launch_dummy_health_check_server(host, port):
         """Check the health of the http server."""
         return Response(status_code=200)
 
+    # Add prometheus middleware
+    if enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
     config = uvicorn.Config(
         app,
         host=host,
@@ -2335,6 +2343,7 @@ def is_fa3_default_architecture(hf_config):
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
         "Qwen3MoeForCausalLM",
+        "Glm4MoeForCausalLM",
     }
     return architectures[0] in default_archs
 
@@ -2855,3 +2864,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
 ]
 
 LORA_TARGET_ALL_MODULES = "all"
+
+
+class ConcurrentCounter:
+    """
+    An asynchronous counter for managing concurrent tasks that need
+    coordinated increments, decrements, and waiting until the count reaches zero.
+
+    This class is useful for scenarios like tracking the number of in-flight tasks
+    and waiting for them to complete.
+    """
+
+    def __init__(self, initial: int = 0):
+        """
+        Initialize the counter with an optional initial value.
+
+        Args:
+            initial (int): The initial value of the counter. Default is 0.
+        """
+        self._count = initial
+        self._condition = asyncio.Condition()
+
+    def value(self) -> int:
+        """
+        Return the current value of the counter.
+
+        Note:
+            This method is not synchronized. It may return a stale value
+            if other coroutines are concurrently modifying the counter.
+
+        Returns:
+            int: The current counter value.
+        """
+        return self._count
+
+    def __repr__(self) -> str:
+        """Return an informative string representation of the counter."""
+        return f"<ConcurrentCounter value={self.value()}>"
+
+    async def increment(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically increment the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to increment the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
+        """
+        async with self._condition:
+            self._count += n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def decrement(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically decrement the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to decrement the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
+        """
+        async with self._condition:
+            self._count -= n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def wait_for(self, condition: Callable[[int], bool]):
+        """
+        Asynchronously wait until the counter satisfies a given condition.
+
+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the condition is met, the coroutine resumes.
+
+        Args:
+            condition (Callable[[int], bool]): A function that takes the current counter value
+                and returns True when the condition is satisfied.
+        """
+        async with self._condition:
+            await self._condition.wait_for(lambda: condition(self._count))
+
+    async def wait_for_zero(self):
+        """
+        Asynchronously wait until the counter reaches zero.
+
+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
+        """
+        self.wait_for(lambda count: count == 0)
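A minimal usage sketch of the new `ConcurrentCounter`, tracking in-flight tasks; the worker body is illustrative. Note that `wait_for_zero` as added above calls `wait_for` without awaiting it, so this sketch blocks via `wait_for` directly:

import asyncio

from sglang.srt.utils import ConcurrentCounter


async def main():
    inflight = ConcurrentCounter()

    async def worker():
        try:
            await asyncio.sleep(0.01)  # stand-in for real work
        finally:
            await inflight.decrement()

    for _ in range(8):
        await inflight.increment()
        asyncio.create_task(worker())

    # Suspend until every task has decremented the counter back to zero.
    await inflight.wait_for(lambda count: count == 0)


asyncio.run(main())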
sglang/srt/weight_sync/utils.py
ADDED
@@ -0,0 +1,119 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
+
+from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.utils import MultiprocessingSerializer
+
+
+async def update_weights(
+    engine: Engine,
+    params_batch: list[tuple[str, torch.Tensor]],
+    device_mesh_key: str,
+    device_mesh: DeviceMesh,
+    load_format: Optional[str] = None,
+):
+    """
+    Update weights for the inference engine.
+    This function is designed to be stateless, so that the caller process could keep the stateful engine.
+    Example Use Case:
+        - Multiple Producer Process will call this function in a SPMD style
+
+    Args:
+        engine: The inference engine created by the caller process.
+        params_batch: A list of (name, tensor) tuples. We batched the tensors to avoid the overhead of cpu call.
+        device_mesh_key: The key of the device mesh. Typically "tp" or "infer_tp"
+        device_mesh: The device mesh.
+        load_format: The format of the weights.
+    """
+    infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
+    infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
+    from sglang.srt.patch_torch import monkey_patch_torch_reductions
+
+    monkey_patch_torch_reductions()
+
+    # [
+    #     (name0, ipc_tensor0_tp0),
+    #     (name1, ipc_tensor1_tp0),
+    # ]
+    named_tensors_batch = [
+        (
+            name,
+            MultiprocessingSerializer.serialize(
+                _preprocess_tensor_for_update_weights(tensor)
+            ),
+        )
+        for name, tensor in params_batch
+    ]
+
+    if infer_tp_rank == 0:
+        gathered_serialized_batches = [None for _ in range(infer_tp_size)]
+    else:
+        gathered_serialized_batches = None
+
+    # [
+    #     [ (name0, ipc_tensor0_tp0), (name1, ipc_tensor1_tp0) ],
+    #     [ (name0, ipc_tensor0_tp1), (name1, ipc_tensor1_tp1) ],
+    # ]
+    dist.gather_object(
+        obj=named_tensors_batch,
+        object_gather_list=gathered_serialized_batches,
+        dst=device_mesh[device_mesh_key].mesh.tolist()[0],
+        group=device_mesh[device_mesh_key].get_group(),
+    )
+
+    if infer_tp_rank == 0:
+        # Use zip(*) to "transpose" the data structure.
+        # After transpose, the data structure is like:
+        # [
+        #     ( (name0, ipc_tensor0_tp0), (name0, ipc_tensor0_tp1) ),
+        #     ( (name1, ipc_tensor1_tp0), (name1, ipc_tensor1_tp1) ),
+        # ]
+        logical_tensors = zip(*gathered_serialized_batches, strict=True)
+
+        named_tensors = [
+            # [
+            #     (name0, LocalSerializedTensor(values=[ipc_tensor0_tp0, ipc_tensor0_tp1])),
+            #     (name1, LocalSerializedTensor(values=[ipc_tensor1_tp0, ipc_tensor1_tp1])),
+            # ]
+            (
+                tensor_group[0][0],
+                LocalSerializedTensor(
+                    values=[rank_part[1] for rank_part in tensor_group]
+                ),
+            )
+            for tensor_group in logical_tensors
+        ]
+
+        update_weights_request = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=[
+                MultiprocessingSerializer.serialize(named_tensors)
+                for _ in range(infer_tp_size)
+            ],
+            load_format=load_format,
+        )
+
+        return await engine.update_weights_from_tensor(update_weights_request)
+
+
+def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
+    """
+    Preprocess the tensor for update weights.
+    Example Use Case:
+        - FSDP: we gather tensor by calling full_tensor in _preprocess_tensor_for_update_weights
+        - Megatron: we do nothing here, assuming it is gathered when feed into this func
+
+    Args:
+        tensor: The tensor to be preprocessed.
+
+    Returns:
+        The full tensor if it is a DTensor, otherwise the original tensor.
+    """
+    if isinstance(tensor, DTensor):
+        return tensor.full_tensor()
+    return tensor
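A hedged caller sketch for the new `update_weights` helper, run SPMD-style by each producer rank; it assumes `torch.distributed` is already initialized, `engine` is the caller-owned sglang `Engine`, and the one-dimensional mesh layout is illustrative:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.weight_sync.utils import update_weights


async def push_weights(engine, model):
    # One-dimensional mesh over all ranks; the dim name must match device_mesh_key.
    device_mesh = init_device_mesh(
        "cuda", (dist.get_world_size(),), mesh_dim_names=("infer_tp",)
    )
    params_batch = [(name, param.detach()) for name, param in model.named_parameters()]

    # Every rank calls this; rank 0 gathers the serialized tensors and issues the
    # engine.update_weights_from_tensor request.
    return await update_weights(
        engine=engine,
        params_batch=params_batch,
        device_mesh_key="infer_tp",
        device_mesh=device_mesh,
    )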
sglang/test/runners.py
CHANGED
@@ -491,6 +491,8 @@ class SRTRunner:
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
+        prefill_attention_backend: Optional[str] = None,
+        decode_attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -540,6 +542,8 @@ class SRTRunner:
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
             attention_backend=attention_backend,
+            prefill_attention_backend=prefill_attention_backend,
+            decode_attention_backend=decode_attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,