sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/expert_location_updater.py +422 -0
@@ -0,0 +1,422 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import logging
+from typing import Dict, List, Tuple
+
+import torch
+import torch.distributed
+from torch.distributed import P2POp
+
+from sglang.srt.managers.expert_location import (
+    ExpertLocationMetadata,
+    get_global_expert_location_metadata,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def update_expert_location(
+    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
+    new_expert_location_metadata: ExpertLocationMetadata,
+    nnodes: int,
+    rank: int,
+):
+    old_expert_location_metadata = get_global_expert_location_metadata()
+    _update_expert_weights(
+        routed_experts_weights_of_layer,
+        old_expert_location_metadata,
+        new_expert_location_metadata,
+        nnodes,
+        rank,
+    )
+    old_expert_location_metadata.update(new_expert_location_metadata)
+
+
+def _update_expert_weights(
+    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
+    old_expert_location_metadata: ExpertLocationMetadata,
+    new_expert_location_metadata: ExpertLocationMetadata,
+    nnodes: int,
+    rank: int,
+):
+    temp_buffers = create_temp_buffers(
+        next(iter(routed_experts_weights_of_layer.values()))
+    )
+
+    world_size = torch.distributed.get_world_size()
+    num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
+    num_gpu_per_node = world_size // nnodes
+
+    old_physical_to_logical_map = (
+        old_expert_location_metadata.physical_to_logical_map.tolist()
+    )
+    new_physical_to_logical_map = (
+        new_expert_location_metadata.physical_to_logical_map.tolist()
+    )
+
+    for layer_id in sorted(routed_experts_weights_of_layer.keys()):
+        update_expert_weights_single_layer(
+            routed_experts_weights=routed_experts_weights_of_layer[layer_id],
+            temp_buffers=temp_buffers,
+            old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
+            new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
+            num_local_physical_experts=num_local_physical_experts,
+            num_gpu_per_node=num_gpu_per_node,
+            rank=rank,
+        )
+
+
+def create_temp_buffers(sample_tensors):
+    return [torch.empty_like(tensor) for tensor in sample_tensors]
+
+
+def update_expert_weights_single_layer(
+    routed_experts_weights: List[torch.Tensor],
+    temp_buffers: List[torch.Tensor],
+    old_physical_to_logical_map: List[int],  # (num_physical_experts,)
+    new_physical_to_logical_map: List[int],  # (num_physical_experts,)
+    num_local_physical_experts: int,
+    num_gpu_per_node: int,
+    rank: int,
+    debug: bool = False,
+):
+    assert all(
+        tensor.shape[0] == num_local_physical_experts
+        for tensor in routed_experts_weights
+    ), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"
+    assert isinstance(old_physical_to_logical_map, list)
+    assert isinstance(new_physical_to_logical_map, list)
+
+    output_logs = [] if debug else None
+
+    num_physical_experts = len(old_physical_to_logical_map)
+    num_tensors = len(routed_experts_weights)
+
+    self_node_id = rank // num_gpu_per_node
+
+    local_expert_location_range = (
+        rank * num_local_physical_experts,
+        (rank + 1) * num_local_physical_experts,
+    )
+
+    def _entrypoint():
+        # List[Tuple[logical_expert_id, List[P2POp]]]
+        p2p_op_infos: List[Tuple[int, List[P2POp]]] = []
+        # List[Tuple[temp_buffers_expert_location, routed_experts_weights_expert_location]]
+        buffer2weight_copy_infos: List[Tuple[int, int]] = []
+
+        _handle_recv(buffer2weight_copy_infos, p2p_op_infos)
+        _create_isend_ops(p2p_op_infos)
+        _execute_p2p_ops(p2p_op_infos)
+        _execute_buffer2weight_copies(buffer2weight_copy_infos)
+
+        if debug:
+            output_logs.append(f"{p2p_op_infos=}")
+            output_logs.append(f"{buffer2weight_copy_infos=}")
+
+    def _handle_recv(buffer2weight_copy_infos, p2p_op_infos):
+        for dst_expert_location in range(*local_expert_location_range):
+            _handle_recv_of_dst_expert_location(
+                dst_expert_location, buffer2weight_copy_infos, p2p_op_infos
+            )
+
+    def _handle_recv_of_dst_expert_location(
+        dst_expert_location: int, buffer2weight_copy_infos, p2p_op_infos
+    ):
+        logical_expert_id = new_physical_to_logical_map[dst_expert_location]
+
+        # case 1: unchanged
+        if old_physical_to_logical_map[dst_expert_location] == logical_expert_id:
+            if debug:
+                output_logs.append(
+                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=unchanged"
+                )
+            return
+
+        # case 2: same-gpu
+        for src_expert_location in range(*local_expert_location_range):
+            if old_physical_to_logical_map[src_expert_location] == logical_expert_id:
+                for i in range(num_tensors):
+                    _get_tensor(temp_buffers, i, dst_expert_location).copy_(
+                        _get_tensor(routed_experts_weights, i, src_expert_location)
+                    )
+                buffer2weight_copy_infos.append(
+                    (dst_expert_location, dst_expert_location)
+                )
+                if debug:
+                    output_logs.append(
+                        f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-gpu {src_expert_location=}"
+                    )
+                return
+
+        # case 3: free-rider
+        for src_expert_location in range(
+            rank * num_local_physical_experts, dst_expert_location
+        ):
+            if new_physical_to_logical_map[src_expert_location] == logical_expert_id:
+                buffer2weight_copy_infos.append(
+                    (src_expert_location, dst_expert_location)
+                )
+                if debug:
+                    output_logs.append(
+                        f"handle_recv_of_dst_expert_location {dst_expert_location=} case=free-rider {src_expert_location=}"
+                    )
+                return
+
+        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
+            _compute_comm_info(logical_expert_id=logical_expert_id)
+        )
+
+        # case 4: same-node
+        if rank in need_comm_self_node_dst_ranks:
+            chosen_src_rank = same_node_mapping.chunk_value_from_element_value(
+                element_value=rank
+            )
+            _create_p2p_recv_and_buffer2weight_copy(
+                buffer2weight_copy_infos,
+                p2p_op_infos,
+                src_rank=chosen_src_rank,
+                logical_expert_id=logical_expert_id,
+                dst_expert_location=dst_expert_location,
+            )
+            if debug:
+                output_logs.append(
+                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-node {chosen_src_rank=}"
+                )
+            return
+
+        # case 5: cross-node
+        # Future work: can optimize when there are multiple ranks in the same dst node that use the same logical expert
+        chosen_src_rank = cross_node_mapping.chunk_value_from_element_value(
+            element_value=rank
+        )
+        _create_p2p_recv_and_buffer2weight_copy(
+            buffer2weight_copy_infos,
+            p2p_op_infos,
+            src_rank=chosen_src_rank,
+            logical_expert_id=logical_expert_id,
+            dst_expert_location=dst_expert_location,
+        )
+        if debug:
+            output_logs.append(
+                f"handle_recv_of_dst_expert_location {dst_expert_location=} case=cross-node {chosen_src_rank=}"
+            )
+        return
+
+    def _create_p2p_recv_and_buffer2weight_copy(
+        buffer2weight_copy_infos,
+        p2p_op_infos,
+        *,
+        logical_expert_id: int,
+        src_rank: int,
+        dst_expert_location: int,
+    ):
+        p2p_op_infos.append(
+            (
+                logical_expert_id,
+                [
+                    P2POp(
+                        op=torch.distributed.irecv,
+                        tensor=_get_tensor(temp_buffers, i, dst_expert_location),
+                        peer=src_rank,
+                    )
+                    for i in range(num_tensors)
+                ],
+            )
+        )
+        buffer2weight_copy_infos.append((dst_expert_location, dst_expert_location))
+
+    def _create_isend_ops(p2p_op_infos):
+        handled_logical_expert_ids = set()
+        for src_expert_location in range(*local_expert_location_range):
+            logical_expert_id = old_physical_to_logical_map[src_expert_location]
+
+            if logical_expert_id in handled_logical_expert_ids:
+                continue
+            handled_logical_expert_ids.add(logical_expert_id)
+
+            _create_isend_ops_of_logical_expert_id(
+                logical_expert_id, src_expert_location, p2p_op_infos
+            )
+
+    def _create_isend_ops_of_logical_expert_id(
+        logical_expert_id, src_expert_location, p2p_op_infos
+    ):
+        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
+            _compute_comm_info(logical_expert_id=logical_expert_id)
+        )
+
+        same_node_dst_ranks = same_node_mapping.element_values_from_chunk_value(
+            chunk_value=rank
+        )
+        cross_node_dst_ranks = cross_node_mapping.element_values_from_chunk_value(
+            chunk_value=rank
+        )
+        all_dst_ranks = same_node_dst_ranks + cross_node_dst_ranks
+
+        if debug:
+            output_logs.append(
+                f"create_isend_ops_of_logical_expert_id {logical_expert_id=} {src_expert_location=} {same_node_dst_ranks=} {cross_node_dst_ranks=}"
+            )
+
+        p2p_op_infos.append(
+            (
+                logical_expert_id,
+                [
+                    P2POp(
+                        op=torch.distributed.isend,
+                        tensor=_get_tensor(
+                            routed_experts_weights, i, src_expert_location
+                        ),
+                        peer=dst_rank,
+                    )
+                    for dst_rank in all_dst_ranks
+                    for i in range(num_tensors)
+                ],
+            )
+        )
+
+    def _compute_comm_info(logical_expert_id: int):
+        all_src_ranks = _deduplicate_ordered(
+            [
+                x // num_local_physical_experts
+                for x in range(num_physical_experts)
+                if old_physical_to_logical_map[x] == logical_expert_id
+            ]
+        )
+        all_src_nodes = [x // num_gpu_per_node for x in all_src_ranks]
+        self_node_src_ranks = [
+            x for x in all_src_ranks if x // num_gpu_per_node == self_node_id
+        ]
+
+        need_comm_dst_ranks = _deduplicate_ordered(
+            [
+                x // num_local_physical_experts
+                for x in range(num_physical_experts)
+                if new_physical_to_logical_map[x] == logical_expert_id
+                and x // num_local_physical_experts not in all_src_ranks
+            ]
+        )
+        need_comm_self_node_dst_ranks = (
+            [x for x in need_comm_dst_ranks if x // num_gpu_per_node == self_node_id]
+            if len(self_node_src_ranks) > 0
+            else []
+        )
+        need_comm_cross_node_dst_ranks = [
+            x
+            for x in need_comm_dst_ranks
+            if (x // num_gpu_per_node) not in all_src_nodes
+        ]
+
+        same_node_mapping = _ChunkUtils(
+            chunk_values=self_node_src_ranks,
+            element_values=need_comm_self_node_dst_ranks,
+        )
+
+        cross_node_mapping = _ChunkUtils(
+            chunk_values=all_src_ranks,
+            element_values=need_comm_cross_node_dst_ranks,
+        )
+
+        return same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks
+
+    def _execute_p2p_ops(p2p_op_infos):
+        sorted_infos = sorted(p2p_op_infos, key=lambda info: info[0])
+        p2p_ops = [op for _, ops in sorted_infos for op in ops]
+        if len(p2p_ops) == 0:
+            return
+
+        reqs = torch.distributed.batch_isend_irecv(p2p_ops)
+        for req in reqs:
+            req.wait()
+
+    def _execute_buffer2weight_copies(buffer2weight_copy_infos):
+        for (
+            temp_buffers_expert_location,
+            routed_experts_weights_expert_location,
+        ) in buffer2weight_copy_infos:
+            for i in range(num_tensors):
+                _get_tensor(
+                    routed_experts_weights, i, routed_experts_weights_expert_location
+                ).copy_(_get_tensor(temp_buffers, i, temp_buffers_expert_location))
+
+    def _get_tensor(tensors, tensor_index: int, expert_location: int) -> torch.Tensor:
+        return tensors[tensor_index][_get_local_expert_location(expert_location)]
+
+    def _get_local_expert_location(expert_location: int) -> int:
+        assert (
+            local_expert_location_range[0]
+            <= expert_location
+            < local_expert_location_range[1]
+        )
+        return expert_location % num_local_physical_experts
+
+    _entrypoint()
+
+    return output_logs
+
+
+class _ChunkUtils:
+    def __init__(self, *, chunk_values: List, element_values: List):
+        self.chunk_values = chunk_values
+        self.element_values = element_values
+
+    def chunk_value_from_element_value(self, element_value):
+        chunk_index = self._chunk_index_from_element_index(
+            num_elements=len(self.element_values),
+            num_chunks=len(self.chunk_values),
+            element_index=self.element_values.index(element_value),
+        )
+        return self.chunk_values[chunk_index]
+
+    def element_values_from_chunk_value(self, chunk_value) -> List:
+        if len(self.element_values) == 0:
+            return []
+        element_slice = self._element_slice_from_chunk_index(
+            num_elements=len(self.element_values),
+            num_chunks=len(self.chunk_values),
+            chunk_index=self.chunk_values.index(chunk_value),
+        )
+        return self.element_values[element_slice]

+    @staticmethod
+    def _chunk_index_from_element_index(
+        num_elements: int, num_chunks: int, element_index: int
+    ) -> int:
+        short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
+        num_elements_for_long_chunks = num_long_chunks * (short_chunk_size + 1)
+        if element_index < num_elements_for_long_chunks:
+            return element_index // (short_chunk_size + 1)
+        else:
+            return (
+                num_long_chunks
+                + (element_index - num_elements_for_long_chunks) // short_chunk_size
+            )
+
+    @staticmethod
+    def _element_slice_from_chunk_index(
+        num_elements: int, num_chunks: int, chunk_index: int
+    ) -> slice:
+        short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
+        start = chunk_index * short_chunk_size + min(chunk_index, num_long_chunks)
+        end = start + short_chunk_size + int(chunk_index < num_long_chunks)
+        return slice(start, end)
+
+
+def _deduplicate_ordered(arr: List[int]):
+    output = []
+    for item in arr:
+        if len(output) == 0 or item != output[-1]:
+            output.append(item)
+    return output
sglang/srt/model_executor/forward_batch_info.py +4 -0
@@ -247,6 +247,7 @@ class ForwardBatch:
 
     # For padding
     padded_static_len: int = -1  # -1 if not padded
+    num_token_non_padded: Optional[torch.Tensor] = None  # scalar tensor
 
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -290,6 +291,9 @@ class ForwardBatch:
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
+            num_token_non_padded=torch.tensor(
+                len(batch.input_ids), dtype=torch.int32
+            ).to(device, non_blocking=True),
         )
 
         # For DP attention