sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py CHANGED
@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -92,10 +91,16 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
     is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -168,6 +172,7 @@ class ModelRunner:
         pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
@@ -219,14 +224,9 @@ class ModelRunner:
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
-                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
-                "deepep_mode": DeepEPMode(server_args.deepep_mode),
             }
         )

-        # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
             self.init_threads_binding()
@@ -234,6 +234,9 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        # CPU offload
+        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -309,8 +312,13 @@ class ModelRunner:
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
-        assert (not model_has_mtp_layers) or (
-            self.num_effective_layers == model_num_layers
+        assert (
+            (not model_has_mtp_layers)
+            or (self.spec_algorithm.is_none())
+            or (
+                (not self.spec_algorithm.is_none())
+                and (self.num_effective_layers == model_num_layers)
+            )
         ), "PP is not compatible with MTP models."

         # Apply torchao quantization
@@ -339,9 +347,12 @@ class ModelRunner:
         if self.device == "cuda":
             self.init_cublas()
             self.init_attention_backend()
-            self.init_cuda_graphs()
+            self.init_device_graphs()
+        elif self.device == "npu":
+            self.init_attention_backend()
+            self.init_device_graphs()
         else:
-            self.cuda_graph_runner = None
+            self.graph_runner = None
             self.cuda_graph_mem_usage = 0
             self.init_attention_backend()

@@ -508,9 +519,6 @@ class ModelRunner:

         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
-        elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
-            server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
@@ -684,6 +692,8 @@ class ModelRunner:
             monkey_patch_vllm_parallel_state(reverse=True)
             monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

+        get_offloader().post_init()
+
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
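
Taken together with the import and constructor hunks above, 0.5.1 replaces the set_cpu_offload_max_bytes global with an offloader object: built from the server args, installed process-wide, then finalized after the weights are loaded. A minimal sketch of that lifecycle follows; only the three offloader calls appear in this diff, the surrounding function and its name are illustrative and assume a working sglang 0.5.1 environment.

from sglang.srt.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)


def setup_offloading(server_args, dp_rank=None):
    # 1) Build and install the offloader (replaces the old
    #    set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))).
    set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

    # ... model weights are loaded here ...

    # 2) Let the offloader finish any deferred setup once the weights exist.
    get_offloader().post_init()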
@@ -915,7 +925,8 @@ class ModelRunner:
         )

         # We need to get device after patch otherwise the device would be wrong
-        infered_device = torch.cuda.current_device()
+        self.device_module = torch.get_device_module(self.device)
+        infered_device = self.device_module.current_device()

         named_tensors = [
             (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
@@ -1046,8 +1057,6 @@ class ModelRunner:
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
-            # FIXME: pipeline parallelism is not compatible with mla backend
-            assert self.pp_size == 1
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * num_layers
@@ -1236,6 +1245,11 @@ class ModelRunner:

         # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
+            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
+            extra_max_context_len = 4
+            if self.server_args.speculative_num_draft_tokens is not None:
+                extra_max_context_len += self.server_args.speculative_num_draft_tokens
+
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

@@ -1244,7 +1258,8 @@ class ModelRunner:
                 pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                 self.req_to_token_pool = DecodeReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len + 4,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                     pre_alloc_size=pre_alloc_size,
@@ -1252,7 +1267,8 @@ class ModelRunner:
             else:
                 self.req_to_token_pool = ReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len + 4,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
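
Both req_to_token pools above move from a hard-coded context_len + 4 to context_len + extra_max_context_len, so the per-request slot count grows by the number of draft tokens when speculative decoding is enabled. A quick worked example with made-up numbers (not taken from this diff):

# Illustrative values only.
context_len = 8192
speculative_num_draft_tokens = 8  # None when speculative decoding is disabled

extra_max_context_len = 4
if speculative_num_draft_tokens is not None:
    extra_max_context_len += speculative_num_draft_tokens

max_context_len = context_len + extra_max_context_len
# 0.5.0rc2: 8192 + 4 = 8196 slots per request
# 0.5.1:    8192 + 4 + 8 = 8204 slots per request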
@@ -1348,11 +1364,6 @@ class ModelRunner:

         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
-        max_num_extend_tokens = (
-            self.server_args.chunked_prefill_size
-            if self.server_args.chunked_prefill_size > 0
-            else self.server_args.max_prefill_tokens
-        )
         if self.token_to_kv_pool_allocator is None:
             if self.server_args.attention_backend == "ascend":
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
@@ -1391,7 +1402,6 @@ class ModelRunner:
                     device=self.device,
                     kvcache=self.token_to_kv_pool,
                     need_sort=need_sort,
-                    max_num_extend_tokens=max_num_extend_tokens,
                 )
         else:
             assert self.is_draft_worker
@@ -1591,9 +1601,9 @@ class ModelRunner:
             .cuda()
         )

-    def init_cuda_graphs(self):
+    def init_device_graphs(self):
         """Capture cuda graphs."""
-        self.cuda_graph_runner = None
+        self.graph_runner = None
         self.cuda_graph_mem_usage = 0

         if not self.is_generation:
@@ -1608,8 +1618,9 @@ class ModelRunner:
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.cuda_graph_runner = CudaGraphRunner(self)
-
+        self.graph_runner = (
+            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
@@ -1761,11 +1772,11 @@ class ModelRunner:
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
-            and self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(forward_batch)
+            and self.graph_runner
+            and self.graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            ret = self.cuda_graph_runner.replay(
+            ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,

sglang/srt/model_executor/npu_graph_runner.py ADDED
@@ -0,0 +1,94 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run the model with npu graph and torch.compile."""
+
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+
+class NPUGraphRunner(CudaGraphRunner):
+    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__(model_runner)
+
+    def _create_device_graph(self):
+        return torch.npu.NPUGraph()
+
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with torch.npu.graph(
+            graph,
+            pool=pool,
+            stream=stream,
+            auto_dispatch_capture=True,
+        ):
+            out = run_once_fn()
+        return out
+
+    def _update_inputs(self, seq_lens):
+        self.graphs[self.bs].update(
+            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
+        )
+
+    def _cache_loc_dtype(self):
+        return torch.int32
+
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
+        # Replay
+        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
+        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+        thread.start()
+        self.graphs[self.bs].replay()
+        thread.join()
+
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
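
NPUGraphRunner customizes CudaGraphRunner only through a few override points (_create_device_graph, _capture_graph, _update_inputs, _cache_loc_dtype, replay). Treating these as the extension surface is an inference from this diff, not documented API; under that assumption, another device backend might look like the skeleton below ("XPUGraphRunner" and its bodies are hypothetical placeholders):

import torch

from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner


class XPUGraphRunner(CudaGraphRunner):  # hypothetical backend, illustration only
    def _create_device_graph(self):
        ...  # return this device's graph object (cf. torch.npu.NPUGraph() above)

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        ...  # capture run_once_fn() into `graph` and return its output

    def _update_inputs(self, seq_lens):
        ...  # push host-side metadata (e.g. sequence lengths) before replay

    def _cache_loc_dtype(self):
        return torch.int32  # dtype used for cache location indices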

sglang/srt/model_loader/loader.py CHANGED
@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         yield module
         return

-    original_device_states: Dict[str, torch.device] = {}
+    original_infos: Dict[str, Dict] = {}

     # Store original device states and move parameters to GPU if they're on CPU
     for name, p in module.named_parameters():
         if p.device.type == "cpu":
-            original_device_states[name] = p.device
-            p.data = p.data.to(target_device)
+            original_data = p.data
+            device_data = p.data.to(target_device)
+            original_infos[name] = dict(
+                device=p.device,
+                original_data=original_data,
+                device_data=device_data,
+            )
+            p.data = device_data
         # Parameters already on target device are not touched

     try:
@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         # Restore parameters to their original devices, ignoring new parameters
         pin_memory = is_pin_memory_available()
         for name, p in module.named_parameters():
-            if name in original_device_states:
-                original_device: torch.device = original_device_states[name]
-                if original_device.type == "cpu":
+            if name in original_infos:
+                original_info = original_infos[name]
+                device_data = original_info["device_data"]
+                original_data = original_info["original_data"]
+                original_device: torch.device = original_info["device"]
+
+                if (
+                    (device_data.device == p.data.device)
+                    and (device_data.data_ptr() == p.data.data_ptr())
+                    and (device_data.shape == p.data.shape)
+                    and (device_data.dtype == p.data.dtype)
+                ):
+                    original_data.copy_(p.data.to(original_data.device))
+                    p.data = original_data
+                elif original_device.type == "cpu":
                     # `torch.empty_like` does not support `pin_memory` argument
                     cpu_data = torch.empty_strided(
                         size=p.data.size(),
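
The new restore path keeps a handle to the original CPU tensor and, when the parameter still aliases the unmodified temporary device copy, copies the loaded values back into that existing storage instead of allocating a fresh CPU buffer. A self-contained sketch of the pattern in plain PyTorch (illustrative only, not the sglang code):

import torch

# Illustrative: load a CPU-resident parameter via a temporary device copy, then
# restore it by reusing the original CPU storage rather than allocating a new one.
target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

p = torch.nn.Parameter(torch.zeros(4), requires_grad=False)  # starts on CPU
original_data = p.data                        # keep the original CPU tensor
device_data = original_data.to(target, copy=True)
p.data = device_data                          # loading code now writes into p.data

p.data += 1.0                                 # stand-in for the actual weight loading

# Copy back into the original CPU storage and re-point the parameter at it.
original_data.copy_(p.data.to(original_data.device))
p.data = original_data
assert p.data.data_ptr() == original_data.data_ptr()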
sglang/srt/models/dbrx.py CHANGED
@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
         self.params_dtype = params_dtype

         self.router = DbrxRouter(config, self.params_dtype)
+        self.topk = TopK(
+            self.top_k,
+            renormalize=True,
+        )
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)
         self.ws = nn.Parameter(
             torch.empty(
                 self.num_total_experts,
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.ws,
             self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )

         if self.tp_size > 1:
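
The dbrx hunks above show the migrated fused_moe interface: routing is computed by a TopK module built once in __init__, and execution options travel in a MoeRunnerConfig, instead of being passed as router_logits/top_k/renormalize/inplace arguments. A before/after sketch of the call shape, using only names that appear in this diff; it assumes a working sglang 0.5.1 install with real expert weight and activation tensors in scope and is not meant to run standalone:

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK

# Built once at module init (as in DbrxExperts.__init__ above).
topk = TopK(top_k, renormalize=True)
moe_runner_config = MoeRunnerConfig(inplace=True)

# 0.5.0rc2 call shape:
#   fused_moe(hidden_states, ws, w2s, router_logits, top_k,
#             renormalize=True, inplace=True)
# 0.5.1 call shape:
topk_output = topk(hidden_states, router_logits)
out = fused_moe(hidden_states, ws, w2s, topk_output, moe_runner_config)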
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(

sglang/srt/models/deepseek.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
             w1=self.w1,
             w2=self.w2,
             topk_output=topk_output,
-            inplace=True,
+            moe_runner_config=MoeRunnerConfig(inplace=True),
         )

         if self.config.n_shared_experts is not None:

sglang/srt/models/deepseek_nextn.py CHANGED
@@ -20,7 +20,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
@@ -135,6 +135,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        # if not set, model load will be broken in DeepseekV3ForCausalLM load_weights()
+        self.pp_group = get_pp_group()
         self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")

         self.model = DeepseekModelNextN(