sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
@@ -108,7 +109,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_cuda,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -275,6 +275,16 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        # Check if the model is using hybrid SWA
+        if (
+            not self.server_args.disable_hybrid_swa_memory
+            and self.sliding_window_size is not None
+            and self.sliding_window_size > 0
+        ):
+            architectures = self.model_config.hf_config.architectures
+            if architectures and not any("Llama4" in arch for arch in architectures):
+                self.is_hybrid = self.model_config.is_hybrid = True
+
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(
             self.model, "end_layer", self.model_config.num_hidden_layers
@@ -295,11 +305,7 @@ class ModelRunner:
             self.apply_torch_tp()

         # Init lora
-
-        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
-        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
-        # load LoRA adapters dynamically later.
-        if server_args.lora_paths is not None:
+        if server_args.enable_lora:
             self.init_lora_manager()

         # Init memory pool and attention backends
@@ -372,6 +378,7 @@ class ModelRunner:
                 is_hopper_with_cuda_12_3()
                 and is_no_spec_infer_or_topk_one(server_args)
                 and is_fa3_default_architecture(self.model_config.hf_config)
+                and (not server_args.enable_hierarchical_cache)
             ):
                 server_args.attention_backend = "fa3"
             elif _is_hip:
@@ -384,7 +391,9 @@ class ModelRunner:
                 )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3()
+                if is_hopper_with_cuda_12_3() and (
+                    not server_args.enable_hierarchical_cache
+                ):
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -402,7 +411,7 @@ class ModelRunner:
             else:
                 server_args.attention_backend = "triton"
             logger.info(
-                f"Attention backend not
+                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
@@ -454,7 +463,7 @@ class ModelRunner:
         if not self.is_multimodal_chunked_prefill_supported:
             server_args.chunked_prefill_size = -1
             logger.info(
-                f"Automatically turn
+                f"Automatically turn off --chunked-prefill-size as it is not supported for "
                 f"{self.model_config.hf_config.model_type}"
             )

@@ -471,10 +480,6 @@ class ModelRunner:
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85

-        if self.is_hybrid and not server_args.disable_radix_cache:
-            logger.info("Automatically disable radix cache for hybrid cache.")
-            server_args.disable_radix_cache = True
-
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

@@ -534,6 +539,7 @@ class ModelRunner:
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
             pipeline_model_parallel_size=self.pp_size,
+            duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
             enable_dp_attention=self.server_args.enable_dp_attention,
@@ -555,7 +561,7 @@ class ModelRunner:

         # Check memory for tensor parallelism
         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not self.is_draft_worker:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                     logger.warning(
@@ -645,11 +651,15 @@ class ModelRunner:
         )

         # Parse other args
-        self.sliding_window_size =
-
-
-
-
+        self.sliding_window_size = None
+        if hasattr(self.model, "get_attention_sliding_window_size"):
+            self.sliding_window_size = self.model.get_attention_sliding_window_size()
+        elif self.model_config.attention_chunk_size is not None:
+            self.sliding_window_size = self.model_config.attention_chunk_size
+            print(
+                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
+            )
+
         self.dtype = self.model_config.dtype

         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
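The runner now discovers the sliding window either from a model-level hook (`get_attention_sliding_window_size`) or from `attention_chunk_size` in the model config. The following is a rough sketch of what the model-side hook could look like; it is an assumption for illustration, not taken from this diff, and the class and config field shown are hypothetical.

# Hypothetical model-side hook probed by the hasattr() check above.
class ToyModelForCausalLM:  # illustrative class, not an sglang model
    def __init__(self, config):
        self.config = config

    def get_attention_sliding_window_size(self):
        # Return the window size if the config defines one, else None so the
        # runner falls back to attention_chunk_size or disables hybrid SWA.
        return getattr(self.config, "sliding_window", None)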
@@ -882,44 +892,40 @@ class ModelRunner:
             lora_backend=self.server_args.lora_backend,
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
+            max_lora_rank=self.server_args.max_lora_rank,
+            target_modules=self.server_args.lora_target_modules,
+            lora_paths=self.server_args.lora_paths,
         )
-        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-        if result.success:
-            logger.info(
-                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
-            )
-        else:
-            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")

-    def load_lora_adapter(self,
+    def load_lora_adapter(self, lora_ref: LoRARef):
         """Load a new lora adapter from disk or huggingface."""

         logger.info(
-            f"LoRA adapter loading starts:
+            f"LoRA adapter loading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.load_lora_adapter(
+        result = self.lora_manager.load_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter loading completes:
+            f"LoRA adapter loading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         return result

-    def unload_lora_adapter(self,
+    def unload_lora_adapter(self, lora_ref: LoRARef):
         """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""

         logger.info(
-            f"LoRA adapter unloading starts:
+            f"LoRA adapter unloading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.unload_lora_adapter(
+        result = self.lora_manager.unload_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter unloading completes:
+            f"LoRA adapter unloading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

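With this change, dynamic adapter management goes through `LoRARef` objects instead of raw paths. A minimal caller-side sketch follows; the `LoRARef` constructor fields (`lora_name`, `lora_path`) are assumptions about `sglang.srt.lora.lora_registry`, and the `result.success`/`result.error_message` fields are borrowed from the removed init-time check, so the actual result type may differ.

# Illustrative sketch only (not part of the diff).
from sglang.srt.lora.lora_registry import LoRARef

def hot_swap_adapter(model_runner, name: str, path: str):
    # Wrap the adapter identity in a LoRARef so load and unload refer to the same object.
    ref = LoRARef(lora_name=name, lora_path=path)  # assumed field names
    result = model_runner.load_lora_adapter(ref)
    if not result.success:  # result shape assumed
        raise RuntimeError(result.error_message)
    return ref

# Later, free the GPU memory held by the adapter:
# model_runner.unload_lora_adapter(ref)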
@@ -992,8 +998,56 @@ class ModelRunner:
             )
             self.max_total_num_tokens = self.full_max_total_num_tokens
         else:
-
-
+            assert self.sliding_window_size is not None and self.sliding_window_size > 0
+            full_attention_layer_ids = []
+            swa_attention_layer_ids = []
+
+            try:
+                layers = self.model.model.layers
+            except:
+                try:
+                    layers = self.model.language_model.model.layers
+                except:
+                    try:
+                        layers = self.model.language_model.layers
+                    except:
+                        self.is_hybrid = False
+                        return
+
+            for layer in layers:
+                if (
+                    layer.self_attn.attn.sliding_window_size is None
+                    or layer.self_attn.attn.sliding_window_size == -1
+                ):
+                    full_attention_layer_ids.append(layer.layer_id)
+                else:
+                    swa_attention_layer_ids.append(layer.layer_id)
+            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
+            self.model_config.full_attention_layer_ids = full_attention_layer_ids
+
+            # Algorithm:
+            # Existing max_total_num_tokens is per layer and assume all layers have the same number of tokens.
+            # - Find total # of tokens available across layers.
+            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
+            total_tokens = (
+                self.max_total_num_tokens * self.model_config.num_hidden_layers
+            )
+            full_layers_num = len(full_attention_layer_ids)
+            swa_layers_num = len(swa_attention_layer_ids)
+            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio
+
+            # Solve the equations:
+            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
+            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
+            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
+            self.full_max_total_num_tokens = int(total_tokens / denominator)
+            self.swa_max_total_num_tokens = int(
+                self.full_max_total_num_tokens * swa_full_tokens_ratio
+            )
+            self.max_total_num_tokens = self.full_max_total_num_tokens
+
+            logger.info(
+                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
             )

     def init_memory_pool(
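The new branch splits the per-layer KV token budget between full-attention and sliding-window layers by solving the two equations in the comments above. A small standalone illustration of that arithmetic, with hypothetical numbers (the ratio corresponds to `server_args.swa_full_tokens_ratio`):

# Standalone illustration of the token-split arithmetic; values are made up.
max_total_num_tokens = 100_000   # per-layer budget before the split (hypothetical)
num_hidden_layers = 32
full_layers, swa_layers = 8, 24  # e.g. one global-attention layer per four SWA layers
swa_full_tokens_ratio = 0.5      # hypothetical ratio

total_tokens = max_total_num_tokens * num_hidden_layers          # 3,200,000
denominator = swa_full_tokens_ratio * swa_layers + full_layers   # 0.5 * 24 + 8 = 20
full_max_total_num_tokens = int(total_tokens / denominator)      # 160,000
swa_max_total_num_tokens = int(full_max_total_num_tokens * swa_full_tokens_ratio)  # 80,000

# Check: 80,000 * 24 + 160,000 * 8 == 3,200,000, so the combined pools fit the original budget.
assert (
    swa_max_total_num_tokens * swa_layers + full_max_total_num_tokens * full_layers
    == total_tokens
)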
@@ -1072,7 +1126,6 @@ class ModelRunner:
             // self.server_args.page_size
             * self.server_args.page_size
         )
-
         # create token size for hybrid cache
         if self.is_hybrid:
             self.set_num_token_hybrid()
@@ -1410,9 +1463,13 @@ class ModelRunner:
         tensor_parallel(self.model, device_mesh)

     def forward_decode(
-        self,
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
-
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
@@ -1457,11 +1514,34 @@ class ModelRunner:
             **kwargs,
         )

+    def forward_split_prefill(
+        self,
+        forward_batch: ForwardBatch,
+        reinit_attn_backend: bool = False,
+        forward_count: int = 1,
+    ) -> LogitsProcessorOutput:
+        if forward_batch.split_index == 0 or reinit_attn_backend:
+            self.attn_backend.init_forward_metadata(forward_batch)
+        next_split_index = min(
+            forward_batch.split_index + forward_count,
+            self.model_config.num_hidden_layers,
+        )
+        ret = self.model.forward_split_prefill(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            (forward_batch.split_index, next_split_index),
+        )
+        forward_batch.split_index = next_split_index
+        return ret
+
     def forward(
         self,
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         self.forward_pass_id += 1

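`forward_split_prefill` processes a prefill batch a few layers at a time: `split_index` tracks progress and the attention metadata is only rebuilt on the first chunk (or when `reinit_attn_backend` is set). A rough caller-side sketch of a driver loop follows; it is not from the diff, and it assumes `forward_batch.split_index` starts at 0 and that the underlying model implements the `forward_split_prefill` hook shown above.

# Illustrative sketch of driving split prefill to completion, chunk by chunk.
def run_split_prefill(model_runner, forward_batch, layers_per_step: int = 4):
    ret = None
    while forward_batch.split_index < model_runner.model_config.num_hidden_layers:
        # Each call advances split_index by up to `layers_per_step` layers.
        ret = model_runner.forward_split_prefill(
            forward_batch,
            forward_count=layers_per_step,
        )
    # Return whatever the final chunk produced.
    return ret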
@@ -1470,7 +1550,11 @@ class ModelRunner:
             forward_batch,
         ):
             output = self._forward_raw(
-                forward_batch,
+                forward_batch,
+                skip_attn_backend_init,
+                pp_proxy_tensors,
+                reinit_attn_backend,
+                split_forward_count,
             )

         if self.eplb_manager is not None:
@@ -1483,6 +1567,8 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool,
         pp_proxy_tensors: Optional[PPProxyTensors],
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
@@ -1495,19 +1581,38 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-
+            return ret, can_run_cuda_graph
+
+        # For MLP sync
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.prepare_mlp_sync_batch(self)
+
+        if forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_extend():
             ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
+        elif forward_batch.forward_mode.is_split_prefill():
+            ret = self.forward_split_prefill(
+                forward_batch,
+                reinit_attn_backend=reinit_attn_backend,
+                forward_count=split_forward_count,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.post_forward_mlp_sync_batch(ret)
+
         return ret, can_run_cuda_graph

     def _preprocess_logits(
sglang/srt/model_loader/loader.py
CHANGED
@@ -575,7 +575,13 @@ class DummyModelLoader(BaseModelLoader):
         # 2. Post-processing of weights, including assigning specific member variables.
         # For `dummy_init`, only the second stage is required.
         if hasattr(model, "post_load_weights"):
-
+            if (
+                model_config.hf_config.architectures[0]
+                == "DeepseekV3ForCausalLMNextN"
+            ):
+                model.post_load_weights(is_nextn=True)
+            else:
+                model.post_load_weights()

         return model.eval()

sglang/srt/model_loader/utils.py
CHANGED
@@ -56,14 +56,14 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str
                 "if the model is custom)."
             )
         model_module = auto_modules["AutoModel"]
-        if model_config.
+        if model_config.model_impl == ModelImpl.TRANSFORMERS:
             if not model_module.is_backend_compatible():
                 raise ValueError(
                     f"The Transformers implementation of {arch} is not "
-                    "compatible with
+                    "compatible with SGLang."
                 )
             architectures[i] = "TransformersForCausalLM"
-        if model_config.
+        if model_config.model_impl == ModelImpl.AUTO:
             if not model_module.is_backend_compatible():
                 raise ValueError(
                     f"{arch} has no SGlang implementation and the Transformers "
@@ -97,7 +97,7 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
     supported_archs = ModelRegistry.get_supported_archs()
     is_native_supported = any(arch in supported_archs for arch in architectures)

-    if not is_native_supported or model_config.
+    if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
         architectures = resolve_transformers_arch(model_config, architectures)

     return ModelRegistry.resolve_model_cls(architectures)
sglang/srt/models/clip.py
CHANGED
@@ -463,7 +463,7 @@ class CLIPModel(nn.Module):
         if forward_batch.mm_inputs is not None:
             mm_inputs = forward_batch.mm_inputs
             pixel_values_list = [
-                item.
+                item.feature
                 for item in flatten_nested_list(
                     [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
                 )
sglang/srt/models/deepseek.py
CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module):
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {self.n_routed_experts}."
             )
-
+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=config.norm_topk_prob,
+        )
         self.experts = nn.ModuleList(
             [
                 DeepseekMLP(
@@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe.fused_moe(
             hidden_states,
-            self.w1,
-            self.w2,
-
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
+            w1=self.w1,
+            w2=self.w2,
+            topk_output=topk_output,
             inplace=True,
         )

sglang/srt/models/deepseek_janus_pro.py
CHANGED
@@ -1960,7 +1960,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         self.logits_processor = LogitsProcessor(config)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values = torch.concat([item.
+        pixel_values = torch.concat([item.feature for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype