sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/pynccl.py CHANGED

@@ -148,7 +148,11 @@ class PyNcclCommunicator:
         )
 
     def all_gather(
-        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
         )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclAllGather(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            input_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+
+        if sizes is not None:
+            split_offset = 0
+
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                dst_slice = output_tensor[split_offset : split_offset + split_size]
+                self.nccl.ncclBroadcast(
+                    buffer_type(input_tensor.data_ptr()),
+                    buffer_type(dst_slice.data_ptr()),
+                    dst_slice.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclAllGather(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                input_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def reduce_scatter(
         self,
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
         input_tensor: torch.Tensor,
         op: ReduceOp = ReduceOp.SUM,
         stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
         )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclReduceScatter(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            output_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            ncclRedOpTypeEnum.from_torch(op),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+
+        if sizes is not None:
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                chunk = input_tensor[split_offset : split_offset + split_size, ...]
+
+                self.nccl.ncclReduce(
+                    buffer_type(chunk.data_ptr()),
+                    buffer_type(output_tensor.data_ptr()),
+                    chunk.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    ncclRedOpTypeEnum.from_torch(op),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclReduceScatter(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                output_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
     def deregister_comm_window(self, window):
        return self.nccl.ncclCommWindowDeregister(self.comm, window)
 
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
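The `sizes` argument added above turns `all_gather` / `reduce_scatter` into variable-length collectives: each rank owns a chunk of length `sizes[rank]`, and the grouped `ncclBroadcast` / `ncclReduce` loop writes (or reduces) those chunks back to back. Below is a minimal single-process sketch of the resulting semantics using plain CPU tensors; the `fake_*` helpers are illustrative stand-ins, not sglang APIs.

from typing import List

import torch


def fake_all_gather_v(inputs: List[torch.Tensor], sizes: List[int]) -> torch.Tensor:
    # Each "rank" broadcasts its chunk into the matching slice of the output,
    # which is what the grouped ncclBroadcast loop does on real devices.
    assert [t.shape[0] for t in inputs] == sizes
    return torch.cat(inputs, dim=0)


def fake_reduce_scatter_v(inputs: List[torch.Tensor], sizes: List[int], rank: int) -> torch.Tensor:
    # Sum over ranks, then keep only this rank's slice of length sizes[rank],
    # mirroring the grouped ncclReduce loop where root == rank.
    summed = torch.stack(inputs, dim=0).sum(dim=0)
    offset = sum(sizes[:rank])
    return summed[offset : offset + sizes[rank]]


if __name__ == "__main__":
    sizes = [2, 3, 1]  # per-rank chunk lengths, deliberately uneven
    per_rank_inputs = [torch.full((s,), float(r)) for r, s in enumerate(sizes)]
    print(fake_all_gather_v(per_rank_inputs, sizes))  # tensor([0., 0., 1., 1., 1., 2.])

    full_inputs = [torch.arange(6, dtype=torch.float32) for _ in sizes]
    print(fake_reduce_scatter_v(full_inputs, sizes, rank=1))  # tensor([6., 9., 12.])

Wrapping the per-rank calls in ncclGroupStart/ncclGroupEnd (also exposed as group_start/group_end on the communicator) lets NCCL batch them instead of issuing world_size separate blocking collectives.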
sglang/srt/distributed/device_communicators/pynccl_wrapper.py CHANGED

@@ -206,6 +206,26 @@ class NCCLLibrary:
                 cudaStream_t,
             ],
         ),
+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function(
+            "ncclReduce",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ncclRedOp_t,
+                ctypes.c_int,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
         # ncclResult_t ncclReduceScatter(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
@@ -278,6 +298,10 @@ class NCCLLibrary:
         # it is better not to call it at all.
         # ncclResult_t ncclCommDestroy(ncclComm_t comm);
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     exported_functions_symm_mem = [
@@ -400,6 +424,28 @@ class NCCLLibrary:
             )
         )
 
+    def ncclReduce(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        root: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(
+            self._funcs["ncclReduce"](
+                sendbuff, recvbuff, count, datatype, op, root, comm, stream
+            )
+        )
+
     def ncclReduceScatter(
         self,
         sendbuff: buffer_type,
@@ -499,6 +545,12 @@ class NCCLLibrary:
     def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary",
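For readers less familiar with the `Function(...)` table the new entries extend: `NCCLLibrary` resolves each named symbol from the NCCL shared library and assigns its `restype`/`argtypes` before exposing a checked wrapper. Below is a rough standalone sketch of that ctypes pattern, binding libc symbols instead of NCCL; it assumes a Unix-like system where libc can be located, and the `Function` dataclass here only mirrors the sglang one.

import ctypes
import ctypes.util
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


# Same shape as the NCCLLibrary table, but against libc for illustration.
exported_functions = [
    # pid_t getpid(void);  -- takes no arguments, like ncclGroupStart/ncclGroupEnd
    Function("getpid", ctypes.c_int, []),
    # int abs(int);
    Function("abs", ctypes.c_int, [ctypes.c_int]),
]

lib = ctypes.CDLL(ctypes.util.find_library("c"))
funcs = {}
for f in exported_functions:
    cfunc = getattr(lib, f.name)       # look up the symbol by name
    cfunc.restype = f.restype          # declare the C return type
    cfunc.argtypes = f.argtypes        # declare the C argument types
    funcs[f.name] = cfunc

print(funcs["getpid"]())   # current process id
print(funcs["abs"](-7))    # 7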
sglang/srt/distributed/naive_distributed.py ADDED

@@ -0,0 +1,112 @@
+import base64
+import os
+import pickle
+import time
+from pathlib import Path
+from typing import Any, List, Optional
+
+import torch
+
+from sglang.srt.utils import MultiprocessingSerializer
+
+
+class NaiveDistributed:
+    def __init__(self, rank: int, world_size: int, rendezvous: str):
+        self._rank = rank
+        self._world_size = world_size
+        self._operation_index = 0
+        self._directory = Path(rendezvous)
+        self._directory.mkdir(parents=True, exist_ok=True)
+        assert 0 <= rank < world_size
+
+        # both barrier to be safe, and as a sanity check
+        self.barrier()
+
+    def get_rank(self):
+        return self._rank
+
+    def get_world_size(self):
+        return self._world_size
+
+    def scatter(
+        self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0
+    ):
+        if self._rank == src:
+            assert len(scatter_list) == self._world_size
+        else:
+            assert scatter_list is None
+
+        gathered_objects = self.all_gather_object(
+            dict(
+                serialized_scatter_list=[
+                    (
+                        None
+                        if item_rank == src
+                        else MultiprocessingSerializer.serialize(item)
+                    )
+                    for item_rank, item in enumerate(scatter_list)
+                ]
+            )
+            if self._rank == src
+            else dict()
+        )
+
+        remote_serialized_tensor = gathered_objects[src]["serialized_scatter_list"][
+            self._rank
+        ]
+        if self._rank == src:
+            assert remote_serialized_tensor is None
+            remote_tensor = scatter_list[self._rank]
+        else:
+            remote_tensor = MultiprocessingSerializer.deserialize(
+                remote_serialized_tensor
+            )
+        tensor.copy_(remote_tensor)
+
+        # avoid src tensor be deleted too early
+        self.barrier()
+
+    def all_gather_object(self, obj: Any) -> List[Any]:
+        self._operation_index += 1
+
+        text_postfix = "\n"
+
+        def _get_path(interesting_rank: int):
+            return (
+                self._directory
+                / f"rank{interesting_rank}_op{self._operation_index}.txt"
+            )
+
+        _get_path(self._rank).write_text(
+            base64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
+        )
+
+        def _read_one(interesting_rank: int):
+            p = _get_path(interesting_rank)
+            while True:
+                if p.exists() and (text := p.read_text()).endswith(text_postfix):
+                    return pickle.loads(base64.b64decode(text[: -len(text_postfix)]))
+                time.sleep(0.001)
+
+        return [
+            _read_one(interesting_rank) for interesting_rank in range(self._world_size)
+        ]
+
+    def barrier(self):
+        actual_objs = self.all_gather_object(self._rank)
+        assert actual_objs == list(range(self._world_size)), f"{actual_objs=}"
+
+
+# Can have multi instances if needed
+_instance: Optional[NaiveDistributed] = None
+
+
+def get_naive_distributed():
+    assert _instance is not None
+    return _instance
+
+
+def set_naive_distributed(instance: NaiveDistributed):
+    global _instance
+    assert _instance is None
+    _instance = instance
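`NaiveDistributed` rendezvouses through the filesystem: for each operation every rank writes a base64-encoded pickle to a shared directory and spins until all peers' files are complete, which is enough to build `all_gather_object`, `barrier`, and a serializer-backed `scatter` without any process group. Below is a rough standalone sketch of the same idea with two local processes and a temporary directory; the `_gather` helper is illustrative, not sglang code.

import base64
import pickle
import tempfile
import time
from multiprocessing import Process
from pathlib import Path


def _gather(directory: Path, rank: int, world_size: int, obj) -> list:
    # Write this rank's payload, then poll until every rank's file is complete.
    (directory / f"rank{rank}.txt").write_text(
        base64.b64encode(pickle.dumps(obj)).decode("utf-8") + "\n"
    )
    out = []
    for peer in range(world_size):
        path = directory / f"rank{peer}.txt"
        while not (path.exists() and (text := path.read_text()).endswith("\n")):
            time.sleep(0.001)
        out.append(pickle.loads(base64.b64decode(text[:-1])))
    return out


def _worker(directory: str, rank: int, world_size: int) -> None:
    gathered = _gather(Path(directory), rank, world_size, {"rank": rank})
    print(rank, gathered)  # every rank sees [{'rank': 0}, {'rank': 1}]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        procs = [Process(target=_worker, args=(tmp, r, 2)) for r in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()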
sglang/srt/distributed/parallel_state.py CHANGED

@@ -55,7 +55,7 @@ _is_npu = is_npu()
 
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream
+    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
 
 
 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
@@ -252,8 +252,11 @@ class GroupCoordinator:
 
         if is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
+        elif _is_npu:
+            self.device = torch.device(f"npu:{local_rank}")
         else:
             self.device = torch.device("cpu")
+        self.device_module = torch.get_device_module(self.device)
 
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
@@ -402,7 +405,7 @@ class GroupCoordinator:
         self, graph_capture_context: Optional[GraphCaptureContext] = None
     ):
         if graph_capture_context is None:
-            stream = torch.cuda.Stream()
+            stream = self.device_module.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
@@ -413,11 +416,11 @@ class GroupCoordinator:
 
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
-        curr_stream = torch.cuda.current_stream()
+        curr_stream = self.device_module.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)
 
-        with torch.cuda.stream(stream), maybe_ca_context:
+        with self.device_module.stream(stream), maybe_ca_context:
             # In graph mode, we have to be very careful about the collective
             # operations. The current status is:
             # allreduce \ Mode |  Eager  |  Graph  |
@@ -583,6 +586,39 @@ class GroupCoordinator:
         torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
         return output
 
+    def reduce_scatterv(
+        self,
+        input_: torch.Tensor,
+        output: Optional[torch.Tensor] = None,
+        sizes: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for reduce_scatterv"
+
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[0] == sum(sizes)
+                chunk_size = sizes[self.rank_in_group]
+            else:
+                assert input_.shape[0] % world_size == 0
+                chunk_size = input_.shape[0] // world_size
+            output_shape = (chunk_size,) + input_.shape[1:]
+
+            if output is None:
+                output = torch.empty(
+                    output_shape, dtype=input_.dtype, device=input_.device
+                )
+            else:
+                assert output.shape == output_shape
+
+            pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
+            return output
+
     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -673,6 +709,54 @@ class GroupCoordinator:
         )
         return output_tensor
 
+    def all_gatherv(
+        self,
+        input_: Union[torch.Tensor, List[torch.Tensor]],
+        sizes: Optional[List[int]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Supports varying sizes per rank and input tensor list.
+        `sizes`: a list of len(world_size) with the number of items per rank to gather.
+        """
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for all_gatherv"
+
+            def _all_gather_single(
+                input_: torch.Tensor, sizes: Optional[List[int]] = None
+            ):
+                input_size = input_.size()
+                if sizes is not None:
+                    assert len(sizes) == world_size
+                    assert input_.shape[0] == sizes[self.rank_in_group]
+                    output_size = (sum(sizes),) + input_size[1:]
+                    # 'sizes' is not needed if all inputs in the same group have the same shape
+                    if all(s == sizes[0] for s in sizes):
+                        sizes = None
+                else:
+                    output_size = (input_size[0] * world_size,) + input_size[1:]
+                # Allocate output tensor.
+                output_tensor = torch.empty(
+                    output_size, dtype=input_.dtype, device=input_.device
+                )
+                pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
+                return output_tensor
+
+            if isinstance(input_, torch.Tensor):
+                return _all_gather_single(input_, sizes)
+
+            output_list = []
+            pynccl_comm.group_start()
+            for inp in input_:
+                output_list.append(_all_gather_single(inp, sizes=sizes))
+            pynccl_comm.group_end()
+
+            return output_list
+
     def gather(
         self, input_: torch.Tensor, dst: int = 0, dim: int = -1
     ) -> Optional[torch.Tensor]:
@@ -1560,6 +1644,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         )
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         torch.xpu.empty_cache()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.empty_cache()
 
 
 def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
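Several of the parallel_state hunks above swap hard-coded `torch.cuda` calls for `self.device_module`, obtained via `torch.get_device_module(self.device)`, so the same stream and graph-capture logic can drive CUDA and NPU builds. A small sketch of that dispatch, guarded for machines without an accelerator:

import torch

# Pick the backend module that matches the device, e.g. torch.cuda for "cuda:0"
# (or torch.npu on Ascend builds that ship torch_npu).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_module = torch.get_device_module(device)

if device.type == "cuda":
    # The graph-capture path in GroupCoordinator uses these same calls.
    stream = device_module.Stream()
    curr = device_module.current_stream()
    if curr != stream:
        stream.wait_stream(curr)
    with device_module.stream(stream):
        pass  # run or capture work on the side stream
else:
    print("No accelerator available; device module is", device_module)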
sglang/srt/entrypoints/context.py CHANGED

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-# Copied from vLLM
+# Copied from vLLM
 import json
 import logging
 from abc import ABC, abstractmethod
@@ -83,6 +83,14 @@ class HarmonyContext(ConversationContext):
         if isinstance(output, dict) and "output_ids" in output:
             output_token_ids = output["output_ids"]
 
+            # TODO: REMOVE here:
+            # Very hacky, find the first occurrence of token 200006 and cut from there
+            try:
+                start_index = output_token_ids.index(200006)
+                output_token_ids = output_token_ids[start_index:]
+            except ValueError:
+                pass
+
             for token_id in output_token_ids:
                 self.parser.process(token_id)
             output_msgs = self.parser.messages
@@ -107,6 +115,8 @@ class HarmonyContext(ConversationContext):
         return self._messages
 
     def need_builtin_tool_call(self) -> bool:
+        if not self.messages:
+            return False
         last_msg = self.messages[-1]
         recipient = last_msg.recipient
         return recipient is not None and (
@@ -188,6 +198,15 @@ class StreamingHarmonyContext(HarmonyContext):
             # RequestOutput from SGLang with outputs
             output_token_ids = output["output_ids"]
 
+            # TODO: REMOVE here:
+            # Very hacky, find the first occurrence of token 200006 and cut from there
+            # Find the first occurrence of token 200006 and cut from there
+            try:
+                start_index = output_token_ids.index(200006)
+                output_token_ids = output_token_ids[start_index:]
+            except ValueError:
+                pass
+
             for token_id in output_token_ids:
                 self.parser.process(token_id)
 
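The two HarmonyContext hunks trim the output token stream so it starts at the first occurrence of token id 200006, and leave it untouched when that token never appears. A quick illustration of that behavior on made-up token lists; `trim_to_first_marker` is a hypothetical helper, not part of the package.

def trim_to_first_marker(output_token_ids: list[int], marker: int = 200006) -> list[int]:
    # Same try/except pattern as the diff: cut from the first marker, else keep as-is.
    try:
        start_index = output_token_ids.index(marker)
        return output_token_ids[start_index:]
    except ValueError:
        return output_token_ids


print(trim_to_first_marker([11, 22, 200006, 33, 44]))  # [200006, 33, 44]
print(trim_to_first_marker([11, 22, 33]))              # [11, 22, 33]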
sglang/srt/entrypoints/engine.py CHANGED

@@ -23,8 +23,10 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
+import random
 import signal
 import threading
+import time
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
 
 import zmq
@@ -94,8 +96,8 @@ class Engine(EngineBase):
     3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
 
     Note:
-    1. The HTTP server, Engine, and TokenizerManager
-    2. Inter-process communication is
+    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
+    2. Inter-process communication (IPC) is handled via the ZMQ library, with each process using a different port.
     """
 
     def __init__(self, **kwargs):
@@ -536,6 +538,22 @@ class Engine(EngineBase):
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )
 
+    def freeze_gc(self):
+        """
+        To maintain a high performance server with low latency, we want to reduce the
+        stalls caused by the garbage collector scanning through a large number of objects.
+
+        It is usually helpful to start the server and warm it up with real requests to
+        initialize many of the long-lived objects that do not need to be garbage collected.
+
+        After sufficient warmup, we can call this function to freeze the garbage collector
+        so that all objects created before this point are considered out of scope for garbage
+        collection.
+        """
+
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.tokenizer_manager.freeze_gc())
+
     """
     Execute an RPC call on all scheduler processes.
     """
@@ -635,6 +653,13 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+    # flashinfer uses this environment variable for various kernels from MoE to quant kernels
+    os.environ["TRTLLM_ENABLE_PDL"] = "1"
+
+    # Can also be passed as argument
+    os.environ["SGLANG_RUN_ID"] = (
+        f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
+    )
 
     # Set prometheus env vars
     if server_args.enable_metrics:
@@ -647,7 +672,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.11.
+            "0.2.11.post3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -655,7 +680,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.
+            "0.3.5",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
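The new `Engine.freeze_gc` docstring describes the intent: after warmup, take the long-lived objects out of the collector's working set so later GC passes stay short. CPython exposes this mechanism as `gc.freeze()`; the sketch below shows that standard-library call under the same warmup-then-freeze assumption, and is not a copy of sglang's tokenizer_manager implementation.

import gc

# Simulate "warmup": allocate long-lived objects (caches, compiled artifacts, ...).
warm_objects = [{"key": i, "payload": "x" * 64} for i in range(100_000)]

gc.collect()   # clear out garbage produced during warmup first
gc.freeze()    # move every surviving tracked object into the permanent generation

print("objects in the permanent generation:", gc.get_freeze_count())

# Collections after this point skip the frozen set, so pauses stay short even
# though the server keeps those warm objects alive for its whole lifetime.
gc.collect()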