sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two published versions.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -92,10 +91,16 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
     is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -168,6 +172,7 @@ class ModelRunner:
         pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
@@ -219,14 +224,9 @@ class ModelRunner:
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
-                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
-                "deepep_mode": DeepEPMode(server_args.deepep_mode),
             }
         )

-        # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
             self.init_threads_binding()
@@ -234,6 +234,9 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        # CPU offload
+        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -309,8 +312,13 @@ class ModelRunner:
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
-        assert (
-
+        assert (
+            (not model_has_mtp_layers)
+            or (self.spec_algorithm.is_none())
+            or (
+                (not self.spec_algorithm.is_none())
+                and (self.num_effective_layers == model_num_layers)
+            )
         ), "PP is not compatible with MTP models."

         # Apply torchao quantization
@@ -339,9 +347,12 @@ class ModelRunner:
         if self.device == "cuda":
             self.init_cublas()
             self.init_attention_backend()
-            self.
+            self.init_device_graphs()
+        elif self.device == "npu":
+            self.init_attention_backend()
+            self.init_device_graphs()
         else:
-            self.
+            self.graph_runner = None
             self.cuda_graph_mem_usage = 0
             self.init_attention_backend()

@@ -508,9 +519,6 @@ class ModelRunner:

         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
-        elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
-            server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
@@ -684,6 +692,8 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state(reverse=True)
         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

+        get_offloader().post_init()
+
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
@@ -915,7 +925,8 @@ class ModelRunner:
         )

         # We need to get device after patch otherwise the device would be wrong
-
+        self.device_module = torch.get_device_module(self.device)
+        infered_device = self.device_module.current_device()

         named_tensors = [
             (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
@@ -1046,8 +1057,6 @@ class ModelRunner:
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
-            # FIXME: pipeline parallelism is not compatible with mla backend
-            assert self.pp_size == 1
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * num_layers
@@ -1236,6 +1245,11 @@ class ModelRunner:

         # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
+            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
+            extra_max_context_len = 4
+            if self.server_args.speculative_num_draft_tokens is not None:
+                extra_max_context_len += self.server_args.speculative_num_draft_tokens
+
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

@@ -1244,7 +1258,8 @@ class ModelRunner:
                 pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                 self.req_to_token_pool = DecodeReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                     pre_alloc_size=pre_alloc_size,
@@ -1252,7 +1267,8 @@ class ModelRunner:
             else:
                 self.req_to_token_pool = ReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
@@ -1348,11 +1364,6 @@ class ModelRunner:

         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
-        max_num_extend_tokens = (
-            self.server_args.chunked_prefill_size
-            if self.server_args.chunked_prefill_size > 0
-            else self.server_args.max_prefill_tokens
-        )
         if self.token_to_kv_pool_allocator is None:
             if self.server_args.attention_backend == "ascend":
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
@@ -1391,7 +1402,6 @@ class ModelRunner:
                     device=self.device,
                     kvcache=self.token_to_kv_pool,
                     need_sort=need_sort,
-                    max_num_extend_tokens=max_num_extend_tokens,
                 )
         else:
             assert self.is_draft_worker
@@ -1591,9 +1601,9 @@ class ModelRunner:
             .cuda()
         )

-    def
+    def init_device_graphs(self):
         """Capture cuda graphs."""
-        self.
+        self.graph_runner = None
         self.cuda_graph_mem_usage = 0

         if not self.is_generation:
@@ -1608,8 +1618,9 @@ class ModelRunner:
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.
-
+        self.graph_runner = (
+            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
@@ -1761,11 +1772,11 @@ class ModelRunner:
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
-            and self.
-            and self.
+            and self.graph_runner
+            and self.graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            ret = self.
+            ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
sglang/srt/model_executor/npu_graph_runner.py
ADDED
@@ -0,0 +1,94 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Run the model with npu graph and torch.compile."""
+
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+
+class NPUGraphRunner(CudaGraphRunner):
+    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__(model_runner)
+
+    def _create_device_graph(self):
+        return torch.npu.NPUGraph()
+
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with torch.npu.graph(
+            graph,
+            pool=pool,
+            stream=stream,
+            auto_dispatch_capture=True,
+        ):
+            out = run_once_fn()
+        return out
+
+    def _update_inputs(self, seq_lens):
+        self.graphs[self.bs].update(
+            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
+        )
+
+    def _cache_loc_dtype(self):
+        return torch.int32
+
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
+        # Replay
+        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
+        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+        thread.start()
+        self.graphs[self.bs].replay()
+        thread.join()
+
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
sglang/srt/model_loader/loader.py
CHANGED
@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         yield module
         return

-
+    original_infos: Dict[str, Dict] = {}

     # Store original device states and move parameters to GPU if they're on CPU
     for name, p in module.named_parameters():
         if p.device.type == "cpu":
-
-
+            original_data = p.data
+            device_data = p.data.to(target_device)
+            original_infos[name] = dict(
+                device=p.device,
+                original_data=original_data,
+                device_data=device_data,
+            )
+            p.data = device_data
         # Parameters already on target device are not touched

     try:
@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         # Restore parameters to their original devices, ignoring new parameters
         pin_memory = is_pin_memory_available()
         for name, p in module.named_parameters():
-            if name in
-
-
+            if name in original_infos:
+                original_info = original_infos[name]
+                device_data = original_info["device_data"]
+                original_data = original_info["original_data"]
+                original_device: torch.device = original_info["device"]
+
+                if (
+                    (device_data.device == p.data.device)
+                    and (device_data.data_ptr() == p.data.data_ptr())
+                    and (device_data.shape == p.data.shape)
+                    and (device_data.dtype == p.data.dtype)
+                ):
+                    original_data.copy_(p.data.to(original_data.device))
+                    p.data = original_data
+                elif original_device.type == "cpu":
                     # `torch.empty_like` does not support `pin_memory` argument
                     cpu_data = torch.empty_strided(
                         size=p.data.size(),
sglang/srt/models/dbrx.py
CHANGED
@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
         self.params_dtype = params_dtype

         self.router = DbrxRouter(config, self.params_dtype)
+        self.topk = TopK(
+            self.top_k,
+            renormalize=True,
+        )
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)
         self.ws = nn.Parameter(
             torch.empty(
                 self.num_total_experts,
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.ws,
             self.w2s,
-
-            self.
-            renormalize=True,
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )

         if self.tp_size > 1:
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
sglang/srt/models/deepseek.py
CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
             w1=self.w1,
             w2=self.w2,
             topk_output=topk_output,
-            inplace=True,
+            moe_runner_config=MoeRunnerConfig(inplace=True),
         )

         if self.config.n_shared_experts is not None:
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -20,7 +20,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
@@ -135,6 +135,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        # if not set, model load will be broken in DeepseekV3ForCausalLM load_weights()
+        self.pp_group = get_pp_group()
         self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")

         self.model = DeepseekModelNextN(