sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -226,11 +226,11 @@ class GenerateReqInput:
|
|
226
226
|
|
227
227
|
# Expand input based on type
|
228
228
|
self._expand_inputs(num)
|
229
|
+
self._normalize_rid(num)
|
229
230
|
self._normalize_lora_paths(num)
|
230
231
|
self._normalize_image_data(num)
|
231
232
|
self._normalize_audio_data(num)
|
232
233
|
self._normalize_sampling_params(num)
|
233
|
-
self._normalize_rid(num)
|
234
234
|
self._normalize_logprob_params(num)
|
235
235
|
self._normalize_custom_logit_processor(num)
|
236
236
|
|
@@ -319,8 +319,16 @@ class GenerateReqInput:
|
|
319
319
|
"""Normalize request IDs for batch processing."""
|
320
320
|
if self.rid is None:
|
321
321
|
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
322
|
-
elif
|
323
|
-
|
322
|
+
elif isinstance(self.rid, str):
|
323
|
+
new_rids = [f"{self.rid}_{i}" for i in range(num)]
|
324
|
+
self.rid = new_rids
|
325
|
+
elif isinstance(self.rid, list):
|
326
|
+
if len(self.rid) != num:
|
327
|
+
raise ValueError(
|
328
|
+
"The specified rids length mismatch with the batch_size for batch processing."
|
329
|
+
)
|
330
|
+
else:
|
331
|
+
raise ValueError("The rid should be a string or a list of strings.")
|
324
332
|
|
325
333
|
def _normalize_logprob_params(self, num):
|
326
334
|
"""Normalize logprob-related parameters for batch processing."""
|
@@ -530,6 +538,7 @@ class EmbeddingReqInput:
|
|
530
538
|
if self.text is not None:
|
531
539
|
if isinstance(self.text, list):
|
532
540
|
self.batch_size += len(self.text)
|
541
|
+
self.is_single = False
|
533
542
|
else:
|
534
543
|
self.batch_size += 1
|
535
544
|
|
@@ -537,12 +546,10 @@ class EmbeddingReqInput:
|
|
537
546
|
if self.input_ids is not None:
|
538
547
|
if isinstance(self.input_ids[0], list):
|
539
548
|
self.batch_size += len(self.input_ids)
|
549
|
+
self.is_single = False
|
540
550
|
else:
|
541
551
|
self.batch_size += 1
|
542
552
|
|
543
|
-
if self.batch_size > 1:
|
544
|
-
self.is_single = False
|
545
|
-
|
546
553
|
# Fill in default arguments
|
547
554
|
if self.is_single:
|
548
555
|
if self.rid is None:
|
@@ -812,7 +819,9 @@ class GetWeightsByNameReqOutput:
|
|
812
819
|
|
813
820
|
@dataclass
|
814
821
|
class ReleaseMemoryOccupationReqInput:
|
815
|
-
|
822
|
+
# Optional tags to identify the memory region, which is primarily used for RL
|
823
|
+
# Currently we only support `weights` and `kv_cache`
|
824
|
+
tags: Optional[List[str]] = None
|
816
825
|
|
817
826
|
|
818
827
|
@dataclass
|
@@ -822,7 +831,9 @@ class ReleaseMemoryOccupationReqOutput:
|
|
822
831
|
|
823
832
|
@dataclass
|
824
833
|
class ResumeMemoryOccupationReqInput:
|
825
|
-
|
834
|
+
# Optional tags to identify the memory region, which is primarily used for RL
|
835
|
+
# Currently we only support `weights` and `kv_cache`
|
836
|
+
tags: Optional[List[str]] = None
|
826
837
|
|
827
838
|
|
828
839
|
@dataclass
|
@@ -861,12 +872,6 @@ class SetInternalStateReq:
|
|
861
872
|
server_args: Dict[str, Any]
|
862
873
|
|
863
874
|
|
864
|
-
@dataclass
|
865
|
-
class V1RerankReqInput:
|
866
|
-
query: str
|
867
|
-
documents: List[str]
|
868
|
-
|
869
|
-
|
870
875
|
@dataclass
|
871
876
|
class SetInternalStateReqOutput:
|
872
877
|
updated: bool
|
@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
|
|
23
23
|
RAW_IMAGES = "raw_images"
|
24
24
|
PRECOMPUTED_FEATURES = "precomputed_features"
|
25
25
|
PIXEL_VALUES = "pixel_values"
|
26
|
+
AUDIO = "audio"
|
26
27
|
|
27
28
|
|
28
29
|
@dataclasses.dataclass
|
@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
|
|
441
442
|
has_image = False
|
442
443
|
has_pixel_values = False
|
443
444
|
has_precomputed_features = False
|
445
|
+
has_audio = False
|
444
446
|
|
445
447
|
for mm_input in mm_inputs:
|
446
448
|
if isinstance(mm_input, Image.Image):
|
447
449
|
has_image = True
|
450
|
+
elif isinstance(mm_input, np.ndarray):
|
451
|
+
has_audio = True
|
448
452
|
elif isinstance(mm_input, dict):
|
449
453
|
if mm_input.get("precomputed_features", None) is not None:
|
450
454
|
has_precomputed_features = True
|
@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):
|
|
461
465
|
|
462
466
|
# Validate format consistency
|
463
467
|
format_count = sum(
|
464
|
-
[has_image, has_pixel_values, has_precomputed_features]
|
468
|
+
[has_image, has_pixel_values, has_precomputed_features, has_audio]
|
465
469
|
)
|
466
470
|
if format_count > 1:
|
467
471
|
raise ValueError(
|
468
472
|
"Unsupported: mixture of multimodal input formats. "
|
469
473
|
f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
|
470
|
-
f"precomputed_features={has_precomputed_features}"
|
474
|
+
f"precomputed_features={has_precomputed_features}, audio={has_audio}"
|
471
475
|
)
|
472
476
|
|
473
477
|
if has_image:
|
@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
|
|
476
480
|
return MultimodalInputFormat.PRECOMPUTED_FEATURES
|
477
481
|
elif has_pixel_values:
|
478
482
|
return MultimodalInputFormat.PIXEL_VALUES
|
483
|
+
elif has_audio:
|
484
|
+
return MultimodalInputFormat.AUDIO
|
479
485
|
else:
|
480
486
|
raise ValueError("No valid multimodal input format found")
|
481
487
|
except Exception as e:
|
@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
|
|
521
527
|
input_ids = tokenize_text(base_output.input_text)
|
522
528
|
return combined_mm_item, input_ids
|
523
529
|
|
530
|
+
def process_audio(
|
531
|
+
base_output: BaseMultiModalProcessorOutput,
|
532
|
+
) -> Tuple[MultimodalDataItem, torch.Tensor]:
|
533
|
+
"""Process inputs with audio."""
|
534
|
+
ret = self.process_mm_data(
|
535
|
+
input_text=base_output.input_text,
|
536
|
+
audio=base_output.audios, # Note: "audio" is for gemma3n only
|
537
|
+
)
|
538
|
+
combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
|
539
|
+
for key, value in ret.items():
|
540
|
+
if key != "input_ids" and hasattr(combined_mm_item, key):
|
541
|
+
setattr(combined_mm_item, key, value)
|
542
|
+
input_ids = ret["input_ids"].flatten()
|
543
|
+
return combined_mm_item, input_ids
|
544
|
+
|
524
545
|
def finalize_mm_item(
|
525
546
|
combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
|
526
547
|
) -> MultimodalDataItem:
|
527
548
|
"""Apply common post-processing to the multimodal item."""
|
528
|
-
combined_mm_item.
|
529
|
-
|
530
|
-
|
531
|
-
|
549
|
+
if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
|
550
|
+
combined_mm_item.image_offsets = self.get_mm_items_offset(
|
551
|
+
input_ids=input_ids,
|
552
|
+
mm_token_id=self.IM_TOKEN_ID,
|
553
|
+
)
|
554
|
+
elif combined_mm_item.modality == Modality.AUDIO:
|
555
|
+
combined_mm_item.audio_offsets = self.get_mm_items_offset(
|
556
|
+
input_ids=input_ids,
|
557
|
+
mm_token_id=self.AUDIO_TOKEN_ID,
|
558
|
+
)
|
559
|
+
elif combined_mm_item.modality == Modality.VIDEO:
|
560
|
+
combined_mm_item.video_offsets = self.get_mm_items_offset(
|
561
|
+
input_ids=input_ids,
|
562
|
+
mm_token_id=self.VIDEO_TOKEN_ID,
|
563
|
+
)
|
564
|
+
else:
|
565
|
+
raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
|
532
566
|
return combined_mm_item
|
533
567
|
|
534
|
-
# Main logic
|
535
|
-
mm_inputs = base_output.images
|
568
|
+
# Main logic - determine input type and handle text-only case
|
569
|
+
mm_inputs = base_output.images or base_output.audios
|
536
570
|
if not mm_inputs:
|
537
|
-
# Return text-only case
|
538
571
|
input_ids = tokenize_text(base_output.input_text)
|
539
572
|
return None, input_ids
|
540
573
|
|
@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
|
|
548
581
|
combined_mm_item, input_ids = process_precomputed_features(base_output)
|
549
582
|
elif input_format == MultimodalInputFormat.PIXEL_VALUES:
|
550
583
|
combined_mm_item, input_ids = process_pixel_values(base_output)
|
584
|
+
elif input_format == MultimodalInputFormat.AUDIO:
|
585
|
+
combined_mm_item, input_ids = process_audio(base_output)
|
551
586
|
else:
|
552
587
|
raise ValueError(f"Unknown input format: {input_format}")
|
553
588
|
|
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright 2025 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
import re
|
16
|
+
from typing import Dict, List, Optional, Union
|
17
|
+
|
18
|
+
from sglang.srt.managers.multimodal_processor import (
|
19
|
+
BaseMultimodalProcessor as SGLangBaseProcessor,
|
20
|
+
)
|
21
|
+
from sglang.srt.managers.multimodal_processors.base_processor import (
|
22
|
+
MultimodalSpecialTokens,
|
23
|
+
)
|
24
|
+
from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
|
25
|
+
|
26
|
+
|
27
|
+
class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
28
|
+
"""Multimodal processor for Gemma3n supporting image and audio inputs."""
|
29
|
+
|
30
|
+
models = [Gemma3nForConditionalGeneration]
|
31
|
+
|
32
|
+
def __init__(self, hf_config, server_args, _processor):
|
33
|
+
super().__init__(hf_config, server_args, _processor)
|
34
|
+
|
35
|
+
self.IMAGE_TOKEN = "<image_soft_token>"
|
36
|
+
self.IMAGE_TOKEN_REGEX = re.compile(
|
37
|
+
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
38
|
+
)
|
39
|
+
|
40
|
+
self.AUDIO_TOKEN = "<audio_soft_token>"
|
41
|
+
self.AUDIO_TOKEN_REGEX = re.compile(
|
42
|
+
r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
|
43
|
+
)
|
44
|
+
|
45
|
+
self.IM_TOKEN_ID = hf_config.image_token_id
|
46
|
+
self.IM_START_TOKEN_ID = hf_config.boi_token_id
|
47
|
+
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
|
48
|
+
|
49
|
+
self.AUDIO_TOKEN_ID = hf_config.audio_token_id
|
50
|
+
self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
|
51
|
+
self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
|
52
|
+
|
53
|
+
async def process_mm_data_async(
|
54
|
+
self,
|
55
|
+
image_data: Optional[List[Union[str, bytes, Dict]]] = None,
|
56
|
+
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
|
57
|
+
input_text: str = "",
|
58
|
+
request_obj=None,
|
59
|
+
max_req_input_len: int = 0,
|
60
|
+
*args,
|
61
|
+
**kwargs,
|
62
|
+
):
|
63
|
+
"""Process multimodal data including images and audio."""
|
64
|
+
|
65
|
+
audio_data = request_obj.audio_data
|
66
|
+
if not image_data and not audio_data:
|
67
|
+
return None
|
68
|
+
|
69
|
+
if isinstance(image_data, str):
|
70
|
+
image_data = [image_data]
|
71
|
+
|
72
|
+
if isinstance(audio_data, str):
|
73
|
+
audio_data = [audio_data]
|
74
|
+
|
75
|
+
base_output = self.load_mm_data(
|
76
|
+
prompt=input_text,
|
77
|
+
image_data=image_data,
|
78
|
+
audio_data=audio_data,
|
79
|
+
max_req_input_len=max_req_input_len,
|
80
|
+
multimodal_tokens=MultimodalSpecialTokens(
|
81
|
+
image_token=self.IMAGE_TOKEN,
|
82
|
+
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
83
|
+
audio_token=self.AUDIO_TOKEN,
|
84
|
+
audio_token_regex=self.AUDIO_TOKEN_REGEX,
|
85
|
+
),
|
86
|
+
)
|
87
|
+
|
88
|
+
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
89
|
+
|
90
|
+
return {
|
91
|
+
"input_ids": input_ids.tolist(),
|
92
|
+
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
|
93
|
+
"im_start_id": self.IM_START_TOKEN_ID,
|
94
|
+
"im_end_id": self.IM_END_TOKEN_ID,
|
95
|
+
"audio_start_id": self.AUDIO_START_TOKEN_ID,
|
96
|
+
"audio_end_id": self.AUDIO_END_TOKEN_ID,
|
97
|
+
}
|
@@ -38,7 +38,7 @@ import logging
|
|
38
38
|
import threading
|
39
39
|
from enum import Enum, auto
|
40
40
|
from http import HTTPStatus
|
41
|
-
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
41
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
42
42
|
|
43
43
|
import numpy as np
|
44
44
|
import torch
|
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
|
54
54
|
)
|
55
55
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
56
56
|
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
57
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
57
58
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
58
59
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
59
|
-
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
60
|
+
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
60
61
|
from sglang.srt.metrics.collector import TimeStats
|
61
62
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
62
63
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
@@ -85,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|
85
86
|
"enable_deepep_moe",
|
86
87
|
"deepep_mode",
|
87
88
|
"enable_ep_moe",
|
89
|
+
"enable_flashinfer_moe",
|
88
90
|
"moe_dense_tp_size",
|
89
91
|
"ep_dispatch_algorithm",
|
90
92
|
"deepep_config",
|
@@ -99,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|
99
101
|
"torchao_config",
|
100
102
|
"triton_attention_reduce_in_fp32",
|
101
103
|
"num_reserved_decode_tokens",
|
104
|
+
"weight_loader_disable_mmap",
|
102
105
|
]
|
103
106
|
|
104
107
|
# Put some global args for easy access
|
@@ -211,6 +214,10 @@ class MultimodalDataItem:
|
|
211
214
|
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
212
215
|
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
213
216
|
|
217
|
+
# gemma3n related
|
218
|
+
input_features: Optional[torch.Tensor] = None
|
219
|
+
input_features_mask: Optional[torch.Tensor] = None
|
220
|
+
|
214
221
|
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
215
222
|
|
216
223
|
@staticmethod
|
@@ -274,7 +281,10 @@ class MultimodalDataItem:
|
|
274
281
|
if self.precomputed_features is not None:
|
275
282
|
self.hash = hash_feature(self.precomputed_features)
|
276
283
|
elif self.is_audio():
|
277
|
-
self.
|
284
|
+
if self.audio_features is not None:
|
285
|
+
self.hash = hash_feature(self.audio_features)
|
286
|
+
elif self.input_features is not None:
|
287
|
+
self.hash = hash_feature(self.input_features)
|
278
288
|
else:
|
279
289
|
self.hash = hash_feature(self.pixel_values)
|
280
290
|
|
@@ -285,6 +295,7 @@ class MultimodalDataItem:
|
|
285
295
|
return (self.modality == Modality.AUDIO) and (
|
286
296
|
self.precomputed_features is not None
|
287
297
|
or not MultimodalDataItem.is_empty_list(self.audio_features)
|
298
|
+
or not MultimodalDataItem.is_empty_list(self.input_features)
|
288
299
|
)
|
289
300
|
|
290
301
|
def is_image(self):
|
@@ -436,7 +447,7 @@ class Req:
|
|
436
447
|
self,
|
437
448
|
rid: str,
|
438
449
|
origin_input_text: str,
|
439
|
-
origin_input_ids:
|
450
|
+
origin_input_ids: List[int],
|
440
451
|
sampling_params: SamplingParams,
|
441
452
|
return_logprob: bool = False,
|
442
453
|
top_logprobs_num: int = 0,
|
@@ -467,7 +478,7 @@ class Req:
|
|
467
478
|
# Each decode stage's output ids
|
468
479
|
self.output_ids = []
|
469
480
|
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
|
470
|
-
self.fill_ids =
|
481
|
+
self.fill_ids = []
|
471
482
|
self.session_id = session_id
|
472
483
|
self.input_embeds = input_embeds
|
473
484
|
|
@@ -519,13 +530,14 @@ class Req:
|
|
519
530
|
|
520
531
|
# Prefix info
|
521
532
|
# The indices to kv cache for the shared prefix.
|
522
|
-
self.prefix_indices = []
|
533
|
+
self.prefix_indices: torch.Tensor = []
|
523
534
|
# Number of tokens to run prefill.
|
524
535
|
self.extend_input_len = 0
|
525
536
|
# The relative logprob_start_len in an extend batch
|
526
537
|
self.extend_logprob_start_len = 0
|
527
|
-
self.last_node = None
|
528
|
-
self.
|
538
|
+
self.last_node: Any = None
|
539
|
+
self.last_host_node: Any = None
|
540
|
+
self.host_hit_length = 0
|
529
541
|
|
530
542
|
# Whether or not if it is chunked. It increments whenever
|
531
543
|
# it is chunked, and decrement whenever chunked request is
|
@@ -583,6 +595,7 @@ class Req:
|
|
583
595
|
self.output_token_ids_logprobs_idx
|
584
596
|
) = None
|
585
597
|
self.hidden_states: List[List[float]] = []
|
598
|
+
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
|
586
599
|
|
587
600
|
# Embedding (return values)
|
588
601
|
self.embedding = None
|
@@ -644,29 +657,17 @@ class Req:
|
|
644
657
|
def init_next_round_input(
|
645
658
|
self,
|
646
659
|
tree_cache: Optional[BasePrefixCache] = None,
|
647
|
-
enable_hierarchical_cache=False,
|
648
660
|
):
|
649
661
|
self.fill_ids = self.origin_input_ids + self.output_ids
|
650
662
|
if tree_cache is not None:
|
651
|
-
|
652
|
-
|
653
|
-
self.
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
)
|
658
|
-
|
659
|
-
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
660
|
-
rid=self.rid, key=self.adjust_max_prefix_ids()
|
661
|
-
)
|
662
|
-
elif enable_hierarchical_cache:
|
663
|
-
# in case last_node is evicted during scheduling, we need to update the prefix_indices
|
664
|
-
while self.last_node.evicted:
|
665
|
-
self.prefix_indices = self.prefix_indices[
|
666
|
-
: -len(self.last_node.host_value)
|
667
|
-
]
|
668
|
-
self.last_node = self.last_node.parent
|
669
|
-
|
663
|
+
(
|
664
|
+
self.prefix_indices,
|
665
|
+
self.last_node,
|
666
|
+
self.last_host_node,
|
667
|
+
self.host_hit_length,
|
668
|
+
) = tree_cache.match_prefix(
|
669
|
+
key=self.adjust_max_prefix_ids(),
|
670
|
+
)
|
670
671
|
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
671
672
|
|
672
673
|
def adjust_max_prefix_ids(self):
|
@@ -796,6 +797,7 @@ class Req:
|
|
796
797
|
self.multimodal_inputs = None
|
797
798
|
self.grammar = None
|
798
799
|
self.origin_input_ids = [0] # set it to one token to skip the long prefill
|
800
|
+
self.return_logprob = False
|
799
801
|
self.finished_reason = FINISH_ABORT(
|
800
802
|
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
801
803
|
)
|
@@ -820,7 +822,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
820
822
|
# Request, memory pool, and cache
|
821
823
|
reqs: List[Req]
|
822
824
|
req_to_token_pool: ReqToTokenPool = None
|
823
|
-
token_to_kv_pool_allocator:
|
825
|
+
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
|
824
826
|
tree_cache: BasePrefixCache = None
|
825
827
|
|
826
828
|
# Batch configs
|
@@ -862,6 +864,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
862
864
|
global_num_tokens: Optional[List[int]] = None
|
863
865
|
global_num_tokens_for_logprob: Optional[List[int]] = None
|
864
866
|
can_run_dp_cuda_graph: bool = False
|
867
|
+
is_extend_in_batch: bool = False
|
865
868
|
tbo_split_seq_index: Optional[int] = None
|
866
869
|
global_forward_mode: Optional[ForwardMode] = None
|
867
870
|
|
@@ -908,12 +911,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
908
911
|
# Whether to return hidden states
|
909
912
|
return_hidden_states: bool = False
|
910
913
|
|
914
|
+
# hicache pointer for synchronizing data loading from CPU to GPU
|
915
|
+
hicache_consumer_index: int = 0
|
916
|
+
|
911
917
|
@classmethod
|
912
918
|
def init_new(
|
913
919
|
cls,
|
914
920
|
reqs: List[Req],
|
915
921
|
req_to_token_pool: ReqToTokenPool,
|
916
|
-
token_to_kv_pool_allocator:
|
922
|
+
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
917
923
|
tree_cache: BasePrefixCache,
|
918
924
|
model_config: ModelConfig,
|
919
925
|
enable_overlap: bool,
|
@@ -1365,7 +1371,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1365
1371
|
return len(self.reqs)
|
1366
1372
|
# In the decoding phase, the length of a request's KV cache should be
|
1367
1373
|
# the total length of the request minus 1
|
1368
|
-
return
|
1374
|
+
return (
|
1375
|
+
sum(1 for req in self.reqs if req.seqlen % page_size == 0)
|
1376
|
+
if self.enable_overlap
|
1377
|
+
else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
|
1378
|
+
)
|
1369
1379
|
|
1370
1380
|
def check_decode_mem(self, buf_multiplier=1):
|
1371
1381
|
tokens_required = (
|
@@ -1734,6 +1744,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1734
1744
|
token_type_ids=self.token_type_ids,
|
1735
1745
|
spec_algorithm=self.spec_algorithm,
|
1736
1746
|
spec_info=self.spec_info,
|
1747
|
+
hicache_consumer_index=self.hicache_consumer_index,
|
1737
1748
|
capture_hidden_mode=(
|
1738
1749
|
CaptureHiddenMode.FULL
|
1739
1750
|
if self.return_hidden_states
|
@@ -1760,11 +1771,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1760
1771
|
decoding_reqs=self.decoding_reqs,
|
1761
1772
|
spec_algorithm=self.spec_algorithm,
|
1762
1773
|
enable_custom_logit_processor=self.enable_custom_logit_processor,
|
1774
|
+
global_num_tokens=self.global_num_tokens,
|
1775
|
+
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
1776
|
+
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
1777
|
+
is_extend_in_batch=self.is_extend_in_batch,
|
1763
1778
|
)
|
1764
1779
|
|
1765
1780
|
def __str__(self):
|
1766
1781
|
return (
|
1767
|
-
f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
|
1782
|
+
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
|
1768
1783
|
f"#req={(len(self.reqs))})"
|
1769
1784
|
)
|
1770
1785
|
|
@@ -1833,6 +1848,8 @@ class ModelWorkerBatch:
|
|
1833
1848
|
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
1834
1849
|
# If set, the output of the batch contains the hidden states of the run.
|
1835
1850
|
capture_hidden_mode: CaptureHiddenMode = None
|
1851
|
+
spec_num_draft_tokens: Optional[int] = None
|
1852
|
+
hicache_consumer_index: int = 0
|
1836
1853
|
|
1837
1854
|
# Overlap event
|
1838
1855
|
launch_done: Optional[threading.Event] = None
|