sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/symm_mem.py
@@ -0,0 +1,164 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
+import logging
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
+)
+from sglang.srt.utils import get_device_capability, is_cuda, is_hip
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    symm_mem_available = True
+except ImportError:
+    symm_mem_available = False
+
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+symm_mem_is_available = False
+if _is_hip:
+    symm_mem_is_available = False
+if _is_cuda:
+    symm_mem_is_available = True
+
+
+class SymmMemCommunicator:
+    """
+    Thin wrapper around symmetric-memory collectives.
+
+    This communicator:
+    - Validates device capability and world size.
+    - Allocates a shared symmetric buffer.
+    - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
+    - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.
+
+    If any prerequisite is not met, the instance remains disabled and will
+    decline to perform symmetric-memory all-reduce.
+    """
+
+    # Mapping: compute capability major -> supported world sizes for multimem
+    # If the current (cc_major, world_size) is not listed, we fall back
+    # to the two-shot path.
+    _WORLD_SIZES_MULTIMEM = {
+        9: [4, 6, 8],
+        10: [6, 8],
+    }
+
+    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
+        """
+        Args:
+            group: Torch process group used for rendezvous and naming.
+            device: Target CUDA device (index, 'cuda:X', or torch.device).
+        """
+
+        self.disabled = True
+
+        if not symm_mem_available:
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        torch.cuda.set_device(device)
+        self.dtype = torch.bfloat16
+        self.device = device
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+        self.device_capability = torch.cuda.get_device_capability(device)[0]
+        if self.device_capability < 9:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                self.device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
+            logger.warning(
+                "SymmMemCommunicator: World size %d not supported, "
+                "communicator is not available.",
+                self.world_size,
+            )
+            return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+            self.world_size
+        ]
+        self.buffer = torch_symm_mem.empty(
+            self.max_size // self.dtype.itemsize,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
+        if handle.multicast_ptr == 0:
+            logger.warning(
+                "SymmMemCommunicator: symmetric memory "
+                "multicast operations are not supported."
+            )
+            self.buffer = None
+            self.disabled = True
+            return
+        self.disabled = False
+
+    def should_symm_mem_allreduce(self, inp: torch.Tensor):
+        """
+        Fast-path eligibility check for a given tensor.
+
+        Conditions:
+        - Communicator must be enabled.
+        - dtype must be bfloat16 (matches kernel + buffer dtype).
+        - Total byte size must be 4-byte aligned (hardware requirement).
+        - Payload must be smaller than the symmetric-memory max size.
+
+        Returns:
+            True if the symmetric-memory path can handle this tensor.
+        """
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # enforce 4-byte alignment
+        if inp_size % 4 != 0:
+            return False
+        return inp_size < self.max_size
+
+    def all_reduce(
+        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
+    ) -> Optional[torch.Tensor]:
+        """
+        Perform an in-place sum all-reduce via symmetric memory.
+
+        Args:
+            inp: Input tensor on the target CUDA device (bfloat16).
+            out: Optional output tensor; if omitted, a new tensor is allocated.
+
+        Returns:
+            The reduced tensor (same shape as inp), or None if disabled.
+
+        Implementation details:
+        - Stages 'inp' into the symmetric buffer.
+        - Selects 'multimem' or 'two_shot' kernel based on topology.
+        - Writes the result into 'out' and returns it.
+        """
+        if out is None:
+            out = torch.empty_like(inp)
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        else:
+            torch.ops.symm_mem.two_shot_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
+        return out
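
The new communicator above is strictly an opt-in fast path: callers are expected to probe it with should_symm_mem_allreduce() and fall back to a regular all-reduce when the dtype, alignment, payload size, device capability, or world size rules it out. A minimal usage sketch under that assumption (not part of the diff; the two-rank launcher and tensor size are hypothetical):

# Hypothetical two-rank driver for SymmMemCommunicator; illustration only.
import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator


def demo(rank: int, world_size: int) -> None:
    # Assumes MASTER_ADDR/MASTER_PORT are set and one CUDA device per rank.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    comm = SymmMemCommunicator(dist.group.WORLD, device=rank)

    x = torch.ones(4096, dtype=torch.bfloat16, device=f"cuda:{rank}")
    if comm.should_symm_mem_allreduce(x):
        x = comm.all_reduce(x)  # staged through the symmetric buffer
    else:
        dist.all_reduce(x)  # fallback: plain NCCL sum
    assert x[0].item() == world_size
    dist.destroy_process_group()
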
sglang/srt/distributed/parallel_state.py
@@ -4,7 +4,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-"""vLLM distributed state.
+"""Distributed state.
 It takes over the control of the distributed environment from PyTorch.
 The typical workflow is:
 
@@ -53,19 +53,26 @@ from sglang.srt.utils import (
 
 _is_npu = is_npu()
 _is_cpu = is_cpu()
+_supports_custom_op = supports_custom_op()
 
 IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
 
 
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
 
 
-TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
-
-# use int value instead of ReduceOp.SUM to support torch compile
-REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+@dataclass
+class P2PWork:
+    work: Optional[torch.distributed.Work]
+    payload: Optional[torch.Tensor]
 
 
 def _split_tensor_dict(
@@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)
 
 
-if supports_custom_op():
+if _supports_custom_op:
 
     def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         assert group_name in _groups, f"Group {group_name} is not found."
@@ -208,12 +215,14 @@ class GroupCoordinator:
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_pymscclpp: bool  # a hint of whether to use PyMsccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
     use_message_queue_broadcaster: (
        bool  # a hint of whether to use message queue broadcaster
    )
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
+    symm_mem_comm: Optional[Any]  # Symm mem communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
    def __init__(
@@ -224,6 +233,7 @@ class GroupCoordinator:
         use_pynccl: bool,
         use_pymscclpp: bool,
         use_custom_allreduce: bool,
+        use_torch_symm_mem: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
         use_npu_communicator: bool,
@@ -272,12 +282,13 @@
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_torch_symm_mem = use_torch_symm_mem
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
         self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
-        # lazy import to avoid documentation build error
+        # Lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
@@ -287,6 +298,9 @@
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator,
+        )
 
         if is_hip():
             from sglang.srt.distributed.device_communicators.quick_all_reduce import (
@@ -335,6 +349,13 @@
             except Exception as e:
                 logger.warning(f"Failed to initialize QuickAllReduce: {e}")
 
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+        if self.use_torch_symm_mem and self.world_size > 1:
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -439,6 +460,7 @@
         # custom allreduce  | enabled | enabled |
         # PyNccl            | disabled| enabled |
         # PyMscclpp         | disabled| enabled |
+        # TorchSymmMem      | disabled| enabled |
         # torch.distributed | enabled | disabled|
         #
         # Note: When custom quick allreduce is enabled, a runtime check
@@ -497,7 +519,7 @@
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
-        if not supports_custom_op():
+        if not _supports_custom_op:
             self._all_reduce_in_place(input_)
             return input_
 
@@ -523,23 +545,29 @@
 
         outplace_all_reduce_method = None
         if (
-            self.qr_comm is not None
-            and not self.qr_comm.disabled
-            and self.qr_comm.should_quick_allreduce(input_)
-        ):
-            outplace_all_reduce_method = "qr"
-        elif (
            self.ca_comm is not None
            and not self.ca_comm.disabled
            and self.ca_comm.should_custom_ar(input_)
        ):
            outplace_all_reduce_method = "ca"
+        elif (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
         elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
             outplace_all_reduce_method = "pymscclpp"
+        elif (
+            self.symm_mem_comm is not None
+            and not self.symm_mem_comm.disabled
+            and self.symm_mem_comm.should_symm_mem_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "symm_mem"
         if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
                 input_,
@@ -553,16 +581,20 @@
     def _all_reduce_out_place(
         self, input_: torch.Tensor, outplace_all_reduce_method: str
     ) -> torch.Tensor:
-        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
+        qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
+        symm_mem_comm = self.symm_mem_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
-        if outplace_all_reduce_method == "qr":
-            assert not qr_comm.disabled
-            out = qr_comm.quick_all_reduce(input_)
-        elif outplace_all_reduce_method == "ca":
+        if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
+        elif outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "symm_mem":
+            assert not symm_mem_comm.disabled
+            out = symm_mem_comm.all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
@@ -637,7 +669,7 @@
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if _is_npu or not supports_custom_op():
+        if _is_npu or not _supports_custom_op:
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -697,15 +729,13 @@
         )
 
         # All-gather.
-        if input_.is_cpu and is_shm_available(
-            input_.dtype, self.world_size, self.local_size
-        ):
-            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-
         if input_.is_cpu:
-            torch.distributed.all_gather_into_tensor(
-                output_tensor, input_, group=self.device_group
-            )
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
         else:
             self.all_gather_into_tensor(output_tensor, input_)
 
@@ -861,45 +891,63 @@
         torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
         return objs
 
-    def send_object(self, obj: Any, dst: int) -> None:
-        """Send the input object list to the destination rank."""
-        """NOTE: `dst` is the local rank of the destination rank."""
+    def send_object(
+        self,
+        obj: Any,
+        dst: int,
+        async_send: bool = False,
+    ) -> List[P2PWork]:
+        """
+        Send the input object list to the destination rank.
+        This function uses the CPU group for all communications.
 
-        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+        TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
+        use other functions (e.g., send), or implement a new function (e.g., send_object_device).
+
+        NOTE: `dst` is the local rank of the destination rank.
+        """
 
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
         assert dst != self.rank_in_group, (
             "Invalid destination rank. Destination rank is the same "
             "as the current rank."
         )
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
-            device=torch.cuda.current_device()
-        )
-
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
         size_tensor = torch.tensor(
-            [object_tensor.numel()],
-            dtype=torch.long,
-            device="cpu",
+            [object_tensor.numel()], dtype=torch.long, device="cpu"
         )
+
         # Send object size
-        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        p2p_work = []
+        size_work = send_func(
+            size_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
+        )
+        if async_send:
+            p2p_work.append(P2PWork(size_work, size_tensor))
 
-        # Send object
-        torch.distributed.send(
+        object_work = send_func(
             object_tensor,
-            dst=self.ranks[dst],
-            group=self.device_group,
+            self.ranks[dst],
+            group=self.cpu_group,
         )
+        if async_send:
+            p2p_work.append(P2PWork(object_work, object_tensor))
 
-        return None
+        return p2p_work
 
-    def recv_object(self, src: int) -> Any:
+    def recv_object(
+        self,
+        src: int,
+    ) -> Any:
         """Receive the input object list from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
 
         assert src < self.world_size, f"Invalid src rank ({src})"
-
         assert (
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
@@ -907,27 +955,25 @@
         size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
 
         # Receive object size
-        rank_size = torch.distributed.recv(
+        # We have to use irecv here to make it work for both isend and send.
+        work = torch.distributed.irecv(
             size_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()
 
         # Tensor to receive serialized objects into.
-        object_tensor = torch.empty(  # type: ignore[call-overload]
+        object_tensor: Any = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )
 
-        rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.device_group
+        work = torch.distributed.irecv(
+            object_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()
 
-        assert (
-            rank_object == rank_size
-        ), "Received object sender rank does not match the size sender rank."
-
-        obj = pickle.loads(object_tensor.cpu().numpy())
-
+        obj = pickle.loads(object_tensor.numpy())
         return obj
 
     def broadcast_tensor_dict(
@@ -1017,12 +1063,13 @@
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        async_send: bool = False,
+    ) -> Optional[List[P2PWork]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
         # Bypass the function if we are using only 1 GPU.
-        if not torch.distributed.is_initialized() or self.world_size == 1:
+        if self.world_size == 1:
             return tensor_dict
 
         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
@@ -1047,7 +1094,10 @@
         # 1. Superior D2D transfer bandwidth
         # 2. Ability to overlap send and recv operations
         # Thus the net performance gain justifies this approach.
-        self.send_object(metadata_list, dst=dst)
+
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
+
         for tensor in tensor_list:
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
@@ -1057,15 +1107,11 @@
             if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.send(
-                    tensor, dst=self.ranks[dst], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
-        return None
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = send_func(tensor, self.ranks[dst], group=comm_group)
+            if async_send:
+                p2p_works.append(P2PWork(work, tensor))
+        return p2p_works
 
     def recv_tensor_dict(
         self,
@@ -1111,17 +1157,15 @@
                     orig_shape = tensor.shape
                     tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-                if tensor.is_cpu:
-                    # use metadata_group for CPU tensors
-                    torch.distributed.recv(
-                        tensor, src=self.ranks[src], group=metadata_group
-                    )
-                else:
-                    # use group for GPU tensors
-                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
+                # We have to use irecv here to make it work for both isend and send.
+                comm_group = metadata_group if tensor.is_cpu else group
+                work = torch.distributed.irecv(
+                    tensor, src=self.ranks[src], group=comm_group
+                )
+                work.wait()
+
                 if use_all_gather:
-                    # do the allgather
-                    tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore
+                    tensor = all_gather_group.all_gather(tensor, dim=0)
                     tensor = tensor.reshape(orig_shape)
 
                 tensor_dict[key] = tensor
@@ -1199,6 +1243,7 @@ def init_world_group(
         use_pynccl=False,
         use_pymscclpp=False,
        use_custom_allreduce=False,
+        use_torch_symm_mem=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
@@ -1214,11 +1259,14 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
+    use_symm_mem_allreduce: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     if use_mscclpp_allreduce is None:
         use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
+    if use_symm_mem_allreduce is None:
+        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
@@ -1226,6 +1274,7 @@ def init_model_parallel_group(
         use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
+        use_torch_symm_mem=use_symm_mem_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
         use_npu_communicator=True,
@@ -1311,6 +1360,7 @@ logger = logging.getLogger(__name__)
 
 _ENABLE_CUSTOM_ALL_REDUCE = True
 _ENABLE_MSCCLPP_ALL_REDUCE = False
+_ENABLE_SYMM_MEM_ALL_REDUCE = False
 
 
 def set_custom_all_reduce(enable: bool):
@@ -1323,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
     _ENABLE_MSCCLPP_ALL_REDUCE = enable
 
 
+def set_symm_mem_all_reduce(enable: bool):
+    global _ENABLE_SYMM_MEM_ALL_REDUCE
+    _ENABLE_SYMM_MEM_ALL_REDUCE = enable
+
+
 def init_distributed_environment(
     world_size: int = -1,
     rank: int = -1,
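
The send_object / send_tensor_dict changes above move metadata transfers onto the CPU group and, with async_send=True, return P2PWork handles instead of blocking. A hedged sketch of how a caller might drive the asynchronous path (illustrative only; group_coordinator stands for any initialized GroupCoordinator):

# Illustration only: overlap other work with the isend operations returned by
# the new async_send path, then drain them. P2PWork keeps a reference to the
# payload tensor so it stays alive until the matching wait() completes.
def send_batch_async(group_coordinator, tensor_dict, dst):
    if group_coordinator.world_size == 1:
        return
    pending = group_coordinator.send_tensor_dict(
        tensor_dict, dst=dst, async_send=True
    )
    # ... do other work here while the sends are in flight ...
    for p2p in pending:
        if p2p.work is not None:
            p2p.work.wait()
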
sglang/srt/entrypoints/engine.py
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
 )
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -433,6 +434,19 @@ class Engine(EngineBase):
             self.tokenizer_manager.init_weights_update_group(obj, None)
         )
 
+    def destroy_weights_update_group(
+        self,
+        group_name: str,
+    ):
+        """Destroy parameter update group."""
+        obj = DestroyWeightsUpdateGroupReqInput(
+            group_name=group_name,
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.destroy_weights_update_group(obj, None)
+        )
+
     def update_weights_from_distributed(
         self,
         names: list[str],
@@ -666,6 +680,13 @@ def _set_envs_and_config(server_args: ServerArgs):
     if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
         os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
+    if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
+        # Default to warning level, to avoid too many logs
+        os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
+    if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
+        # Need to set log to console, otherwise the log level won't take effect
+        os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
+
     # Can also be passed as argument
     os.environ["SGLANG_RUN_ID"] = (
         f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
@@ -682,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.3.1",
+            "0.4.0rc3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -690,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.9.post2",
+            "0.3.14",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -791,7 +812,6 @@ def _launch_subprocesses(
                 pp_rank,
                 None,
                 writer,
-                None,
             ),
         )
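
The new Engine.destroy_weights_update_group mirrors init_weights_update_group, so a trainer that opens a temporary process group for weight broadcast (the usual RLHF weight-sync pattern) can now tear it down between rounds. A rough sketch; only destroy_weights_update_group(group_name=...) comes from this diff, while the surrounding Engine calls and values are assumptions:

import sglang as sgl

# Placeholder model path; illustration only.
engine = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")

# ... earlier: engine.init_weights_update_group(..., group_name="weight_sync")
# ... the trainer process broadcasts updated weights into that group ...

engine.destroy_weights_update_group(group_name="weight_sync")
engine.shutdown()
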