sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/model_executor/model_runner.py
+++ b/sglang/srt/model_executor/model_runner.py
@@ -25,17 +25,19 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
-from urllib.parse import urlparse
 
-import requests
 import torch
 import torch.distributed as dist
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.model_config import (
+    AttentionArch,
+    ModelConfig,
+    get_nsa_index_head_dim,
+    is_deepseek_nsa,
+)
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
-from sglang.srt.connector import ConnectorType
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_pp_group,
@@ -45,6 +47,7 @@ from sglang.srt.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
+    set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -60,6 +63,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -94,6 +101,7 @@ from sglang.srt.mem_cache.memory_pool import (
     HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
+    NSATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
@@ -103,6 +111,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_init_weights_send_group_for_remote_instance_request,
+)
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.offloader import (
@@ -110,10 +121,6 @@ from sglang.srt.offloader import (
     get_offloader,
     set_offloader,
 )
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.remote_instance_weight_loader_utils import (
-    trigger_init_weights_send_group_for_remote_instance_request,
-)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,7 +134,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -135,16 +141,38 @@ from sglang.srt.utils import (
     is_no_spec_infer_or_topk_one,
     is_npu,
     is_sm100_supported,
+    log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    parse_connector_type,
     set_cuda_arch,
+    slow_rank_detector,
 )
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
 )
 
+MLA_ATTENTION_BACKENDS = [
+    "aiter",
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "triton",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+    "nsa",
+]
+
+
+def add_mla_attention_backend(backend_name):
+    if backend_name not in MLA_ATTENTION_BACKENDS:
+        MLA_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
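The hardcoded MLA allow-list that previously lived inside ModelRunner (see the -511,16 hunk below) is now this module-level MLA_ATTENTION_BACKENDS table, and add_mla_attention_backend lets integrations extend it at runtime. A minimal usage sketch; the backend name "my_mla" is invented:

    # Hypothetical out-of-tree registration: mark a custom backend as
    # MLA-capable so the MLA model path accepts it. The call is idempotent.
    from sglang.srt.model_executor.model_runner import add_mla_attention_backend

    add_mla_attention_backend("my_mla")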
@@ -158,6 +186,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+
+
 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
 
@@ -252,6 +287,9 @@ class ModelRunner:
         # CPU offload
         set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
 
+        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
+            slow_rank_detector.execute()
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -319,7 +357,6 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             logger.warning("Hybrid GDN model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            self.server_args.attention_backend = "hybrid_linear_attn"
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
                     self.server_args.max_mamba_cache_size = (
@@ -385,6 +422,12 @@ class ModelRunner:
         )
         self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
+        # Enable batch invariant mode
+        if server_args.enable_deterministic_inference:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            enable_batch_invariant_mode()
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
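For context on what enable_batch_invariant_mode guards (backed by the new sglang/srt/batch_invariant_ops package in the file list): deterministic inference requires that a row's result not depend on which other rows share its batch. A standalone illustration of the property, not sglang's kernels:

    import torch

    # Row-wise softmax reads only its own row, so computing a row alone or
    # inside a batch gives bitwise-identical results with this CPU kernel.
    # Batched GPU kernels that split reductions can break this equality,
    # which is what batch-invariant ops are meant to prevent.
    torch.manual_seed(0)
    x = torch.randn(8, 16)
    alone = torch.softmax(x[:1], dim=-1)
    batched = torch.softmax(x, dim=-1)[:1]
    assert torch.equal(alone, batched)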
@@ -496,9 +539,7 @@ class ModelRunner:
         elif _is_hip:
             head_num = self.model_config.get_num_kv_heads(self.tp_size)
             # TODO current aiter only support head number 16 or 128 head number
-            if (
-                head_num == 128 or head_num == 16
-            ) and self.spec_algorithm.is_none():
+            if head_num == 128 or head_num == 16:
                 server_args.attention_backend = "aiter"
             else:
                 server_args.attention_backend = "triton"
@@ -511,16 +552,7 @@ class ModelRunner:
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
-                if server_args.attention_backend in [
-                    "aiter",
-                    "flashinfer",
-                    "fa3",
-                    "triton",
-                    "flashmla",
-                    "cutlass_mla",
-                    "trtllm_mla",
-                    "ascend",
-                ]:
+                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
@@ -562,18 +594,6 @@ class ModelRunner:
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
 
-        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
-        # For more details, see: https://github.com/sgl-project/sglang/issues/8616
-        elif (
-            self.dp_size > 1
-            and is_sm100_supported()
-            and server_args.attention_backend != "triton"
-            and server_args.attention_backend == "trtllm_mla"
-        ):
-            logger.info(
-                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
-            )
-            server_args.disable_chunked_prefix_cache = True
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
 
@@ -599,7 +619,7 @@ class ModelRunner:
             server_args.hicache_io_backend = "direct"
             logger.warning(
                 "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
             )
 
     def init_torch_distributed(self):
@@ -634,6 +654,7 @@ class ModelRunner:
         dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
 
         if not self.is_draft_worker:
             if self.device == "cpu":
@@ -730,6 +751,10 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
+            tp_rank=self.tp_rank,
+            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -757,7 +782,10 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -1035,6 +1063,19 @@ class ModelRunner:
             logger.error(message)
             return False, message
 
+    def destroy_weights_update_group(self, group_name):
+        try:
+            if group_name in self._model_update_group:
+                pg = self._model_update_group.pop(group_name)
+                torch.distributed.destroy_process_group(pg)
+                return True, "Succeeded to destroy custom process group."
+            else:
+                return False, "The group to be destroyed does not exist."
+        except Exception as e:
+            message = f"Failed to destroy custom process group: {e}."
+            logger.error(message)
+            return False, message
+
     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
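destroy_weights_update_group is the teardown counterpart of the weights-update process groups used for online weight updates. Note the pop-before-destroy ordering, which keeps the bookkeeping dict consistent even when destruction raises; a self-contained sketch of the same pattern (names invented):

    import torch.distributed as dist

    groups = {}  # name -> ProcessGroup, standing in for self._model_update_group

    def destroy_group(name):
        # Pop first so a failing destroy cannot leave a dangling entry behind.
        if name not in groups:
            return False, "The group to be destroyed does not exist."
        pg = groups.pop(name)
        dist.destroy_process_group(pg)
        return True, "Succeeded to destroy custom process group."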
@@ -1072,7 +1113,7 @@ class ModelRunner:
             handle.wait()
 
             self.model.load_weights(weights)
-            return True, f"Succeeded to update parameter online."
+            return True, "Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
@@ -1176,6 +1217,7 @@ class ModelRunner:
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
             lora_paths=self.server_args.lora_paths,
+            server_args=self.server_args,
         )
 
     def load_lora_adapter(self, lora_ref: LoRARef):
@@ -1260,6 +1302,7 @@ class ModelRunner:
         return self.model_config.hf_config.architectures[0] in [
             "Qwen3NextForCausalLM",
             "Qwen3NextForCausalLMMTP",
+            "FalconH1ForCausalLM",
         ]
 
     def set_num_token_hybrid(self):
@@ -1352,7 +1395,18 @@ class ModelRunner:
     ):
         # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            quant_config = getattr(self.model, "quant_config", None)
+            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
+            if (
+                isinstance(kv_cache_quant_algo, str)
+                and kv_cache_quant_algo.upper() == "FP8"
+            ):
+                if _is_hip:
+                    self.kv_cache_dtype = torch.float8_e4m3fnuz
+                else:
+                    self.kv_cache_dtype = torch.float8_e4m3fn
+            else:
+                self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
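The "auto" branch above now honors an FP8 KV-cache quantization algorithm declared by the checkpoint instead of always falling back to the model dtype. The same decision logic, distilled into a standalone sketch (the function name is invented):

    import torch

    def resolve_auto_kv_cache_dtype(model, model_dtype, is_hip):
        # Mirrors the new "auto" branch: prefer the checkpoint's FP8 KV-cache
        # algo; ROCm uses the fnuz variant of float8_e4m3.
        quant_config = getattr(model, "quant_config", None)
        algo = getattr(quant_config, "kv_cache_quant_algo", None)
        if isinstance(algo, str) and algo.upper() == "FP8":
            return torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
        return model_dtype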
@@ -1368,6 +1422,8 @@ class ModelRunner:
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
             )
 
+        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if SGLANG_CI_SMALL_KV_SIZE:
             self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
@@ -1385,7 +1441,7 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
 
-        if not self.spec_algorithm.is_none():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                 max_num_reqs = self.server_args.max_num_reqs
@@ -1438,7 +1494,8 @@ class ModelRunner:
 
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enough memory. Please try to increase --mem-fraction-static."
+                f"Not enough memory. Please try to increase --mem-fraction-static. "
+                f"Current value: {self.server_args.mem_fraction_static=}"
             )
 
         # Initialize req_to_token_pool
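The improved out-of-memory message uses the f-string debug specifier, which renders both the expression and its value:

    mem_fraction_static = 0.88
    f"Current value: {mem_fraction_static=}"  # 'Current value: mem_fraction_static=0.88'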
@@ -1497,6 +1554,7 @@ class ModelRunner:
             assert self.is_draft_worker
 
         # Initialize token_to_kv_pool
+        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1505,6 +1563,7 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     kv_lora_rank=self.model_config.kv_lora_rank,
                     qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    index_head_dim=self.model_config.index_head_dim,
                     layer_num=self.num_effective_layers,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1524,7 +1583,22 @@ class ModelRunner:
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
+        elif self.use_mla_backend and is_nsa_model:
+            self.token_to_kv_pool = NSATokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+            )
         elif self.use_mla_backend:
+            assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
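With the NSA branch placed ahead of the plain MLA branch (and guarded by the new assert), the KV-pool dispatch visible in these hunks reads roughly as follows. A sketch that returns class names as strings; the Ascend non-MLA and remaining branches are simplified:

    def select_kv_pool(attention_backend, use_mla, is_nsa_model, is_hybrid_gdn):
        if attention_backend == "ascend" and use_mla:
            return "AscendMLAPagedTokenToKVPool"  # now also takes index_head_dim
        if use_mla and is_nsa_model:
            return "NSATokenToKVPool"  # new: carries the NSA indexer head dim
        if use_mla:
            return "MLATokenToKVPool"
        if is_hybrid_gdn:
            return "HybridLinearKVPool"
        return "MHATokenToKVPool"  # common full-attention default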
@@ -1568,7 +1642,7 @@ class ModelRunner:
             )
         elif self.is_hybrid_gdn:
             self.token_to_kv_pool = HybridLinearKVPool(
-                page_size=self.page_size if _is_npu else 1,
+                page_size=self.page_size,
                 size=self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(
@@ -1603,10 +1677,9 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if _is_npu and self.server_args.attention_backend in [
-                "ascend",
-                "hybrid_linear_attn",
-            ]:
+            if _is_npu and (
+                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+            ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1700,8 +1773,8 @@ class ModelRunner:
                 f"prefill_backend={self.prefill_attention_backend_str}."
             )
             logger.warning(
-                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
-                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+                "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
@@ -1717,140 +1790,10 @@ class ModelRunner:
         return attn_backend
 
     def _get_attention_backend_from_str(self, backend_str: str):
-        if backend_str == "flashinfer":
-            if not self.use_mla_backend:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                )
-
-                # Init streams
-                if self.server_args.speculative_algorithm == "EAGLE":
-                    if (
-                        not hasattr(self, "plan_stream_for_flashinfer")
-                        or not self.plan_stream_for_flashinfer
-                    ):
-                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                return FlashInferAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                )
-
-                return FlashInferMLAAttnBackend(self)
-        elif backend_str == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "wave":
-            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
-
-            return WaveAttnBackend(self)
-        elif backend_str == "ascend":
-            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-            return AscendAttnBackend(self)
-        elif backend_str == "triton":
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                from sglang.srt.layers.attention.double_sparsity_backend import (
-                    DoubleSparseAttnBackend,
-                )
-
-                return DoubleSparseAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                return TritonAttnBackend(self)
-        elif backend_str == "torch_native":
-            from sglang.srt.layers.attention.torch_native_backend import (
-                TorchNativeAttnBackend,
-            )
-
-            return TorchNativeAttnBackend(self)
-        elif backend_str == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
-
-            return FlashMLABackend(self)
-        elif backend_str == "fa3":
-            assert (
-                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
-            ) or torch.cuda.get_device_capability()[0] == 9, (
-                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
-
-            return FlashAttentionBackend(self)
-        elif backend_str == "cutlass_mla":
-            from sglang.srt.layers.attention.cutlass_mla_backend import (
-                CutlassMLABackend,
-            )
-
-            return CutlassMLABackend(self)
-        elif backend_str == "trtllm_mla":
-            if not self.use_mla_backend:
-                raise ValueError("trtllm_mla backend can only be used with MLA models.")
-            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-            return TRTLLMMLABackend(self)
-        elif backend_str == "trtllm_mha":
-            if self.use_mla_backend:
-                raise ValueError(
-                    "trtllm_mha backend can only be used with non-MLA models."
-                )
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-            )
-
-            return TRTLLMHAAttnBackend(self)
-        elif backend_str == "intel_amx":
-            from sglang.srt.layers.attention.intel_amx_backend import (
-                IntelAMXAttnBackend,
-            )
-
-            return IntelAMXAttnBackend(self)
-        elif backend_str == "dual_chunk_flash_attn":
-            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
-                DualChunkFlashAttentionBackend,
-            )
-
-            return DualChunkFlashAttentionBackend(self)
-        elif backend_str == "hybrid_linear_attn":
-            assert (
-                self.is_hybrid_gdn
-            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
-            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
-                HybridLinearAttnBackend,
-                MambaAttnBackend,
-            )
-
-            if _is_npu:
-                from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-                full_attn_backend = AscendAttnBackend(self)
-            elif is_blackwell():
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                full_attn_backend = TritonAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashattention_backend import (
-                    FlashAttentionBackend,
-                )
-
-                full_attn_backend = FlashAttentionBackend(self)
-
-            linear_attn_backend = MambaAttnBackend(self)
-            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
-            return HybridLinearAttnBackend(
-                full_attn_backend, linear_attn_backend, full_attn_layers
-            )
-        else:
+        if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
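The 130-line if/elif chain collapses into a table lookup against the new attention_registry module. The registration side is not part of this diff; a plausible sketch of the pattern, where the decorator name and factory signature are assumptions (only ATTENTION_BACKENDS and attn_backend_wrapper appear above):

    ATTENTION_BACKENDS = {}

    def register_attention_backend(name):
        def decorator(factory):
            ATTENTION_BACKENDS[name] = factory
            return factory
        return decorator

    @register_attention_backend("triton")
    def create_triton_backend(runner):
        # Import lazily, as the old if/elif chain did.
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
        return TritonAttnBackend(runner)

Lookup then reduces to the four kept lines: a membership check, a factory call, and the attn_backend_wrapper hook, which can substitute a wrapped backend such as the hybrid linear-attention composition formerly inlined here.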
@@ -2147,7 +2090,6 @@ class ModelRunner:
             )
 
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
-
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
@@ -2155,6 +2097,12 @@ class ModelRunner:
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
+            # For prefill, we only use the position of the last token.
+            (
+                forward_batch.positions
+                if forward_batch.forward_mode.is_decode()
+                else forward_batch.seq_lens - 1
+            ),
         )
         return next_token_ids
 
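The sampler now receives per-token positions. During decode every row already sits at its current position, while for extend/prefill only each sequence's final token is sampled, and its 0-based position is seq_len - 1:

    import torch

    seq_lens = torch.tensor([5, 3, 8])  # prompt lengths in an extend batch
    last_positions = seq_lens - 1       # tensor([4, 2, 7])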
--- a/sglang/srt/model_executor/npu_graph_runner.py
+++ b/sglang/srt/model_executor/npu_graph_runner.py
@@ -19,8 +19,10 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union
 
+import numpy as np
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
 logger = logging.getLogger(__name__)
@@ -73,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
         self.positions[: self.raw_num_token].copy_(forward_batch.positions)
 
         # Replay
-        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
-        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
-        thread.start()
-        self.graphs[self.bs].replay()
-        thread.join()
+        if self.model_runner.model_config.index_head_dim is None:
+            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                self.bs - self.raw_bs
+            )
+            thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+            thread.start()
+            self.graphs[self.bs].replay()
+            thread.join()
+        else:
+            self.graphs[self.bs].replay()
 
         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
--- a/sglang/srt/model_loader/loader.py
+++ b/sglang/srt/model_loader/loader.py
@@ -54,6 +54,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_transferring_weights_request,
+)
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
     post_load_weights,
@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
     set_runai_streamer_env,
 )
-from sglang.srt.remote_instance_weight_loader_utils import (
-    trigger_transferring_weights_request,
-)
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_capability,
@@ -206,7 +206,10 @@ def _initialize_model(
     if _is_npu:
         packed_modules_mapping.update(
             {
-                "visual": {"qkv_proj": ["qkv"]},
+                "visual": {
+                    "qkv_proj": ["qkv"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                },
                 "vision_model": {
                     "qkv_proj": ["q_proj", "k_proj", "v_proj"],
                     "proj": ["out_proj"],
@@ -1417,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
                 f"load format {load_config.load_format}"
             )
 
-        model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
+        model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
@@ -1439,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader):
     def load_model_from_remote_instance(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
+        load_config = self.load_config
         instance_ip = socket.gethostbyname(socket.gethostname())
         start_build_group_tic = time.time()
         client.build_group(
             gpu_id=device_config.gpu_id,
-            tp_rank=model_config.tp_rank,
+            tp_rank=load_config.tp_rank,
             instance_ip=instance_ip,
         )
        torch.cuda.synchronize()
@@ -1452,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
             f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
         )
 
-        if model_config.tp_rank == 0:
+        if load_config.tp_rank == 0:
             t = threading.Thread(
                 target=trigger_transferring_weights_request,
                 args=(
-                    model_config.remote_instance_weight_loader_seed_instance_ip,
-                    model_config.remote_instance_weight_loader_seed_instance_service_port,
-                    model_config.remote_instance_weight_loader_send_weights_group_ports,
+                    load_config.remote_instance_weight_loader_seed_instance_ip,
+                    load_config.remote_instance_weight_loader_seed_instance_service_port,
+                    load_config.remote_instance_weight_loader_send_weights_group_ports,
                     instance_ip,
                 ),
             )
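Across these loader hunks, tp_rank and the remote-instance seed fields move from ModelConfig to LoadConfig, matching the new keyword arguments ModelRunner passes when constructing its LoadConfig. A sketch of the fields involved; the real definitions live in sglang/srt/configs/load_config.py (+8 lines in this release) and defaults may differ:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class LoadConfigSketch:  # hypothetical mirror of the new LoadConfig fields
        tp_rank: Optional[int] = None
        remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
        remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
        remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None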