sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py
@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
-    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.srt.weight_sync.tensor_bucket import (
+    FlattenedTensorBucket,
+    FlattenedTensorMetadata,
+)

 _is_hip = is_hip()
 _is_npu = is_npu()
@@ -172,10 +176,6 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
-
-        # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.moe_ep_rank = moe_ep_rank
@@ -201,15 +201,17 @@
         self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
-
         self.forward_pass_id = 0

-        # Model-specific adjustment
-        self.model_specific_adjustment()
-
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         if server_args.show_time_cost:
             enable_show_time_cost()

+        # Model-specific adjustment
+        self.model_specific_adjustment()
+
         # Global vars
         global_server_args_dict.update(
             {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@@ -217,8 +219,6 @@
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
-            }
-            | {
                 "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
                 "deepep_mode": DeepEPMode(server_args.deepep_mode),
             }
@@ -238,13 +238,15 @@
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

-        # If it is a draft model, tp_group can be different
+        # Initialize the model runner
         self.initialize(min_per_gpu_memory)

-        # temporary cached values
+        # Temporary cached values
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+
+        # For weight updates
         self._model_update_group = {}

     def initialize(self, min_per_gpu_memory: float):
@@ -273,6 +275,7 @@
             )
         )

+        # Expert parallelism
         self.eplb_manager = (
             EPLBManager(self)
             if self.server_args.enable_eplb and (not self.is_draft_worker)
@@ -378,6 +381,25 @@
             )
             server_args.attention_backend = "torch_native"

+        if server_args.prefill_attention_backend is not None and (
+            server_args.prefill_attention_backend
+            == server_args.decode_attention_backend
+        ):  # override the default attention backend
+            server_args.attention_backend = server_args.prefill_attention_backend
+
+        if (
+            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
+            is not None
+        ):
+            if server_args.attention_backend is None:
+                server_args.attention_backend = "dual_chunk_flash_attn"
+                logger.info("Dual chunk attention is turned on by default.")
+            elif server_args.attention_backend != "dual_chunk_flash_attn":
+                raise ValueError(
+                    "Dual chunk attention is enabled, but attention backend is set to "
+                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+                )
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.
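
The hunk above also hardens backend selection for dual chunk attention models. A minimal sketch of the resulting user-facing behavior, assuming the standard `sglang.Engine` entry point forwards `attention_backend` to `ServerArgs`; the model name is only an example of a checkpoint whose config carries `dual_chunk_attention_config`:

```python
import sglang as sgl

# With attention_backend left unset, such a model now defaults to
# "dual_chunk_flash_attn" ("Dual chunk attention is turned on by default.").
engine = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct-1M")

# Forcing any other backend for the same model raises the ValueError added above.
# engine = sgl.Engine(
#     model_path="Qwen/Qwen2.5-7B-Instruct-1M", attention_backend="triton"
# )
```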
@@ -397,7 +419,6 @@
                     is_hopper_with_cuda_12_3()
                     and is_no_spec_infer_or_topk_one(server_args)
                     and is_fa3_default_architecture(self.model_config.hf_config)
-                    and (not server_args.enable_hierarchical_cache)
                 ):
                     server_args.attention_backend = "fa3"
                 elif _is_hip:
@@ -410,9 +431,7 @@
                     )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3() and (
-                    not server_args.enable_hierarchical_cache
-                ):
+                if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -500,6 +519,27 @@
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85

+        if (
+            server_args.enable_hierarchical_cache
+            and server_args.hicache_io_backend == "kernel"
+        ):
+            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+            if server_args.decode_attention_backend is None:
+                if not self.use_mla_backend:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif server_args.decode_attention_backend == "fa3":
+                server_args.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

@@ -563,12 +603,8 @@
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
-            enable_dp_attention=self.server_args.enable_dp_attention,
-            tp_rank=self.tp_rank,
-            tp_size=self.tp_size,
-            dp_size=self.server_args.dp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-            pp_size=self.server_args.pp_size,
+            server_args=self.server_args,
+            model_config=self.model_config,
         )

         min_per_gpu_memory = get_available_gpu_memory(
@@ -871,8 +907,18 @@
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
+        monkey_patch_torch_reductions()
+        if load_format == "flattened_bucket":
+            # Handle flattened bucket format
+            return self._update_weights_from_flattened_bucket(
+                flattened_tensor_bucket_dict=named_tensors
+            )
+
+        # We need to get device after patch otherwise the device would be wrong
+        infered_device = torch.cuda.current_device()
+
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
@@ -886,6 +932,38 @@
             raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"

+    def _update_weights_from_flattened_bucket(
+        self,
+        flattened_tensor_bucket_dict,
+    ):
+        """Handle flattened bucket format for weight updates"""
+        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+        metadata = flattened_tensor_bucket_dict["metadata"]
+
+        # Convert metadata dict to our format
+        converted_metadata = []
+        for meta in metadata:
+            converted_meta = FlattenedTensorMetadata(
+                name=meta.name,
+                shape=meta.shape,
+                dtype=meta.dtype,
+                start_idx=meta.start_idx,
+                end_idx=meta.end_idx,
+                numel=meta.numel,
+            )
+            converted_metadata.append(converted_meta)
+
+        # Create bucket and reconstruct tensors
+        bucket = FlattenedTensorBucket(
+            flattened_tensor=flattened_tensor, metadata=converted_metadata
+        )
+        reconstructed_tensors = bucket.reconstruct_tensors()
+
+        # Load the reconstructed tensors using the standard method
+        self.model.load_weights(reconstructed_tensors)
+
+        return True, "Success"
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
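
For reference, a rough sketch of the payload shape the new `_update_weights_from_flattened_bucket` path expects. Only the `FlattenedTensorMetadata` fields, the `FlattenedTensorBucket(flattened_tensor=..., metadata=...)` constructor, and `reconstruct_tensors()` are taken from the hunk above; the sender-side packing shown here is an illustrative assumption, not the actual API of `sglang/srt/weight_sync/tensor_bucket.py`:

```python
import torch

from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

# Hypothetical sender side: flatten a few named tensors into one buffer
# and record where each tensor lives inside it.
weights = {"layers.0.w": torch.randn(4, 4), "layers.0.b": torch.randn(4)}
metadata, chunks, offset = [], [], 0
for name, tensor in weights.items():
    flat = tensor.reshape(-1)
    metadata.append(
        FlattenedTensorMetadata(
            name=name,
            shape=tensor.shape,
            dtype=tensor.dtype,
            start_idx=offset,
            end_idx=offset + flat.numel(),
            numel=flat.numel(),
        )
    )
    chunks.append(flat)
    offset += flat.numel()
flattened = torch.cat(chunks)

# Receiver side, mirroring _update_weights_from_flattened_bucket above.
bucket = FlattenedTensorBucket(flattened_tensor=flattened, metadata=metadata)
named_tensors = bucket.reconstruct_tensors()  # then passed to model.load_weights(...)
```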
@@ -1077,6 +1155,7 @@
         max_num_reqs: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
+        # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
@@ -1095,6 +1174,8 @@
         )

         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

         if max_num_reqs is None:
             max_num_reqs = min(
@@ -1107,9 +1188,6 @@
                 4096,
             )

-        if SGLANG_CI_SMALL_KV_SIZE:
-            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
-
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -1156,6 +1234,7 @@
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

+        # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
@@ -1181,30 +1260,34 @@
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker

-        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
-            self.token_to_kv_pool = AscendTokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
-                head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-            )
-        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
-            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                kv_lora_rank=self.model_config.kv_lora_rank,
-                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.num_effective_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                start_layer=self.start_layer,
-                end_layer=self.end_layer,
-            )
+        # Initialize token_to_kv_pool
+        if self.server_args.attention_backend == "ascend":
+            if self.use_mla_backend:
+                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    kv_lora_rank=self.model_config.kv_lora_rank,
+                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )
+            else:
+                self.token_to_kv_pool = AscendTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.model_config.num_hidden_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                )
         elif self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1263,39 +1346,52 @@
                 end_layer=self.end_layer,
             )

+        # Initialize token_to_kv_pool_allocator
+        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
+        max_num_extend_tokens = (
+            self.server_args.chunked_prefill_size
+            if self.server_args.chunked_prefill_size > 0
+            else self.server_args.max_prefill_tokens
+        )
         if self.token_to_kv_pool_allocator is None:
-            if self.page_size == 1:
-                if self.is_hybrid:
-                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
-                        self.full_max_total_num_tokens,
-                        self.swa_max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                    )
-                else:
-                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                    )
+            if self.server_args.attention_backend == "ascend":
+                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                    need_sort=need_sort,
+                )
             else:
-                if _is_npu:
-                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        page_size=self.page_size,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                    )
+                if self.page_size == 1:
+                    if self.is_hybrid:
+                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                            self.full_max_total_num_tokens,
+                            self.swa_max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
+                    else:
+                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                            self.max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
                 else:
+                    assert not self.is_hybrid
                     self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
+                        max_num_extend_tokens=max_num_extend_tokens,
                     )
         else:
             assert self.is_draft_worker
@@ -1396,6 +1492,10 @@
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

             return AiterAttnBackend(self)
+        elif self.server_args.attention_backend == "wave":
+            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+            return WaveAttnBackend(self)
         elif backend_str == "ascend":
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

@@ -1459,15 +1559,13 @@
             )

             return TRTLLMHAAttnBackend(self)
-
         elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )

-            logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
-        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+        elif backend_str == "dual_chunk_flash_attn":
             from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                 DualChunkFlashAttentionBackend,
             )
@@ -1511,6 +1609,7 @@
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
@@ -1785,11 +1884,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
         default_weight_loader(params_dict[name], tensor)


-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(torch.cuda.current_device())
+    return tensor.to(device)


 @dataclass

sglang/srt/model_loader/loader.py
@@ -162,12 +162,24 @@ def _initialize_model(
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
     if _is_npu:
-        packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
-            "q_a_proj",
-            "kv_a_proj_with_mqa",
-        ]
-        packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
-        packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
+        packed_modules_mapping.update(
+            {
+                "visual": {"qkv_proj": ["qkv"]},
+                "vision_model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "proj": ["out_proj"],
+                },
+                "model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                    "fused_qkv_a_proj_with_mqa": [
+                        "q_a_proj",
+                        "kv_a_proj_with_mqa",
+                    ],
+                },
+            }
+        )
+
     quant_config = _get_quantization_config(
         model_config, load_config, packed_modules_mapping
     )

sglang/srt/models/deepseek_nextn.py
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class DeepseekModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
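
The last two hunks replace the `global_server_args_dict["enable_dp_attention"]` lookup with a helper. A one-line sketch of the new pattern, assuming `is_dp_attention_enabled()` simply reports the state recorded by the `initialize_dp_attention(server_args=..., model_config=...)` call shown earlier in this diff:

```python
from sglang.srt.layers.dp_attention import is_dp_attention_enabled

# Before: enable_tp = not global_server_args_dict["enable_dp_attention"]
enable_tp = not is_dp_attention_enabled()
```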