sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
@@ -108,7 +109,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_cuda,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -275,6 +275,16 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        # Check if the model is using hybrid SWA
+        if (
+            not self.server_args.disable_hybrid_swa_memory
+            and self.sliding_window_size is not None
+            and self.sliding_window_size > 0
+        ):
+            architectures = self.model_config.hf_config.architectures
+            if architectures and not any("Llama4" in arch for arch in architectures):
+                self.is_hybrid = self.model_config.is_hybrid = True
+
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(
             self.model, "end_layer", self.model_config.num_hidden_layers
@@ -295,11 +305,7 @@ class ModelRunner:
             self.apply_torch_tp()

         # Init lora
-        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
-        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
-        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
-        # load LoRA adapters dynamically later.
-        if server_args.lora_paths is not None:
+        if server_args.enable_lora:
             self.init_lora_manager()

         # Init memory pool and attention backends
@@ -372,6 +378,7 @@ class ModelRunner:
                 is_hopper_with_cuda_12_3()
                 and is_no_spec_infer_or_topk_one(server_args)
                 and is_fa3_default_architecture(self.model_config.hf_config)
+                and (not server_args.enable_hierarchical_cache)
             ):
                 server_args.attention_backend = "fa3"
             elif _is_hip:
@@ -384,7 +391,9 @@ class ModelRunner:
                 )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3():
+                if is_hopper_with_cuda_12_3() and (
+                    not server_args.enable_hierarchical_cache
+                ):
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -402,7 +411,7 @@ class ModelRunner:
                 else:
                     server_args.attention_backend = "triton"
             logger.info(
-                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
@@ -454,7 +463,7 @@ class ModelRunner:
             if not self.is_multimodal_chunked_prefill_supported:
                 server_args.chunked_prefill_size = -1
                 logger.info(
-                    f"Automatically turn of --chunked-prefill-size as it is not supported for "
+                    f"Automatically turn off --chunked-prefill-size as it is not supported for "
                     f"{self.model_config.hf_config.model_type}"
                 )

@@ -471,10 +480,6 @@ class ModelRunner:
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85

-        if self.is_hybrid and not server_args.disable_radix_cache:
-            logger.info("Automatically disable radix cache for hybrid cache.")
-            server_args.disable_radix_cache = True
-
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

@@ -534,6 +539,7 @@ class ModelRunner:
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
             pipeline_model_parallel_size=self.pp_size,
+            duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
             enable_dp_attention=self.server_args.enable_dp_attention,
@@ -555,7 +561,7 @@ class ModelRunner:

         # Check memory for tensor parallelism
         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not self.is_draft_worker:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                     logger.warning(
@@ -645,11 +651,15 @@ class ModelRunner:
         )

         # Parse other args
-        self.sliding_window_size = (
-            self.model.get_attention_sliding_window_size()
-            if hasattr(self.model, "get_attention_sliding_window_size")
-            else None
-        )
+        self.sliding_window_size = None
+        if hasattr(self.model, "get_attention_sliding_window_size"):
+            self.sliding_window_size = self.model.get_attention_sliding_window_size()
+        elif self.model_config.attention_chunk_size is not None:
+            self.sliding_window_size = self.model_config.attention_chunk_size
+            print(
+                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
+            )
+
         self.dtype = self.model_config.dtype

         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
@@ -882,44 +892,40 @@ class ModelRunner:
             lora_backend=self.server_args.lora_backend,
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
+            max_lora_rank=self.server_args.max_lora_rank,
+            target_modules=self.server_args.lora_target_modules,
+            lora_paths=self.server_args.lora_paths,
         )
-        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-        if result.success:
-            logger.info(
-                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
-            )
-        else:
-            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")

-    def load_lora_adapter(self, lora_name: str, lora_path: str):
+    def load_lora_adapter(self, lora_ref: LoRARef):
         """Load a new lora adapter from disk or huggingface."""

         logger.info(
-            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"LoRA adapter loading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+        result = self.lora_manager.load_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"LoRA adapter loading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         return result

-    def unload_lora_adapter(self, lora_name: str):
+    def unload_lora_adapter(self, lora_ref: LoRARef):
         """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""

         logger.info(
-            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"LoRA adapter unloading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.unload_lora_adapter(lora_name)
+        result = self.lora_manager.unload_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"LoRA adapter unloading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

@@ -992,8 +998,56 @@ class ModelRunner:
             )
             self.max_total_num_tokens = self.full_max_total_num_tokens
         else:
-            raise ValueError(
-                f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
+            assert self.sliding_window_size is not None and self.sliding_window_size > 0
+            full_attention_layer_ids = []
+            swa_attention_layer_ids = []
+
+            try:
+                layers = self.model.model.layers
+            except:
+                try:
+                    layers = self.model.language_model.model.layers
+                except:
+                    try:
+                        layers = self.model.language_model.layers
+                    except:
+                        self.is_hybrid = False
+                        return
+
+            for layer in layers:
+                if (
+                    layer.self_attn.attn.sliding_window_size is None
+                    or layer.self_attn.attn.sliding_window_size == -1
+                ):
+                    full_attention_layer_ids.append(layer.layer_id)
+                else:
+                    swa_attention_layer_ids.append(layer.layer_id)
+            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
+            self.model_config.full_attention_layer_ids = full_attention_layer_ids
+
+            # Algorithm:
+            # Existing max_total_num_tokens is per layer and assume all layers have the same number of tokens.
+            # - Find total # of tokens available across layers.
+            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
+            total_tokens = (
+                self.max_total_num_tokens * self.model_config.num_hidden_layers
+            )
+            full_layers_num = len(full_attention_layer_ids)
+            swa_layers_num = len(swa_attention_layer_ids)
+            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio
+
+            # Solve the equations:
+            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
+            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
+            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
+            self.full_max_total_num_tokens = int(total_tokens / denominator)
+            self.swa_max_total_num_tokens = int(
+                self.full_max_total_num_tokens * swa_full_tokens_ratio
+            )
+            self.max_total_num_tokens = self.full_max_total_num_tokens
+
+            logger.info(
+                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
             )

     def init_memory_pool(
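
For intuition, here is a small worked example of the token-split arithmetic introduced in the hunk above. All numbers are hypothetical (they are not taken from the diff); only the formula mirrors the code:

    # Assumed example: 32-layer model, 8 full-attention layers, 24 SWA layers,
    # a per-layer budget of 100_000 tokens, and swa_full_tokens_ratio = 0.25.
    max_total_num_tokens = 100_000
    num_hidden_layers = 32
    full_layers_num, swa_layers_num = 8, 24
    swa_full_tokens_ratio = 0.25

    total_tokens = max_total_num_tokens * num_hidden_layers                     # 3_200_000
    denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num      # 14.0
    full_max_total_num_tokens = int(total_tokens / denominator)                 # 228_571
    swa_max_total_num_tokens = int(full_max_total_num_tokens * swa_full_tokens_ratio)  # 57_142

    # Sanity check of equation 1 (up to integer truncation):
    # 57_142 * 24 + 228_571 * 8 = 3_199_976 ≈ 3_200_000

In this example the full-attention layers end up with a larger per-layer pool than the old uniform split, paid for by the smaller pools of the sliding-window layers.
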
@@ -1072,7 +1126,6 @@ class ModelRunner:
            // self.server_args.page_size
            * self.server_args.page_size
        )
-
        # create token size for hybrid cache
        if self.is_hybrid:
            self.set_num_token_hybrid()
@@ -1410,9 +1463,13 @@ class ModelRunner:
             tensor_parallel(self.model, device_mesh)

     def forward_decode(
-        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
-        self.attn_backend.init_forward_metadata(forward_batch)
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
@@ -1457,11 +1514,34 @@ class ModelRunner:
             **kwargs,
         )

+    def forward_split_prefill(
+        self,
+        forward_batch: ForwardBatch,
+        reinit_attn_backend: bool = False,
+        forward_count: int = 1,
+    ) -> LogitsProcessorOutput:
+        if forward_batch.split_index == 0 or reinit_attn_backend:
+            self.attn_backend.init_forward_metadata(forward_batch)
+        next_split_index = min(
+            forward_batch.split_index + forward_count,
+            self.model_config.num_hidden_layers,
+        )
+        ret = self.model.forward_split_prefill(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            (forward_batch.split_index, next_split_index),
+        )
+        forward_batch.split_index = next_split_index
+        return ret
+
     def forward(
         self,
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         self.forward_pass_id += 1

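
To make the new split-prefill path concrete, here is a hedged driver-loop sketch. This is not the scheduler's actual code; the chunk size, the initial split_index value, and what intermediate chunks return are assumptions, while the method names and fields come from the hunk above:

    # Assumes `runner` is a ModelRunner and `batch` a ForwardBatch whose
    # forward_mode satisfies is_split_prefill(), with batch.split_index == 0.
    LAYERS_PER_STEP = 4  # hypothetical number of layers to run per call

    output = None
    while batch.split_index < runner.model_config.num_hidden_layers:
        # The first call (split_index == 0) initializes attention metadata;
        # each call runs layers [split_index, split_index + LAYERS_PER_STEP)
        # and advances batch.split_index.
        output = runner.forward_split_prefill(batch, forward_count=LAYERS_PER_STEP)
    # After the final chunk, `output` is expected to hold the LogitsProcessorOutput.
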
@@ -1470,7 +1550,11 @@ class ModelRunner:
             forward_batch,
         ):
             output = self._forward_raw(
-                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+                forward_batch,
+                skip_attn_backend_init,
+                pp_proxy_tensors,
+                reinit_attn_backend,
+                split_forward_count,
             )

         if self.eplb_manager is not None:
@@ -1483,6 +1567,8 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool,
         pp_proxy_tensors: Optional[PPProxyTensors],
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
@@ -1495,19 +1581,38 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_decode():
-            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            return ret, can_run_cuda_graph
+
+        # For MLP sync
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.prepare_mlp_sync_batch(self)
+
+        if forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_extend():
             ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
+        elif forward_batch.forward_mode.is_split_prefill():
+            ret = self.forward_split_prefill(
+                forward_batch,
+                reinit_attn_backend=reinit_attn_backend,
+                forward_count=split_forward_count,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.post_forward_mlp_sync_batch(ret)
+
         return ret, can_run_cuda_graph

     def _preprocess_logits(
sglang/srt/model_loader/loader.py CHANGED
@@ -575,7 +575,13 @@ class DummyModelLoader(BaseModelLoader):
             # 2. Post-processing of weights, including assigning specific member variables.
             # For `dummy_init`, only the second stage is required.
             if hasattr(model, "post_load_weights"):
-                model.post_load_weights()
+                if (
+                    model_config.hf_config.architectures[0]
+                    == "DeepseekV3ForCausalLMNextN"
+                ):
+                    model.post_load_weights(is_nextn=True)
+                else:
+                    model.post_load_weights()

         return model.eval()

sglang/srt/model_loader/utils.py CHANGED
@@ -56,14 +56,14 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str
                "if the model is custom)."
            )
        model_module = auto_modules["AutoModel"]
-        if model_config.impl == ModelImpl.TRANSFORMERS:
+        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
-                    "compatible with vLLM."
+                    "compatible with SGLang."
                )
            architectures[i] = "TransformersForCausalLM"
-        if model_config.impl == ModelImpl.AUTO:
+        if model_config.model_impl == ModelImpl.AUTO:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"{arch} has no SGlang implementation and the Transformers "
@@ -97,7 +97,7 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
    supported_archs = ModelRegistry.get_supported_archs()
    is_native_supported = any(arch in supported_archs for arch in architectures)

-    if not is_native_supported or model_config.impl == ModelImpl.TRANSFORMERS:
+    if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
        architectures = resolve_transformers_arch(model_config, architectures)

    return ModelRegistry.resolve_model_cls(architectures)
sglang/srt/models/clip.py CHANGED
@@ -463,7 +463,7 @@ class CLIPModel(nn.Module):
         if forward_batch.mm_inputs is not None:
             mm_inputs = forward_batch.mm_inputs
             pixel_values_list = [
-                item.pixel_values
+                item.feature
                 for item in flatten_nested_list(
                     [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
                 )
sglang/srt/models/deepseek.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module):
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {self.n_routed_experts}."
             )
-
+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=config.norm_topk_prob,
+        )
         self.experts = nn.ModuleList(
             [
                 DeepseekMLP(
@@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe.fused_moe(
             hidden_states,
-            self.w1,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
+            w1=self.w1,
+            w2=self.w2,
+            topk_output=topk_output,
             inplace=True,
         )

sglang/srt/models/deepseek_janus_pro.py CHANGED
@@ -1960,7 +1960,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         self.logits_processor = LogitsProcessor(config)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
+        pixel_values = torch.concat([item.feature for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype