sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -19,19 +19,26 @@ import inspect
 import json
 import logging
 import os
+import socket
+import threading
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
+from urllib.parse import urlparse

+import requests
 import torch
 import torch.distributed as dist

 from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
+from sglang.srt.connector import ConnectorType
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tp_group,
     get_world_group,
     init_distributed_environment,
@@ -83,11 +90,14 @@ from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
@@ -101,6 +111,9 @@ from sglang.srt.offloader import (
     set_offloader,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.remote_instance_weight_loader_utils import (
+    trigger_init_weights_send_group_for_remote_instance_request,
+)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -114,6 +127,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
+    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -123,6 +137,7 @@ from sglang.srt.utils import (
     is_sm100_supported,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
+    parse_connector_type,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -251,6 +266,7 @@ class ModelRunner:

         # For weight updates
         self._model_update_group = {}
+        self._weights_send_group = {}

     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -300,6 +316,26 @@ class ModelRunner:
             if architectures and not any("Llama4" in arch for arch in architectures):
                 self.is_hybrid = self.model_config.is_hybrid = True

+        if self.is_hybrid_gdn:
+            logger.warning("Hybrid GDN model detected, disable radix cache")
+            self.server_args.disable_radix_cache = True
+            self.server_args.attention_backend = "hybrid_linear_attn"
+            if self.server_args.max_mamba_cache_size is None:
+                if self.server_args.max_running_requests is not None:
+                    self.server_args.max_mamba_cache_size = (
+                        self.server_args.max_running_requests
+                    )
+                else:
+                    self.server_args.max_mamba_cache_size = 512
+            self.server_args.max_mamba_cache_size = (
+                self.server_args.max_mamba_cache_size
+                // (
+                    self.server_args.dp_size
+                    if self.server_args.enable_dp_attention
+                    else 1
+                )
+            )
+
         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
         # determine the number of layers.
@@ -341,6 +377,14 @@ class ModelRunner:
         if server_args.enable_lora:
             self.init_lora_manager()

+        # Init Double Sparsity
+        if server_args.enable_double_sparsity:
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -351,12 +395,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device == "npu":
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.cuda_graph_mem_usage = 0
+            self.graph_mem_usage = 0
             self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -506,11 +550,6 @@ class ModelRunner:
             )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
-            if server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
             if not self.is_multimodal_chunked_prefill_supported:
@@ -523,6 +562,18 @@ class ModelRunner:
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True

+        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
+        # For more details, see: https://github.com/sgl-project/sglang/issues/8616
+        elif (
+            self.dp_size > 1
+            and is_sm100_supported()
+            and server_args.attention_backend != "triton"
+            and server_args.attention_backend == "trtllm_mla"
+        ):
+            logger.info(
+                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
+            )
+            server_args.disable_chunked_prefix_cache = True
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")

@@ -593,6 +644,11 @@ class ModelRunner:
             # Set local size to hint SGLang to use shared memory based AllReduce
             os.environ["LOCAL_SIZE"] = str(self.tp_size)
             torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+            @torch.library.register_fake("sgl_kernel::shm_allgather")
+            def _(data, dim):
+                return torch.cat([data] * self.tp_size, dim=dim)
+
         else:
             logger.warning(
                 "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -625,6 +681,7 @@ class ModelRunner:
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
+        self.pp_group = get_pp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
@@ -681,6 +738,20 @@ class ModelRunner:
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()

+        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
+            if self.tp_rank == 0:
+                instance_ip = socket.gethostbyname(socket.gethostname())
+                t = threading.Thread(
+                    target=trigger_init_weights_send_group_for_remote_instance_request,
+                    args=(
+                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
+                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+                        instance_ip,
+                    ),
+                )
+                t.start()
+
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
@@ -690,7 +761,7 @@ class ModelRunner:
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
-                device_config=DeviceConfig(self.device),
+                device_config=DeviceConfig(self.device, self.gpu_id),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
@@ -822,6 +893,103 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

+    def init_weights_send_group_for_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_rank,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        logger.info(
+            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
+            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
+        )
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            self._weights_send_group[group_name] = init_custom_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_address}:{group_port}",
+                world_size=world_size,
+                rank=group_rank,
+                group_name=group_name,
+                device_id=torch.device("cuda", self.gpu_id),
+            )
+            dist.barrier(group=self._weights_send_group[group_name])
+            success = True
+            message = (
+                f"Succeeded to init group through {master_address}:{group_port} group."
+            )
+        except Exception as e:
+            message = f"Failed to init group: {e}."
+            logger.error(message)
+
+        torch.cuda.empty_cache()
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_name,
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        if self._weights_send_group[group_name] is not None:
+            send_group = self._weights_send_group[group_name]
+        else:
+            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
+            logger.error(message)
+            return False, message
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            for _, weights in self.model.named_parameters():
+                torch.distributed.broadcast(
+                    weights,
+                    src=0,
+                    group=send_group,
+                )
+            success = True
+            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
+        except Exception as e:
+            message = f"Failed to send weights: {e}."
+            logger.error(message)
+
+        # destroy the process group after sending weights
+        del self._weights_send_group[group_name]
+        torch.distributed.distributed_c10d.destroy_process_group(send_group)
+        torch.cuda.empty_cache()
+        return success, message
+
     def init_weights_update_group(
         self,
         master_address,
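Note: the two methods added above follow the usual torch.distributed pattern for pushing weights to another instance: set up an ad-hoc process group, then broadcast every parameter from the seed rank in a fixed order. Below is a minimal standalone sketch of that broadcast loop; the helper name is illustrative and it assumes the process group has already been created and torch.distributed is initialized.

import torch
import torch.distributed as dist


def broadcast_all_weights(model: torch.nn.Module, group, src_rank: int = 0) -> None:
    # All ranks in `group` iterate parameters in the same deterministic order;
    # the seed rank sends each tensor, every other rank receives it in place.
    for _, param in model.named_parameters():
        dist.broadcast(param.data, src=src_rank, group=group)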
@@ -1057,6 +1225,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
+        elif self.is_hybrid_gdn:
+            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1076,9 +1246,22 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
+        if self.is_hybrid_gdn:
+            rest_memory -= (
+                self.server_args.max_mamba_cache_size
+                * self.model_config.hf_config.mamba_cache_per_req
+                / (1 << 30)
+            )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

+    @property
+    def is_hybrid_gdn(self):
+        return self.model_config.hf_config.architectures[0] in [
+            "Qwen3NextForCausalLM",
+            "Qwen3NextForCausalLMMTP",
+        ]
+
     def set_num_token_hybrid(self):
         if (
             "Llama4ForConditionalGeneration"
@@ -1199,6 +1382,8 @@ class ModelRunner:
                 ),
                 4096,
             )
+        if self.is_hybrid_gdn:
+            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
@@ -1237,6 +1422,16 @@ class ModelRunner:
             // self.server_args.page_size
             * self.server_args.page_size
         )
+        # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens
+        if self.pp_size > 1:
+            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
+            torch.distributed.all_reduce(
+                tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=get_world_group().cpu_group,
+            )
+            self.max_total_num_tokens = tensor.item()
+
         # create token size for hybrid cache
         if self.is_hybrid:
             self.set_num_token_hybrid()
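Note: the pipeline-parallel change above makes every rank agree on a single KV-cache budget by reducing with MIN over the group. Below is a standalone sketch of that reduction; the function name is illustrative and it assumes torch.distributed is already initialized (e.g. a gloo-backed CPU group).

import torch
import torch.distributed as dist


def agree_on_min_tokens(local_max_total_num_tokens: int, group=None) -> int:
    # Each rank proposes its own capacity; ReduceOp.MIN makes every rank adopt
    # the smallest proposal so no pipeline stage overcommits its KV cache.
    t = torch.tensor(local_max_total_num_tokens, dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
    return int(t.item())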
@@ -1267,6 +1462,28 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
+        elif self.is_hybrid_gdn:
+            config = self.model_config.hf_config
+            (
+                conv_state_shape,
+                temporal_state_shape,
+                conv_dtype,
+                ssm_dtype,
+                mamba_layers,
+            ) = config.hybrid_gdn_params
+            self.req_to_token_pool = HybridReqToTokenPool(
+                size=max_num_reqs,
+                max_context_len=self.model_config.context_len
+                + extra_max_context_len,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                conv_state_shape=conv_state_shape,
+                temporal_state_shape=temporal_state_shape,
+                conv_dtype=conv_dtype,
+                ssm_dtype=ssm_dtype,
+                mamba_layers=mamba_layers,
+                speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+            )
         else:
             self.req_to_token_pool = ReqToTokenPool(
                 size=max_num_reqs,
@@ -1349,6 +1566,24 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
+        elif self.is_hybrid_gdn:
+            self.token_to_kv_pool = HybridLinearKVPool(
+                page_size=self.page_size if _is_npu else 1,
+                size=self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(
+                    get_attention_tp_size()
+                ),
+                head_dim=self.model_config.head_dim,
+                # if draft worker, we only need 1 attention layer's kv pool
+                full_attention_layer_ids=(
+                    [0]
+                    if self.is_draft_worker
+                    else self.model_config.hf_config.full_attention_layer_ids
+                ),
+                enable_kvcache_transpose=False,
+                device=self.device,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1368,7 +1603,10 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if self.server_args.attention_backend == "ascend":
+            if _is_npu and self.server_args.attention_backend in [
+                "ascend",
+                "hybrid_linear_attn",
+            ]:
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1582,6 +1820,35 @@ class ModelRunner:
             )

             return DualChunkFlashAttentionBackend(self)
+        elif backend_str == "hybrid_linear_attn":
+            assert (
+                self.is_hybrid_gdn
+            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
+            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+                HybridLinearAttnBackend,
+                MambaAttnBackend,
+            )
+
+            if _is_npu:
+                from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+                full_attn_backend = AscendAttnBackend(self)
+            elif is_blackwell():
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+                full_attn_backend = TritonAttnBackend(self)
+            else:
+                from sglang.srt.layers.attention.flashattention_backend import (
+                    FlashAttentionBackend,
+                )

+                full_attn_backend = FlashAttentionBackend(self)
+
+            linear_attn_backend = MambaAttnBackend(self)
+            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
+            return HybridLinearAttnBackend(
+                full_attn_backend, linear_attn_backend, full_attn_layers
+            )
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")

@@ -1603,38 +1870,46 @@ class ModelRunner:
         )

     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
             return

         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-
-        CudaGraphRunner
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
         )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def init_threads_binding(self):
         omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        cpu_ids_by_node = get_cpu_ids_by_node()
+        n_numa_node = len(cpu_ids_by_node)
         if omp_cpuids == "all":
-            cpu_ids_by_node = get_cpu_ids_by_node()
-            n_numa_node = len(cpu_ids_by_node)
-
             assert self.tp_size <= n_numa_node, (
                 f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                 f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
@@ -1651,7 +1926,18 @@ class ModelRunner:
             )
             self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
         else:
-            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+            threads_bind_list = omp_cpuids.split("|")
+            assert self.tp_size == len(threads_bind_list), (
+                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
+                f"Please double check your settings."
+            )
+            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
+
+            if self.tp_size > n_numa_node:
+                logger.warning(
+                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
+                    f"in this case the available memory amount of each rank cannot be determined in prior. "
+                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
+                )

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -1771,18 +2057,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph

         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1811,10 +2103,13 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

-        if forward_batch.global_num_tokens_cpu is not None:
+        if (
+            forward_batch.global_num_tokens_cpu is not None
+            and self.pp_group.is_last_rank
+        ):
             forward_batch.post_forward_mlp_sync_batch(ret)

-        return ret, can_run_cuda_graph
+        return ret, can_run_graph

     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
@@ -1863,6 +2158,38 @@ class ModelRunner:
         )
         return next_token_ids

+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> None:
+        """
+        Compute token_ids_logprobs without performing sampling.
+
+        Optimized path for prefill-only requests that need token_ids_logprobs but don't
+        require next token generation. Skips expensive sampling operations
+        while still providing requested probability information.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+        """
+        if not forward_batch.token_ids_logprobs:
+            return
+
+        # Preprocess logits (same as in sample method)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
+        # Delegate to sampler for logprob-only computation
+        # This populates logits_output with requested token probabilities
+        self.sampler.compute_logprobs_only(
+            logits_output,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
+        )
+
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
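Note: compute_logprobs_only gives prefill-only requests their requested token log-probabilities while skipping the sampler's token-selection work. Below is a rough standalone sketch of what such a logprob-only computation looks like; it mirrors the intent of the new method, not the sglang Sampler API.

import torch


def logprobs_without_sampling(logits, token_ids, top_k=0):
    # logits: [batch, vocab] final-position logits; token_ids: ids whose
    # log-probabilities the caller asked for. No next token is ever drawn.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    requested = logprobs[:, token_ids]  # [batch, len(token_ids)]
    topk = torch.topk(logprobs, k=top_k, dim=-1) if top_k > 0 else None
    return requested, topk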
sglang/srt/model_loader/__init__.py

@@ -1,16 +1,22 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py

+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from torch import nn

-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import (
     get_architecture_class_name,
     get_model_architecture,
 )

+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+

 def get_model(
     *,