sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.srt.weight_sync.tensor_bucket import (
+    FlattenedTensorBucket,
+    FlattenedTensorMetadata,
+)

 _is_hip = is_hip()
 _is_npu = is_npu()
@@ -378,6 +382,25 @@ class ModelRunner:
             )
             server_args.attention_backend = "torch_native"

+        if server_args.prefill_attention_backend is not None and (
+            server_args.prefill_attention_backend
+            == server_args.decode_attention_backend
+        ):  # override the default attention backend
+            server_args.attention_backend = server_args.prefill_attention_backend
+
+        if (
+            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
+            is not None
+        ):
+            if server_args.attention_backend is None:
+                server_args.attention_backend = "dual_chunk_flash_attn"
+                logger.info("Dual chunk attention is turned on by default.")
+            elif server_args.attention_backend != "dual_chunk_flash_attn":
+                raise ValueError(
+                    "Dual chunk attention is enabled, but attention backend is set to "
+                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+                )
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.
@@ -397,7 +420,6 @@ class ModelRunner:
                     is_hopper_with_cuda_12_3()
                     and is_no_spec_infer_or_topk_one(server_args)
                     and is_fa3_default_architecture(self.model_config.hf_config)
-                    and (not server_args.enable_hierarchical_cache)
                 ):
                     server_args.attention_backend = "fa3"
                 elif _is_hip:
@@ -410,9 +432,7 @@ class ModelRunner:
                     )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3() and (
-                    not server_args.enable_hierarchical_cache
-                ):
+                if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -500,6 +520,27 @@ class ModelRunner:
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85

+        if (
+            server_args.enable_hierarchical_cache
+            and server_args.hicache_io_backend == "kernel"
+        ):
+            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+            if server_args.decode_attention_backend is None:
+                if not self.use_mla_backend:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif server_args.decode_attention_backend == "fa3":
+                server_args.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

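The hunk above encodes a compatibility rule: with hierarchical cache on the "kernel" HiCache I/O backend, an unset decode attention backend is auto-filled with a compatible one, while an explicitly requested FlashAttention3 decode backend instead drops the I/O backend to "direct". A condensed, standalone restatement of that decision (a sketch only: the namespace stands in for ServerArgs, and the booleans stand in for the hardware probes):

```python
from types import SimpleNamespace


def resolve_hicache_backends(args, *, use_mla, flashinfer_ok, sm100_ok):
    """Sketch of the rule above; `args` mimics the relevant ServerArgs fields."""
    if not (args.enable_hierarchical_cache and args.hicache_io_backend == "kernel"):
        return args
    if args.decode_attention_backend is None:
        # Pick a decode backend that works with the HiCache kernel I/O path.
        if not use_mla:
            args.decode_attention_backend = "flashinfer" if flashinfer_ok else "triton"
        else:
            args.decode_attention_backend = "flashinfer" if sm100_ok else "triton"
    elif args.decode_attention_backend == "fa3":
        # An explicit FA3 decode backend wins; fall back to direct (vanilla) I/O instead.
        args.hicache_io_backend = "direct"
    return args


args = SimpleNamespace(
    enable_hierarchical_cache=True,
    hicache_io_backend="kernel",
    decode_attention_backend="fa3",
)
resolve_hicache_backends(args, use_mla=False, flashinfer_ok=True, sm100_ok=False)
print(args.hicache_io_backend)  # -> direct
```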
@@ -871,8 +912,18 @@ class ModelRunner:
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
+        monkey_patch_torch_reductions()
+        if load_format == "flattened_bucket":
+            # Handle flattened bucket format
+            return self._update_weights_from_flattened_bucket(
+                flattened_tensor_bucket_dict=named_tensors
+            )
+
+        # We need to get device after patch otherwise the device would be wrong
+        infered_device = torch.cuda.current_device()
+
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
@@ -886,6 +937,38 @@ class ModelRunner:
             raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"

+    def _update_weights_from_flattened_bucket(
+        self,
+        flattened_tensor_bucket_dict,
+    ):
+        """Handle flattened bucket format for weight updates"""
+        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+        metadata = flattened_tensor_bucket_dict["metadata"]
+
+        # Convert metadata dict to our format
+        converted_metadata = []
+        for meta in metadata:
+            converted_meta = FlattenedTensorMetadata(
+                name=meta.name,
+                shape=meta.shape,
+                dtype=meta.dtype,
+                start_idx=meta.start_idx,
+                end_idx=meta.end_idx,
+                numel=meta.numel,
+            )
+            converted_metadata.append(converted_meta)
+
+        # Create bucket and reconstruct tensors
+        bucket = FlattenedTensorBucket(
+            flattened_tensor=flattened_tensor, metadata=converted_metadata
+        )
+        reconstructed_tensors = bucket.reconstruct_tensors()
+
+        # Load the reconstructed tensors using the standard method
+        self.model.load_weights(reconstructed_tensors)
+
+        return True, "Success"
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
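For context on the flattened-bucket path added above: the payload is a single 1-D buffer plus per-tensor metadata (name, shape, dtype, start/end offsets), and FlattenedTensorBucket slices that buffer back into named tensors before the usual load_weights call. A minimal, self-contained sketch of the reconstruction step in plain PyTorch; the TensorSlice dataclass is illustrative and only mirrors the metadata fields used in the diff:

```python
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class TensorSlice:
    # Illustrative stand-in for FlattenedTensorMetadata (same field names as the diff).
    name: str
    shape: Tuple[int, ...]
    dtype: torch.dtype
    start_idx: int
    end_idx: int


def reconstruct(flattened: torch.Tensor, metadata: List[TensorSlice]):
    """Slice one flat buffer back into (name, tensor) pairs."""
    return [
        (m.name, flattened[m.start_idx : m.end_idx].view(m.shape).to(m.dtype))
        for m in metadata
    ]


# Tiny round trip: two fake weights packed into one flat buffer.
a, b = torch.randn(2, 3), torch.randn(4)
flat = torch.cat([a.reshape(-1), b.reshape(-1)])
meta = [
    TensorSlice("layer.a", (2, 3), a.dtype, 0, a.numel()),
    TensorSlice("layer.b", (4,), b.dtype, a.numel(), a.numel() + b.numel()),
]
named = reconstruct(flat, meta)
assert torch.equal(named[0][1], a) and torch.equal(named[1][1], b)
```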
@@ -1181,30 +1264,33 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker

-        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
-            self.token_to_kv_pool = AscendTokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
-                head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-            )
-        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
-            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                kv_lora_rank=self.model_config.kv_lora_rank,
-                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.num_effective_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                start_layer=self.start_layer,
-                end_layer=self.end_layer,
-            )
+        if self.server_args.attention_backend == "ascend":
+            if self.use_mla_backend:
+                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    kv_lora_rank=self.model_config.kv_lora_rank,
+                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )
+            else:
+                self.token_to_kv_pool = AscendTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.model_config.num_hidden_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                )
         elif self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1263,6 +1349,7 @@ class ModelRunner:
                 end_layer=self.end_layer,
             )

+        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if self.page_size == 1:
                 if self.is_hybrid:
@@ -1272,6 +1359,7 @@ class ModelRunner:
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
                 else:
                     self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
@@ -1279,23 +1367,26 @@ class ModelRunner:
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
             else:
-                if _is_npu:
-                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                if not _is_npu:
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
                 else:
-                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
         else:
             assert self.is_draft_worker
@@ -1396,6 +1487,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

             return AiterAttnBackend(self)
+        elif self.server_args.attention_backend == "wave":
+            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+            return WaveAttnBackend(self)
         elif backend_str == "ascend":
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

@@ -1443,19 +1538,36 @@ class ModelRunner:
             )

             return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

             return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "trtllm_mha":
+            if self.use_mla_backend:
+                raise ValueError(
+                    "trtllm_mha backend can only be used with non-MLA models."
+                )
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+            )
+
+            return TRTLLMHAAttnBackend(self)
+
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )

             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")

@@ -1768,11 +1880,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
             default_weight_loader(params_dict[name], tensor)


-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(torch.cuda.current_device())
+    return tensor.to(device)


 @dataclass

sglang/srt/model_loader/loader.py
@@ -162,12 +162,24 @@ def _initialize_model(
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
     if _is_npu:
-        packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
-            "q_a_proj",
-            "kv_a_proj_with_mqa",
-        ]
-        packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
-        packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
+        packed_modules_mapping.update(
+            {
+                "visual": {"qkv_proj": ["qkv"]},
+                "vision_model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "proj": ["out_proj"],
+                },
+                "model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                    "fused_qkv_a_proj_with_mqa": [
+                        "q_a_proj",
+                        "kv_a_proj_with_mqa",
+                    ],
+                },
+            }
+        )
+
     quant_config = _get_quantization_config(
         model_config, load_config, packed_modules_mapping
     )
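The NPU branch above replaces the flat fusion map with one nested by submodule ("visual", "vision_model", "model"). Purely to illustrate what such a mapping expresses, here is a small sketch that resolves a per-projection checkpoint key to the packed parameter it would feed; the prefix-matching semantics are an assumption for illustration, and the actual consumption of this mapping happens elsewhere in the loader:

```python
# Nested fusion map from the hunk above, trimmed to the entries used in this example.
nested_mapping = {
    "visual": {"qkv_proj": ["qkv"]},
    "model": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
}


def packed_target(param_name: str, mapping=nested_mapping):
    """Map a per-projection checkpoint key to its packed parameter (illustrative only).

    Assumed semantics: the top-level key scopes the fusion rules to parameter names
    that start with that submodule prefix.
    """
    for submodule, rules in mapping.items():
        if not param_name.startswith(submodule + "."):
            continue
        for packed_name, sources in rules.items():
            for source in sources:
                piece = f".{source}."
                if piece in param_name:
                    return param_name.replace(piece, f".{packed_name}."), source
    return param_name, None


print(packed_target("model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'q_proj')
print(packed_target("model.layers.0.mlp.gate_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj.weight', 'gate_proj')
```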

sglang/srt/model_loader/weight_utils.py
@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name

+    quark_scale_names = {
+        ".q_proj.output_scale": ".attn.q_scale",
+        ".k_proj.output_scale": ".attn.k_scale",
+        ".v_proj.output_scale": ".attn.v_scale",
+        "self_attn.prob_output_scale": ".attn.prob_scale",
+    }
+    for quark_scale_name, sglang_scale_name in quark_scale_names.items():
+        if name.endswith(quark_scale_name):
+            return name.replace(quark_scale_name, sglang_scale_name)
+
     # If there were no matches, return the untouched param name
     return name
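A quick worked example of the suffix remapping added above; the checkpoint key is an illustrative Quark-style name rather than one taken from a specific model:

```python
quark_scale_names = {
    ".q_proj.output_scale": ".attn.q_scale",
    ".k_proj.output_scale": ".attn.k_scale",
    ".v_proj.output_scale": ".attn.v_scale",
    "self_attn.prob_output_scale": ".attn.prob_scale",
}


def remap(name: str) -> str:
    # Mirrors the added loop: match on suffix, then rewrite that suffix.
    for quark_name, sglang_name in quark_scale_names.items():
        if name.endswith(quark_name):
            return name.replace(quark_name, sglang_name)
    return name


print(remap("model.layers.0.self_attn.k_proj.output_scale"))
# -> model.layers.0.self_attn.attn.k_scale
```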