sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two published versions.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py
CHANGED

@@ -13,7 +13,6 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
-import collections
 import datetime
 import gc
 import inspect
@@ -32,6 +31,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -51,6 +51,18 @@ from sglang.srt.layers.quantization.deep_gemm import (
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.eplb_manager import EPLBManager
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import (
+    ExpertLocationMetadata,
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
@@ -60,6 +72,7 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
+from sglang.srt.model_executor import expert_location_updater
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
@@ -93,6 +106,8 @@ from sglang.srt.utils import (
     set_cuda_arch,
 )
 
+_is_hip = is_hip()
+
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 
@@ -102,6 +117,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
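
The hunk above adds a `RankZeroFilter` so that only TP rank 0 emits INFO-level logs while warnings and errors still surface from every rank. Below is a minimal, self-contained sketch of how such a filter behaves; the demo logger names and the rank loop are illustrative, not sglang code.

```python
import logging


class RankZeroFilter(logging.Filter):
    """Allow INFO records only on rank 0; pass all other levels through."""

    def __init__(self, is_rank_zero: bool):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record: logging.LogRecord) -> bool:
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    for rank in (0, 1):
        logger = logging.getLogger(f"demo.rank{rank}")
        logger.addFilter(RankZeroFilter(rank == 0))
        logger.info("INFO from rank %d", rank)        # printed only for rank 0
        logger.warning("WARNING from rank %d", rank)  # printed for both ranks
```

Because the filter is attached to the logger itself, INFO records on non-zero ranks are dropped before they ever reach a handler.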

@@ -125,6 +153,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -134,7 +166,9 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.
+        self.is_multimodal_chunked_prefill_supported = (
+            model_config.is_multimodal_chunked_prefill_supported
+        )
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -144,6 +178,8 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
+        self.forward_pass_id = 0
+
         # Model-specific adjustment
         self.model_specific_adjustment()
 
@@ -162,10 +198,13 @@ class ModelRunner:
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_dp_lm_head": server_args.enable_dp_lm_head,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
+                "deepep_config": server_args.deepep_config,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
+                "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "torchao_config": server_args.torchao_config,
@@ -174,6 +213,7 @@ class ModelRunner:
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
                 "mm_attention_backend": server_args.mm_attention_backend,
+                "ep_num_redundant_experts": server_args.ep_num_redundant_experts,
             }
         )
 
@@ -201,6 +241,31 @@ class ModelRunner:
             enable=self.server_args.enable_memory_saver
         )
 
+        if not self.is_draft_worker:
+            set_global_expert_location_metadata(
+                compute_initial_expert_location_metadata(server_args, self.model_config)
+            )
+            if self.tp_rank == 0 and get_bool_env_var(
+                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
+            ):
+                logger.info(
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                )
+
+        set_global_expert_distribution_recorder(
+            ExpertDistributionRecorder.init_new(
+                server_args,
+                get_global_expert_location_metadata(),
+                rank=self.tp_rank,
+            )
+        )
+
+        self.eplb_manager = (
+            EPLBManager(self)
+            if self.server_args.enable_eplb and (not self.is_draft_worker)
+            else None
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -269,6 +334,8 @@ class ModelRunner:
                     and is_fa3_default_architecture(self.model_config.hf_config)
                 ):
                     server_args.attention_backend = "fa3"
+                elif _is_hip:
+                    server_args.attention_backend = "aiter"
                 else:
                     server_args.attention_backend = (
                         "flashinfer" if is_flashinfer_available() else "triton"
@@ -279,10 +346,9 @@ class ModelRunner:
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
-
-
-
-            )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -292,10 +358,9 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-
-
-
-                    )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -314,10 +379,9 @@ class ModelRunner:
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-
-
-
-            )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -328,26 +392,25 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-
-
-
-
-
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+            if not self.is_multimodal_chunked_prefill_supported:
+                server_args.chunked_prefill_size = -1
                 logger.info(
-                    "Automatically turn
+                    f"Automatically turn of --chunked-prefill-size as it is not supported for "
+                    f"{self.model_config.hf_config.model_type}"
                 )
-            server_args.chunked_prefill_size = -1
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-
-            logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-
-            logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -400,11 +463,15 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
-            self.device,
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -440,10 +507,9 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-
-
-
-                )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -479,11 +545,10 @@ class ModelRunner:
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-
-
-
-
-                )
+                logger.info(
+                    "Loaded KV cache scaling factors from %s",
+                    self.server_args.quantization_param_path,
+                )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -526,6 +591,16 @@ class ModelRunner:
                 f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
             ) from None
 
+    def update_expert_location(
+        self, new_expert_location_metadata: ExpertLocationMetadata
+    ):
+        expert_location_updater.update_expert_location(
+            self.model.routed_experts_weights_of_layer,
+            new_expert_location_metadata,
+            nnodes=self.server_args.nnodes,
+            rank=self.tp_rank,
+        )
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
@@ -547,13 +622,7 @@ class ModelRunner:
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
-                DefaultModelLoader.Source(
-                    config.model_path,
-                    revision=config.revision,
-                    fall_back_to_pt=getattr(
-                        self.model, "fall_back_to_pt_during_load", True
-                    ),
-                )
+                DefaultModelLoader.Source.init_new(config, self.model)
             )
             return iter
 
@@ -626,7 +695,6 @@ class ModelRunner:
                 rank=rank,
                 group_name=group_name,
             )
-            dist.barrier(group=self._model_update_group, device_ids=[rank])
             return True, "Succeeded to initialize custom process group."
         except Exception as e:
             message = f"Failed to initialize custom process group: {e}."
@@ -716,14 +784,20 @@ class ModelRunner:
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device,
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
-        if self.
-            num_layers = (
-                self.model_config.
-
-
+        if self.is_draft_worker:
+            num_layers = getattr(
+                self.model_config.hf_config,
+                "num_nextn_predict_layers",
+                self.num_effective_layers,
             )
+        else:
+            num_layers = self.num_effective_layers
+        if self.use_mla_backend:
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            cell_size = (
@@ -735,7 +809,7 @@ class ModelRunner:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                *
+                * num_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
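
The `profile_max_num_token` hunks size the KV cache by dividing the remaining GPU memory by a per-token "cell size" of `num_kv_heads * head_dim * num_layers * 2 * element_size` (the 2 covers the K and V tensors). A back-of-the-envelope sketch of that arithmetic; the model dimensions below are hypothetical, not taken from this diff.

```python
def kv_cache_cell_size(num_kv_heads: int, head_dim: int, num_layers: int,
                       dtype_bytes: int) -> int:
    """Bytes of KV cache needed per token: K and V for every layer."""
    return num_kv_heads * head_dim * num_layers * 2 * dtype_bytes


def max_kv_tokens(available_bytes: int, cell_size: int) -> int:
    """How many tokens fit in the memory left over for the KV cache."""
    return available_bytes // cell_size


if __name__ == "__main__":
    # Hypothetical 32-layer model with 8 KV heads of dim 128 and an fp16 KV cache.
    cell = kv_cache_cell_size(num_kv_heads=8, head_dim=128, num_layers=32, dtype_bytes=2)
    print(cell)                               # 131072 bytes per token (128 KiB)
    print(max_kv_tokens(20 * 1024**3, cell))  # 163840 tokens fit in 20 GiB
```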

@@ -754,7 +828,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if
+            if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
@@ -932,6 +1006,10 @@ class ModelRunner:
             )
 
             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+            self.attn_backend = AiterAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -1012,7 +1090,7 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return
 
-        tic = time.
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
@@ -1020,13 +1098,12 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.
+            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
             f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     def apply_torch_tp(self):
-
-        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
@@ -1085,32 +1162,54 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        self.forward_pass_id += 1
+
+        with get_global_expert_distribution_recorder().with_forward_pass(
+            self.forward_pass_id,
+            forward_batch,
+        ):
+            output = self._forward_raw(
+                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+            )
+
+        if self.eplb_manager is not None:
+            self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
+
+        return output
+
+    def _forward_raw(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors],
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
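
With this hunk, `ModelRunner.forward` starts returning an `(output, can_run_cuda_graph)` tuple and delegates the per-mode work to `_forward_raw`, which dispatches on the batch's forward mode. A framework-free sketch of that dispatch-and-flag shape follows; the `FakeRunner`, its mode enum, and the string results are hypothetical stand-ins, not sglang APIs.

```python
from enum import Enum, auto
from typing import Tuple


class ForwardMode(Enum):
    DECODE = auto()
    EXTEND = auto()
    IDLE = auto()


class FakeRunner:
    """Illustrative stand-in: dispatch on mode and report whether the CUDA-graph path ran."""

    def __init__(self, cuda_graph_ready: bool):
        self.cuda_graph_ready = cuda_graph_ready

    def forward(self, mode: ForwardMode) -> Tuple[str, bool]:
        can_run_cuda_graph = self.cuda_graph_ready and mode is ForwardMode.DECODE
        if can_run_cuda_graph:
            ret = "replayed-cuda-graph"
        elif mode is ForwardMode.DECODE:
            ret = "eager-decode"
        elif mode is ForwardMode.EXTEND:
            ret = "eager-extend"
        elif mode is ForwardMode.IDLE:
            ret = "idle"
        else:
            raise ValueError(f"Invalid forward mode: {mode}")
        return ret, can_run_cuda_graph


output, used_cuda_graph = FakeRunner(cuda_graph_ready=True).forward(ForwardMode.DECODE)
print(output, used_cuda_graph)  # replayed-cuda-graph True
```

The flag lets callers of `forward` see whether the CUDA-graph replay path was actually taken for that batch.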

@@ -1145,9 +1244,7 @@ class ModelRunner:
                 [self.sample(values, forward_batch) for values in logits_output],
                 axis=-1,
             )
-
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
+
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
 
         # Sample the next tokens
@@ -1158,15 +1255,13 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
-        if sampling_info.thinking_budgets is not None:
-            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids
 
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.
+        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
         is_mrope_enabled = "mrope_section" in rope_scaling

sglang/srt/model_loader/loader.py
CHANGED

@@ -197,6 +197,15 @@ class DefaultModelLoader(BaseModelLoader):
         fall_back_to_pt: bool = True
         """Whether .pt weights can be used."""
 
+        @classmethod
+        def init_new(cls, model_config: ModelConfig, model):
+            return cls(
+                model_config.model_path,
+                model_config.revision,
+                prefix="",
+                fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
+            )
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
@@ -341,12 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
         model: nn.Module,
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
 
-        primary_weights = DefaultModelLoader.Source(
-            model_config.model_path,
-            model_config.revision,
-            prefix="",
-            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
-        )
+        primary_weights = DefaultModelLoader.Source.init_new(model_config, model)
         yield from self._get_weights_iterator(primary_weights)
 
         secondary_weights = cast(
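
The loader change folds two copies of the `DefaultModelLoader.Source(...)` construction (here and in model_runner.py above) into a single `Source.init_new(...)` classmethod. A generic sketch of that factory-classmethod refactor, using hypothetical `WeightSource`/`_Cfg`/`_Model` stand-ins rather than the sglang classes themselves:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WeightSource:
    """Hypothetical stand-in for a model-weight source description."""

    model_path: str
    revision: Optional[str] = None
    prefix: str = ""
    fall_back_to_pt: bool = True

    @classmethod
    def init_new(cls, model_config, model) -> "WeightSource":
        # Encode the construction convention once instead of repeating it
        # at every call site.
        return cls(
            model_config.model_path,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
        )


class _Cfg:
    # Minimal stand-ins so the sketch runs on its own.
    model_path = "org/model"
    revision = None


class _Model:
    fall_back_to_pt_during_load = False


print(WeightSource.init_new(_Cfg(), _Model()))
# WeightSource(model_path='org/model', revision=None, prefix='', fall_back_to_pt=False)
```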

sglang/srt/models/clip.py
CHANGED

@@ -168,7 +168,7 @@ class CLIPEncoderLayer(nn.Module):
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
-            prefix=add_prefix("
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = CLIPMLP(
             config,
@@ -395,6 +395,10 @@ class CLIPVisionModel(nn.Module):
             config, quant_config, prefix=add_prefix("vision_model", prefix)
         )
 
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
     def forward(self, pixel_values: torch.Tensor):
         return self.vision_model(pixel_values)
 
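
The clip.py hunk exposes a `device` property on `CLIPVisionModel` that simply forwards to the inner vision tower. A common way to implement the same idea on a wrapper `nn.Module` is sketched below with made-up `Inner`/`Wrapper` modules; it assumes PyTorch is installed and is not the sglang implementation, which delegates to `self.vision_model.device`.

```python
import torch
from torch import nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def device(self) -> torch.device:
        # Derive the device from the parameters themselves.
        return next(self.parameters()).device


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_model = Inner()

    @property
    def device(self) -> torch.device:
        # Same shape as the property added in the diff: delegate to the submodule.
        return self.vision_model.device


print(Wrapper().device)  # cpu
```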

sglang/srt/models/deepseek_janus_pro.py
CHANGED

@@ -188,7 +188,7 @@ def trunc_normal_tf_(
     best when :math:`a \\leq \text{mean} \\leq b`.
     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
-    and the result is
+    and the result is subsequently scaled and shifted by the mean and std args.
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
@@ -735,7 +735,7 @@ class VisionTransformer(nn.Module):
             img_size: Input image size.
             patch_size: Patch size.
             in_chans: Number of image input channels.
-            num_classes:
+            num_classes: Number of classes for classification head.
             global_pool: Type of global pooling for final sequence (default: 'token').
             embed_dim: Transformer embedding dimension.
             depth: Depth of transformer.