sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,422 @@
1
+ # Copyright 2023-2025 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ import logging
15
+ from typing import Dict, List, Tuple
16
+
17
+ import torch
18
+ import torch.distributed
19
+ from torch.distributed import P2POp
20
+
21
+ from sglang.srt.managers.expert_location import (
22
+ ExpertLocationMetadata,
23
+ get_global_expert_location_metadata,
24
+ )
25
+
26
logger = logging.getLogger(__name__)


def update_expert_location(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    new_expert_location_metadata: ExpertLocationMetadata,
    nnodes: int,
    rank: int,
):
    """Move routed-expert weights to a new placement and adopt that placement.

    The current placement is read from ``get_global_expert_location_metadata()``.
    The weights of every layer are rearranged from that placement to
    ``new_expert_location_metadata``. Only afterwards is ``update(...)`` called
    on the global metadata object with the new metadata.
    """
    current_metadata = get_global_expert_location_metadata()
    _update_expert_weights(
        routed_experts_weights_of_layer=routed_experts_weights_of_layer,
        old_expert_location_metadata=current_metadata,
        new_expert_location_metadata=new_expert_location_metadata,
        nnodes=nnodes,
        rank=rank,
    )
    # Switch the global metadata over only once the weights already follow it.
    current_metadata.update(new_expert_location_metadata)
44
+
45
+
46
def _update_expert_weights(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    old_expert_location_metadata: ExpertLocationMetadata,
    new_expert_location_metadata: ExpertLocationMetadata,
    nnodes: int,
    rank: int,
):
    """Rearrange every layer's routed-expert weights from the old to the new placement.

    Args:
        routed_experts_weights_of_layer: layer id -> list of this rank's
            routed-expert weight tensors for that layer.
        old_expert_location_metadata: placement the weights currently follow.
        new_expert_location_metadata: placement the weights should follow.
        nnodes: number of nodes. GPUs per node is computed as
            ``torch.distributed.get_world_size() // nnodes``.
        rank: this process's rank.

    Layers are processed in ascending layer-id order. NOTE(review): this is
    presumably so that all ranks issue their per-layer P2P batches in the same
    order; confirm.
    """
    if not routed_experts_weights_of_layer:
        # No layers means nothing to move. Returning here also avoids
        # `next(iter(...))` raising a bare StopIteration on the empty dict.
        return

    # One set of scratch buffers, shaped like the first layer's tensors, is
    # reused for every layer. This relies on all layers having identical
    # tensor shapes.
    temp_buffers = create_temp_buffers(
        next(iter(routed_experts_weights_of_layer.values()))
    )

    world_size = torch.distributed.get_world_size()
    num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
    num_gpu_per_node = world_size // nnodes

    # The per-layer routine requires plain Python lists (it asserts isinstance
    # list), so convert the maps once here rather than per layer.
    old_physical_to_logical_map = (
        old_expert_location_metadata.physical_to_logical_map.tolist()
    )
    new_physical_to_logical_map = (
        new_expert_location_metadata.physical_to_logical_map.tolist()
    )

    for layer_id in sorted(routed_experts_weights_of_layer.keys()):
        update_expert_weights_single_layer(
            routed_experts_weights=routed_experts_weights_of_layer[layer_id],
            temp_buffers=temp_buffers,
            old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
            new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
            num_local_physical_experts=num_local_physical_experts,
            num_gpu_per_node=num_gpu_per_node,
            rank=rank,
        )
78
+
79
+
80
def create_temp_buffers(sample_tensors):
    """Allocate one uninitialized scratch tensor per sample tensor.

    Each buffer comes from ``torch.empty_like`` on its sample, so it matches
    that sample's shape, dtype and device. Its contents are undefined until
    written.
    """
    buffers = []
    for sample in sample_tensors:
        buffers.append(torch.empty_like(sample))
    return buffers
82
+
83
+
84
def update_expert_weights_single_layer(
    routed_experts_weights: List[torch.Tensor],
    temp_buffers: List[torch.Tensor],
    old_physical_to_logical_map: List[int],  # (num_physical_Experts,)
    new_physical_to_logical_map: List[int],  # (num_physical_Experts,)
    num_local_physical_experts: int,
    num_gpu_per_node: int,
    rank: int,
    debug: bool = False,
):
    """Rearrange one layer's routed-expert weights on this rank.

    This rank owns the physical expert slots
    ``[rank * num_local_physical_experts, (rank + 1) * num_local_physical_experts)``.
    Every tensor in ``routed_experts_weights`` and ``temp_buffers`` is indexed
    along dim 0 by a slot's offset within that range.

    Receive side, for each local slot ``dst`` (the first matching case wins):
      1. unchanged  -- old and new logical expert agree; nothing to do.
      2. same-gpu   -- a local slot held the expert under the old map; copy it
                       into ``temp_buffers`` immediately.
      3. free-rider -- an earlier local slot receives the same expert; reuse
                       that slot's temp buffer afterwards.
      4. same-node  -- P2P receive from a source rank on this node.
      5. cross-node -- P2P receive from a source rank on another node.

    Send side: each distinct logical expert held locally under the old map is
    sent to the destination ranks that ``_ChunkUtils`` assigns to this rank.

    All P2P ops run in one ``batch_isend_irecv`` call, sorted by logical expert
    id (presumably so each peer pair posts matching ops in the same order).
    They are all awaited before any temp buffer is copied into
    ``routed_experts_weights``, so sends always read the old weights.

    Returns:
        A list of debug strings when ``debug`` is true, otherwise ``None``.
    """
    assert all(
        tensor.shape[0] == num_local_physical_experts
        for tensor in routed_experts_weights
    ), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"
    assert isinstance(old_physical_to_logical_map, list)
    assert isinstance(new_physical_to_logical_map, list)

    output_logs = [] if debug else None

    num_physical_experts = len(old_physical_to_logical_map)
    num_tensors = len(routed_experts_weights)
    self_node_id = rank // num_gpu_per_node

    first_local_location = rank * num_local_physical_experts
    local_locations = range(
        first_local_location, first_local_location + num_local_physical_experts
    )

    def expert_slice(tensors, tensor_index: int, expert_location: int) -> torch.Tensor:
        # Only slots owned by this rank are addressable.
        per_tensor = tensors[tensor_index]
        assert expert_location in local_locations
        return per_tensor[expert_location - first_local_location]

    def comm_info(logical_expert_id: int):
        # Ranks holding the expert under the old map, in ascending order.
        src_ranks = _deduplicate_ordered(
            [
                location // num_local_physical_experts
                for location, old_id in enumerate(old_physical_to_logical_map)
                if old_id == logical_expert_id
            ]
        )
        src_nodes = [r // num_gpu_per_node for r in src_ranks]
        local_node_src_ranks = [
            r for r in src_ranks if r // num_gpu_per_node == self_node_id
        ]

        # Ranks that need the expert under the new map but do not hold it yet.
        dst_ranks = _deduplicate_ordered(
            [
                location // num_local_physical_experts
                for location in range(num_physical_experts)
                if new_physical_to_logical_map[location] == logical_expert_id
                and location // num_local_physical_experts not in src_ranks
            ]
        )

        # Same-node transfers only apply when this node has a source rank.
        local_node_dst_ranks = []
        if local_node_src_ranks:
            local_node_dst_ranks = [
                r for r in dst_ranks if r // num_gpu_per_node == self_node_id
            ]
        # Destinations on nodes without any source rank are fed across nodes.
        remote_dst_ranks = [
            r for r in dst_ranks if r // num_gpu_per_node not in src_nodes
        ]

        same_node_mapping = _ChunkUtils(
            chunk_values=local_node_src_ranks,
            element_values=local_node_dst_ranks,
        )
        cross_node_mapping = _ChunkUtils(
            chunk_values=src_ranks,
            element_values=remote_dst_ranks,
        )
        return same_node_mapping, cross_node_mapping, local_node_dst_ranks

    def add_recv(
        copy_plan, op_plan, logical_expert_id, src_rank, dst_expert_location
    ):
        recv_ops = [
            P2POp(
                op=torch.distributed.irecv,
                tensor=expert_slice(temp_buffers, i, dst_expert_location),
                peer=src_rank,
            )
            for i in range(num_tensors)
        ]
        op_plan.append((logical_expert_id, recv_ops))
        copy_plan.append((dst_expert_location, dst_expert_location))

    def plan_recv(dst_expert_location, copy_plan, op_plan):
        logical_expert_id = new_physical_to_logical_map[dst_expert_location]

        # case 1: unchanged
        if old_physical_to_logical_map[dst_expert_location] == logical_expert_id:
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=unchanged"
                )
            return

        # case 2: same-gpu
        src_expert_location = next(
            (
                location
                for location in local_locations
                if old_physical_to_logical_map[location] == logical_expert_id
            ),
            None,
        )
        if src_expert_location is not None:
            for i in range(num_tensors):
                expert_slice(temp_buffers, i, dst_expert_location).copy_(
                    expert_slice(routed_experts_weights, i, src_expert_location)
                )
            copy_plan.append((dst_expert_location, dst_expert_location))
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-gpu {src_expert_location=}"
                )
            return

        # case 3: free-rider
        src_expert_location = next(
            (
                location
                for location in range(first_local_location, dst_expert_location)
                if new_physical_to_logical_map[location] == logical_expert_id
            ),
            None,
        )
        if src_expert_location is not None:
            copy_plan.append((src_expert_location, dst_expert_location))
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=free-rider {src_expert_location=}"
                )
            return

        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
            comm_info(logical_expert_id)
        )

        if rank in need_comm_self_node_dst_ranks:
            # case 4: same-node
            chosen_src_rank = same_node_mapping.chunk_value_from_element_value(
                element_value=rank
            )
            add_recv(
                copy_plan, op_plan, logical_expert_id, chosen_src_rank,
                dst_expert_location,
            )
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-node {chosen_src_rank=}"
                )
        else:
            # case 5: cross-node
            # Future work: can optimize when there are multiple ranks in the
            # same dst node that uses the same logical expert
            chosen_src_rank = cross_node_mapping.chunk_value_from_element_value(
                element_value=rank
            )
            add_recv(
                copy_plan, op_plan, logical_expert_id, chosen_src_rank,
                dst_expert_location,
            )
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=cross-node {chosen_src_rank=}"
                )

    def plan_sends(op_plan):
        seen_logical_ids = set()
        for src_expert_location in local_locations:
            logical_expert_id = old_physical_to_logical_map[src_expert_location]
            if logical_expert_id in seen_logical_ids:
                continue
            seen_logical_ids.add(logical_expert_id)

            same_node_mapping, cross_node_mapping, _ = comm_info(logical_expert_id)
            same_node_dst_ranks = same_node_mapping.element_values_from_chunk_value(
                chunk_value=rank
            )
            cross_node_dst_ranks = cross_node_mapping.element_values_from_chunk_value(
                chunk_value=rank
            )
            targets = same_node_dst_ranks + cross_node_dst_ranks

            if debug:
                output_logs.append(
                    f"create_isend_ops_of_logical_expert_id {logical_expert_id=} {src_expert_location=} {same_node_dst_ranks=} {cross_node_dst_ranks=}"
                )

            send_ops = [
                P2POp(
                    op=torch.distributed.isend,
                    tensor=expert_slice(routed_experts_weights, i, src_expert_location),
                    peer=dst_rank,
                )
                for dst_rank in targets
                for i in range(num_tensors)
            ]
            op_plan.append((logical_expert_id, send_ops))

    def run_p2p(op_plan):
        ordered = sorted(op_plan, key=lambda pair: pair[0])
        flat_ops = [op for _, ops in ordered for op in ops]
        if not flat_ops:
            return
        for request in torch.distributed.batch_isend_irecv(flat_ops):
            request.wait()

    def copy_back(copy_plan):
        for buffer_location, weight_location in copy_plan:
            for i in range(num_tensors):
                expert_slice(routed_experts_weights, i, weight_location).copy_(
                    expert_slice(temp_buffers, i, buffer_location)
                )

    # (logical_expert_id, ops) pairs and (temp-buffer slot, weight slot) pairs.
    p2p_op_infos: List[Tuple[int, List[P2POp]]] = []
    buffer2weight_copy_infos: List[Tuple[int, int]] = []

    for dst_expert_location in local_locations:
        plan_recv(dst_expert_location, buffer2weight_copy_infos, p2p_op_infos)
    plan_sends(p2p_op_infos)
    run_p2p(p2p_op_infos)
    copy_back(buffer2weight_copy_infos)

    if debug:
        output_logs.append(f"{p2p_op_infos=}")
        output_logs.append(f"{buffer2weight_copy_infos=}")

    return output_logs
368
+
369
+
370
+ class _ChunkUtils:
371
+ def __init__(self, *, chunk_values: List, element_values: List):
372
+ self.chunk_values = chunk_values
373
+ self.element_values = element_values
374
+
375
+ def chunk_value_from_element_value(self, element_value):
376
+ chunk_index = self._chunk_index_from_element_index(
377
+ num_elements=len(self.element_values),
378
+ num_chunks=len(self.chunk_values),
379
+ element_index=self.element_values.index(element_value),
380
+ )
381
+ return self.chunk_values[chunk_index]
382
+
383
+ def element_values_from_chunk_value(self, chunk_value) -> List:
384
+ if len(self.element_values) == 0:
385
+ return []
386
+ element_slice = self._element_slice_from_chunk_index(
387
+ num_elements=len(self.element_values),
388
+ num_chunks=len(self.chunk_values),
389
+ chunk_index=self.chunk_values.index(chunk_value),
390
+ )
391
+ return self.element_values[element_slice]
392
+
393
+ @staticmethod
394
+ def _chunk_index_from_element_index(
395
+ num_elements: int, num_chunks: int, element_index: int
396
+ ) -> int:
397
+ short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
398
+ num_elements_for_long_chunks = num_long_chunks * (short_chunk_size + 1)
399
+ if element_index < num_elements_for_long_chunks:
400
+ return element_index // (short_chunk_size + 1)
401
+ else:
402
+ return (
403
+ num_long_chunks
404
+ + (element_index - num_elements_for_long_chunks) // short_chunk_size
405
+ )
406
+
407
+ @staticmethod
408
+ def _element_slice_from_chunk_index(
409
+ num_elements: int, num_chunks: int, chunk_index: int
410
+ ) -> slice:
411
+ short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
412
+ start = chunk_index * short_chunk_size + min(chunk_index, num_long_chunks)
413
+ end = start + short_chunk_size + int(chunk_index < num_long_chunks)
414
+ return slice(start, end)
415
+
416
+
417
+ def _deduplicate_ordered(arr: List[int]):
418
+ output = []
419
+ for item in arr:
420
+ if len(output) == 0 or item != output[-1]:
421
+ output.append(item)
422
+ return output
@@ -58,7 +58,7 @@ class ForwardMode(IntEnum):
58
58
  DECODE = auto()
59
59
  # Contains both EXTEND and DECODE when doing chunked prefill.
60
60
  MIXED = auto()
61
- # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
61
+ # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
62
62
  IDLE = auto()
63
63
 
64
64
  # Used in speculative decoding: verify a batch in the target model.
@@ -247,6 +247,7 @@ class ForwardBatch:
247
247
 
248
248
  # For padding
249
249
  padded_static_len: int = -1 # -1 if not padded
250
+ num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
250
251
 
251
252
  # For Qwen2-VL
252
253
  mrope_positions: torch.Tensor = None
@@ -290,6 +291,9 @@ class ForwardBatch:
290
291
  capture_hidden_mode=batch.capture_hidden_mode,
291
292
  input_embeds=batch.input_embeds,
292
293
  extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
294
+ num_token_non_padded=torch.tensor(
295
+ len(batch.input_ids), dtype=torch.int32
296
+ ).to(device, non_blocking=True),
293
297
  )
294
298
 
295
299
  # For DP attention