sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -19,8 +19,10 @@ import logging
19
19
  import threading
20
20
  from typing import TYPE_CHECKING, Optional, Union
21
21
 
22
+ import numpy as np
22
23
  import torch
23
24
 
25
+ from sglang.srt.configs.model_config import AttentionArch
24
26
  from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
25
27
 
26
28
  logger = logging.getLogger(__name__)
@@ -73,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
73
75
  self.positions[: self.raw_num_token].copy_(forward_batch.positions)
74
76
 
75
77
  # Replay
76
- seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
77
- thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
78
- thread.start()
79
- self.graphs[self.bs].replay()
80
- thread.join()
78
+ if self.model_runner.model_config.index_head_dim is None:
79
+ seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
80
+ self.bs - self.raw_bs
81
+ )
82
+ thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
83
+ thread.start()
84
+ self.graphs[self.bs].replay()
85
+ thread.join()
86
+ else:
87
+ self.graphs[self.bs].replay()
81
88
 
82
89
  output = self.output_buffers[self.bs]
83
90
  if isinstance(output, LogitsProcessorOutput):
@@ -1,16 +1,22 @@
1
1
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
3
7
  from torch import nn
4
8
 
5
- from sglang.srt.configs.device_config import DeviceConfig
6
- from sglang.srt.configs.load_config import LoadConfig
7
- from sglang.srt.configs.model_config import ModelConfig
8
9
  from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
9
10
  from sglang.srt.model_loader.utils import (
10
11
  get_architecture_class_name,
11
12
  get_model_architecture,
12
13
  )
13
14
 
15
+ if TYPE_CHECKING:
16
+ from sglang.srt.configs.device_config import DeviceConfig
17
+ from sglang.srt.configs.load_config import LoadConfig
18
+ from sglang.srt.configs.model_config import ModelConfig
19
+
14
20
 
15
21
  def get_model(
16
22
  *,
@@ -1,5 +1,7 @@
1
1
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  # ruff: noqa: SIM117
4
6
  import collections
5
7
  import concurrent
@@ -10,14 +12,29 @@ import json
10
12
  import logging
11
13
  import math
12
14
  import os
15
+ import re
16
+ import socket
17
+ import threading
13
18
  import time
14
19
  from abc import ABC, abstractmethod
15
20
  from concurrent.futures import ThreadPoolExecutor
16
21
  from contextlib import contextmanager
17
- from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
22
+ from typing import (
23
+ TYPE_CHECKING,
24
+ Any,
25
+ Dict,
26
+ Generator,
27
+ Iterable,
28
+ List,
29
+ Optional,
30
+ Tuple,
31
+ cast,
32
+ )
33
+ from urllib.parse import urlparse
18
34
 
19
35
  import huggingface_hub
20
36
  import numpy as np
37
+ import requests
21
38
  import safetensors.torch
22
39
  import torch
23
40
  from huggingface_hub import HfApi, hf_hub_download
@@ -26,9 +43,7 @@ from tqdm.auto import tqdm
26
43
  from transformers import AutoModelForCausalLM
27
44
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
28
45
 
29
- from sglang.srt.configs.device_config import DeviceConfig
30
46
  from sglang.srt.configs.load_config import LoadConfig, LoadFormat
31
- from sglang.srt.configs.model_config import ModelConfig
32
47
  from sglang.srt.connector import (
33
48
  ConnectorType,
34
49
  create_remote_connector,
@@ -39,7 +54,9 @@ from sglang.srt.distributed import (
39
54
  get_tensor_model_parallel_rank,
40
55
  get_tensor_model_parallel_world_size,
41
56
  )
42
- from sglang.srt.layers.quantization.base_config import QuantizationConfig
57
+ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
58
+ trigger_transferring_weights_request,
59
+ )
43
60
  from sglang.srt.model_loader.utils import (
44
61
  get_model_architecture,
45
62
  post_load_weights,
@@ -47,6 +64,7 @@ from sglang.srt.model_loader.utils import (
47
64
  )
48
65
  from sglang.srt.model_loader.weight_utils import (
49
66
  _BAR_FORMAT,
67
+ default_weight_loader,
50
68
  download_safetensors_index_file_from_hf,
51
69
  download_weights_from_hf,
52
70
  filter_duplicate_safetensors_files,
@@ -70,6 +88,11 @@ from sglang.srt.utils import (
70
88
  set_weight_attrs,
71
89
  )
72
90
 
91
+ if TYPE_CHECKING:
92
+ from sglang.srt.configs.device_config import DeviceConfig
93
+ from sglang.srt.configs.model_config import ModelConfig
94
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
95
+
73
96
  _is_npu = is_npu()
74
97
 
75
98
 
@@ -183,7 +206,10 @@ def _initialize_model(
183
206
  if _is_npu:
184
207
  packed_modules_mapping.update(
185
208
  {
186
- "visual": {"qkv_proj": ["qkv"]},
209
+ "visual": {
210
+ "qkv_proj": ["qkv"],
211
+ "gate_up_proj": ["gate_proj", "up_proj"],
212
+ },
187
213
  "vision_model": {
188
214
  "qkv_proj": ["q_proj", "k_proj", "v_proj"],
189
215
  "proj": ["out_proj"],
@@ -1366,6 +1392,105 @@ class GGUFModelLoader(BaseModelLoader):
1366
1392
  return model
1367
1393
 
1368
1394
 
1395
+ class RemoteInstanceModelLoader(BaseModelLoader):
1396
+ """Model loader that can load Tensors from remote sglang instance."""
1397
+
1398
+ def __init__(self, load_config: LoadConfig):
1399
+ super().__init__(load_config)
1400
+ if load_config.model_loader_extra_config:
1401
+ raise ValueError(
1402
+ f"Model loader extra config is not supported for "
1403
+ f"load format {load_config.load_format}"
1404
+ )
1405
+
1406
+ def download_model(self, model_config: ModelConfig) -> None:
1407
+ raise NotImplementedError
1408
+
1409
+ def load_model(
1410
+ self,
1411
+ *,
1412
+ model_config: ModelConfig,
1413
+ device_config: DeviceConfig,
1414
+ ) -> nn.Module:
1415
+ logger.info("Loading weights from remote instance ...")
1416
+ load_config = self.load_config
1417
+
1418
+ assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
1419
+ f"Model loader {self.load_config.load_format} is not supported for "
1420
+ f"load format {load_config.load_format}"
1421
+ )
1422
+
1423
+ model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
1424
+
1425
+ with set_default_torch_dtype(model_config.dtype):
1426
+ with torch.device(device_config.device):
1427
+ model = _initialize_model(model_config, self.load_config)
1428
+
1429
+ with create_remote_connector(model_weights, device_config.device) as client:
1430
+ connector_type = get_connector_type(client)
1431
+ if connector_type == ConnectorType.INSTANCE:
1432
+ self.load_model_from_remote_instance(
1433
+ model, client, model_config, device_config
1434
+ )
1435
+ else:
1436
+ raise ValueError(
1437
+ f"Unsupported connector type {connector_type} for "
1438
+ f"remote tensor model loading."
1439
+ )
1440
+ return model.eval()
1441
+
1442
+ def load_model_from_remote_instance(
1443
+ self, model, client, model_config: ModelConfig, device_config: DeviceConfig
1444
+ ) -> nn.Module:
1445
+ load_config = self.load_config
1446
+ instance_ip = socket.gethostbyname(socket.gethostname())
1447
+ start_build_group_tic = time.time()
1448
+ client.build_group(
1449
+ gpu_id=device_config.gpu_id,
1450
+ tp_rank=load_config.tp_rank,
1451
+ instance_ip=instance_ip,
1452
+ )
1453
+ torch.cuda.synchronize()
1454
+ end_build_group_tic = time.time()
1455
+ logger.debug(
1456
+ f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
1457
+ )
1458
+
1459
+ if load_config.tp_rank == 0:
1460
+ t = threading.Thread(
1461
+ target=trigger_transferring_weights_request,
1462
+ args=(
1463
+ load_config.remote_instance_weight_loader_seed_instance_ip,
1464
+ load_config.remote_instance_weight_loader_seed_instance_service_port,
1465
+ load_config.remote_instance_weight_loader_send_weights_group_ports,
1466
+ instance_ip,
1467
+ ),
1468
+ )
1469
+ t.start()
1470
+
1471
+ start_get_weights_tic = time.time()
1472
+ with set_default_torch_dtype(model_config.dtype):
1473
+ for _, tensor in model.named_parameters():
1474
+ torch.distributed.broadcast(
1475
+ tensor.data,
1476
+ src=0,
1477
+ group=client._model_update_group,
1478
+ )
1479
+ torch.cuda.synchronize()
1480
+
1481
+ if hasattr(model, "post_load_weights"):
1482
+ model.post_load_weights()
1483
+ end_get_weights_tic = time.time()
1484
+ logger.debug(
1485
+ f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
1486
+ )
1487
+ # destroy the process group after loading weights
1488
+ torch.distributed.distributed_c10d.destroy_process_group(
1489
+ client._model_update_group
1490
+ )
1491
+ torch.cuda.empty_cache()
1492
+
1493
+
1369
1494
  class RemoteModelLoader(BaseModelLoader):
1370
1495
  """Model loader that can load Tensors from remote database."""
1371
1496
 
@@ -1567,4 +1692,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
1567
1692
  if load_config.load_format == LoadFormat.REMOTE:
1568
1693
  return RemoteModelLoader(load_config)
1569
1694
 
1695
+ if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
1696
+ return RemoteInstanceModelLoader(load_config)
1697
+
1570
1698
  return DefaultModelLoader(load_config)
@@ -0,0 +1,69 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import logging
4
+ from typing import List
5
+
6
+ import requests
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
def trigger_init_weights_send_group_for_remote_instance_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    """Ask the seed instance to create its side of the weight-sending groups.

    Sends a POST to ``/init_weights_send_group_for_remote_instance`` on the
    seed instance's HTTP service.

    Args:
        remote_instance_weight_loader_seed_instance_ip: Address of the seed
            instance; also sent as the group's ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports sent to
            the seed instance as a comma-separated string.
        remote_instance_weight_loader_client_id: Identifier used to derive the
            group name ``send_weights_<client_id>``.

    Raises:
        Exception: Any error raised while sending the request, or
            ``requests.HTTPError`` when the seed instance answers with a
            4xx/5xx status. The error is logged and then re-raised.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    # Only loading weights from an instance with the same parallelism strategy
    # is supported. Each TP rank pair between the seed and dst instances builds
    # its own communication group for sending weights,
    # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
    # Each communication group has a world size of 2.
    try:
        response = requests.post(
            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_rank": 0,
                "world_size": 2,
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
                "backend": "nccl",
            },
        )
        # A rejected request must not be mistaken for success: surface HTTP
        # error statuses through the same log-and-re-raise path.
        response.raise_for_status()
    except Exception as e:
        logger.error(
            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
        )
        raise
44
+
45
+
46
def trigger_transferring_weights_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    """Ask the seed instance to start sending its weights to this instance.

    Sends a POST to ``/send_weights_to_remote_instance`` on the seed
    instance's HTTP service.

    Args:
        remote_instance_weight_loader_seed_instance_ip: Address of the seed
            instance; also sent as ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports sent to
            the seed instance as a comma-separated string.
        remote_instance_weight_loader_client_id: Identifier used to derive the
            group name ``send_weights_<client_id>``.

    Raises:
        Exception: Any error raised while sending the request, or
            ``requests.HTTPError`` when the seed instance answers with a
            4xx/5xx status. The error is logged and then re-raised.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    try:
        response = requests.post(
            f"{seed_instance_service_url}/send_weights_to_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
            },
        )
        # Without this check an HTTP error status would pass silently, leaving
        # no trace of why the weights never arrive.
        response.raise_for_status()
    except Exception as e:
        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
        raise
@@ -8,7 +8,7 @@ import hashlib
8
8
  import json
9
9
  import logging
10
10
  import os
11
- import queue
11
+ import re
12
12
  import tempfile
13
13
  from collections import defaultdict
14
14
  from typing import (
@@ -35,9 +35,11 @@ from tqdm.auto import tqdm
35
35
  from sglang.srt.configs.load_config import LoadConfig
36
36
  from sglang.srt.configs.model_config import ModelConfig
37
37
  from sglang.srt.distributed import get_tensor_model_parallel_rank
38
+ from sglang.srt.layers.dp_attention import get_attention_tp_rank
38
39
  from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
39
40
  from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
40
- from sglang.srt.utils import print_warning_once
41
+ from sglang.srt.utils import find_local_repo_dir, print_warning_once
42
+ from sglang.utils import is_in_ci
41
43
 
42
44
  logger = logging.getLogger(__name__)
43
45
 
@@ -235,6 +237,149 @@ def get_quant_config(
235
237
  return quant_cls.from_config(config)
236
238
 
237
239
 
240
def _list_missing_shards(snapshot_dir: str, file_name: str) -> Optional[List[str]]:
    """Return the shard file names of ``file_name``'s set missing from ``snapshot_dir``.

    Returns ``None`` when ``file_name`` does not look like a shard name such as
    ``model-00001-of-00009.safetensors``; otherwise a (possibly empty) list.
    """
    shard_match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", file_name)
    if shard_match is None:
        return None

    stem, shard_id, shard_count, extension = shard_match.groups()
    # Keep the zero padding used by the shard id we were given.
    width = len(shard_id)
    expected_names = [
        f"{stem}-{index:0{width}d}-of-{shard_count}.{extension}"
        for index in range(1, int(shard_count) + 1)
    ]
    # os.path.exists is False for broken symlinks, which counts as missing.
    return [
        name
        for name in expected_names
        if not os.path.exists(os.path.join(snapshot_dir, name))
    ]


def find_local_hf_snapshot_dir(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> Optional[str]:
    """Locate an already-downloaded, complete Hugging Face snapshot.

    The custom ``cache_dir`` (when given) is searched first, then the default
    HF cache via ``find_local_repo_dir``. A snapshot is accepted only if

    * the repo's ``blobs`` directory holds no ``*.incomplete`` files,
    * at least one existing file matches one of ``allow_patterns``, and
    * for the first matching file named like ``<prefix>-<i>-of-<n>.<suffix>``,
      every shard ``1..n`` of that set exists.

    Args:
        model_name_or_path: Hub repo id. If this is an existing local
            directory, the function returns ``None`` immediately.
        cache_dir: Optional custom cache root to look in first.
        allow_patterns: Glob patterns, evaluated relative to the snapshot
            directory, that identify weight files.
        revision: Snapshot revision; when omitted, the custom cache lookup
            reads it from the repo's ``refs/main`` file.

    Returns:
        The snapshot directory if it passes all checks, otherwise ``None``.
    """
    if os.path.isdir(model_name_or_path):
        return None

    snapshot_dir: Optional[str] = None

    # 1) Custom cache directory, if one was provided.
    if cache_dir:
        try:
            repo_dir_name = huggingface_hub.constants.REPO_ID_SEPARATOR.join(
                ["models", *model_name_or_path.split("/")]
            )
            cached_repo = os.path.join(cache_dir, repo_dir_name)
            commit = revision
            if not commit:
                main_ref = os.path.join(cached_repo, "refs", "main")
                if os.path.isfile(main_ref):
                    with open(main_ref) as ref_file:
                        commit = ref_file.read().strip()
            if commit:
                candidate = os.path.join(cached_repo, "snapshots", commit)
                if os.path.isdir(candidate):
                    snapshot_dir = candidate
        except Exception as e:
            logger.warning(
                "Failed to find local snapshot in custom cache_dir %s: %s",
                cache_dir,
                e,
            )

    # 2) Fall back to the default HF cache.
    if not snapshot_dir:
        try:
            candidate = find_local_repo_dir(model_name_or_path, revision)
            if candidate and os.path.isdir(candidate):
                snapshot_dir = candidate
        except Exception as e:
            logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    if snapshot_dir is None:
        return None

    # 3) Any leftover partial download means the snapshot cannot be trusted.
    cached_repo = os.path.abspath(os.path.join(snapshot_dir, "..", ".."))
    blobs_dir = os.path.join(cached_repo, "blobs")
    has_partial_downloads = os.path.isdir(blobs_dir) and glob.glob(
        os.path.join(blobs_dir, "*.incomplete")
    )
    if has_partial_downloads:
        logger.info(
            "Found .incomplete files in %s for %s. "
            "Considering local snapshot incomplete.",
            blobs_dir,
            model_name_or_path,
        )
        return None

    # 4) Collect the files matching allow_patterns that actually exist.
    weight_files: List[str] = []
    try:
        for pattern in allow_patterns:
            for path in glob.glob(os.path.join(snapshot_dir, pattern)):
                # os.path.exists is False for broken symlinks; skip those.
                if os.path.exists(path):
                    weight_files.append(path)
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
            snapshot_dir,
            allow_patterns,
            e,
        )
        weight_files = []

    # 5) Verify completeness of the first sharded set encountered; one
    #    verified set is enough.
    for weight_file in weight_files:
        missing_shards = _list_missing_shards(
            snapshot_dir, os.path.basename(weight_file)
        )
        if missing_shards is None:
            continue
        if missing_shards:
            logger.info(
                "Found incomplete sharded model %s. Missing shards: %s. "
                "Will attempt download.",
                model_name_or_path,
                missing_shards,
            )
            return None
        break

    if not weight_files:
        logger.info(
            "Local HF snapshot at %s has no files matching %s; will attempt download.",
            snapshot_dir,
            allow_patterns,
        )
        return None

    logger.info(
        "Found local HF snapshot for %s at %s; skipping download.",
        model_name_or_path,
        snapshot_dir,
    )
    return snapshot_dir
381
+
382
+
238
383
  def download_weights_from_hf(
239
384
  model_name_or_path: str,
240
385
  cache_dir: Optional[str],
@@ -259,6 +404,16 @@ def download_weights_from_hf(
259
404
  Returns:
260
405
  str: The path to the downloaded model weights.
261
406
  """
407
+
408
+ if is_in_ci():
409
+ # If the weights are already local, skip downloading and returns the path.
410
+ # This is used to skip too-many Huggingface API calls in CI.
411
+ path = find_local_hf_snapshot_dir(
412
+ model_name_or_path, cache_dir, allow_patterns, revision
413
+ )
414
+ if path is not None:
415
+ return path
416
+
262
417
  if not huggingface_hub.constants.HF_HUB_OFFLINE:
263
418
  # Before we download we look at that is available:
264
419
  fs = HfFileSystem()
@@ -680,7 +835,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
680
835
  """Create a weight loader that shards the weights along the given axis"""
681
836
 
682
837
  def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
683
- tp_rank = get_tensor_model_parallel_rank()
838
+ tp_rank = get_attention_tp_rank()
684
839
 
685
840
  shard_size = param.data.shape[shard_axis]
686
841
  start_idx = tp_rank * shard_size