sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/offloader.py
ADDED
@@ -0,0 +1,433 @@
import logging
import os
from abc import ABC
from typing import Callable, Generator, List, Optional

import torch
from torch.func import functional_call

from sglang.srt.distributed.naive_distributed import (
    NaiveDistributed,
    get_naive_distributed,
    set_naive_distributed,
)
from sglang.srt.host_shared_memory import (
    HostSharedMemoryManager,
    get_host_shared_memory_manager,
    set_host_shared_memory_manager,
)
from sglang.srt.layers.parameter import ModelWeightParameter
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available

logger = logging.getLogger(__name__)

_SubmoduleAccessor = Callable[[torch.nn.Module], torch.nn.Module]
_WhitelistParamNamesCreator = Callable[[torch.nn.Module], List[str]]


class BaseOffloader(ABC):
    def wrap_modules(
        self,
        all_modules_generator: Generator[torch.nn.Module, None, None],
        submodule_accessor: Optional[_SubmoduleAccessor] = None,
        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
    ):
        return list(all_modules_generator)

    def post_init(self):
        pass


class NoopOffloader(BaseOffloader):
    pass


# For simplicity use singleton, but can surely support multi instance
_instance: Optional[BaseOffloader] = NoopOffloader()


def get_offloader():
    assert _instance is not None
    return _instance


def set_offloader(instance: BaseOffloader):
    global _instance
    _instance = instance


def create_offloader_from_server_args(server_args: ServerArgs, dp_rank: int):
    if server_args.cpu_offload_gb > 0:
        return OffloaderV1(
            cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
        )
    if server_args.offload_group_size > 0:
        assert (
            server_args.cpu_offload_gb == 0
        ), "V2 offload does not support cpu_offload_gb yet"
        return OffloaderV2(
            group_size=server_args.offload_group_size,
            num_in_group=server_args.offload_num_in_group,
            prefetch_step=server_args.offload_prefetch_step,
            mode=server_args.offload_mode,
            dp_rank=dp_rank,
            dp_size=server_args.dp_size,
        )
    return NoopOffloader()


class OffloaderV1(BaseOffloader):
    def __init__(self, cpu_offload_max_bytes: int):
        self._cpu_offload_bytes = 0
        self._cpu_offload_max_bytes = cpu_offload_max_bytes

    def wrap_modules(
        self,
        all_modules_generator: Generator[torch.nn.Module, None, None],
        submodule_accessor: Optional[_SubmoduleAccessor] = None,
        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
    ):
        return [self.maybe_offload_to_cpu(module) for module in all_modules_generator]

    def maybe_offload_to_cpu(self, module: torch.nn.Module) -> torch.nn.Module:
        if (params := next(module.parameters(), None)) is None:
            return module

        device = params.device

        if device == torch.device("cpu"):
            return module

        if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
            return module

        pin_memory = is_pin_memory_available()
        # offload parameters to CPU
        # use pin_memory if possible, which helps cudagraph capture speed
        offloaded_parameters = False
        for p in module.parameters():
            if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
                # we use per-parameter offloading
                # one module might have some parameters offloaded and some not
                break

            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(
                size=p.data.size(),
                stride=p.data.stride(),
                dtype=p.data.dtype,
                layout=p.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(p.data)
            p.data = cpu_data
            self._cpu_offload_bytes += p.data.numel() * p.data.element_size()
            offloaded_parameters = True

        if offloaded_parameters:
            original_forward = module.forward

            def forward(*args, **kwargs):
                module.forward = original_forward
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                output = functional_call(module, device_state, args=args, kwargs=kwargs)
                module.forward = forward
                return output

            module.forward = forward

        return module


class OffloaderV2(BaseOffloader):
    def __init__(
        self,
        group_size: int,
        num_in_group: int,
        prefetch_step: int,
        mode: str,
        dp_rank: int,
        dp_size: int,
    ):
        self.group_size = group_size
        self.num_in_group = num_in_group
        self.prefetch_step = prefetch_step
        self.mode = mode

        run_id = os.environ["SGLANG_RUN_ID"]

        # Temporarily init inside Offloader, can move if other modules also need this
        if self.mode in {"sharded_gpu", "shm_cpu"}:
            from sglang.srt.distributed import get_tensor_model_parallel_world_size

            assert (
                get_tensor_model_parallel_world_size() == 1
            ), "not yet support tp_size!=1"
            set_naive_distributed(
                NaiveDistributed(
                    rank=dp_rank,
                    world_size=dp_size,
                    rendezvous=f"/tmp/{run_id}",
                )
            )
        if self.mode in {"shm_cpu"}:
            set_host_shared_memory_manager(
                HostSharedMemoryManager(
                    base_name=run_id,
                )
            )

        self.offloaders = []

    def wrap_modules(
        self,
        all_modules_generator: Generator[torch.nn.Module, None, None],
        submodule_accessor: Optional[_SubmoduleAccessor] = None,
        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
    ):
        assert len(self.offloaders) == 0, "should only call wrap_modules once"

        alt_stream = torch.cuda.Stream()

        all_modules = []
        offload_submodules = []
        for module_index, module in enumerate(all_modules_generator):
            all_modules.append(module)
            if module_index % self.group_size >= self.group_size - self.num_in_group:
                submodule = submodule_accessor(module)
                whitelist_param_names = whitelist_param_names_creator(submodule)
                logger.info(
                    f"[offloader] offload {module_index=} submodule={type(submodule)} params={whitelist_param_names} memory_allocated={torch.cuda.memory_allocated()}"
                )
                offload_submodules.append(submodule)
                self.offloaders.append(
                    _ModuleOffloader(
                        mode=self.mode,
                        module=submodule,
                        alt_stream=alt_stream,
                        whitelist_param_names=whitelist_param_names,
                    )
                )

        for index, module in enumerate(offload_submodules):
            _hook_module_forward_for_offloader(
                index=index,
                module=module,
                offloaders=self.offloaders,
                prefetch_step=self.prefetch_step,
            )

        return all_modules

    def post_init(self):
        for offloader in self.offloaders:
            offloader.post_init()

        for i in range(self.prefetch_step):
            self.offloaders[i].start_onload()


def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
    def _on_forward_end():
        offloaders[(index + prefetch_step) % len(offloaders)].start_onload()
        offloaders[index].offload()

    _hook_module_forward_raw(
        module,
        on_forward_end=_on_forward_end,
        get_parameter_and_buffer_dicts=lambda: offloaders[
            index
        ].wait_and_get_device_tensors(),
    )


def _hook_module_forward_raw(module, on_forward_end, get_parameter_and_buffer_dicts):
    original_forward = module.forward

    def forward(*args, **kwargs):
        module.forward = original_forward
        output = functional_call(
            module, get_parameter_and_buffer_dicts(), args=args, kwargs=kwargs
        )
        on_forward_end()
        module.forward = forward
        return output

    module.forward = forward


class _ModuleOffloader(ABC):
    def __init__(
        self,
        mode: str,
        module: torch.nn.Module,
        alt_stream: torch.cuda.Stream,
        whitelist_param_names: List[str],
    ):
        self.mode = mode
        self.module = module
        self.device = next(module.parameters()).device
        self.alt_stream = alt_stream

        assert self.device != torch.device(
            "cpu"
        ), "not handled device=cpu case yet (should skip this tensor)"

        self._device_tensors = None
        self._load_event = None

        param_dict = dict(self.module.named_parameters())
        assert all(
            name in param_dict for name in whitelist_param_names
        ), f"{whitelist_param_names=} {list(param_dict.keys())=}"

        self._param_offloaders = {
            name: _BaseParamOffloader.create(mode, module=module, param_name=name)
            for name in whitelist_param_names
        }

    def post_init(self):
        for name, param_offloader in self._param_offloaders.items():
            param_offloader.post_init()

    def start_onload(self):
        self.alt_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.alt_stream):
            self._device_tensors = self._create_device_tensors()
            self._load_event = torch.cuda.Event()
            self._load_event.record()

    def offload(self):
        self._device_tensors = None
        self._load_event = None

    def wait_and_get_device_tensors(self):
        assert self._device_tensors is not None
        self._load_event.wait()
        return self._device_tensors

    def _create_device_tensors(self):
        return {k: v.create_device_tensor() for k, v in self._param_offloaders.items()}


class _BaseParamOffloader(ABC):
    @staticmethod
    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
        return {
            "cpu": _CpuParamOffloader,
            "shm_cpu": _ShmCpuParamOffloader,
            "sharded_gpu": _ShardedGpuParamOffloader,
        }[mode](**kwargs)

    def __init__(self, module, param_name):
        self._module = module
        self._param_name = param_name

    @property
    def _param(self):
        return getattr(self._module, self._param_name)

    def post_init(self):
        pass

    def create_device_tensor(self):
        raise NotImplementedError


class _CpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        _move_param_to_cpu(self._param, pin_memory=True)

    def create_device_tensor(self):
        return self._param.to("cuda", non_blocking=True)


class _ShmCpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        self._rank = get_naive_distributed().get_rank()
        self._world_size = get_naive_distributed().get_world_size()

        from sglang.srt.distributed import get_tensor_model_parallel_world_size

        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

        self.shm_cpu_data = get_host_shared_memory_manager().malloc(
            shape=self._param.shape, dtype=self._param.dtype
        )

        if self._rank == 0:
            self.shm_cpu_data.copy_(self._param.data.to("cpu"))
            self._param.data = self.shm_cpu_data
        else:
            _move_param_to_meta(self._module, self._param_name)
        get_naive_distributed().barrier()

    def post_init(self):
        if self._rank == 0:
            assert (
                self.shm_cpu_data.data_ptr() == self._param.data.data_ptr()
            ), f"{self.shm_cpu_data.data_ptr()=} {self._param.data.data_ptr()=} {self.shm_cpu_data=} {self._param.data=}"

        _move_param_to_meta(self._module, self._param_name)

    def create_device_tensor(self):
        return self.shm_cpu_data.to("cuda", non_blocking=True)


def _move_param_to_cpu(param, pin_memory: bool):
    cpu_data = _empty_strided_like(
        param.data,
        device="cpu",
        pin_memory=pin_memory,
    )
    cpu_data.copy_(param.data)
    param.data = cpu_data


def _move_param_to_meta(module, param_name):
    old_param = getattr(module, param_name)
    old_param_type = type(old_param)

    new_data = old_param.data.to("meta")

    if old_param_type == ModelWeightParameter:
        # manually checked how `w13_weight` and `w2_weight` are constructed
        new_param = ModelWeightParameter(
            data=new_data,
            **{
                k: getattr(old_param, k)
                for k in ["input_dim", "output_dim", "weight_loader"]
            },
        )
    elif old_param_type == torch.nn.Parameter:
        new_param = torch.nn.Parameter(
            data=new_data,
            requires_grad=False,
        )
    else:
        raise ValueError(f"Unknown {old_param_type=} {old_param=}")

    setattr(module, param_name, new_param)


def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
    return torch.empty_strided(
        size=x.size(),
        stride=x.stride(),
        dtype=x.dtype,
        layout=x.layout,
        device=device,
        pin_memory=pin_memory,
    )
sglang/srt/operations.py
CHANGED
@@ -84,6 +84,7 @@ class _StageExecutor:
         forward_batch: ForwardBatch = inputs["forward_batch"]
         self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
         self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+        self._global_num_tokens = forward_batch.global_num_tokens_cpu

     def next(self):
         assert not self.done
@@ -91,7 +92,11 @@
         stage = self._stages[self._index]

         if self._global_dp_buffer_len is not None:
-            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+            set_dp_buffer_len(
+                self._global_dp_buffer_len,
+                self._local_dp_buffer_len,
+                self._global_num_tokens,
+            )

         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
sglang/srt/reasoning_parser.py
CHANGED
@@ -513,12 +513,13 @@ class ReasoningParser:
     DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
-        "qwen3": Qwen3Detector,
-        "qwen3-thinking": Qwen3Detector,
+        "deepseek-v3": Qwen3Detector,
         "glm45": Qwen3Detector,
+        "gpt-oss": GptOssDetector,
         "kimi": KimiDetector,
+        "qwen3": Qwen3Detector,
+        "qwen3-thinking": Qwen3Detector,
         "step3": DeepSeekR1Detector,
-        "gpt-oss": GptOssDetector,
     }

     def __init__(