sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -40,9 +40,10 @@ import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
@@ -240,6 +241,9 @@ class ForwardBatch:
     prefix_chunk_num_tokens: Optional[List[int]] = None
     # KV Indices for each chunk
     prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+    # For MLA chunked prefix cache used in chunked prefill
+    # Tell attention backend whether lse needs to be returned
+    mha_return_lse: Optional[bool] = None
 
     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -274,13 +278,13 @@ class ForwardBatch:
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-    gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +632,7 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size
 
-        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
@@ -642,17 +646,14 @@
         else:
             buffer_len = sum(global_num_tokens)
 
-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
-
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
+
         bs = self.batch_size
 
         if self.forward_mode.is_decode():
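A note on the round-up in the @@ -628 hunk above: each rank's token count is padded up to a multiple of the attention TP size before the DP padding mode is chosen. A quick worked sketch with made-up numbers (illustrative only, not sglang code):

    # e.g. attn_tp_size = 4, global_num_tokens = [5, 3, 8]
    padded = [((n - 1) // attn_tp_size + 1) * attn_tp_size for n in global_num_tokens]
    # -> [8, 4, 8]; the non-max-len branch shown in the hunk then uses
    # buffer_len = sum(global_num_tokens), and the batch now records only the length
    # (global_dp_buffer_len / set_dp_buffer_len) instead of eagerly allocating the
    # old gathered_buffer tensor.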
sglang/srt/model_executor/model_runner.py CHANGED
@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -75,12 +74,12 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
-    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
@@ -92,10 +91,16 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
     is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -168,6 +172,7 @@ class ModelRunner:
         pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
@@ -176,10 +181,6 @@
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
-
-        # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.moe_ep_rank = moe_ep_rank
@@ -205,15 +206,17 @@
         self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
-
         self.forward_pass_id = 0
 
-        # Model-specific adjustment
-        self.model_specific_adjustment()
-
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         if server_args.show_time_cost:
             enable_show_time_cost()
 
+        # Model-specific adjustment
+        self.model_specific_adjustment()
+
         # Global vars
         global_server_args_dict.update(
             {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@@ -222,15 +225,8 @@
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
             }
-            | {
-                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
-                "deepep_mode": DeepEPMode(server_args.deepep_mode),
-            }
         )
 
-        # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
             self.init_threads_binding()
@@ -238,17 +234,22 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # CPU offload
+        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
-        # If it is a draft model, tp_group can be different
+        # Initialize the model runner
         self.initialize(min_per_gpu_memory)
 
-        # temporary cached values
+        # Temporary cached values
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+
+        # For weight updates
         self._model_update_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
@@ -277,6 +278,7 @@ class ModelRunner:
             )
         )
 
+        # Expert parallelism
         self.eplb_manager = (
             EPLBManager(self)
             if self.server_args.enable_eplb and (not self.is_draft_worker)
@@ -310,8 +312,13 @@ class ModelRunner:
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
-        assert (not model_has_mtp_layers) or (
-            self.num_effective_layers == model_num_layers
+        assert (
+            (not model_has_mtp_layers)
+            or (self.spec_algorithm.is_none())
+            or (
+                (not self.spec_algorithm.is_none())
+                and (self.num_effective_layers == model_num_layers)
+            )
         ), "PP is not compatible with MTP models."
 
         # Apply torchao quantization
@@ -340,9 +347,12 @@ class ModelRunner:
         if self.device == "cuda":
             self.init_cublas()
             self.init_attention_backend()
-            self.init_cuda_graphs()
+            self.init_device_graphs()
+        elif self.device == "npu":
+            self.init_attention_backend()
+            self.init_device_graphs()
         else:
-            self.cuda_graph_runner = None
+            self.graph_runner = None
             self.cuda_graph_mem_usage = 0
             self.init_attention_backend()
 
@@ -509,9 +519,6 @@ class ModelRunner:
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
-        elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
-            server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
@@ -604,12 +611,8 @@ class ModelRunner:
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
-            enable_dp_attention=self.server_args.enable_dp_attention,
-            tp_rank=self.tp_rank,
-            tp_size=self.tp_size,
-            dp_size=self.server_args.dp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-            pp_size=self.server_args.pp_size,
+            server_args=self.server_args,
+            model_config=self.model_config,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
@@ -689,6 +692,8 @@ class ModelRunner:
             monkey_patch_vllm_parallel_state(reverse=True)
             monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
+        get_offloader().post_init()
+
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
@@ -920,7 +925,8 @@ class ModelRunner:
             )
 
             # We need to get device after patch otherwise the device would be wrong
-            infered_device = torch.cuda.current_device()
+            self.device_module = torch.get_device_module(self.device)
+            infered_device = self.device_module.current_device()
 
             named_tensors = [
                 (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
@@ -1051,8 +1057,6 @@ class ModelRunner:
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
-            # FIXME: pipeline parallelism is not compatible with mla backend
-            assert self.pp_size == 1
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * num_layers
@@ -1160,6 +1164,7 @@ class ModelRunner:
         max_num_reqs: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
+        # Determine the kv cache dtype
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
@@ -1178,6 +1183,8 @@ class ModelRunner:
             )
 
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
 
         if max_num_reqs is None:
             max_num_reqs = min(
@@ -1190,9 +1197,6 @@ class ModelRunner:
                 4096,
             )
 
-        if SGLANG_CI_SMALL_KV_SIZE:
-            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
-
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -1239,7 +1243,13 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
+        # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
+            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
+            extra_max_context_len = 4
+            if self.server_args.speculative_num_draft_tokens is not None:
+                extra_max_context_len += self.server_args.speculative_num_draft_tokens
+
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
 
@@ -1248,7 +1258,8 @@ class ModelRunner:
                 pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                 self.req_to_token_pool = DecodeReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len + 4,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                     pre_alloc_size=pre_alloc_size,
@@ -1256,7 +1267,8 @@ class ModelRunner:
             else:
                 self.req_to_token_pool = ReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len + 4,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
@@ -1264,6 +1276,7 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
+        # Initialize token_to_kv_pool
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1349,38 +1362,40 @@ class ModelRunner:
             end_layer=self.end_layer,
         )
 
+        # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if self.page_size == 1:
-                if self.is_hybrid:
-                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
-                        self.full_max_total_num_tokens,
-                        self.swa_max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
-                else:
-                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
+            if self.server_args.attention_backend == "ascend":
+                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                    need_sort=need_sort,
+                )
             else:
-                if not _is_npu:
-                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        page_size=self.page_size,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
+                if self.page_size == 1:
+                    if self.is_hybrid:
+                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                            self.full_max_total_num_tokens,
+                            self.swa_max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
+                    else:
+                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                            self.max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
                 else:
-                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                    assert not self.is_hybrid
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
@@ -1554,15 +1569,13 @@ class ModelRunner:
             )
 
             return TRTLLMHAAttnBackend(self)
-
         elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
 
-            logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
-        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+        elif backend_str == "dual_chunk_flash_attn":
             from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                 DualChunkFlashAttentionBackend,
             )
@@ -1588,9 +1601,9 @@ class ModelRunner:
             .cuda()
         )
 
-    def init_cuda_graphs(self):
+    def init_device_graphs(self):
         """Capture cuda graphs."""
-        self.cuda_graph_runner = None
+        self.graph_runner = None
         self.cuda_graph_mem_usage = 0
 
         if not self.is_generation:
@@ -1605,7 +1618,9 @@ class ModelRunner:
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.cuda_graph_runner = CudaGraphRunner(self)
+        self.graph_runner = (
+            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
@@ -1757,11 +1772,11 @@ class ModelRunner:
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
-            and self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(forward_batch)
+            and self.graph_runner
+            and self.graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            ret = self.cuda_graph_runner.replay(
+            ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
sglang/srt/model_executor/npu_graph_runner.py ADDED
@@ -0,0 +1,94 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run the model with npu graph and torch.compile."""
+
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+
+class NPUGraphRunner(CudaGraphRunner):
+    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__(model_runner)
+
+    def _create_device_graph(self):
+        return torch.npu.NPUGraph()
+
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with torch.npu.graph(
+            graph,
+            pool=pool,
+            stream=stream,
+            auto_dispatch_capture=True,
+        ):
+            out = run_once_fn()
+        return out
+
+    def _update_inputs(self, seq_lens):
+        self.graphs[self.bs].update(
+            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
+        )
+
+    def _cache_loc_dtype(self):
+        return torch.int32
+
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
+        # Replay
+        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
+        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+        thread.start()
+        self.graphs[self.bs].replay()
+        thread.join()
+
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
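How this runner is wired in is visible in the model_runner.py hunks above; a condensed, non-verbatim sketch of the selection and replay path, with surrounding code elided (names taken from those hunks):

    # Inside ModelRunner.init_device_graphs() / forward(), per the hunks above:
    self.graph_runner = CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
    # ...
    if forward_batch.forward_mode.is_cuda_graph() and self.graph_runner.can_run(forward_batch):
        ret = self.graph_runner.replay(forward_batch, skip_attn_backend_init=skip_attn_backend_init)

Note that replay() above updates actual_seq_lengths_kv on a helper thread while the captured NPU graph replays, then joins before reading the output buffers.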
sglang/srt/model_loader/loader.py CHANGED
@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         yield module
         return
 
-    original_device_states: Dict[str, torch.device] = {}
+    original_infos: Dict[str, Dict] = {}
 
     # Store original device states and move parameters to GPU if they're on CPU
     for name, p in module.named_parameters():
         if p.device.type == "cpu":
-            original_device_states[name] = p.device
-            p.data = p.data.to(target_device)
+            original_data = p.data
+            device_data = p.data.to(target_device)
+            original_infos[name] = dict(
+                device=p.device,
+                original_data=original_data,
+                device_data=device_data,
+            )
+            p.data = device_data
         # Parameters already on target device are not touched
 
     try:
@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         # Restore parameters to their original devices, ignoring new parameters
         pin_memory = is_pin_memory_available()
         for name, p in module.named_parameters():
-            if name in original_device_states:
-                original_device: torch.device = original_device_states[name]
-                if original_device.type == "cpu":
+            if name in original_infos:
+                original_info = original_infos[name]
+                device_data = original_info["device_data"]
+                original_data = original_info["original_data"]
+                original_device: torch.device = original_info["device"]
+
+                if (
+                    (device_data.device == p.data.device)
+                    and (device_data.data_ptr() == p.data.data_ptr())
+                    and (device_data.shape == p.data.shape)
+                    and (device_data.dtype == p.data.dtype)
+                ):
+                    original_data.copy_(p.data.to(original_data.device))
+                    p.data = original_data
+                elif original_device.type == "cpu":
                     # `torch.empty_like` does not support `pin_memory` argument
                     cpu_data = torch.empty_strided(
                         size=p.data.size(),
sglang/srt/models/dbrx.py CHANGED
@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
         self.params_dtype = params_dtype
 
         self.router = DbrxRouter(config, self.params_dtype)
+        self.topk = TopK(
+            self.top_k,
+            renormalize=True,
+        )
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)
         self.ws = nn.Parameter(
             torch.empty(
                 self.num_total_experts,
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.ws,
             self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )
 
         if self.tp_size > 1:
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
sglang/srt/models/deepseek.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
             w1=self.w1,
             w2=self.w2,
             topk_output=topk_output,
-            inplace=True,
+            moe_runner_config=MoeRunnerConfig(inplace=True),
         )
 
         if self.config.n_shared_experts is not None:
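The dbrx.py and deepseek.py hunks reflect the same interface change in the Triton fused-MoE path: top-k routing moves into a TopK module and runner options travel in a MoeRunnerConfig. A condensed before/after of one call site, using the names from the hunks (schematic excerpt, not complete code):

    # 0.5.0rc1: routing arguments passed straight into fused_moe
    final_hidden_states = fused_moe(
        hidden_states, self.ws, self.w2s, router_logits, self.top_k,
        renormalize=True, inplace=True,
    )

    # 0.5.1: TopK produces topk_output; options are bundled in MoeRunnerConfig
    topk_output = self.topk(hidden_states, router_logits)
    final_hidden_states = fused_moe(
        hidden_states, self.ws, self.w2s, topk_output, self.moe_runner_config,
    )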