sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/pynccl.py

@@ -148,7 +148,11 @@ class PyNcclCommunicator:
         )
 
     def all_gather(
-        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
         )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclAllGather(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            input_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+
+        if sizes is not None:
+            split_offset = 0
+
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                dst_slice = output_tensor[split_offset : split_offset + split_size]
+                self.nccl.ncclBroadcast(
+                    buffer_type(input_tensor.data_ptr()),
+                    buffer_type(dst_slice.data_ptr()),
+                    dst_slice.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclAllGather(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                input_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def reduce_scatter(
         self,
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
         input_tensor: torch.Tensor,
         op: ReduceOp = ReduceOp.SUM,
         stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
         )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclReduceScatter(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            output_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            ncclRedOpTypeEnum.from_torch(op),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+
+        if sizes is not None:
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                chunk = input_tensor[split_offset : split_offset + split_size, ...]
+
+                self.nccl.ncclReduce(
+                    buffer_type(chunk.data_ptr()),
+                    buffer_type(output_tensor.data_ptr()),
+                    chunk.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    ncclRedOpTypeEnum.from_torch(op),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclReduceScatter(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                output_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
     def deregister_comm_window(self, window):
         return self.nccl.ncclCommWindowDeregister(self.comm, window)
 
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
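NCCL has no native variable-length all-gather or reduce-scatter, so the `sizes` path above emulates them with one per-root collective per rank (`ncclBroadcast` for the gather, `ncclReduce` for the scatter), batched between `ncclGroupStart()` and `ncclGroupEnd()` so the loop is submitted as a single group instead of `world_size` serialized launches. A minimal pure-`torch.distributed` sketch of the same all-gather-v construction (the helper name and setup are illustrative, not part of the diff):

    import torch
    import torch.distributed as dist

    def all_gatherv_reference(input_: torch.Tensor, sizes: list[int]) -> torch.Tensor:
        # Variable-size all-gather built from one broadcast per source rank,
        # mirroring the grouped ncclBroadcast loop in the hunk above.
        rank = dist.get_rank()
        assert input_.shape[0] == sizes[rank]
        output = torch.empty(
            (sum(sizes),) + input_.shape[1:], dtype=input_.dtype, device=input_.device
        )
        split_offset = 0
        for root, split_size in enumerate(sizes):
            dst_slice = output[split_offset : split_offset + split_size]
            if root == rank:
                dst_slice.copy_(input_)  # the root stages its rows in the send buffer
            dist.broadcast(dst_slice, src=root)
            split_offset += split_size
        return output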
sglang/srt/distributed/device_communicators/pynccl_wrapper.py

@@ -206,6 +206,26 @@ class NCCLLibrary:
                 cudaStream_t,
             ],
         ),
+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function(
+            "ncclReduce",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ncclRedOp_t,
+                ctypes.c_int,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
         # ncclResult_t ncclReduceScatter(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
@@ -278,6 +298,10 @@ class NCCLLibrary:
         # it is better not to call it at all.
         # ncclResult_t ncclCommDestroy(ncclComm_t comm);
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     exported_functions_symm_mem = [
@@ -400,6 +424,28 @@ class NCCLLibrary:
             )
         )
 
+    def ncclReduce(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        root: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(
+            self._funcs["ncclReduce"](
+                sendbuff, recvbuff, count, datatype, op, root, comm, stream
+            )
+        )
+
     def ncclReduceScatter(
         self,
         sendbuff: buffer_type,
@@ -499,6 +545,12 @@ class NCCLLibrary:
     def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary",
sglang/srt/distributed/naive_distributed.py (new file)

@@ -0,0 +1,112 @@
+import base64
+import os
+import pickle
+import time
+from pathlib import Path
+from typing import Any, List, Optional
+
+import torch
+
+from sglang.srt.utils import MultiprocessingSerializer
+
+
+class NaiveDistributed:
+    def __init__(self, rank: int, world_size: int, rendezvous: str):
+        self._rank = rank
+        self._world_size = world_size
+        self._operation_index = 0
+        self._directory = Path(rendezvous)
+        self._directory.mkdir(parents=True, exist_ok=True)
+        assert 0 <= rank < world_size
+
+        # both barrier to be safe, and as a sanity check
+        self.barrier()
+
+    def get_rank(self):
+        return self._rank
+
+    def get_world_size(self):
+        return self._world_size
+
+    def scatter(
+        self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0
+    ):
+        if self._rank == src:
+            assert len(scatter_list) == self._world_size
+        else:
+            assert scatter_list is None
+
+        gathered_objects = self.all_gather_object(
+            dict(
+                serialized_scatter_list=[
+                    (
+                        None
+                        if item_rank == src
+                        else MultiprocessingSerializer.serialize(item)
+                    )
+                    for item_rank, item in enumerate(scatter_list)
+                ]
+            )
+            if self._rank == src
+            else dict()
+        )
+
+        remote_serialized_tensor = gathered_objects[src]["serialized_scatter_list"][
+            self._rank
+        ]
+        if self._rank == src:
+            assert remote_serialized_tensor is None
+            remote_tensor = scatter_list[self._rank]
+        else:
+            remote_tensor = MultiprocessingSerializer.deserialize(
+                remote_serialized_tensor
+            )
+        tensor.copy_(remote_tensor)
+
+        # avoid src tensor be deleted too early
+        self.barrier()
+
+    def all_gather_object(self, obj: Any) -> List[Any]:
+        self._operation_index += 1
+
+        text_postfix = "\n"
+
+        def _get_path(interesting_rank: int):
+            return (
+                self._directory
+                / f"rank{interesting_rank}_op{self._operation_index}.txt"
+            )
+
+        _get_path(self._rank).write_text(
+            base64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
+        )
+
+        def _read_one(interesting_rank: int):
+            p = _get_path(interesting_rank)
+            while True:
+                if p.exists() and (text := p.read_text()).endswith(text_postfix):
+                    return pickle.loads(base64.b64decode(text[: -len(text_postfix)]))
+                time.sleep(0.001)
+
+        return [
+            _read_one(interesting_rank) for interesting_rank in range(self._world_size)
+        ]
+
+    def barrier(self):
+        actual_objs = self.all_gather_object(self._rank)
+        assert actual_objs == list(range(self._world_size)), f"{actual_objs=}"
+
+
+# Can have multi instances if needed
+_instance: Optional[NaiveDistributed] = None
+
+
+def get_naive_distributed():
+    assert _instance is not None
+    return _instance
+
+
+def set_naive_distributed(instance: NaiveDistributed):
+    global _instance
+    assert _instance is None
+    _instance = instance
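`NaiveDistributed` is a deliberately simple file-based collective layer: every operation bumps a counter, each rank writes `rank{r}_op{n}.txt` containing a base64-encoded pickle, and readers poll until all peers' files end with the newline sentinel. A hypothetical two-process smoke test (the worker function and directory layout are illustrative):

    import multiprocessing as mp
    import tempfile

    from sglang.srt.distributed.naive_distributed import NaiveDistributed

    def _worker(rank: int, world_size: int, rendezvous: str):
        dist = NaiveDistributed(rank=rank, world_size=world_size, rendezvous=rendezvous)
        # Exchange plain Python objects through the shared directory.
        gathered = dist.all_gather_object({"rank": rank, "payload": rank * 10})
        assert [g["payload"] for g in gathered] == [0, 10]
        dist.barrier()

    if __name__ == "__main__":
        rendezvous = tempfile.mkdtemp()
        procs = [mp.Process(target=_worker, args=(r, 2, rendezvous)) for r in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
            assert p.exitcode == 0

Note that `scatter` goes through `MultiprocessingSerializer` rather than plain pickle, presumably so tensors can be shared between local processes via device IPC handles instead of being copied by value.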
sglang/srt/distributed/parallel_state.py

@@ -55,7 +55,7 @@ _is_npu = is_npu()
 
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream
+    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
 
 
 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
@@ -252,8 +252,11 @@ class GroupCoordinator:
 
         if is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
+        elif _is_npu:
+            self.device = torch.device(f"npu:{local_rank}")
         else:
             self.device = torch.device("cpu")
+        self.device_module = torch.get_device_module(self.device)
 
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
@@ -402,7 +405,7 @@ class GroupCoordinator:
         self, graph_capture_context: Optional[GraphCaptureContext] = None
     ):
         if graph_capture_context is None:
-            stream = torch.cuda.Stream()
+            stream = self.device_module.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
@@ -413,11 +416,11 @@ class GroupCoordinator:
 
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
-        curr_stream = torch.cuda.current_stream()
+        curr_stream = self.device_module.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)
 
-        with torch.cuda.stream(stream), maybe_ca_context:
+        with self.device_module.stream(stream), maybe_ca_context:
             # In graph mode, we have to be very careful about the collective
             # operations. The current status is:
             # allreduce \ Mode | Eager | Graph |
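These hunks make `GroupCoordinator`'s stream handling backend-agnostic: `torch.get_device_module(device)` resolves to the matching accelerator namespace (`torch.cuda`, or `torch.npu` once `torch_npu` is loaded), and each namespace exposes the same `Stream`/`current_stream`/`stream` surface. A small sketch of the dispatch pattern, assuming a CUDA device is available:

    import torch

    device = torch.device("cuda:0")  # would be "npu:0" on Ascend with torch_npu
    device_module = torch.get_device_module(device)  # -> torch.cuda here

    side_stream = device_module.Stream()
    side_stream.wait_stream(device_module.current_stream())
    with device_module.stream(side_stream):
        y = torch.ones(4, device=device) * 2  # enqueued on the side stream
    device_module.current_stream().wait_stream(side_stream)

The next two hunks in the same file build variable-size collective wrappers on top of the pynccl changes shown earlier.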
@@ -583,6 +586,39 @@ class GroupCoordinator:
         torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
         return output
 
+    def reduce_scatterv(
+        self,
+        input_: torch.Tensor,
+        output: Optional[torch.Tensor] = None,
+        sizes: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for reduce_scatterv"
+
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[0] == sum(sizes)
+                chunk_size = sizes[self.rank_in_group]
+            else:
+                assert input_.shape[0] % world_size == 0
+                chunk_size = input_.shape[0] // world_size
+            output_shape = (chunk_size,) + input_.shape[1:]
+
+            if output is None:
+                output = torch.empty(
+                    output_shape, dtype=input_.dtype, device=input_.device
+                )
+            else:
+                assert output.shape == output_shape
+
+            pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
+        return output
+
     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -673,6 +709,54 @@ class GroupCoordinator:
         )
         return output_tensor
 
+    def all_gatherv(
+        self,
+        input_: Union[torch.Tensor, List[torch.Tensor]],
+        sizes: Optional[List[int]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Supports varying sizes per rank and input tensor list.
+        `sizes`: a list of len(world_size) with the number of items per rank to gather.
+        """
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for all_gatherv"
+
+            def _all_gather_single(
+                input_: torch.Tensor, sizes: Optional[List[int]] = None
+            ):
+                input_size = input_.size()
+                if sizes is not None:
+                    assert len(sizes) == world_size
+                    assert input_.shape[0] == sizes[self.rank_in_group]
+                    output_size = (sum(sizes),) + input_size[1:]
+                    # 'sizes' is not needed if all inputs in the same group have the same shape
+                    if all(s == sizes[0] for s in sizes):
+                        sizes = None
+                else:
+                    output_size = (input_size[0] * world_size,) + input_size[1:]
+                # Allocate output tensor.
+                output_tensor = torch.empty(
+                    output_size, dtype=input_.dtype, device=input_.device
+                )
+                pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
+                return output_tensor
+
+            if isinstance(input_, torch.Tensor):
+                return _all_gather_single(input_, sizes)
+
+            output_list = []
+            pynccl_comm.group_start()
+            for inp in input_:
+                output_list.append(_all_gather_single(inp, sizes=sizes))
+            pynccl_comm.group_end()
+
+        return output_list
+
     def gather(
         self, input_: torch.Tensor, dst: int = 0, dim: int = -1
     ) -> Optional[torch.Tensor]:
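A hypothetical call site for the two new v-collectives with uneven per-rank row counts (the group handle, sizes, and shapes are illustrative; as the asserts above enforce, both paths require the pynccl communicator to be enabled):

    import torch
    from sglang.srt.distributed import get_tp_group

    group = get_tp_group()          # GroupCoordinator; assume world_size == 4
    sizes = [3, 5, 2, 6]            # rows contributed by each rank
    rank = group.rank_in_group

    # all_gatherv: contribute sizes[rank] rows, receive the concatenation.
    local_rows = torch.randn(sizes[rank], 128, device="cuda")
    gathered = group.all_gatherv(local_rows, sizes=sizes)    # shape (16, 128)

    # reduce_scatterv: sum across ranks, keep only this rank's row slice.
    full = torch.randn(sum(sizes), 128, device="cuda")
    my_rows = group.reduce_scatterv(full, sizes=sizes)       # shape (sizes[rank], 128)

The final parallel_state.py hunk extends device cleanup to NPUs.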
@@ -1560,6 +1644,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         )
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         torch.xpu.empty_cache()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.empty_cache()
 
 
 def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
sglang/srt/entrypoints/context.py

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-# Copied from vLLM: https://github.com/zyongye/vllm/blob/6a70830065701b163e36a86fd331b41b5feac401/vllm/entrypoints/context.py
+# Copied from vLLM
 import json
 import logging
 from abc import ABC, abstractmethod
@@ -83,6 +83,14 @@ class HarmonyContext(ConversationContext):
         if isinstance(output, dict) and "output_ids" in output:
             output_token_ids = output["output_ids"]
 
+        # TODO: REMOVE here:
+        # Very hacky, find the first occurrence of token 200006 and cut from there
+        try:
+            start_index = output_token_ids.index(200006)
+            output_token_ids = output_token_ids[start_index:]
+        except ValueError:
+            pass
+
         for token_id in output_token_ids:
             self.parser.process(token_id)
         output_msgs = self.parser.messages
@@ -107,6 +115,8 @@ class HarmonyContext(ConversationContext):
         return self._messages
 
     def need_builtin_tool_call(self) -> bool:
+        if not self.messages:
+            return False
         last_msg = self.messages[-1]
         recipient = last_msg.recipient
         return recipient is not None and (
@@ -188,6 +198,15 @@ class StreamingHarmonyContext(HarmonyContext):
         # RequestOutput from SGLang with outputs
         output_token_ids = output["output_ids"]
 
+        # TODO: REMOVE here:
+        # Very hacky, find the first occurrence of token 200006 and cut from there
+        # Find the first occurrence of token 200006 and cut from there
+        try:
+            start_index = output_token_ids.index(200006)
+            output_token_ids = output_token_ids[start_index:]
+        except ValueError:
+            pass
+
         for token_id in output_token_ids:
             self.parser.process(token_id)
 
sglang/srt/entrypoints/engine.py

@@ -23,8 +23,10 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
+import random
 import signal
 import threading
+import time
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
 
 import zmq
@@ -94,8 +96,8 @@ class Engine(EngineBase):
     3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
 
     Note:
-    1. The HTTP server, Engine, and TokenizerManager both run in the main process.
-    2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
+    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
+    2. Inter-process communication (IPC) is handled via the ZMQ library, with each process using a different port.
     """
 
     def __init__(self, **kwargs):
@@ -536,6 +538,22 @@ class Engine(EngineBase):
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )
 
+    def freeze_gc(self):
+        """
+        To maintain a high performance server with low latency, we want to reduce the
+        stalls caused by the garbage collector scanning through a large number of objects.
+
+        It is usually helpful to start the server and warm it up with real requests to
+        initialize many of the long-lived objects that do not need to be garbage collected.
+
+        After sufficient warmup, we can call this function to freeze the garbage collector
+        so that all objects created before this point are considered out of scope for garbage
+        collection.
+        """
+
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.tokenizer_manager.freeze_gc())
+
 
     """
     Execute an RPC call on all scheduler processes.
@@ -635,6 +653,13 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+    # flashinfer uses this environment variable for various kernels from MoE to quant kernels
+    os.environ["TRTLLM_ENABLE_PDL"] = "1"
+
+    # Can also be passed as argument
+    os.environ["SGLANG_RUN_ID"] = (
+        f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
+    )
 
     # Set prometheus env vars
     if server_args.enable_metrics:
@@ -647,7 +672,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.11.post1",
+            "0.2.11.post3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -655,7 +680,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.4",
+            "0.3.5",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )