sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -29,6 +29,7 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     add_prefix,
     get_cmo_stream,
@@ -88,8 +89,16 @@ class Qwen3Attention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.tp_rank = get_tensor_model_parallel_rank()

-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
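Note: the `norm_kwargs` construction above (and the matching one in `Qwen3DecoderLayer` further down) is a plain conditional-kwargs idiom: the extra RMSNorm options are built only when `rl_on_policy_target == "fsdp"` and otherwise expand to nothing, leaving the default behavior untouched. A minimal sketch of the idiom with a hypothetical `make_norm` helper (not an SGLang API):

def make_norm(dim, eps=1e-6, **extra):
    # Stand-in for a module constructor; it just echoes what it received.
    return {"dim": dim, "eps": eps, **extra}

on_policy_fsdp = False
norm_kwargs = (
    dict(weight_dtype="float32", cast_x_before_out_mul=True)
    if on_policy_fsdp
    else {}
)
norm = make_norm(128, eps=1e-6, **norm_kwargs)  # identical to the default when the flag is off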
@@ -158,10 +167,18 @@ class Qwen3Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            hidden_states = hidden_states.bfloat16()
+
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
+
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
+
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -204,9 +221,22 @@ class Qwen3DecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+                override_orig_dtype=torch.float32,
+                fp32_residual=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+        )
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
         )

         self.layer_scatter_modes = LayerScatterModes.init_new(
@@ -331,7 +361,7 @@ class Qwen3ForCausalLM(nn.Module):
                 self.pp_group.send(
                     self.model.embed_tokens.weight, dst=self.pp_group.last_rank
                 )
-            else:
+            elif self.pp_group.is_last_rank:
                 emb_token_weight = self.pp_group.recv(
                     size=(config.vocab_size, config.hidden_size),
                     dtype=next(self.model.parameters()).dtype,
@@ -241,16 +241,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.run_moe_core(
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
             self.experts.dispatcher.combine_a(
-                hidden_states=state.pop("hidden_states_experts_output"),
-                topk_ids=state.dispatch_output.topk_ids,
-                topk_weights=state.dispatch_output.topk_weights,
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
@@ -539,10 +537,16 @@ class Qwen3MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )

         if hidden_states.shape[0] != 0:
@@ -774,13 +778,15 @@ class Qwen3MoeForCausalLM(nn.Module):
             self.capture_aux_hidden_states = True
             if layer_ids is None:
                 num_layers = self.config.num_hidden_layers
-                self.model.layers_to_capture = [
-                    2,
-                    num_layers // 2,
-                    num_layers - 3,
-                ]  # Specific layers for EAGLE3 support
+                self.model.set_eagle3_layers_to_capture(
+                    [
+                        2,
+                        num_layers // 2,
+                        num_layers - 3,
+                    ]
+                )  # Specific layers for EAGLE3 support
             else:
-                self.model.layers_to_capture = [val + 1 for val in layer_ids]
+                self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
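For concreteness, the default EAGLE3 capture list above resolves to fixed layer indices per model depth; a quick check in plain Python (48 layers chosen arbitrarily as an example):

num_layers = 48
layers_to_capture = [2, num_layers // 2, num_layers - 3]
assert layers_to_capture == [2, 24, 45]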
@@ -478,6 +478,13 @@ class Qwen3GatedDeltaNet(nn.Module):
         # reshape input data into 2D tensor
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])
+
+        # Add padding for DP-Attn
+        if is_dp_attention_enabled():
+            core_attn_out_pad = torch.zeros_like(z)
+            core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
+            core_attn_out = core_attn_out_pad
+
         core_attn_out = self.norm(core_attn_out, z)
         core_attn_out = core_attn_out.reshape(z_shape_og)
         core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
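The new DP-attention branch zero-pads `core_attn_out` along the token dimension so its row count matches `z` before the gated norm. A standalone illustration of that padding step with arbitrary dummy shapes (the real guard is SGLang's `is_dp_attention_enabled()`, assumed true here):

import torch

z = torch.randn(10, 64)             # padded token count expected by the norm
core_attn_out = torch.randn(7, 64)  # tokens actually produced locally

core_attn_out_pad = torch.zeros_like(z)                        # zeros for the padded rows
core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out  # real outputs at the front
core_attn_out = core_attn_out_pad

assert core_attn_out.shape == z.shape  # now safe to pass alongside z into the norm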
@@ -0,0 +1,35 @@
+from typing import Dict, Type
+
+from transformers import PretrainedConfig, ProcessorMixin
+
+# Useful for registering a custom processor different from Hugging Face's default.
+_CUSTOMIZED_MM_PROCESSOR: Dict[str, Type[ProcessorMixin]] = dict()
+
+
+def register_customized_processor(
+    processor_class: Type[ProcessorMixin],
+):
+    """Class decorator that maps a config class's model_type field to a customized processor class.
+
+    Args:
+        processor_class: A processor class that inherits from ProcessorMixin
+
+    Example:
+        ```python
+        @register_customized_processor(MyCustomProcessor)
+        class MyModelConfig(PretrainedConfig):
+            model_type = "my_model"
+
+        ```
+    """
+
+    def decorator(config_class: PretrainedConfig):
+        if not hasattr(config_class, "model_type"):
+            raise ValueError(
+                f"Class {config_class.__name__} with register_customized_processor should "
+                f"have a 'model_type' class attribute."
+            )
+        _CUSTOMIZED_MM_PROCESSOR[config_class.model_type] = processor_class
+        return config_class
+
+    return decorator
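Usage-wise, the decorator added above is applied to a config class and records the mapping in `_CUSTOMIZED_MM_PROCESSOR` keyed by the config's `model_type`. A hedged sketch mirroring the docstring example, with made-up `MyProcessor`/`MyModelConfig` classes:

from transformers import PretrainedConfig, ProcessorMixin

from sglang.srt.multimodal.customized_mm_processor_utils import (
    _CUSTOMIZED_MM_PROCESSOR,
    register_customized_processor,
)


class MyProcessor(ProcessorMixin):
    """Hypothetical processor that should replace the Hugging Face default."""


@register_customized_processor(MyProcessor)
class MyModelConfig(PretrainedConfig):
    model_type = "my_model"


# The registry now maps the config's model_type to the custom processor class.
assert _CUSTOMIZED_MM_PROCESSOR["my_model"] is MyProcessor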
@@ -185,6 +185,7 @@ class BaseMultimodalProcessor(ABC):
        "aspect_ratio_mask": Modality.IMAGE,
        "num_patches": Modality.IMAGE,
        "patch_pixel_values": Modality.IMAGE,
+       "block_sizes": Modality.IMAGE,
        # Audio-related attributes
        "audio_features": Modality.AUDIO,
        "audio_feature_lens": Modality.AUDIO,
@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)

-        # GLM-4.1V and GLM-4.5V specific tokens
+        # GLM-V specific tokens
         self.IMAGE_TOKEN = "<|image|>"
         self.VIDEO_TOKEN = "<|video|>"
         self.IMAGE_START_TOKEN = "<|begin_of_image|>"
@@ -1,64 +1,72 @@
-from typing import Any, Dict, List, Optional, Type
+from typing import Any

 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.processing_utils import ProcessorMixin
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

-from sglang.srt.managers.io_struct import (
-    EmbeddingReqInput,
-    GenerateReqInput,
-    ImageDataInputItem,
-)
-from sglang.srt.models.vila import VILAForConditionalGeneration
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.models.nvila import NVILAForConditionalGeneration
+from sglang.srt.models.nvila_lite import NVILALiteForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
 )
 from sglang.srt.server_args import ServerArgs

+NUM_VIDEO_FRAMES = 8

-class VILAProcessor(ProcessorMixin):
-    """A stub class for the VILA processor."""
-
-    tokenizer: PreTrainedTokenizerBase
-
-
-class VILAMultimodalProcessor(BaseMultimodalProcessor):
-    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]

-    _processor: VILAProcessor
+class NVILAMultimodalProcessor(BaseMultimodalProcessor):
+    models: list[type[nn.Module]] = [
+        NVILAForConditionalGeneration,
+        NVILALiteForConditionalGeneration,
+    ]

     def __init__(
         self,
         hf_config: PretrainedConfig,
         server_args: ServerArgs,
-        _processor: VILAProcessor,
+        _processor: ProcessorMixin,
         *args,
         **kwargs,
     ) -> None:
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+        self._processor: ProcessorMixin
+
+        tokenizer: PreTrainedTokenizerBase = getattr(self._processor, "tokenizer")
+
         self.mm_tokens = MultimodalSpecialTokens(
-            image_token=self._processor.tokenizer.image_token,
+            image_token=tokenizer.image_token,
             image_token_id=hf_config.image_token_id,
+            video_token=tokenizer.video_token,
             video_token_id=hf_config.video_token_id,
         ).build(_processor)

     async def process_mm_data_async(
         self,
-        image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]],
-        input_text: str | List[int],
-        request_obj: GenerateReqInput | EmbeddingReqInput,
+        image_data,
+        audio_data,
+        input_text,
+        request_obj: GenerateReqInput,
         **kwargs,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> dict[str, Any] | None:
         base_output = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=self.mm_tokens,
-            image_data=image_data,
+            image_data=request_obj.image_data,  # type: ignore
+            video_data=request_obj.video_data,  # type: ignore
         )

+        for i, video in enumerate(base_output.videos):  # type: ignore
+            base_output.videos[i] = [x.asnumpy() for x in video]  # type: ignore
+
         mm_items, input_ids, _ = self.process_and_combine_mm_data(
-            base_output, self.mm_tokens
+            base_output,
+            self.mm_tokens,
+            do_sample_frames=True,
+            num_frames=NUM_VIDEO_FRAMES,
         )

         return {
@@ -7,12 +7,12 @@ from PIL import Image

 from sglang.srt.models.points_v15_chat import POINTSV15ChatModel
 from sglang.srt.multimodal.processors.qwen_vl import (
-    Qwen2_5VLImageProcessor,
+    QwenVLImageProcessor,
     resize_image_async,
 )


-class POINTSV15ChatProcessor(Qwen2_5VLImageProcessor):
+class POINTSV15ChatProcessor(QwenVLImageProcessor):
     models = [POINTSV15ChatModel]

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
@@ -0,0 +1,209 @@
+"""
+Mixin class providing multiplexing scheduling logic
+"""
+
+import logging
+
+import torch
+import torch.distributed as dist
+from torch.cuda.streams import ExternalStream
+
+from sglang.srt.distributed.parallel_state import set_pdmux_status
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.multiplex.pdmux_context import (
+    get_current_stream_idx,
+    get_sm_counts,
+    get_stream_groups,
+    initialize_stream_groups,
+    load_pdmux_config,
+    set_current_stream_idx,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerMultiplexMixin:
+
+    def init_pdmux(self):
+        # for pd_multiplexing, Init stream_groups, exclude normal stream for prefill only and decode only
+        self.pdmux_config = load_pdmux_config(self.server_args.pdmux_config_path)
+        initialize_stream_groups(self.gpu_id, self.pdmux_config)
+        self.stream_groups = get_stream_groups()
+        self.sm_counts = get_sm_counts()
+        self.real_sm_group_num = len(self.stream_groups)
+        logger.info(
+            f"PD-Multiplexing enabled with {self.real_sm_group_num} stream groups, sm_counts (prefill_sm, decode_sm): {self.sm_counts}"
+        )
+
+    # TODO(jason-fxz): This is a temporary demo
+    def adjust_stream_groups(self) -> tuple[int, tuple[ExternalStream, ExternalStream]]:
+        if not self.running_batch.is_empty() and self.split_prefill_batch:
+            decode_bs = self.running_batch.batch_size()
+            manual_divisions = self.pdmux_config.manual_divisions
+            if manual_divisions:
+                for i in range(len(manual_divisions)):
+                    _, _, threshold = manual_divisions[i]
+                    if decode_bs >= threshold:
+                        stream_idx = i + 1
+            else:
+                stream_idx = max(
+                    1,
+                    min(
+                        self.real_sm_group_num - 2,
+                        decode_bs
+                        * (self.real_sm_group_num - 2)
+                        // self.pdmux_config.decode_bs_divisor,
+                    ),
+                )
+            set_current_stream_idx(stream_idx)
+        elif not self.running_batch.is_empty():
+            set_current_stream_idx(self.real_sm_group_num - 1)
+        else:
+            set_current_stream_idx(0)
+
+        stream_idx = get_current_stream_idx()
+
+        self.tp_worker.model_runner.update_decode_attn_backend(stream_idx)
+        return stream_idx, self.stream_groups[stream_idx]
+
+    def update_split_prefill_batch(self, sm_count: int) -> bool:
+        if self.split_prefill_batch:
+            return False
+
+        # add new request
+        batch = self.get_new_batch_prefill()
+        if batch and not batch.is_empty():
+            batch.forward_mode = (
+                ForwardMode.SPLIT_PREFILL
+            )  # Set forward mode for split prefill
+            self.split_prefill_batch = batch
+            return True
+        return False
+
+    @torch.inference_mode()
+    def event_loop_pdmux(self):
+        """A scheduler loop for pd multiplexing."""
+        decode_done = False
+        prefill_done = False
+        wait_prefill_kernel_done = False
+        adjust_stream_group = False
+        stream_idx = get_current_stream_idx()
+        stream_group = self.stream_groups[stream_idx]
+        prefill_stream = stream_group[0]
+        decode_stream = stream_group[1]
+        torch.cuda.empty_cache()
+
+        logger.debug("Starting event loop for pd multiplexing...")
+
+        while True:
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                sm_count = self.sm_counts[stream_idx][0]
+                if not wait_prefill_kernel_done:
+                    adjust_stream_group = (
+                        self.update_split_prefill_batch(sm_count) or adjust_stream_group
+                    )
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                self.running_batch = self.update_running_batch(self.running_batch)
+                adjust_stream_group = adjust_stream_group or (
+                    stream_idx > 0 and self.running_batch.is_empty()
+                )
+                if self.running_batch.is_empty() and self.split_prefill_batch is None:
+                    self.check_memory()
+                    self.check_tree_cache()
+                    self.new_token_ratio = self.init_new_token_ratio
+                    self.maybe_sleep_on_idle()
+
+            if adjust_stream_group:
+                prefill_stream.synchronize()
+                decode_stream.synchronize()
+                stream_idx, stream_group = self.adjust_stream_groups()
+                prefill_stream = stream_group[0]
+                decode_stream = stream_group[1]
+                adjust_stream_group = False
+                logger.debug(
+                    f"Adjusting stream groups: {stream_idx}, prefill sm: {self.sm_counts[stream_idx][0]}, decode sm: {self.sm_counts[stream_idx][1]}"
+                )
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                # process decode batch
+                if self.running_batch and not self.running_batch.is_empty():
+                    decode_result = self.run_batch(self.running_batch)
+                    decode_done = True
+                else:
+                    decode_done = False
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                if (
+                    self.split_prefill_batch
+                    and not self.split_prefill_batch.is_empty()
+                    and not wait_prefill_kernel_done
+                ):
+                    prefill_done = True
+                    forward_count = (
+                        max(
+                            1,
+                            self.pdmux_config.split_forward_token_budget
+                            // self.split_prefill_batch.extend_num_tokens,
+                        )
+                        if self.split_prefill_batch.extend_num_tokens > 0
+                        else self.model_config.num_hidden_layers
+                    )
+                    next_split_index = min(
+                        self.split_prefill_batch.split_index + forward_count,
+                        self.model_config.num_hidden_layers,
+                    )
+                    forward_count = (
+                        next_split_index - self.split_prefill_batch.split_index
+                    )
+
+                    self.split_prefill_batch.split_forward_count = forward_count
+                    prefill_result = self.run_batch(self.split_prefill_batch)
+                    if next_split_index == self.model_config.num_hidden_layers:
+                        self.split_prefill_batch.split_prefill_finished = True
+                        prefill_exe_done = prefill_stream.record_event()
+                    self.split_prefill_batch.split_index = next_split_index
+
+                elif wait_prefill_kernel_done:
+                    prefill_done = True
+                else:
+                    prefill_done = False
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                decode_stream.synchronize()
+                if decode_done:
+                    self.process_batch_result(self.running_batch, decode_result)
+
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                if prefill_done and self.split_prefill_batch.split_prefill_finished:
+                    wait_prefill_kernel_done = True
+                    prefill_exe_done_flag = prefill_exe_done.query()
+                    flags = (
+                        torch.ones(1, device="cpu", dtype=torch.int32)
+                        if prefill_exe_done_flag
+                        else torch.zeros(1, device="cpu", dtype=torch.int32)
+                    )
+
+                    self.tp_cpu_group.allreduce(flags, dist.ReduceOp.SUM).wait()
+                    if flags.item() == self.tp_size:
+                        self.process_batch_result(
+                            self.split_prefill_batch, prefill_result
+                        )
+                        if self.running_batch and not self.running_batch.is_empty():
+                            self.running_batch.merge_batch(self.split_prefill_batch)
+                        else:
+                            self.running_batch = self.split_prefill_batch
+
+                        self.split_prefill_batch = None
+                        wait_prefill_kernel_done = False
+                        adjust_stream_group = True
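The loop above overlaps prefill and decode by issuing each on its own CUDA stream and synchronizing only at batch boundaries. A minimal, self-contained sketch of that overlap structure using plain `torch.cuda.Stream` (the actual mixin uses SM-partitioned `ExternalStream` groups from `pdmux_context`, which this sketch does not attempt to reproduce):

import torch


def run_overlapped_step(prefill_fn, decode_fn):
    """Issue prefill and decode work on separate CUDA streams, then sync both."""
    if not torch.cuda.is_available():
        return prefill_fn(), decode_fn()  # sequential fallback without a GPU

    prefill_stream = torch.cuda.Stream()
    decode_stream = torch.cuda.Stream()

    with torch.cuda.stream(prefill_stream):
        prefill_out = prefill_fn()  # e.g. run_batch(split_prefill_batch)
    with torch.cuda.stream(decode_stream):
        decode_out = decode_fn()  # e.g. run_batch(running_batch)

    # Mirror the prefill_stream.synchronize()/decode_stream.synchronize() calls
    # the loop performs before consuming results.
    prefill_stream.synchronize()
    decode_stream.synchronize()
    return prefill_out, decode_out


# Example with dummy GPU work:
if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    run_overlapped_step(lambda: a @ a, lambda: a + 1)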