sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/offloader.py ADDED
@@ -0,0 +1,433 @@
+import logging
+import os
+from abc import ABC
+from typing import Callable, Generator, List, Optional
+
+import torch
+from torch.func import functional_call
+
+from sglang.srt.distributed.naive_distributed import (
+    NaiveDistributed,
+    get_naive_distributed,
+    set_naive_distributed,
+)
+from sglang.srt.host_shared_memory import (
+    HostSharedMemoryManager,
+    get_host_shared_memory_manager,
+    set_host_shared_memory_manager,
+)
+from sglang.srt.layers.parameter import ModelWeightParameter
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available
+
+logger = logging.getLogger(__name__)
+
+_SubmoduleAccessor = Callable[[torch.nn.Module], torch.nn.Module]
+_WhitelistParamNamesCreator = Callable[[torch.nn.Module], List[str]]
+
+
+class BaseOffloader(ABC):
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        return list(all_modules_generator)
+
+    def post_init(self):
+        pass
+
+
+class NoopOffloader(BaseOffloader):
+    pass
+
+
+# For simplicity use singleton, but can surely support multi instance
+_instance: Optional[BaseOffloader] = NoopOffloader()
+
+
+def get_offloader():
+    assert _instance is not None
+    return _instance
+
+
+def set_offloader(instance: BaseOffloader):
+    global _instance
+    _instance = instance
+
+
+def create_offloader_from_server_args(server_args: ServerArgs, dp_rank: int):
+    if server_args.cpu_offload_gb > 0:
+        return OffloaderV1(
+            cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
+        )
+    if server_args.offload_group_size > 0:
+        assert (
+            server_args.cpu_offload_gb == 0
+        ), "V2 offload does not support cpu_offload_gb yet"
+        return OffloaderV2(
+            group_size=server_args.offload_group_size,
+            num_in_group=server_args.offload_num_in_group,
+            prefetch_step=server_args.offload_prefetch_step,
+            mode=server_args.offload_mode,
+            dp_rank=dp_rank,
+            dp_size=server_args.dp_size,
+        )
+    return NoopOffloader()
+
+
+class OffloaderV1(BaseOffloader):
+    def __init__(self, cpu_offload_max_bytes: int):
+        self._cpu_offload_bytes = 0
+        self._cpu_offload_max_bytes = cpu_offload_max_bytes
+
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        return [self.maybe_offload_to_cpu(module) for module in all_modules_generator]
+
+    def maybe_offload_to_cpu(self, module: torch.nn.Module) -> torch.nn.Module:
+        if (params := next(module.parameters(), None)) is None:
+            return module
+
+        device = params.device
+
+        if device == torch.device("cpu"):
+            return module
+
+        if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
+            return module
+
+        pin_memory = is_pin_memory_available()
+        # offload parameters to CPU
+        # use pin_memory if possible, which helps cudagraph capture speed
+        offloaded_parameters = False
+        for p in module.parameters():
+            if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
+                # we use per-parameter offloading
+                # one module might have some parameters offloaded and some not
+                break
+
+            # `torch.empty_like` does not support `pin_memory` argument
+            cpu_data = torch.empty_strided(
+                size=p.data.size(),
+                stride=p.data.stride(),
+                dtype=p.data.dtype,
+                layout=p.data.layout,
+                device="cpu",
+                pin_memory=pin_memory,
+            )
+            cpu_data.copy_(p.data)
+            p.data = cpu_data
+            self._cpu_offload_bytes += p.data.numel() * p.data.element_size()
+            offloaded_parameters = True
+
+        if offloaded_parameters:
+            original_forward = module.forward
+
+            def forward(*args, **kwargs):
+                module.forward = original_forward
+                device_state = {
+                    # here we blindly call `to(device)`
+                    # if the parameter is already on the device, it will be a no-op
+                    k: v.to(device, non_blocking=True)
+                    for k, v in module.state_dict().items()
+                }
+                output = functional_call(module, device_state, args=args, kwargs=kwargs)
+                module.forward = forward
+                return output
+
+            module.forward = forward
+
+        return module
+
+
+class OffloaderV2(BaseOffloader):
+    def __init__(
+        self,
+        group_size: int,
+        num_in_group: int,
+        prefetch_step: int,
+        mode: str,
+        dp_rank: int,
+        dp_size: int,
+    ):
+        self.group_size = group_size
+        self.num_in_group = num_in_group
+        self.prefetch_step = prefetch_step
+        self.mode = mode
+
+        run_id = os.environ["SGLANG_RUN_ID"]
+
+        # Temporarily init inside Offloader, can move if other modules also need this
+        if self.mode in {"sharded_gpu", "shm_cpu"}:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+            assert (
+                get_tensor_model_parallel_world_size() == 1
+            ), "not yet support tp_size!=1"
+            set_naive_distributed(
+                NaiveDistributed(
+                    rank=dp_rank,
+                    world_size=dp_size,
+                    rendezvous=f"/tmp/{run_id}",
+                )
+            )
+        if self.mode in {"shm_cpu"}:
+            set_host_shared_memory_manager(
+                HostSharedMemoryManager(
+                    base_name=run_id,
+                )
+            )
+
+        self.offloaders = []
+
+    def wrap_modules(
+        self,
+        all_modules_generator: Generator[torch.nn.Module, None, None],
+        submodule_accessor: Optional[_SubmoduleAccessor] = None,
+        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
+    ):
+        assert len(self.offloaders) == 0, "should only call wrap_modules once"
+
+        alt_stream = torch.cuda.Stream()
+
+        all_modules = []
+        offload_submodules = []
+        for module_index, module in enumerate(all_modules_generator):
+            all_modules.append(module)
+            if module_index % self.group_size >= self.group_size - self.num_in_group:
+                submodule = submodule_accessor(module)
+                whitelist_param_names = whitelist_param_names_creator(submodule)
+                logger.info(
+                    f"[offloader] offload {module_index=} submodule={type(submodule)} params={whitelist_param_names} memory_allocated={torch.cuda.memory_allocated()}"
+                )
+                offload_submodules.append(submodule)
+                self.offloaders.append(
+                    _ModuleOffloader(
+                        mode=self.mode,
+                        module=submodule,
+                        alt_stream=alt_stream,
+                        whitelist_param_names=whitelist_param_names,
+                    )
+                )
+
+        for index, module in enumerate(offload_submodules):
+            _hook_module_forward_for_offloader(
+                index=index,
+                module=module,
+                offloaders=self.offloaders,
+                prefetch_step=self.prefetch_step,
+            )
+
+        return all_modules
+
+    def post_init(self):
+        for offloader in self.offloaders:
+            offloader.post_init()
+
+        for i in range(self.prefetch_step):
+            self.offloaders[i].start_onload()
+
+
+def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
+    def _on_forward_end():
+        offloaders[(index + prefetch_step) % len(offloaders)].start_onload()
+        offloaders[index].offload()
+
+    _hook_module_forward_raw(
+        module,
+        on_forward_end=_on_forward_end,
+        get_parameter_and_buffer_dicts=lambda: offloaders[
+            index
+        ].wait_and_get_device_tensors(),
+    )
+
+
+def _hook_module_forward_raw(module, on_forward_end, get_parameter_and_buffer_dicts):
+    original_forward = module.forward
+
+    def forward(*args, **kwargs):
+        module.forward = original_forward
+        output = functional_call(
+            module, get_parameter_and_buffer_dicts(), args=args, kwargs=kwargs
+        )
+        on_forward_end()
+        module.forward = forward
+        return output
+
+    module.forward = forward
+
+
+class _ModuleOffloader(ABC):
+    def __init__(
+        self,
+        mode: str,
+        module: torch.nn.Module,
+        alt_stream: torch.cuda.Stream,
+        whitelist_param_names: List[str],
+    ):
+        self.mode = mode
+        self.module = module
+        self.device = next(module.parameters()).device
+        self.alt_stream = alt_stream
+
+        assert self.device != torch.device(
+            "cpu"
+        ), "not handled device=cpu case yet (should skip this tensor)"
+
+        self._device_tensors = None
+        self._load_event = None
+
+        param_dict = dict(self.module.named_parameters())
+        assert all(
+            name in param_dict for name in whitelist_param_names
+        ), f"{whitelist_param_names=} {list(param_dict.keys())=}"
+
+        self._param_offloaders = {
+            name: _BaseParamOffloader.create(mode, module=module, param_name=name)
+            for name in whitelist_param_names
+        }
+
+    def post_init(self):
+        for name, param_offloader in self._param_offloaders.items():
+            param_offloader.post_init()
+
+    def start_onload(self):
+        self.alt_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.alt_stream):
+            self._device_tensors = self._create_device_tensors()
+            self._load_event = torch.cuda.Event()
+            self._load_event.record()
+
+    def offload(self):
+        self._device_tensors = None
+        self._load_event = None
+
+    def wait_and_get_device_tensors(self):
+        assert self._device_tensors is not None
+        self._load_event.wait()
+        return self._device_tensors
+
+    def _create_device_tensors(self):
+        return {k: v.create_device_tensor() for k, v in self._param_offloaders.items()}
+
+
+class _BaseParamOffloader(ABC):
+    @staticmethod
+    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
+        return {
+            "cpu": _CpuParamOffloader,
+            "shm_cpu": _ShmCpuParamOffloader,
+            "sharded_gpu": _ShardedGpuParamOffloader,
+        }[mode](**kwargs)
+
+    def __init__(self, module, param_name):
+        self._module = module
+        self._param_name = param_name
+
+    @property
+    def _param(self):
+        return getattr(self._module, self._param_name)
+
+    def post_init(self):
+        pass
+
+    def create_device_tensor(self):
+        raise NotImplementedError
+
+
+class _CpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        _move_param_to_cpu(self._param, pin_memory=True)
+
+    def create_device_tensor(self):
+        return self._param.to("cuda", non_blocking=True)
+
+
+class _ShmCpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        self._rank = get_naive_distributed().get_rank()
+        self._world_size = get_naive_distributed().get_world_size()
+
+        from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
+        assert (
+            self._param.data.is_contiguous()
+        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
+
+        self.shm_cpu_data = get_host_shared_memory_manager().malloc(
+            shape=self._param.shape, dtype=self._param.dtype
+        )
+
+        if self._rank == 0:
+            self.shm_cpu_data.copy_(self._param.data.to("cpu"))
+            self._param.data = self.shm_cpu_data
+        else:
+            _move_param_to_meta(self._module, self._param_name)
+        get_naive_distributed().barrier()
+
+    def post_init(self):
+        if self._rank == 0:
+            assert (
+                self.shm_cpu_data.data_ptr() == self._param.data.data_ptr()
+            ), f"{self.shm_cpu_data.data_ptr()=} {self._param.data.data_ptr()=} {self.shm_cpu_data=} {self._param.data=}"
+
+        _move_param_to_meta(self._module, self._param_name)
+
+    def create_device_tensor(self):
+        return self.shm_cpu_data.to("cuda", non_blocking=True)
+
+
+def _move_param_to_cpu(param, pin_memory: bool):
+    cpu_data = _empty_strided_like(
+        param.data,
+        device="cpu",
+        pin_memory=pin_memory,
+    )
+    cpu_data.copy_(param.data)
+    param.data = cpu_data
+
+
+def _move_param_to_meta(module, param_name):
+    old_param = getattr(module, param_name)
+    old_param_type = type(old_param)
+
+    new_data = old_param.data.to("meta")
+
+    if old_param_type == ModelWeightParameter:
+        # manually checked how `w13_weight` and `w2_weight` are constructed
+        new_param = ModelWeightParameter(
+            data=new_data,
+            **{
+                k: getattr(old_param, k)
+                for k in ["input_dim", "output_dim", "weight_loader"]
+            },
+        )
+    elif old_param_type == torch.nn.Parameter:
+        new_param = torch.nn.Parameter(
+            data=new_data,
+            requires_grad=False,
+        )
+    else:
+        raise ValueError(f"Unknown {old_param_type=} {old_param=}")
+
+    setattr(module, param_name, new_param)
+
+
+def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
+    return torch.empty_strided(
+        size=x.size(),
+        stride=x.stride(),
+        dtype=x.dtype,
+        layout=x.layout,
+        device=device,
+        pin_memory=pin_memory,
+    )
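The file above is the core of the new V2 offloading path: `OffloaderV1` keeps the old byte-budget CPU offload behind `cpu_offload_gb`, while `OffloaderV2` offloads the whitelisted submodule of the last `num_in_group` layers in every `group_size`-layer group and onloads them back on a side CUDA stream `prefetch_step` layers ahead. The sketch below shows the intended call pattern using only the API added here; `server_args`, the `build_layer` generator, the `experts` accessor, and the parameter names are hypothetical stand-ins for the real call sites in model loading, not the actual ones.

from sglang.srt.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)

# Startup: pick OffloaderV1 (cpu_offload_gb > 0), OffloaderV2
# (offload_group_size > 0), or the NoopOffloader default.
set_offloader(create_offloader_from_server_args(server_args, dp_rank=0))

# Weight loading: wrap the decoder layers. For V2, the accessor picks the
# submodule to offload inside each selected layer and the creator lists the
# parameter names to manage (names here are illustrative).
layers = get_offloader().wrap_modules(
    (build_layer(i) for i in range(num_layers)),  # hypothetical generator
    submodule_accessor=lambda layer: layer.mlp.experts,
    whitelist_param_names_creator=lambda m: ["w13_weight", "w2_weight"],
)

# After weights are materialized: finalize per-parameter homes (e.g. move
# shm_cpu replicas to meta) and start the first `prefetch_step` onloads.
get_offloader().post_init()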
sglang/srt/operations.py CHANGED
@@ -84,6 +84,7 @@ class _StageExecutor:
         forward_batch: ForwardBatch = inputs["forward_batch"]
         self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
         self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+        self._global_num_tokens = forward_batch.global_num_tokens_cpu
 
     def next(self):
         assert not self.done
@@ -91,7 +92,11 @@ class _StageExecutor:
         stage = self._stages[self._index]
 
         if self._global_dp_buffer_len is not None:
-            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+            set_dp_buffer_len(
+                self._global_dp_buffer_len,
+                self._local_dp_buffer_len,
+                self._global_num_tokens,
+            )
 
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
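Taken together, the two hunks snapshot the DP token counts when the stage executor is built and thread them into `set_dp_buffer_len`, whose home in `sglang/srt/layers/dp_attention.py` was also updated in this release (+23/-4 in the list above). A condensed view of the data flow; the trailing comments are interpretation, not source:

# Inside _StageExecutor (condensed from the hunks above):
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len  # padded global size
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]     # this rank's tokens
self._global_num_tokens = forward_batch.global_num_tokens_cpu    # per-DP-rank counts

# ...and in next(), before running each stage's ops:
if self._global_dp_buffer_len is not None:
    set_dp_buffer_len(
        self._global_dp_buffer_len,
        self._local_dp_buffer_len,
        self._global_num_tokens,  # new third argument in 0.5.1
    )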
sglang/srt/reasoning_parser.py CHANGED
@@ -513,12 +513,13 @@ class ReasoningParser:
 
     DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
-        "qwen3": Qwen3Detector,
-        "qwen3-thinking": Qwen3Detector,
+        "deepseek-v3": Qwen3Detector,
         "glm45": Qwen3Detector,
+        "gpt-oss": GptOssDetector,
         "kimi": KimiDetector,
+        "qwen3": Qwen3Detector,
+        "qwen3-thinking": Qwen3Detector,
         "step3": DeepSeekR1Detector,
-        "gpt-oss": GptOssDetector,
     }
 
     def __init__(
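This DetectorMap edit alphabetizes the registry and, most notably, adds a `deepseek-v3` key that reuses the Qwen3-style detector. A minimal usage sketch; the constructor and `parse_non_stream` come from this module, but the exact return convention is an assumption:

from sglang.srt.reasoning_parser import ReasoningParser

# "deepseek-v3" now resolves to Qwen3Detector, so <think> blocks are
# separated from the final answer (return convention assumed here).
parser = ReasoningParser(model_type="deepseek-v3")
reasoning, answer = parser.parse_non_stream("<think>2+2=4</think>The answer is 4.")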