sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available

 if is_flashinfer_available():
     import flashinfer
@@ -30,7 +30,12 @@ if is_flashinfer_available():
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import concat_mla_absorb_q

 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
@@ -122,6 +127,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             "disable_chunked_prefix_cache"
         ]

+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -207,12 +214,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""

         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -223,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )

+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -260,12 +270,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -277,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 seq_lens_cpu,
             )

+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
+
         metadata = self.decode_cuda_graph_metadata[bs]

         # Update block indices for new sequences.
@@ -327,7 +341,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 cum_seq_lens_q,
                 seq_lens,
             )
-        elif forward_batch.forward_mode.is_decode_or_idle():
+        elif (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             bs = forward_batch.batch_size

             # Get maximum sequence length.
@@ -336,13 +353,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             else:
                 max_seq = forward_batch.seq_lens.max().item()

+            seq_lens = forward_batch.seq_lens
+
+            if forward_batch.forward_mode.is_target_verify():
+                max_seq = max_seq + self.num_draft_tokens
+                seq_lens = seq_lens + self.num_draft_tokens
+
             max_seqlen_pad = self._calc_padded_blocks(max_seq)
             block_kv_indices = self._create_block_kv_indices(
                 bs,
                 max_seqlen_pad,
                 forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
+                seq_lens,
+                seq_lens.device,
             )

             max_seq_len_val = int(max_seq)
@@ -482,7 +505,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+            query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -545,49 +568,163 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
-        if (
-            forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
-        ):
+        if forward_batch.forward_mode.is_draft_extend():
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
-        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
-        if forward_batch.attn_attend_prefix_cache is None:
-            return super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
+        # Save KV cache if requested
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
            )

-        if not forward_batch.attn_attend_prefix_cache:
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
+        if forward_batch.forward_mode.is_target_verify():
+            metadata = (
+                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                or self.forward_decode_metadata
+            )
+
+            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+            bs = forward_batch.batch_size
+            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+            q_scale = 1.0
+            k_scale = (
+                layer.k_scale_float
+                if getattr(layer, "k_scale_float", None) is not None
+                else 1.0
+            )
+
+            bmm1_scale = q_scale * k_scale * layer.scaling
+
+            seq_lens = (
+                forward_batch.seq_lens.to(torch.int32)
+                + forward_batch.spec_info.draft_token_num
+            )
+            max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+
+            # TODO may use `mla_rope_quantize_fp8` fusion
+            q = q.to(self.data_type)
+            assert kv_cache.dtype == self.data_type
+
+            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_cache,
+                workspace_buffer=self.workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=metadata.block_kv_indices,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                bmm1_scale=bmm1_scale,
+            )
+
+            # Reshape output directly without slicing
+            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return output
+
+        if forward_batch.attn_attend_prefix_cache:
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert q_rope is None
+            assert k_rope is None
+            chunk_idx = forward_batch.prefix_chunk_idx
+
+            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
                 query=q,
                 key=k,
                 value=v,
                 workspace_buffer=self.workspace_buffer,
-                seq_lens=self.forward_prefill_metadata.seq_lens,
+                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
                 max_q_len=self.forward_prefill_metadata.max_seq_len,
-                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                 bmm1_scale=layer.scaling,
                 bmm2_scale=1.0,
-                o_sf_scale=1.0,
+                o_sf_scale=-1.0,
                 batch_size=forward_batch.batch_size,
                 window_left=-1,
                 cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                 enable_pdl=False,
-                is_causal=True,
-                return_lse=forward_batch.mha_return_lse,
+                is_causal=False,
+                return_lse=True,
+                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
             )
-        else:
-            # replace with trtllm ragged attention once accuracy is resolved.
-            output = super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
-            return output
+
+        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self.workspace_buffer,
+            seq_lens=self.forward_prefill_metadata.seq_lens,
+            max_q_len=self.forward_prefill_metadata.max_seq_len,
+            max_kv_len=self.forward_prefill_metadata.max_seq_len,
+            bmm1_scale=layer.scaling,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=forward_batch.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=forward_batch.mha_return_lse,
+        )


 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
@@ -605,3 +742,10 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
                 kv_indptr_buf=self.kv_indptr[i],
                 q_indptr_decode_buf=self.q_indptr_decode,
             )
+
+
+def _concat_mla_absorb_q_general(q_nope, q_rope):
+    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
+        return concat_mla_absorb_q(q_nope, q_rope)
+    else:
+        return torch.cat([q_nope, q_rope], dim=-1)
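Note on the new `_concat_mla_absorb_q_general` helper: when the dedicated sgl_kernel op is unavailable (non-CUDA builds) or the head-dim split is not the 512/64 layout the kernel specializes on, it falls back to a plain `torch.cat`. A minimal sketch of that fallback path, with made-up tensor sizes:

    import torch

    # Illustrative shapes only: 4 tokens, 16 heads, 512 "nope" dims + 64 rope dims.
    q_nope = torch.randn(4, 16, 512)
    q_rope = torch.randn(4, 16, 64)
    query = torch.cat([q_nope, q_rope], dim=-1)  # fallback branch of _concat_mla_absorb_q_general
    assert query.shape == (4, 16, 576)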
sglang/srt/layers/attention/vision.py

@@ -16,14 +16,19 @@ from sglang.srt.utils import (
     get_device_capability,
     is_blackwell,
     is_cuda,
+    is_npu,
     print_info_once,
 )

 _is_cuda = is_cuda()
+_is_npu = is_npu()

 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func

+if _is_npu:
+    import torch_npu
+
 from sglang.srt.distributed import (
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
@@ -331,10 +336,63 @@ class VisionFlash3Attention(nn.Module):
         return output


+class VisionAscendAttention(nn.Module):
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        if not _is_npu:
+            raise Exception("VisionAscendAttention is only available for ascend npu")
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        if seq_lens.is_npu:
+            # cu_seqlens must be on cpu because of operator restriction
+            seq_lens = seq_lens.to("cpu")
+        _, num_heads, head_size = q.shape
+        num_kv_heads = k.shape[1]
+        output = torch.empty_like(q)
+
+        # operator requires pta version >= 2.5.1
+        torch_npu._npu_flash_attention_unpad(
+            query=q,
+            key=k,
+            value=v,
+            seq_len=seq_lens.to(torch.int32),
+            scale_value=head_size**-0.5,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            out=output,
+        )
+
+        return output
+
+
 QKV_BACKEND_IMPL = {
     "triton_attn": VisionTritonAttention,
     "sdpa": VisionSdpaAttention,
     "fa3": VisionFlash3Attention,
+    "ascend_attn": VisionAscendAttention,
 }

sglang/srt/layers/attention/wave_backend.py

@@ -2,7 +2,7 @@ from __future__ import annotations

 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput

 logger = logging.getLogger(__name__)

@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"

@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
sglang/srt/layers/communicator.py

@@ -50,6 +50,7 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    prepare_weight_cache,
 )

 _is_flashinfer_available = is_flashinfer_available()
@@ -275,7 +276,11 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        cache=None,
     ):
+        if cache is not None:
+            self._context.cache = cache
+
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,
@@ -349,6 +354,7 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
+    cache = None

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            if context.cache is not None:
+                _ = prepare_weight_cache(hidden_states, context.cache)
             hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual

@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
17
17
  get_tp_group,
18
18
  tensor_model_parallel_all_reduce,
19
19
  )
20
+ from sglang.srt.utils import get_bool_env_var, is_hip
20
21
 
21
22
  if TYPE_CHECKING:
22
23
  from sglang.srt.configs.model_config import ModelConfig
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
36
37
  _LOCAL_ATTN_DP_RANK: Optional[int] = None
37
38
  _ENABLE_DP_ATTENTION_FLAG: bool = False
38
39
 
40
+ _is_hip = is_hip()
41
+ _USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
42
+
39
43
 
40
44
  class DpPaddingMode(IntEnum):
41
45
 
@@ -67,7 +71,12 @@ class DpPaddingMode(IntEnum):
67
71
 
68
72
  @classmethod
69
73
  def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
70
- return cls.MAX_LEN
74
+ # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
75
+ # it can be safely removed later, once RCCL fixed
76
+ if _USE_ROCM700A_WA:
77
+ return cls.SUM_LEN
78
+ else:
79
+ return cls.MAX_LEN
71
80
 
72
81
 
73
82
  class _DpGatheredBufferWrapper:
@@ -254,6 +263,7 @@ def initialize_dp_attention(
254
263
  use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
255
264
  use_pymscclpp=False,
256
265
  use_custom_allreduce=False,
266
+ use_torch_symm_mem=False,
257
267
  use_hpu_communicator=False,
258
268
  use_xpu_communicator=False,
259
269
  use_npu_communicator=False,
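The ROCm 7.0.0-alpha workaround above is gated purely by an environment variable. A hedged sketch of the same pattern, using a stand-in helper rather than sglang's `get_bool_env_var` (whose exact truthy values may differ):

    import os

    def _bool_env(name: str, default: str = "false") -> bool:
        # Stand-in for sglang.srt.utils.get_bool_env_var.
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

    # Under the workaround, CUDA-graph DP padding falls back to SUM_LEN instead of MAX_LEN.
    padding_mode = "SUM_LEN" if _bool_env("SGLANG_USE_ROCM700A") else "MAX_LEN"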
sglang/srt/layers/elementwise.py

@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(

 def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert len(x.shape) == 2
-    assert x.shape == residual.shape and x.dtype == residual.dtype
+    assert (
+        x.shape == residual.shape and x.dtype == residual.dtype
+    ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
     if autotune:
sglang/srt/layers/layernorm.py

@@ -80,6 +80,8 @@ class RMSNorm(CustomOp):
         )
         if _use_aiter:
             self._forward_method = self.forward_aiter
+        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
+            self._forward_method = self.forward_native

     def forward_cuda(
         self,
sglang/srt/layers/linear.py

@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
     _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs

 if TYPE_CHECKING:
@@ -625,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
             if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    output_dim, start_idx, shard_size
-                )
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )

         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -1302,7 +1310,16 @@ class RowParallelLinear(LinearBase):
                 shard_size,
             )
         else:
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    input_dim, start_idx, shard_size
+                )

         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
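The weight loaders above now fall back to a `pad_or_narrow_weight` helper (added in sglang/srt/layers/utils.py, not shown in this hunk) when the requested shard would run past the end of the checkpoint tensor, e.g. for Qwen2.5-VL MLP shapes that do not divide evenly across shards. A hypothetical sketch of what such a helper could look like; the shipped implementation may differ:

    import torch

    def pad_or_narrow_weight_sketch(w: torch.Tensor, dim: int, start: int, size: int) -> torch.Tensor:
        # If the full shard fits, behave exactly like Tensor.narrow.
        available = w.shape[dim] - start
        if available >= size:
            return w.narrow(dim, start, size)
        # Otherwise keep what exists and zero-pad the tail up to the shard size.
        shard = w.narrow(dim, start, available)
        pad_shape = list(shard.shape)
        pad_shape[dim] = size - available
        return torch.cat([shard, shard.new_zeros(pad_shape)], dim=dim)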
sglang/srt/layers/logits_processor.py

@@ -220,6 +220,7 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.logit_scale = logit_scale
         self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (
@@ -461,7 +462,11 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            if use_intel_amx_backend(lm_head):
+            if self.use_fp32_lm_head:
+                logits = torch.matmul(
+                    hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
+                )
+            elif use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
@@ -475,7 +480,15 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             # TODO: use weight_packed_linear for GGUF models
-            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
+            if self.use_fp32_lm_head:
+                with torch.cuda.amp.autocast(enabled=False):
+                    logits = lm_head.quant_method.apply(
+                        lm_head, hidden_states.to(torch.float32), embedding_bias
+                    )
+            else:
+                logits = lm_head.quant_method.apply(
+                    lm_head, hidden_states, embedding_bias
+                )

         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)
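With the new `enable_fp32_lm_head` option, the final projection above is carried out in float32 regardless of the model dtype. A minimal illustration of that cast-then-matmul (tensor sizes are arbitrary):

    import torch

    hidden_states = torch.randn(2, 1024, dtype=torch.bfloat16)
    lm_head_weight = torch.randn(32000, 1024, dtype=torch.bfloat16)  # [vocab, hidden]
    logits = torch.matmul(hidden_states.to(torch.float32), lm_head_weight.to(torch.float32).T)
    assert logits.dtype == torch.float32 and logits.shape == (2, 32000)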
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1104,10 +1104,10 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
+    BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024  # block size of quantization
     assert hidden_size % BLOCK_D == 0
     grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
     _fwd_kernel_ep_gather[grid](
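The `BLOCK_D` choice in `ep_gather` now depends on the actual hidden size instead of a CI flag, so hidden sizes that are a multiple of 128 but not of 1024 pass the divisibility assert outside CI as well. A quick check of the selection rule (hidden sizes are just examples):

    # Mirrors the new selection rule: fall back to 128 when 1024 does not divide hidden_size.
    for hidden_size in (1536, 2048, 4096, 7168):
        BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024
        assert hidden_size % BLOCK_D == 0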