sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares publicly released package versions from one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/offloader.py ADDED
@@ -0,0 +1,433 @@
+import logging
+import os
+from abc import ABC
+from typing import Callable, Generator, List, Optional
+
+import torch
+from torch.func import functional_call
+
+from sglang.srt.distributed.naive_distributed import (
+    NaiveDistributed,
+    get_naive_distributed,
+    set_naive_distributed,
+)
+from sglang.srt.host_shared_memory import (
+    HostSharedMemoryManager,
+    get_host_shared_memory_manager,
+    set_host_shared_memory_manager,
+)
+from sglang.srt.layers.parameter import ModelWeightParameter
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available
+
+logger = logging.getLogger(__name__)
+
+_SubmoduleAccessor = Callable[[torch.nn.Module], torch.nn.Module]
+_WhitelistParamNamesCreator = Callable[[torch.nn.Module], List[str]]
+
+
+class BaseOffloader(ABC):
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        return list(all_modules_generator)
+
+    def post_init(self):
+        pass
+
+
+class NoopOffloader(BaseOffloader):
+    pass
+
+
+# For simplicity use singleton, but can surely support multi instance
+_instance: Optional[BaseOffloader] = NoopOffloader()
+
+
+def get_offloader():
+    assert _instance is not None
+    return _instance
+
+
+def set_offloader(instance: BaseOffloader):
+    global _instance
+    _instance = instance
+
+
+def create_offloader_from_server_args(server_args: ServerArgs, dp_rank: int):
+    if server_args.cpu_offload_gb > 0:
+        return OffloaderV1(
+            cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
+        )
+    if server_args.offload_group_size > 0:
+        assert (
+            server_args.cpu_offload_gb == 0
+        ), "V2 offload does not support cpu_offload_gb yet"
+        return OffloaderV2(
+            group_size=server_args.offload_group_size,
+            num_in_group=server_args.offload_num_in_group,
+            prefetch_step=server_args.offload_prefetch_step,
+            mode=server_args.offload_mode,
+            dp_rank=dp_rank,
+            dp_size=server_args.dp_size,
+        )
+    return NoopOffloader()
+
+
+class OffloaderV1(BaseOffloader):
+    def __init__(self, cpu_offload_max_bytes: int):
+        self._cpu_offload_bytes = 0
+        self._cpu_offload_max_bytes = cpu_offload_max_bytes
+
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        return [self.maybe_offload_to_cpu(module) for module in all_modules_generator]
+
+    def maybe_offload_to_cpu(self, module: torch.nn.Module) -> torch.nn.Module:
+        if (params := next(module.parameters(), None)) is None:
+            return module
+
+        device = params.device
+
+        if device == torch.device("cpu"):
+            return module
+
+        if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
+            return module
+
+        pin_memory = is_pin_memory_available()
+        # offload parameters to CPU
+        # use pin_memory if possible, which helps cudagraph capture speed
+        offloaded_parameters = False
+        for p in module.parameters():
+            if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
+                # we use per-parameter offloading
+                # one module might have some parameters offloaded and some not
+                break
+
+            # `torch.empty_like` does not support `pin_memory` argument
+            cpu_data = torch.empty_strided(
+                size=p.data.size(),
+                stride=p.data.stride(),
+                dtype=p.data.dtype,
+                layout=p.data.layout,
+                device="cpu",
+                pin_memory=pin_memory,
+            )
+            cpu_data.copy_(p.data)
+            p.data = cpu_data
+            self._cpu_offload_bytes += p.data.numel() * p.data.element_size()
+            offloaded_parameters = True
+
+        if offloaded_parameters:
+            original_forward = module.forward
+
+            def forward(*args, **kwargs):
+                module.forward = original_forward
+                device_state = {
+                    # here we blindly call `to(device)`
+                    # if the parameter is already on the device, it will be a no-op
+                    k: v.to(device, non_blocking=True)
+                    for k, v in module.state_dict().items()
+                }
+                output = functional_call(module, device_state, args=args, kwargs=kwargs)
+                module.forward = forward
+                return output
+
+            module.forward = forward
+
+        return module
+
+
+class OffloaderV2(BaseOffloader):
+    def __init__(
+        self,
+        group_size: int,
+        num_in_group: int,
+        prefetch_step: int,
+        mode: str,
+        dp_rank: int,
+        dp_size: int,
+    ):
+        self.group_size = group_size
+        self.num_in_group = num_in_group
+        self.prefetch_step = prefetch_step
+        self.mode = mode
+
+        run_id = os.environ["SGLANG_RUN_ID"]
+
+        # Temporarily init inside Offloader, can move if other modules also need this
+        if self.mode in {"sharded_gpu", "shm_cpu"}:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+            assert (
+                get_tensor_model_parallel_world_size() == 1
+            ), "not yet support tp_size!=1"
+            set_naive_distributed(
+                NaiveDistributed(
+                    rank=dp_rank,
+                    world_size=dp_size,
+                    rendezvous=f"/tmp/{run_id}",
+                )
+            )
+        if self.mode in {"shm_cpu"}:
+            set_host_shared_memory_manager(
+                HostSharedMemoryManager(
+                    base_name=run_id,
+                )
+            )
+
+        self.offloaders = []
+
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        assert len(self.offloaders) == 0, "should only call wrap_modules once"
+
+        alt_stream = torch.cuda.Stream()
+
+        all_modules = []
+        offload_submodules = []
+        for module_index, module in enumerate(all_modules_generator):
+            all_modules.append(module)
+            if module_index % self.group_size >= self.group_size - self.num_in_group:
+                submodule = submodule_accessor(module)
+                whitelist_param_names = whitelist_param_names_creator(submodule)
+                logger.info(
+                    f"[offloader] offload {module_index=} submodule={type(submodule)} params={whitelist_param_names} memory_allocated={torch.cuda.memory_allocated()}"
+                )
+                offload_submodules.append(submodule)
+                self.offloaders.append(
+                    _ModuleOffloader(
+                        mode=self.mode,
+                        module=submodule,
+                        alt_stream=alt_stream,
+                        whitelist_param_names=whitelist_param_names,
+                    )
+                )
+
+        for index, module in enumerate(offload_submodules):
+            _hook_module_forward_for_offloader(
+                index=index,
+                module=module,
+                offloaders=self.offloaders,
+                prefetch_step=self.prefetch_step,
+            )
+
+        return all_modules
+
+    def post_init(self):
+        for offloader in self.offloaders:
+            offloader.post_init()
+
+        for i in range(self.prefetch_step):
+            self.offloaders[i].start_onload()
+
+
+def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
+    def _on_forward_end():
+        offloaders[(index + prefetch_step) % len(offloaders)].start_onload()
+        offloaders[index].offload()
+
+    _hook_module_forward_raw(
+        module,
+        on_forward_end=_on_forward_end,
+        get_parameter_and_buffer_dicts=lambda: offloaders[
+            index
+        ].wait_and_get_device_tensors(),
+    )
+
+
+def _hook_module_forward_raw(module, on_forward_end, get_parameter_and_buffer_dicts):
+    original_forward = module.forward
+
+    def forward(*args, **kwargs):
+        module.forward = original_forward
+        output = functional_call(
+            module, get_parameter_and_buffer_dicts(), args=args, kwargs=kwargs
+        )
+        on_forward_end()
+        module.forward = forward
+        return output
+
+    module.forward = forward
+
+
+class _ModuleOffloader(ABC):
+    def __init__(
+        self,
+        mode: str,
+        module: torch.nn.Module,
+        alt_stream: torch.cuda.Stream,
+        whitelist_param_names: List[str],
+    ):
+        self.mode = mode
+        self.module = module
+        self.device = next(module.parameters()).device
+        self.alt_stream = alt_stream
+
+        assert self.device != torch.device(
+            "cpu"
+        ), "not handled device=cpu case yet (should skip this tensor)"
+
+        self._device_tensors = None
+        self._load_event = None
+
+        param_dict = dict(self.module.named_parameters())
+        assert all(
+            name in param_dict for name in whitelist_param_names
+        ), f"{whitelist_param_names=} {list(param_dict.keys())=}"
+
+        self._param_offloaders = {
+            name: _BaseParamOffloader.create(mode, module=module, param_name=name)
+            for name in whitelist_param_names
+        }
+
+    def post_init(self):
+        for name, param_offloader in self._param_offloaders.items():
+            param_offloader.post_init()
+
+    def start_onload(self):
+        self.alt_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.alt_stream):
+            self._device_tensors = self._create_device_tensors()
+            self._load_event = torch.cuda.Event()
+            self._load_event.record()
+
+    def offload(self):
+        self._device_tensors = None
+        self._load_event = None
+
+    def wait_and_get_device_tensors(self):
+        assert self._device_tensors is not None
+        self._load_event.wait()
+        return self._device_tensors
+
+    def _create_device_tensors(self):
+        return {k: v.create_device_tensor() for k, v in self._param_offloaders.items()}
+
+
+class _BaseParamOffloader(ABC):
+    @staticmethod
+    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
+        return {
+            "cpu": _CpuParamOffloader,
+            "shm_cpu": _ShmCpuParamOffloader,
+            "sharded_gpu": _ShardedGpuParamOffloader,
+        }[mode](**kwargs)
+
+    def __init__(self, module, param_name):
+        self._module = module
+        self._param_name = param_name
+
+    @property
+    def _param(self):
+        return getattr(self._module, self._param_name)
+
+    def post_init(self):
+        pass
+
+    def create_device_tensor(self):
+        raise NotImplementedError
+
+
+class _CpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        _move_param_to_cpu(self._param, pin_memory=True)
+
+    def create_device_tensor(self):
+        return self._param.to("cuda", non_blocking=True)
+
+
+class _ShmCpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        self._rank = get_naive_distributed().get_rank()
+        self._world_size = get_naive_distributed().get_world_size()
+
+        from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
+        assert (
+            self._param.data.is_contiguous()
+        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
+
+        self.shm_cpu_data = get_host_shared_memory_manager().malloc(
+            shape=self._param.shape, dtype=self._param.dtype
+        )
+
+        if self._rank == 0:
+            self.shm_cpu_data.copy_(self._param.data.to("cpu"))
+            self._param.data = self.shm_cpu_data
+        else:
+            _move_param_to_meta(self._module, self._param_name)
+        get_naive_distributed().barrier()
+
+    def post_init(self):
+        if self._rank == 0:
+            assert (
+                self.shm_cpu_data.data_ptr() == self._param.data.data_ptr()
+            ), f"{self.shm_cpu_data.data_ptr()=} {self._param.data.data_ptr()=} {self.shm_cpu_data=} {self._param.data=}"
+
+        _move_param_to_meta(self._module, self._param_name)
+
+    def create_device_tensor(self):
+        return self.shm_cpu_data.to("cuda", non_blocking=True)
+
+
+def _move_param_to_cpu(param, pin_memory: bool):
+    cpu_data = _empty_strided_like(
+        param.data,
+        device="cpu",
+        pin_memory=pin_memory,
+    )
+    cpu_data.copy_(param.data)
+    param.data = cpu_data
+
+
+def _move_param_to_meta(module, param_name):
+    old_param = getattr(module, param_name)
+    old_param_type = type(old_param)
+
+    new_data = old_param.data.to("meta")
+
+    if old_param_type == ModelWeightParameter:
+        # manually checked how `w13_weight` and `w2_weight` are constructed
+        new_param = ModelWeightParameter(
+            data=new_data,
+            **{
+                k: getattr(old_param, k)
+                for k in ["input_dim", "output_dim", "weight_loader"]
+            },
+        )
+    elif old_param_type == torch.nn.Parameter:
+        new_param = torch.nn.Parameter(
+            data=new_data,
+            requires_grad=False,
+        )
+    else:
+        raise ValueError(f"Unknown {old_param_type=} {old_param=}")
+
+    setattr(module, param_name, new_param)
+
+
+def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
+    return torch.empty_strided(
+        size=x.size(),
+        stride=x.stride(),
+        dtype=x.dtype,
+        layout=x.layout,
+        device=device,
+        pin_memory=pin_memory,
+    )
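Note on the mechanism: OffloaderV1 swaps module.forward for a wrapper that temporarily unhooks itself, materializes device copies of the offloaded weights, and runs the real forward through torch.func.functional_call. A minimal self-contained sketch of that pattern, assuming a CUDA device and synchronous per-call copies (no side stream or prefetch):

import torch
from torch.func import functional_call

def offload_module(module: torch.nn.Module) -> torch.nn.Module:
    # Move each parameter's storage to pinned host memory (torch.empty, unlike
    # torch.empty_like, accepts pin_memory).
    for p in module.parameters():
        cpu_data = torch.empty(
            p.data.shape, dtype=p.data.dtype, device="cpu", pin_memory=True
        )
        cpu_data.copy_(p.data)
        p.data = cpu_data

    original_forward = module.forward

    def forward(*args, **kwargs):
        module.forward = original_forward  # unhook: functional_call invokes module.forward
        device_state = {
            k: v.to("cuda", non_blocking=True) for k, v in module.state_dict().items()
        }
        output = functional_call(module, device_state, args=args, kwargs=kwargs)
        module.forward = forward  # re-arm; device copies are freed with device_state
        return output

    module.forward = forward
    return module

layer = offload_module(torch.nn.Linear(16, 16).cuda())
out = layer(torch.randn(2, 16, device="cuda"))  # weights live on CPU between calls

The swap-unhook-reswap dance exists because functional_call invokes module.forward internally; without restoring the original first, the wrapper would recurse.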
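OffloaderV2 instead pipelines transfers: when module i finishes its forward, _on_forward_end starts the onload of module (i + prefetch_step) % n on a side stream and frees module i's device copies, so weights arrive before they are consumed. A toy model of that ring schedule, with plain Python objects standing in for CUDA streams and events (illustrative, not sglang's API):

class _FakeModuleOffloader:
    # Plain-Python stand-in for _ModuleOffloader: no streams or events, just a flag.
    def __init__(self, i):
        self.i = i
        self.loaded = False

    def start_onload(self):
        self.loaded = True   # real version: H2D copy on alt_stream + event record

    def offload(self):
        self.loaded = False  # real version: drop references to device tensors

    def wait_and_get_device_tensors(self):
        assert self.loaded, f"module {self.i} ran before its prefetch finished"
        return {}

n, prefetch_step = 6, 2
ring = [_FakeModuleOffloader(i) for i in range(n)]
for i in range(prefetch_step):  # post_init(): warm up the first prefetch_step modules
    ring[i].start_onload()

for i in range(n):  # one forward pass over the offloaded submodules
    ring[i].wait_and_get_device_tensors()         # weights must already be resident
    ring[(i + prefetch_step) % n].start_onload()  # kick off a later module's copy
    ring[i].offload()                             # free this module's device copy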
sglang/srt/operations.py CHANGED
@@ -1,10 +1,17 @@
+from __future__ import annotations
+
 import os
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
 
 import torch
 
+from sglang.srt.layers.dp_attention import set_dp_buffer_len
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
 _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
 
 if _ENABLE_PROFILE:
@@ -66,18 +73,31 @@ Stage = List[ExecutionOperation]
 
 
 class _StageExecutor:
-    def __init__(self, debug_name: str, stages: List[Stage], inputs):
+    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
         self._debug_name = debug_name
         self._stages = stages
         self._index = 0
         self._stage_state = _StateDict()
         self._stage_output = inputs
 
+        # handling DP attention
+        forward_batch: ForwardBatch = inputs["forward_batch"]
+        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
+        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+        self._global_num_tokens = forward_batch.global_num_tokens_cpu
+
     def next(self):
         assert not self.done
 
         stage = self._stages[self._index]
 
+        if self._global_dp_buffer_len is not None:
+            set_dp_buffer_len(
+                self._global_dp_buffer_len,
+                self._local_dp_buffer_len,
+                self._global_num_tokens,
+            )
+
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
                 with _annotate_region(debug_name=op.debug_name):
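Context for this change: set_dp_buffer_len (imported above from sglang.srt.layers.dp_attention) writes shared state that layers read during forward, and stages of different micro-batches can interleave on one thread (e.g., under two-batch overlap), so each executor re-asserts its own batch's lengths before every stage. A self-contained sketch of that restore-before-each-stage pattern; all names below are illustrative stand-ins, not sglang's API:

_DP_BUFFER_LEN = None  # stand-in for the module-level state behind set_dp_buffer_len

def set_dp_buffer_len_sketch(value):
    global _DP_BUFFER_LEN
    _DP_BUFFER_LEN = value

class StageExecutorSketch:
    def __init__(self, name, dp_buffer_len, num_stages):
        self._name = name
        self._dp_buffer_len = dp_buffer_len  # captured once per batch, as in __init__ above
        self._num_stages = num_stages
        self._index = 0

    def next(self):
        # Restore this batch's value: the other executor may have overwritten it.
        set_dp_buffer_len_sketch(self._dp_buffer_len)
        print(f"{self._name} stage {self._index}: buffer_len={_DP_BUFFER_LEN}")
        self._index += 1

a = StageExecutorSketch("batch_a", dp_buffer_len=4096, num_stages=2)
b = StageExecutorSketch("batch_b", dp_buffer_len=1024, num_stages=2)
for executor in (a, b, a, b):  # interleaved execution, as in two-batch overlap
    executor.next()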
sglang/srt/reasoning_parser.py CHANGED
@@ -513,12 +513,13 @@ class ReasoningParser:
 
     DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
-        "qwen3": Qwen3Detector,
-        "qwen3-thinking": Qwen3Detector,
+        "deepseek-v3": Qwen3Detector,
         "glm45": Qwen3Detector,
+        "gpt-oss": GptOssDetector,
         "kimi": KimiDetector,
+        "qwen3": Qwen3Detector,
+        "qwen3-thinking": Qwen3Detector,
         "step3": DeepSeekR1Detector,
-        "gpt-oss": GptOssDetector,
     }
 
     def __init__(
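DetectorMap is a flat name-to-class dispatch, now alphabetized, with deepseek-v3 added as a new entry; several model families reuse one detector class because they emit the same <think>...</think> markup. A self-contained analogue of that dispatch; ThinkTagDetector below is a toy, not one of sglang's detector classes:

from typing import Dict, Type

class ThinkTagDetector:
    # Toy detector for <think>...</think> reasoning markup.
    START, END = "<think>", "</think>"

    def in_reasoning(self, text: str) -> bool:
        return self.START in text and self.END not in text

DETECTORS: Dict[str, Type[ThinkTagDetector]] = {
    "qwen3": ThinkTagDetector,
    "deepseek-v3": ThinkTagDetector,  # new in 0.5.1: reuses the Qwen3-style detector
}

detector = DETECTORS["deepseek-v3"]()
print(detector.in_reasoning("<think>step 1..."))  # True: output is still reasoning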
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -68,6 +68,8 @@ class SamplingBatchInfo:
 
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
         reqs = batch.reqs
         device = batch.device
         temperatures = (
@@ -97,10 +99,11 @@ class SamplingBatchInfo:
                         logit_bias[i, int(key)] = value
 
         # Check if any request has custom logit processor
-        has_custom_logit_processor = (
-            batch.enable_custom_logit_processor  # check the flag first.
-            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
-        )
+        has_custom_logit_processor = global_server_args_dict[
+            "enable_custom_logit_processor"
+        ] and any(  # check the flag first.
+            r.custom_logit_processor for r in reqs
+        )  # then check the requests.
 
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
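This rewrite reads the flag from the process-wide global_server_args_dict instead of a ScheduleBatch attribute; the evaluation order is unchanged, with the cheap flag test short-circuiting the per-request scan. A self-contained sketch of that order (the dict and Req class below are stand-ins):

global_server_args = {"enable_custom_logit_processor": False}  # stand-in dict

class Req:
    def __init__(self, custom_logit_processor=None):
        self.custom_logit_processor = custom_logit_processor

reqs = [Req(), Req(custom_logit_processor=lambda logits, args: logits)]

# `and` short-circuits: with the flag off, the per-request scan never runs.
has_custom = global_server_args["enable_custom_logit_processor"] and any(
    r.custom_logit_processor for r in reqs
)
print(has_custom)  # False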