sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +4 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  16. sglang/srt/function_call/ebnf_composer.py +10 -3
  17. sglang/srt/function_call/function_call_parser.py +2 -0
  18. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  19. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  20. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  21. sglang/srt/layers/attention/vision.py +56 -8
  22. sglang/srt/layers/layernorm.py +26 -1
  23. sglang/srt/layers/logits_processor.py +14 -3
  24. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  27. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  28. sglang/srt/layers/moe/topk.py +84 -22
  29. sglang/srt/layers/multimodal.py +11 -8
  30. sglang/srt/layers/quantization/fp8.py +25 -247
  31. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  32. sglang/srt/layers/quantization/modelopt_quant.py +25 -10
  33. sglang/srt/layers/quantization/unquant.py +24 -76
  34. sglang/srt/layers/quantization/w4afp8.py +68 -17
  35. sglang/srt/lora/lora_registry.py +93 -29
  36. sglang/srt/managers/cache_controller.py +9 -7
  37. sglang/srt/managers/mm_utils.py +154 -35
  38. sglang/srt/managers/multimodal_processor.py +3 -14
  39. sglang/srt/managers/schedule_batch.py +14 -8
  40. sglang/srt/managers/scheduler.py +35 -1
  41. sglang/srt/managers/tokenizer_manager.py +37 -6
  42. sglang/srt/managers/tp_worker.py +3 -0
  43. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  44. sglang/srt/model_executor/model_runner.py +68 -14
  45. sglang/srt/models/deepseek_v2.py +62 -28
  46. sglang/srt/models/glm4_moe.py +1035 -0
  47. sglang/srt/models/glm4_moe_nextn.py +167 -0
  48. sglang/srt/models/interns1.py +328 -0
  49. sglang/srt/models/internvl.py +143 -47
  50. sglang/srt/models/llava.py +9 -5
  51. sglang/srt/models/minicpmo.py +4 -1
  52. sglang/srt/models/qwen2_moe.py +2 -2
  53. sglang/srt/models/qwen3_moe.py +5 -2
  54. sglang/srt/multimodal/processors/base_processor.py +20 -6
  55. sglang/srt/multimodal/processors/clip.py +2 -2
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  57. sglang/srt/multimodal/processors/gemma3.py +2 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  59. sglang/srt/multimodal/processors/internvl.py +21 -8
  60. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  61. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  62. sglang/srt/multimodal/processors/llava.py +4 -4
  63. sglang/srt/multimodal/processors/minicpm.py +2 -3
  64. sglang/srt/multimodal/processors/mlama.py +2 -2
  65. sglang/srt/multimodal/processors/mllama4.py +18 -111
  66. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  67. sglang/srt/multimodal/processors/pixtral.py +2 -2
  68. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  69. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  70. sglang/srt/multimodal/processors/vila.py +3 -1
  71. sglang/srt/reasoning_parser.py +2 -1
  72. sglang/srt/server_args.py +57 -6
  73. sglang/srt/utils.py +96 -1
  74. sglang/srt/weight_sync/utils.py +119 -0
  75. sglang/test/runners.py +4 -0
  76. sglang/test/test_utils.py +65 -5
  77. sglang/utils.py +19 -0
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
  80. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
  81. sglang/srt/debug_utils.py +0 -74
  82. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/multimodal/processors/mllama4.py CHANGED
@@ -18,16 +18,16 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Mllama4ImageProcessor(BaseMultimodalProcessor):
     models = [Llama4ForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
-        self.boi_token_index = hf_config.boi_token_index
-        self.eoi_token_index = hf_config.eoi_token_index
-        self.image_token_index = hf_config.image_token_index
-        self.multimodal_tokens = MultimodalSpecialTokens(
+        self.IM_START_TOKEN_ID = hf_config.boi_token_index
+        self.IM_END_TOKEN_ID = hf_config.eoi_token_index
+        self.IM_TOKEN_ID = hf_config.image_token_index
+        self.mm_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
-            image_token_id=self.image_token_index,
+            image_token_id=self.IM_TOKEN_ID,
         ).build(_processor)
 
     async def process_mm_data_async(
@@ -37,114 +37,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
-        # Process images and text using the base processor's load_mm_data method
-        processed_data = self.load_mm_data(
+        base_output = self.load_mm_data(
             prompt=input_text,
-            multimodal_tokens=self.multimodal_tokens,
             image_data=image_data,
-            return_text=True,
+            multimodal_tokens=self.mm_tokens,
         )
 
-        # Process the images using the processor
-        processor = self._processor
-
         # Process the prompt and images
-        processor_output = self.process_mm_data(
-            input_text=processed_data.input_text,
-            images=processed_data.images,
-        )
-
-        # Handle image resolutions and aspect ratios
-        if "pixel_values" not in processor_output:  # no image processed
-            return None
-
-        image_processor = processor.image_processor
-        tokenizer = self._processor.tokenizer
-
-        # Calculate tile size and find supported resolutions
-        tile_size = self.vision_config.image_size
-        max_num_tiles = getattr(self.vision_config, "max_patches", 1)
-
-        possible_resolutions = find_supported_resolutions(
-            max_num_chunks=max_num_tiles,
-            patch_size=SizeDict(height=tile_size, width=tile_size),
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
         )
 
-        # Find best fit for each image
-        best_fit_sizes = [
-            get_best_fit(
-                (image.size[1], image.size[0]),  # (height, width)
-                torch.tensor(possible_resolutions),
-                resize_to_max_canvas=image_processor.resize_to_max_canvas,
-            )
-            for image in processed_data.images
-        ]
-
-        # Calculate aspect ratios and patches per image
-        aspect_ratios = [
-            (image_size[0] // tile_size, image_size[1] // tile_size)
-            for image_size in best_fit_sizes
-        ]
-
-        patches_per_image = [
-            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
-        ]
-
-        # Add to image_inputs
-        processor_output["aspect_ratios"] = aspect_ratios
-        processor_output["patches_per_image"] = torch.tensor(patches_per_image)
-
-        # Process embed_is_patch
-        vocab = tokenizer.get_vocab()
-        patch_id = vocab.get(processor.img_patch_token, -1)
-        image_end_id = vocab.get(processor.end_of_img_token, -1)
-
-        if patch_id != -1 and image_end_id != -1:
-            input_ids = processor_output["input_ids"].view(-1)
-
-            # Remove BOS token if present
-            if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
-                input_ids = input_ids[1:]
-
-            # Find image end indices and split input_ids
-            image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
-
-            if image_end_indices.size(0) > 0:
-                # Split at image boundaries
-                split_indices = (image_end_indices + 1)[:-1]
-                split_input_ids = torch.tensor_split(input_ids, split_indices)
-                split_input_ids = [x for x in split_input_ids if x.numel() > 0]
-
-                # Create embed_is_patch for each image
-                embed_is_patch = []
-                for per_image_input_ids in split_input_ids:
-                    embed_is_patch.append(per_image_input_ids == patch_id)
-
-                processor_output["embed_is_patch"] = embed_is_patch
-
-        # Convert to the format expected by SGLang
-        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
-
-        processor_output["im_start_id"] = self.boi_token_index
-        processor_output["im_end_id"] = self.eoi_token_index
-        processor_output["im_token_id"] = self.image_token_index
-
-        image_offsets = self.get_mm_items_offset(
-            input_ids=torch.tensor(processor_output["input_ids"]),
-            mm_token_id=self.image_token_index,
-        )
-
-        # Add metadata for image processing
-        processor_output["mm_items"] = [
-            MultimodalDataItem(
-                feature=processor_output["pixel_values"],
-                modality=Modality.IMAGE,
-                offsets=image_offsets,
-            )
-        ]
-
-        return processor_output
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "im_token_id": self.IM_TOKEN_ID,
+        }
sglang/srt/multimodal/processors/phi4mm.py CHANGED
@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
 class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
     models = [Phi4MMForCausalLM]
 
-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         self.processor = Phi4MMProcessorAdapter(_processor)
-        super().__init__(hf_config, server_args, self.processor)
+        super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
 
     # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
     # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
sglang/srt/multimodal/processors/pixtral.py CHANGED
@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
 
         return ncols, nrows
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_TOKEN_ID = getattr(
             hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
         )
sglang/srt/multimodal/processors/qwen_audio.py CHANGED
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
     models = [Qwen2AudioForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
         self.AUDIO_TOKEN_REGEX = re.compile(
             r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
sglang/srt/multimodal/processors/qwen_vl.py CHANGED
@@ -201,8 +201,8 @@ async def preprocess_video(
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
 
-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # The regex that matches expanded image tokens.
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
sglang/srt/multimodal/processors/vila.py CHANGED
@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         hf_config: PretrainedConfig,
         server_args: ServerArgs,
         _processor: VILAProcessor,
+        *args,
+        **kwargs,
     ) -> None:
-        super().__init__(hf_config, server_args, _processor)
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.tokenizer.image_token,
             image_token_id=hf_config.image_token_id,
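
All of the processor hunks above follow the same shape: the constructor now forwards *args and **kwargs to BaseMultimodalProcessor, the special tokens are declared once as mm_tokens, and (as the mllama4 rewrite shows) load_mm_data plus process_and_combine_mm_data replace the per-model token bookkeeping. For orientation only, a hypothetical processor written against that pattern; the class name, empty model list, and config attribute are illustrative, while the helper names come from the hunks above:

# Sketch of the refactored processor pattern; not a real model integration.
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class ExampleImageProcessor(BaseMultimodalProcessor):
    models = []  # would list the supported HF model classes

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
        self.IM_TOKEN_ID = hf_config.image_token_index  # illustrative config field
        self.mm_tokens = MultimodalSpecialTokens(
            image_token=_processor.image_token,
            image_token_id=self.IM_TOKEN_ID,
        ).build(_processor)

    async def process_mm_data_async(self, image_data, input_text, *args, **kwargs):
        # Expand image placeholders and load the raw images.
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
        )
        # Run the HF processor and compute per-item offsets in one shared helper.
        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )
        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_token_id": self.IM_TOKEN_ID,
        }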
sglang/srt/reasoning_parser.py CHANGED
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+        in_reasoning = self._in_reasoning or self.think_start_token in text
 
         if not in_reasoning:
             return StreamingParseResult(normal_text=text)
@@ -231,6 +231,7 @@ class ReasoningParser:
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
         "qwen3-thinking": Qwen3ThinkingDetector,
+        "glm45": Qwen3Detector,
         "kimi": KimiDetector,
     }
 
sglang/srt/server_args.py CHANGED
@@ -151,6 +151,8 @@ class ServerArgs:
 
     # Kernel backend
     attention_backend: Optional[str] = None
+    decode_attention_backend: Optional[str] = None
+    prefill_attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
@@ -169,7 +171,8 @@
     ep_size: int = 1
     enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
-    enable_flashinfer_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
@@ -386,13 +389,19 @@
             )
             self.page_size = 128
 
-        if self.attention_backend == "flashmla":
+        if (
+            self.attention_backend == "flashmla"
+            or self.decode_attention_backend == "flashmla"
+        ):
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
 
-        if self.attention_backend == "cutlass_mla":
+        if (
+            self.attention_backend == "cutlass_mla"
+            or self.decode_attention_backend == "cutlass_mla"
+        ):
             logger.warning(
                 "Cutlass MLA only supports a page_size of 128, change page_size to 128."
             )
@@ -428,12 +437,16 @@
         ), "Please enable dp attention when setting enable_dp_lm_head. "
 
         # MoE kernel
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
@@ -458,6 +471,9 @@
                 "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
             )
 
+        if self.enable_eplb:
+            assert self.enable_ep_moe or self.enable_deepep_moe
+
         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None
         ):
@@ -497,7 +513,7 @@
             )
 
         model_arch = self.get_hf_config().architectures[0]
-        if model_arch == "DeepseekV3ForCausalLM":
+        if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
             # Auto set draft_model_path DeepSeek-V3/R1
             if self.speculative_draft_model_path is None:
                 self.speculative_draft_model_path = self.model_path
@@ -1092,6 +1108,7 @@
                 "pythonic",
                 "kimi_k2",
                 "qwen3_coder",
+                "glm45",
             ],
             default=ServerArgs.tool_call_parser,
             help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
@@ -1205,6 +1222,35 @@
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+        )
+
+        parser.add_argument(
+            "--prefill-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.prefill_attention_backend,
+            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
@@ -1290,10 +1336,15 @@
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
         parser.add_argument(
-            "--enable-flashinfer-moe",
+            "--enable-flashinfer-cutlass-moe",
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
            "--enable-flashinfer-allreduce-fusion",
            action="store_true",
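
The new decode_attention_backend and prefill_attention_backend fields (and the CLI flags above) let the two phases use different attention kernels, --enable-flashinfer-moe is split into separate CUTLASS and TRTLLM toggles, and "glm45" joins the tool-call parser choices. A minimal offline-engine sketch exercising the split attention backends, assuming sgl.Engine forwards keyword arguments into ServerArgs; the model path and backend pairing are illustrative, and whether a given prefill/decode combination is supported still depends on the model and hardware:

# Sketch only: same effect as --prefill-attention-backend / --decode-attention-backend on the CLI.
import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    attention_backend="flashinfer",      # fallback for both phases
    prefill_attention_backend="fa3",     # overrides the fallback for prefill
    decode_attention_backend="flashinfer",  # overrides the fallback for decode
)
print(engine.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 8}))
engine.shutdown()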
sglang/srt/utils.py CHANGED
@@ -15,6 +15,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import builtins
 import ctypes
 import dataclasses
@@ -85,6 +86,8 @@ from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
 from triton.runtime.cache import FileCacheManager
 
+from sglang.srt.metrics.func_timer import enable_func_timer
+
 logger = logging.getLogger(__name__)
 
 show_time_cost = False
@@ -2049,7 +2052,7 @@ def rank0_log(msg: str):
         logger.info(msg)
 
 
-def launch_dummy_health_check_server(host, port):
+def launch_dummy_health_check_server(host, port, enable_metrics):
     import asyncio
 
     import uvicorn
@@ -2067,6 +2070,11 @@ def launch_dummy_health_check_server(host, port):
         """Check the health of the http server."""
         return Response(status_code=200)
 
+    # Add prometheus middleware
+    if enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
     config = uvicorn.Config(
         app,
         host=host,
@@ -2335,6 +2343,7 @@ def is_fa3_default_architecture(hf_config):
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
         "Qwen3MoeForCausalLM",
+        "Glm4MoeForCausalLM",
     }
     return architectures[0] in default_archs
 
@@ -2855,3 +2864,89 @@ SUPPORTED_LORA_TARGET_MODULES = [
 ]
 
 LORA_TARGET_ALL_MODULES = "all"
+
+
+class ConcurrentCounter:
+    """
+    An asynchronous counter for managing concurrent tasks that need
+    coordinated increments, decrements, and waiting until the count reaches zero.
+
+    This class is useful for scenarios like tracking the number of in-flight tasks
+    and waiting for them to complete.
+    """
+
+    def __init__(self, initial: int = 0):
+        """
+        Initialize the counter with an optional initial value.
+
+        Args:
+            initial (int): The initial value of the counter. Default is 0.
+        """
+        self._count = initial
+        self._condition = asyncio.Condition()
+
+    def value(self) -> int:
+        """
+        Return the current value of the counter.
+
+        Note:
+            This method is not synchronized. It may return a stale value
+            if other coroutines are concurrently modifying the counter.
+
+        Returns:
+            int: The current counter value.
+        """
+        return self._count
+
+    def __repr__(self) -> str:
+        """Return an informative string representation of the counter."""
+        return f"<ConcurrentCounter value={self.value()}>"
+
+    async def increment(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically increment the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to increment the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after incrementing. Default is True.
+        """
+        async with self._condition:
+            self._count += n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def decrement(self, n: int = 1, notify_all: bool = True):
+        """
+        Atomically decrement the counter by a given amount and notify all waiters.
+
+        Args:
+            n (int): The amount to decrement the counter by. Default is 1.
+            notify_all (bool): Whether to notify all waiters after decrementing. Default is True.
+        """
+        async with self._condition:
+            self._count -= n
+            if notify_all:
+                self._condition.notify_all()
+
+    async def wait_for(self, condition: Callable[[int], bool]):
+        """
+        Asynchronously wait until the counter satisfies a given condition.
+
+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the condition is met, the coroutine resumes.
+
+        Args:
+            condition (Callable[[int], bool]): A function that takes the current counter value
+                and returns True when the condition is satisfied.
+        """
+        async with self._condition:
+            await self._condition.wait_for(lambda: condition(self._count))
+
+    async def wait_for_zero(self):
+        """
+        Asynchronously wait until the counter reaches zero.
+
+        This suspends the calling coroutine without blocking the thread, allowing
+        other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
+        """
+        self.wait_for(lambda count: count == 0)
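
Since ConcurrentCounter ships as a standalone utility in sglang.srt.utils, a minimal usage sketch is straightforward; the worker coroutine and sleep below are illustrative, and only ConcurrentCounter itself comes from the diff above:

# Track three in-flight tasks and wait until they all check out.
import asyncio

from sglang.srt.utils import ConcurrentCounter


async def worker(counter: ConcurrentCounter, task_id: int) -> None:
    try:
        await asyncio.sleep(0.1 * task_id)  # stand-in for real work
    finally:
        await counter.decrement()  # wakes every coroutine waiting on the condition


async def main() -> None:
    counter = ConcurrentCounter()
    await counter.increment(3)  # three tasks in flight
    tasks = [asyncio.create_task(worker(counter, i)) for i in range(3)]
    # Suspend (cooperatively) until the count drops back to zero.
    await counter.wait_for(lambda count: count == 0)
    await asyncio.gather(*tasks)
    print("in-flight tasks remaining:", counter.value())  # 0


asyncio.run(main())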
sglang/srt/weight_sync/utils.py ADDED
@@ -0,0 +1,119 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
+
+from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.utils import MultiprocessingSerializer
+
+
+async def update_weights(
+    engine: Engine,
+    params_batch: list[tuple[str, torch.Tensor]],
+    device_mesh_key: str,
+    device_mesh: DeviceMesh,
+    load_format: Optional[str] = None,
+):
+    """
+    Update weights for the inference engine.
+    This function is designed to be stateless, so that the caller process could keep the stateful engine.
+    Example Use Case:
+        - Multiple Producer Process will call this function in a SPMD style
+
+    Args:
+        engine: The inference engine created by the caller process.
+        params_batch: A list of (name, tensor) tuples. We batched the tensors to avoid the overhead of cpu call.
+        device_mesh_key: The key of the device mesh. Typically "tp" or "infer_tp"
+        device_mesh: The device mesh.
+        load_format: The format of the weights.
+    """
+    infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
+    infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
+    from sglang.srt.patch_torch import monkey_patch_torch_reductions
+
+    monkey_patch_torch_reductions()
+
+    # [
+    #     (name0, ipc_tensor0_tp0),
+    #     (name1, ipc_tensor1_tp0),
+    # ]
+    named_tensors_batch = [
+        (
+            name,
+            MultiprocessingSerializer.serialize(
+                _preprocess_tensor_for_update_weights(tensor)
+            ),
+        )
+        for name, tensor in params_batch
+    ]
+
+    if infer_tp_rank == 0:
+        gathered_serialized_batches = [None for _ in range(infer_tp_size)]
+    else:
+        gathered_serialized_batches = None
+
+    # [
+    #     [ (name0, ipc_tensor0_tp0), (name1, ipc_tensor1_tp0) ],
+    #     [ (name0, ipc_tensor0_tp1), (name1, ipc_tensor1_tp1) ],
+    # ]
+    dist.gather_object(
+        obj=named_tensors_batch,
+        object_gather_list=gathered_serialized_batches,
+        dst=device_mesh[device_mesh_key].mesh.tolist()[0],
+        group=device_mesh[device_mesh_key].get_group(),
+    )
+
+    if infer_tp_rank == 0:
+        # Use zip(*) to "transpose" the data structure.
+        # After transpose, the data structure is like:
+        # [
+        #     ( (name0, ipc_tensor0_tp0), (name0, ipc_tensor0_tp1) ),
+        #     ( (name1, ipc_tensor1_tp0), (name1, ipc_tensor1_tp1) ),
+        # ]
+        logical_tensors = zip(*gathered_serialized_batches, strict=True)
+
+        named_tensors = [
+            # [
+            #     (name0, LocalSerializedTensor(values=[ipc_tensor0_tp0, ipc_tensor0_tp1])),
+            #     (name1, LocalSerializedTensor(values=[ipc_tensor1_tp0, ipc_tensor1_tp1])),
+            # ]
+            (
+                tensor_group[0][0],
+                LocalSerializedTensor(
+                    values=[rank_part[1] for rank_part in tensor_group]
+                ),
+            )
+            for tensor_group in logical_tensors
+        ]
+
+        update_weights_request = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=[
+                MultiprocessingSerializer.serialize(named_tensors)
+                for _ in range(infer_tp_size)
+            ],
+            load_format=load_format,
+        )
+
+        return await engine.update_weights_from_tensor(update_weights_request)
+
+
+def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
+    """
+    Preprocess the tensor for update weights.
+    Example Use Case:
+        - FSDP: we gather tensor by calling full_tensor in _preprocess_tensor_for_update_weights
+        - Megatron: we do nothing here, assuming it is gathered when feed into this func
+
+    Args:
+        tensor: The tensor to be preprocessed.
+
+    Returns:
+        The full tensor if it is a DTensor, otherwise the original tensor.
+    """
+    if isinstance(tensor, DTensor):
+        return tensor.full_tensor()
+    return tensor
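
For the new sglang/srt/weight_sync/utils.py, the caller side looks roughly like the sketch below. It is hedged and single-rank to stay self-contained: the process group setup, mesh name, model path, and the placeholder parameter batch are all illustrative. In the intended SPMD use, every training rank calls update_weights, and only the rank that is infer_tp rank 0 drives the engine.

# Hedged, single-process sketch; everything except the update_weights() call is scaffolding.
import asyncio

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

import sglang as sgl
from sglang.srt.weight_sync.utils import update_weights


def main() -> None:
    # One-rank process group and a flat "infer_tp" mesh, just to satisfy the API.
    dist.init_process_group(
        backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    device_mesh = init_device_mesh("cuda", (1,), mesh_dim_names=("infer_tp",))

    engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # example model

    # Placeholder batch: names and shapes must match the deployed model's parameters.
    params_batch = [("lm_head.weight", torch.zeros(151936, 896, device="cuda"))]

    asyncio.run(
        update_weights(
            engine=engine,
            params_batch=params_batch,
            device_mesh_key="infer_tp",
            device_mesh=device_mesh,
        )
    )
    engine.shutdown()


if __name__ == "__main__":
    main()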
sglang/test/runners.py CHANGED
@@ -491,6 +491,8 @@ class SRTRunner:
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
+        prefill_attention_backend: Optional[str] = None,
+        decode_attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -540,6 +542,8 @@
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
             attention_backend=attention_backend,
+            prefill_attention_backend=prefill_attention_backend,
+            decode_attention_backend=decode_attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,