sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/batch_invariant_ops/batch_invariant_ops.py
@@ -9,6 +9,22 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
+from sglang.srt.utils.common import calc_diff, get_bool_env_var
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+_ENABLE_MM_DEEPGEMM = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1"
+)
+_ENABLE_MM_COMPARISON_TEST = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
+)
+
+if not _ENABLE_MM_DEEPGEMM:
+    print("Disable DeepGEMM in batch invariant ops. Performance may be suboptimal.")
+
 __all__ = [
     "set_batch_invariant_mode",
     "is_batch_invariant_mode_enabled",
@@ -140,7 +156,7 @@ def matmul_kernel_persistent(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
-def matmul_persistent(
+def _matmul_persistent_triton(
     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
 ):
     # Check constraints.
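
Illustrative note (not part of the diff): the two toggles introduced in the first hunk above are read once via get_bool_env_var when the module is imported, so they must be set in the environment beforehand. A minimal sketch of how a caller might exercise them; the import path is inferred from the file list, and the defaults follow the code above (DeepGEMM on by default, the comparison test off by default):

    import os

    # Must be set before the module is imported, because both flags are
    # evaluated at import time (see the hunk above).
    os.environ["SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM"] = "1"
    os.environ["SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"] = "1"

    from sglang.srt.batch_invariant_ops import batch_invariant_ops  # noqa: E402
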
@@ -217,6 +233,54 @@ def matmul_persistent(
     return c
 
 
+def _matmul_persistent_deepgemm(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    M, K = a.shape
+    K, N = b.shape
+    dtype = a.dtype
+    out = torch.empty((M, N), device=a.device, dtype=dtype)
+
+    deep_gemm.bf16_gemm_nn(a, b, out)
+
+    # TODO can this be put in DeepGEMM's `c`?
+    if bias is not None:
+        out += bias
+
+    return out
+
+
+def matmul_persistent(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    if (
+        _ENABLE_MM_DEEPGEMM
+        and ENABLE_JIT_DEEPGEMM
+        and (a.dtype == torch.bfloat16)
+        and (b.dtype == torch.bfloat16)
+        and a.is_contiguous()
+        and b.transpose(0, 1).is_contiguous()
+    ):
+        if _ENABLE_MM_COMPARISON_TEST:
+            out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
+            out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+            diff = calc_diff(out_triton, out_deepgemm)
+            assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
+            # can be enabled for debugging
+            # print(
+            #     f"{diff=} "
+            #     f"{(out_triton - out_deepgemm).abs().mean()=} "
+            #     f"{(out_triton - out_deepgemm).abs().sum()=} "
+            #     f"{torch.sum(out_triton != out_deepgemm)=} "
+            # )
+            # print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
+            return out_deepgemm
+
+        return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+
+    return _matmul_persistent_triton(a=a, b=b, bias=bias)
+
+
 @triton.jit
 def _log_softmax_kernel(
     input_ptr,
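
Illustrative note (not part of the diff): the dispatch above only takes the DeepGEMM path for bf16 inputs where `a` is contiguous and `b` is the transpose of a contiguous (N, K) matrix, i.e. the usual `x @ w.t()` layout; anything else falls back to the persistent Triton kernel. A hedged usage sketch, where the shapes, the CUDA device, and the import path are assumptions:

    import torch

    from sglang.srt.batch_invariant_ops.batch_invariant_ops import matmul_persistent

    x = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)    # (M, K), contiguous
    w = torch.randn(11008, 4096, device="cuda", dtype=torch.bfloat16)  # (N, K), contiguous

    # w.t() is (K, N) and w.t().transpose(0, 1) is contiguous, so with
    # ENABLE_JIT_DEEPGEMM and the env toggle on this takes the DeepGEMM path;
    # otherwise it falls back to _matmul_persistent_triton.
    out = matmul_persistent(x, w.t())
    assert out.shape == (128, 11008)
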
@@ -495,16 +559,159 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
     return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
 
 
+def bmm_batch_invariant(a, b, *, out=None):
+    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
+    # Process each batch separately with our persistent kernel
+    if a.ndim == 3 and b.ndim == 3:
+        results = []
+        for i in range(a.shape[0]):
+            results.append(matmul_persistent(a[i], b[i]))
+        result = torch.stack(results, dim=0)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    else:
+        raise ValueError(
+            f"bmm_batch_invariant expects 3D tensors, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
+@triton.jit
+def _rms_norm_kernel(
+    input_ptr,
+    weight_ptr,
+    output_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Compute RMS normalization along the last dimension of a 2D tensor.
+    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
+    Each block handles one row of the input tensor.
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+    # Step 1: Compute sum of squares in float32 to avoid overflow
+    sum_sq = tl.zeros([1], dtype=tl.float32)
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        # Convert to float32 for accumulation to prevent overflow
+        vals_f32 = vals.to(tl.float32)
+        sq_vals = vals_f32 * vals_f32
+        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
+
+    # Step 2: Compute RMS (root mean square) in float32
+    mean_sq = sum_sq / n_cols
+    rms = tl.sqrt(mean_sq + eps)
+    inv_rms = 1.0 / rms
+
+    # Step 3: Normalize and apply weight
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
+        # Compute in float32 then convert back to input dtype
+        vals_f32 = vals.to(tl.float32)
+        weight_f32 = weight.to(tl.float32)
+        output_f32 = vals_f32 * inv_rms * weight_f32
+        output = output_f32.to(vals.dtype)
+        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+def rms_norm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """
+    Compute RMS normalization using Triton kernel.
+
+    RMS Norm normalizes the input by the root mean square and scales by weight:
+    output = input / sqrt(mean(input^2) + eps) * weight
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tensor with RMS normalization applied along the last dimension
+    """
+    assert weight.dim() == 1, "Weight must be 1-dimensional"
+    assert input.shape[-1] == weight.shape[0], (
+        f"Input last dimension ({input.shape[-1]}) must match "
+        f"weight dimension ({weight.shape[0]})"
+    )
+
+    # Flatten all dimensions except the last one
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1])
+    input_2d = input_2d.contiguous()
+    weight = weight.contiguous()
+
+    n_rows, n_cols = input_2d.shape
+
+    output = torch.empty_like(input_2d)
+    BLOCK_SIZE = 1024
+    grid = (n_rows,)
+    _rms_norm_kernel[grid](
+        input_2d,
+        weight,
+        output,
+        input_2d.stride(0),
+        output.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output.reshape(original_shape)
+
+
+def rms_norm_batch_invariant(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """
+    Batch-invariant wrapper for RMS normalization.
+
+    This function provides a deterministic, batch-invariant implementation
+    of RMS normalization for use with the batch_invariant mode.
+
+    Adapted from @https://github.com/vllm-project/vllm/blob/66a168a197ba214a5b70a74fa2e713c9eeb3251a/vllm/model_executor/layers/batch_invariant.py#L649
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        RMS normalized tensor
+    """
+    return rms_norm(input, weight, eps=eps)
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
+_original_torch_bmm = None
 
 
 def is_batch_invariant_mode_enabled():
     return _batch_invariant_MODE
 
 
-def enable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+def enable_batch_invariant_mode(
+    enable_bmm: bool = True,
+):
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_MODE:
         return
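
Illustrative note (not part of the diff): as a quick sanity check on the Triton rms_norm added above, it should agree with a plain PyTorch reference computed in float32. A minimal sketch, assuming a CUDA device and the module import path from the file list; tolerances are loose to allow for bf16 rounding:

    import torch

    from sglang.srt.batch_invariant_ops.batch_invariant_ops import rms_norm

    def rms_norm_ref(x, weight, eps=1e-6):
        # Reference formula from the docstring: x / sqrt(mean(x^2) + eps) * weight
        x_f32 = x.float()
        inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x_f32 * inv_rms * weight.float()).to(x.dtype)

    x = torch.randn(4, 8, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
    torch.testing.assert_close(rms_norm(x, w), rms_norm_ref(x, w), rtol=2e-2, atol=2e-2)
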
@@ -517,11 +724,21 @@ def enable_batch_invariant_mode():
     )
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
+    if enable_bmm:
+        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
+
+        # Also monkeypatch torch.bmm directly as a fallback
+        _original_torch_bmm = torch.bmm
+        torch.bmm = bmm_batch_invariant
+
 
 def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
+    if _original_torch_bmm is not None:
+        torch.bmm = _original_torch_bmm
+        _original_torch_bmm = None
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
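
Illustrative note (not part of the diff): taken together, enabling the mode now also registers an aten::bmm override and monkeypatches torch.bmm, and disabling it restores the original function. A minimal usage sketch; the CUDA device, shapes, and import path are assumptions, while the function names come from the diff above:

    import torch

    from sglang.srt.batch_invariant_ops.batch_invariant_ops import (
        disable_batch_invariant_mode,
        enable_batch_invariant_mode,
    )

    a = torch.randn(4, 64, 32, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4, 32, 16, device="cuda", dtype=torch.bfloat16)

    enable_batch_invariant_mode()       # patches aten::bmm and torch.bmm
    try:
        out = torch.bmm(a, b)           # routed through bmm_batch_invariant
    finally:
        disable_batch_invariant_mode()  # restores the original torch.bmm

    assert out.shape == (4, 64, 16)
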

sglang/srt/checkpoint_engine/__init__.py
@@ -0,0 +1,9 @@
+"""
+Checkpoint engine module for SGLang.
+
+This module provides functionality for updating model weights via checkpoint engine.
+"""
+
+from sglang.srt.checkpoint_engine.update import main
+
+__all__ = ["main"]

sglang/srt/checkpoint_engine/update.py
@@ -0,0 +1,317 @@
+"""
+Usage:
+1) Launch the server with wait-for-initial-weights option in one terminal:
+python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
+
+2) Torchrun this script in another terminal:
+torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
+
+Or use the integrated entry point:
+python -m sglang.srt.checkpoint_engine.update --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
+"""
+
+import argparse
+import json
+import os
+import pickle
+import subprocess
+import sys
+import time
+from collections import defaultdict
+from collections.abc import Callable
+from contextlib import contextmanager
+from typing import Literal
+
+import httpx
+import torch
+import torch.distributed as dist
+from safetensors import safe_open
+
+try:
+    from checkpoint_engine.ps import ParameterServer
+    from loguru import logger
+except ImportError:
+    # Fallback for when checkpoint_engine is not available
+    ParameterServer = None
+    import logging
+
+    logger = logging.getLogger(__name__)
+
+
+@contextmanager
+def timer(msg: str):
+    start = time.perf_counter()
+    yield
+    end = time.perf_counter()
+    logger.info(f"{msg} duration: {end - start:.2f} seconds")
+
+
+def check_sglang_ready(
+    endpoint: str, inference_parallel_size: int, uds: str | None = None
+):
+    rank = int(os.getenv("RANK", 0))
+    if rank != rank // inference_parallel_size * inference_parallel_size:
+        return
+    retry_num = 0
+    transport = None
+    if uds is not None:
+        transport = httpx.HTTPTransport(uds=uds)
+    with httpx.Client(transport=transport) as client:
+        while True:
+            try:
+                response = client.get(f"{endpoint}/ping", timeout=10)
+                response.raise_for_status()
+                break
+            except (httpx.ConnectError, httpx.HTTPStatusError) as e:
+                if retry_num % 10 == 0:
+                    logger.warning(
+                        f"fail to check sglang ready, retry {retry_num} times, error: {e}"
+                    )
+                retry_num += 1
+                time.sleep(0.1)
+
+
+def split_checkpoint_files(
+    checkpoint_path: str, rank: int, world_size: int
+) -> list[str]:
+    checkpoint_files = [
+        os.path.join(checkpoint_path, f)
+        for f in filter(
+            lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path)
+        )
+    ]
+    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
+    return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
+
+
+def split_tensors(
+    checkpoint_path: str, rank: int, world_size: int
+) -> dict[str, torch.Tensor]:
+    index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
+    with open(index_fn) as f:
+        weight_map: dict[str, str] = json.load(f)["weight_map"]
+    weights_per_rank = (len(weight_map) + world_size - 1) // world_size
+    fn_tensors: dict[str, list[str]] = defaultdict(list)
+    weight_keys = list(weight_map.items())
+    for name, file in weight_keys[
+        rank * weights_per_rank : (rank + 1) * weights_per_rank
+    ]:
+        fn_tensors[file].append(name)
+    named_tensors = {}
+    for file, names in fn_tensors.items():
+        with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f:
+            for name in names:
+                named_tensors[name] = f.get_tensor(name)
+    return named_tensors
+
+
+def req_inference(
+    endpoint: str,
+    inference_parallel_size: int,
+    timeout: float = 300.0,
+    uds: str | None = None,
+    weight_version: str | None = None,
+) -> Callable[[list[tuple[str, str]]], None]:
+    rank = int(os.getenv("RANK", 0))
+    src = rank // inference_parallel_size * inference_parallel_size
+
+    def req_func(socket_paths: list[tuple[str, str]]):
+        if rank == src:
+            with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
+                resp = client.post(
+                    f"{endpoint}/update_weights_from_ipc",
+                    json={
+                        "zmq_handles": dict(
+                            socket_paths[src : src + inference_parallel_size]
+                        ),
+                        "flush_cache": True,
+                        "weight_version": weight_version,
+                    },
+                    timeout=timeout,
+                )
+                resp.raise_for_status()
+
+    return req_func
+
+
+def update_weights(
+    ps,
+    checkpoint_name: str,
+    checkpoint_files: list[str],
+    named_tensors: dict[str, torch.Tensor],
+    req_func: Callable[[list[tuple[str, str]]], None],
+    inference_parallel_size: int,
+    endpoint: str,
+    save_metas_file: str | None = None,
+    update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
+    uds: str | None = None,
+):
+    ps.register_checkpoint(
+        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
+    )
+    ps.init_process_group()
+    check_sglang_ready(endpoint, inference_parallel_size, uds)
+    dist.barrier()
+    with timer("Gather metas"):
+        ps.gather_metas(checkpoint_name)
+    if save_metas_file and int(os.getenv("RANK")) == 0:
+        with open(save_metas_file, "wb") as f:
+            pickle.dump(ps.get_metas(), f)
+
+    if update_method == "broadcast" or update_method == "all":
+        with timer("Update weights without setting ranks"):
+            ps.update(checkpoint_name, req_func)
+
+    if update_method == "p2p" or update_method == "all":
+        if update_method:
+            # sleep 2s to wait destroy process group
+            time.sleep(2)
+        with timer("Update weights with setting ranks"):
+            ps.update(
+                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
+            )
+
+
+def join(
+    ps: ParameterServer,
+    checkpoint_name: str,
+    load_metas_file: str,
+    req_func: Callable[[list[tuple[str, str]]], None],
+    inference_parallel_size: int,
+    endpoint: str,
+    uds: str | None = None,
+):
+    assert load_metas_file, "load_metas_file is required"
+    with open(load_metas_file, "rb") as f:
+        metas = pickle.load(f)
+    ps.init_process_group()
+    check_sglang_ready(endpoint, inference_parallel_size, uds)
+    dist.barrier()
+    with timer("Gather metas before join"):
+        ps.gather_metas(checkpoint_name)
+    ps.load_metas(metas)
+    with timer(
+        f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
+    ):
+        ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))
+
+
+def run_with_torchrun():
+    """Run the update script with torchrun automatically."""
+    # Parse inference_parallel_size from command line arguments to determine nproc-per-node
+    inference_parallel_size = 8  # default
+    args = sys.argv[1:]  # Skip the script name
+
+    # Look for --inference-parallel-size in arguments
+    for i, arg in enumerate(args):
+        if arg == "--inference-parallel-size" and i + 1 < len(args):
+            try:
+                inference_parallel_size = int(args[i + 1])
+            except ValueError:
+                pass
+            break
+        elif arg.startswith("--inference-parallel-size="):
+            try:
+                inference_parallel_size = int(arg.split("=", 1)[1])
+            except ValueError:
+                pass
+            break
+
+    # Build torchrun command
+    cmd = ["torchrun", f"--nproc-per-node={inference_parallel_size}", __file__] + args
+
+    print(f"Running: {' '.join(cmd)}", file=sys.stderr)
+
+    # Execute torchrun with the original script
+    try:
+        result = subprocess.run(cmd, check=False)
+        sys.exit(result.returncode)
+    except FileNotFoundError:
+        print(
+            "Error: torchrun command not found. Please ensure PyTorch is installed.",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+    except KeyboardInterrupt:
+        print("\nInterrupted by user", file=sys.stderr)
+        sys.exit(130)
+
+
+def main():
+    # Check if we're running under torchrun or need to invoke it
+    if os.getenv("RANK") is None:
+        # Not running under torchrun, so invoke it
+        run_with_torchrun()
+        return
+
+    # Running under torchrun, proceed with normal execution
+    parser = argparse.ArgumentParser(description="Update weights example")
+    parser.add_argument("--checkpoint-path", type=str, default=None)
+    parser.add_argument("--save-metas-file", type=str, default=None)
+    parser.add_argument("--load-metas-file", type=str, default=None)
+    parser.add_argument("--sleep-time", type=int, default=0)
+    parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
+    parser.add_argument("--inference-parallel-size", type=int, default=8)
+    parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
+    parser.add_argument("--update-method", type=str, default="broadcast")
+    parser.add_argument("--uds", type=str, default=None)
+    parser.add_argument("--weight-version", type=str, default=None)
+    args = parser.parse_args()
+
+    # Get rank and world_size from environment (set by torchrun)
+    rank = int(os.getenv("RANK", 0))
+    world_size = int(os.getenv("WORLD_SIZE", 1))
+
+    req_func = req_inference(
+        args.endpoint,
+        args.inference_parallel_size,
+        uds=args.uds,
+        weight_version=args.weight_version,
+    )
+
+    if ParameterServer is None:
+        print("Error: checkpoint_engine package not available", file=sys.stderr)
+        sys.exit(1)
+
+    ps = ParameterServer(auto_pg=True)
+    ps._p2p_store = None
+    if args.load_metas_file:
+        join(
+            ps,
+            args.checkpoint_name,
+            args.load_metas_file,
+            req_func,
+            args.inference_parallel_size,
+            args.endpoint,
+            args.uds,
+        )
+    else:
+        if args.checkpoint_path and os.path.exists(
+            os.path.join(args.checkpoint_path, "model.safetensors.index.json")
+        ):
+            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
+            checkpoint_files = []
+        else:
+            checkpoint_files = (
+                split_checkpoint_files(args.checkpoint_path, rank, world_size)
+                if args.checkpoint_path
+                else []
+            )
+            named_tensors = {}
+        update_weights(
+            ps,
+            args.checkpoint_name,
+            checkpoint_files,
+            named_tensors,
+            req_func,
+            args.inference_parallel_size,
+            args.endpoint,
+            args.save_metas_file,
+            args.update_method,
+            args.uds,
+        )
+    time.sleep(args.sleep_time)
+
+
+if __name__ == "__main__":
+    main()
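
Illustrative note (not part of the diff): two details of update.py worth calling out are that checkpoint shards are assigned to torchrun ranks by ceil division, and that only the first rank of each inference group (rank // inference_parallel_size * inference_parallel_size) posts to /update_weights_from_ipc. A small self-contained sketch of that arithmetic, with made-up numbers:

    # Mirrors the slicing in split_checkpoint_files and the src-rank
    # computation in req_inference / check_sglang_ready.
    world_size = 4
    inference_parallel_size = 2
    files = [f"model-{i:05d}.safetensors" for i in range(10)]

    files_per_rank = (len(files) + world_size - 1) // world_size  # ceil(10 / 4) = 3
    for rank in range(world_size):
        shard = files[rank * files_per_rank : (rank + 1) * files_per_rank]
        src = rank // inference_parallel_size * inference_parallel_size
        is_requester = rank == src  # only this rank posts /update_weights_from_ipc
        print(rank, len(shard), is_requester)
    # -> (0, 3, True), (1, 3, False), (2, 3, True), (3, 1, False)
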

sglang/srt/compilation/backend.py
@@ -392,7 +392,7 @@ class SGLangBackend:
         self.configure_post_pass()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, ["sglang.unified_attention_with_output"]
+            graph, ["sglang.unified_attention_with_output", "sglang.inplace_all_reduce"]
         )
 
         from torch._dynamo.utils import lazy_format_graph_code

sglang/srt/configs/__init__.py
@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.falcon_h1 import FalconH1Config
 from sglang.srt.configs.janus_pro import MultiModalityConfig
+from sglang.srt.configs.kimi_linear import KimiLinearConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
 from sglang.srt.configs.longcat_flash import LongcatFlashConfig
@@ -31,6 +32,7 @@ __all__ = [
     "Step3TextConfig",
     "Step3VisionEncoderConfig",
     "Olmo3Config",
+    "KimiLinearConfig",
     "Qwen3NextConfig",
     "DotsVLMConfig",
     "DotsOCRConfig",