sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,12 @@
|
|
13
13
|
# ==============================================================================
|
14
14
|
"""Radix attention."""
|
15
15
|
|
16
|
+
from typing import Optional
|
17
|
+
|
16
18
|
from torch import nn
|
17
19
|
|
20
|
+
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
21
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
18
22
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
19
23
|
|
20
24
|
|
@@ -34,7 +38,9 @@ class RadixAttention(nn.Module):
|
|
34
38
|
v_head_dim: int = -1,
|
35
39
|
sliding_window_size: int = -1,
|
36
40
|
is_cross_attention: bool = False,
|
41
|
+
quant_config: Optional[QuantizationConfig] = None,
|
37
42
|
prefix: str = "",
|
43
|
+
use_irope: bool = False,
|
38
44
|
):
|
39
45
|
super().__init__()
|
40
46
|
self.tp_q_head_num = num_heads
|
@@ -48,8 +54,16 @@ class RadixAttention(nn.Module):
|
|
48
54
|
self.logit_cap = logit_cap
|
49
55
|
self.sliding_window_size = sliding_window_size or -1
|
50
56
|
self.is_cross_attention = is_cross_attention
|
57
|
+
self.use_irope = use_irope
|
51
58
|
self.k_scale = None
|
52
59
|
self.v_scale = None
|
60
|
+
self.k_scale_float = None
|
61
|
+
self.v_scale_float = None
|
62
|
+
self.quant_method = None
|
63
|
+
if quant_config is not None:
|
64
|
+
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
65
|
+
if self.quant_method is not None:
|
66
|
+
self.quant_method.create_weights(self)
|
53
67
|
|
54
68
|
def forward(
|
55
69
|
self,
|
@@ -645,7 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|
645
645
|
cache = torch.cat((cos, sin), dim=-1)
|
646
646
|
return cache
|
647
647
|
|
648
|
-
def
|
648
|
+
def forward_hip(self, *args, **kwargs):
|
649
|
+
return self.forward_native(*args, **kwargs)
|
650
|
+
|
651
|
+
def forward(self, *args, **kwargs):
|
652
|
+
if torch.compiler.is_compiling():
|
653
|
+
return self.forward_native(*args, **kwargs)
|
654
|
+
if _is_cuda_available:
|
655
|
+
return self.forward_cuda(*args, **kwargs)
|
656
|
+
else:
|
657
|
+
return self.forward_native(*args, **kwargs)
|
658
|
+
|
659
|
+
def forward_native(
|
649
660
|
self,
|
650
661
|
positions: torch.Tensor,
|
651
662
|
query: torch.Tensor,
|
@@ -733,6 +744,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
|
733
744
|
return new_freqs
|
734
745
|
|
735
746
|
|
747
|
+
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
748
|
+
|
749
|
+
def __init__(
|
750
|
+
self,
|
751
|
+
head_size: int,
|
752
|
+
rotary_dim: int,
|
753
|
+
max_position_embeddings: int,
|
754
|
+
base: int,
|
755
|
+
is_neox_style: bool,
|
756
|
+
dtype: torch.dtype,
|
757
|
+
):
|
758
|
+
super().__init__(
|
759
|
+
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
760
|
+
)
|
761
|
+
|
762
|
+
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
763
|
+
inv_freqs = super()._compute_inv_freq(base)
|
764
|
+
inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
|
765
|
+
return inv_freqs
|
766
|
+
|
767
|
+
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
768
|
+
inv_freq = self._compute_inv_freq(self.base)
|
769
|
+
|
770
|
+
# self.max_position_embeddings here is number of image patches
|
771
|
+
# i.e. (image_size // patch_size) ** 2
|
772
|
+
num_patches = self.max_position_embeddings
|
773
|
+
img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
|
774
|
+
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
775
|
+
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
|
776
|
+
num_patches_single_dim = int(math.sqrt(num_patches))
|
777
|
+
frequencies_x = img_idx % num_patches_single_dim
|
778
|
+
frequencies_y = img_idx // num_patches_single_dim
|
779
|
+
freqs_x = (
|
780
|
+
(frequencies_x + 1)[..., None] * inv_freq[None, None, :]
|
781
|
+
).repeat_interleave(2, dim=-1)
|
782
|
+
freqs_y = (
|
783
|
+
(frequencies_y + 1)[..., None] * inv_freq[None, None, :]
|
784
|
+
).repeat_interleave(2, dim=-1)
|
785
|
+
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
786
|
+
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
|
787
|
+
cache = torch.view_as_complex(
|
788
|
+
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
|
789
|
+
)
|
790
|
+
return cache
|
791
|
+
|
792
|
+
def forward(
|
793
|
+
self,
|
794
|
+
query: torch.Tensor,
|
795
|
+
key: torch.Tensor,
|
796
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
797
|
+
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
|
798
|
+
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
|
799
|
+
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
|
800
|
+
broadcast_shape = [
|
801
|
+
d if i == 1 or i == (query_.ndim - 1) else 1
|
802
|
+
for i, d in enumerate(query_.shape)
|
803
|
+
]
|
804
|
+
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
|
805
|
+
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
|
806
|
+
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
|
807
|
+
return query_out.type_as(query), key_out.type_as(key)
|
808
|
+
|
809
|
+
|
736
810
|
class MRotaryEmbedding(RotaryEmbedding):
|
737
811
|
"""Rotary Embedding with Multimodal Sections."""
|
738
812
|
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,7 +20,13 @@ import copy
|
|
20
20
|
import uuid
|
21
21
|
from dataclasses import dataclass, field
|
22
22
|
from enum import Enum
|
23
|
-
from typing import Any, Dict, List, Literal, Optional, Union
|
23
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
24
|
+
|
25
|
+
# handle serialization of Image for pydantic
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from PIL.Image import Image
|
28
|
+
else:
|
29
|
+
Image = Any
|
24
30
|
|
25
31
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
26
32
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
@@ -42,10 +48,16 @@ class GenerateReqInput:
|
|
42
48
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
43
49
|
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
|
44
50
|
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
45
|
-
# The image input. It can be
|
46
|
-
#
|
47
|
-
|
48
|
-
#
|
51
|
+
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
52
|
+
# Can be formatted as:
|
53
|
+
# - Single image for a single request
|
54
|
+
# - List of images (one per request in a batch)
|
55
|
+
# - List of lists of images (multiple images per request)
|
56
|
+
# See also python/sglang/srt/utils.py:load_image for more details.
|
57
|
+
image_data: Optional[
|
58
|
+
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
59
|
+
] = None
|
60
|
+
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
49
61
|
audio_data: Optional[Union[List[str], str]] = None
|
50
62
|
# The sampling_params. See descriptions below.
|
51
63
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
@@ -83,7 +95,36 @@ class GenerateReqInput:
|
|
83
95
|
# Whether to return hidden states
|
84
96
|
return_hidden_states: bool = False
|
85
97
|
|
98
|
+
# For disaggregated inference
|
99
|
+
bootstrap_host: Optional[str] = None
|
100
|
+
bootstrap_room: Optional[int] = None
|
101
|
+
|
86
102
|
def normalize_batch_and_arguments(self):
|
103
|
+
"""
|
104
|
+
Normalize the batch size and arguments for the request.
|
105
|
+
|
106
|
+
This method resolves various input formats and ensures all parameters
|
107
|
+
are properly formatted as either single values or batches depending on the input.
|
108
|
+
It also handles parallel sampling expansion and sets default values for
|
109
|
+
unspecified parameters.
|
110
|
+
|
111
|
+
Raises:
|
112
|
+
ValueError: If inputs are not properly specified (e.g., none or all of
|
113
|
+
text, input_ids, input_embeds are provided)
|
114
|
+
"""
|
115
|
+
self._validate_inputs()
|
116
|
+
self._determine_batch_size()
|
117
|
+
self._handle_parallel_sampling()
|
118
|
+
|
119
|
+
if self.is_single:
|
120
|
+
self._normalize_single_inputs()
|
121
|
+
else:
|
122
|
+
self._normalize_batch_inputs()
|
123
|
+
|
124
|
+
self._validate_session_params()
|
125
|
+
|
126
|
+
def _validate_inputs(self):
|
127
|
+
"""Validate that the input configuration is valid."""
|
87
128
|
if (
|
88
129
|
self.text is None and self.input_ids is None and self.input_embeds is None
|
89
130
|
) or (
|
@@ -95,7 +136,8 @@ class GenerateReqInput:
|
|
95
136
|
"Either text, input_ids or input_embeds should be provided."
|
96
137
|
)
|
97
138
|
|
98
|
-
|
139
|
+
def _determine_batch_size(self):
|
140
|
+
"""Determine if this is a single example or a batch and the batch size."""
|
99
141
|
if self.text is not None:
|
100
142
|
if isinstance(self.text, str):
|
101
143
|
self.is_single = True
|
@@ -119,21 +161,25 @@ class GenerateReqInput:
|
|
119
161
|
self.is_single = True
|
120
162
|
self.batch_size = 1
|
121
163
|
else:
|
164
|
+
self.is_single = False
|
122
165
|
self.batch_size = len(self.input_embeds)
|
123
166
|
|
124
|
-
|
125
|
-
|
167
|
+
def _handle_parallel_sampling(self):
|
168
|
+
"""Handle parallel sampling parameters and adjust batch size if needed."""
|
169
|
+
# Determine parallel sample count
|
126
170
|
if self.sampling_params is None:
|
127
171
|
self.parallel_sample_num = 1
|
128
172
|
elif isinstance(self.sampling_params, dict):
|
129
173
|
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
130
174
|
else: # isinstance(self.sampling_params, list):
|
131
175
|
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
|
132
|
-
|
133
|
-
self.parallel_sample_num
|
134
|
-
|
135
|
-
|
176
|
+
for sampling_params in self.sampling_params:
|
177
|
+
if self.parallel_sample_num != sampling_params.get("n", 1):
|
178
|
+
raise ValueError(
|
179
|
+
"The parallel_sample_num should be the same for all samples in sample params."
|
180
|
+
)
|
136
181
|
|
182
|
+
# If using parallel sampling with a single example, convert to batch
|
137
183
|
if self.parallel_sample_num > 1 and self.is_single:
|
138
184
|
self.is_single = False
|
139
185
|
if self.text is not None:
|
@@ -141,97 +187,190 @@ class GenerateReqInput:
|
|
141
187
|
if self.input_ids is not None:
|
142
188
|
self.input_ids = [self.input_ids]
|
143
189
|
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
190
|
+
def _normalize_single_inputs(self):
|
191
|
+
"""Normalize inputs for a single example."""
|
192
|
+
if self.sampling_params is None:
|
193
|
+
self.sampling_params = {}
|
194
|
+
if self.rid is None:
|
195
|
+
self.rid = uuid.uuid4().hex
|
196
|
+
if self.return_logprob is None:
|
197
|
+
self.return_logprob = False
|
198
|
+
if self.logprob_start_len is None:
|
199
|
+
self.logprob_start_len = -1
|
200
|
+
if self.top_logprobs_num is None:
|
201
|
+
self.top_logprobs_num = 0
|
202
|
+
if not self.token_ids_logprob: # covers both None and []
|
203
|
+
self.token_ids_logprob = None
|
204
|
+
|
205
|
+
def _normalize_batch_inputs(self):
|
206
|
+
"""Normalize inputs for a batch of examples, including parallel sampling expansion."""
|
207
|
+
# Calculate expanded batch size
|
208
|
+
if self.parallel_sample_num == 1:
|
209
|
+
num = self.batch_size
|
158
210
|
else:
|
159
|
-
|
160
|
-
|
211
|
+
# Expand parallel_sample_num
|
212
|
+
num = self.batch_size * self.parallel_sample_num
|
213
|
+
|
214
|
+
# Expand input based on type
|
215
|
+
self._expand_inputs(num)
|
216
|
+
self._normalize_lora_paths(num)
|
217
|
+
self._normalize_image_data(num)
|
218
|
+
self._normalize_audio_data(num)
|
219
|
+
self._normalize_sampling_params(num)
|
220
|
+
self._normalize_rid(num)
|
221
|
+
self._normalize_logprob_params(num)
|
222
|
+
self._normalize_custom_logit_processor(num)
|
223
|
+
|
224
|
+
def _expand_inputs(self, num):
|
225
|
+
"""Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
|
226
|
+
if self.text is not None:
|
227
|
+
if not isinstance(self.text, list):
|
228
|
+
raise ValueError("Text should be a list for batch processing.")
|
229
|
+
self.text = self.text * self.parallel_sample_num
|
230
|
+
elif self.input_ids is not None:
|
231
|
+
if not isinstance(self.input_ids, list) or not isinstance(
|
232
|
+
self.input_ids[0], list
|
233
|
+
):
|
234
|
+
raise ValueError(
|
235
|
+
"input_ids should be a list of lists for batch processing."
|
236
|
+
)
|
237
|
+
self.input_ids = self.input_ids * self.parallel_sample_num
|
238
|
+
elif self.input_embeds is not None:
|
239
|
+
if not isinstance(self.input_embeds, list):
|
240
|
+
raise ValueError("input_embeds should be a list for batch processing.")
|
241
|
+
self.input_embeds = self.input_embeds * self.parallel_sample_num
|
242
|
+
|
243
|
+
def _normalize_lora_paths(self, num):
|
244
|
+
"""Normalize LoRA paths for batch processing."""
|
245
|
+
if self.lora_path is not None:
|
246
|
+
if isinstance(self.lora_path, str):
|
247
|
+
self.lora_path = [self.lora_path] * num
|
248
|
+
elif isinstance(self.lora_path, list):
|
249
|
+
self.lora_path = self.lora_path * self.parallel_sample_num
|
161
250
|
else:
|
251
|
+
raise ValueError("lora_path should be a list or a string.")
|
252
|
+
|
253
|
+
def _normalize_image_data(self, num):
|
254
|
+
"""Normalize image data for batch processing."""
|
255
|
+
if self.image_data is None:
|
256
|
+
self.image_data = [None] * num
|
257
|
+
elif not isinstance(self.image_data, list):
|
258
|
+
# Single image, convert to list of single-image lists
|
259
|
+
self.image_data = [[self.image_data]] * num
|
260
|
+
self.modalities = ["image"] * num
|
261
|
+
elif isinstance(self.image_data, list):
|
262
|
+
if len(self.image_data) != self.batch_size:
|
263
|
+
raise ValueError(
|
264
|
+
"The length of image_data should be equal to the batch size."
|
265
|
+
)
|
266
|
+
|
267
|
+
self.modalities = []
|
268
|
+
if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
|
269
|
+
# Already a list of lists, keep as is
|
270
|
+
for i in range(len(self.image_data)):
|
271
|
+
if self.image_data[i] is None or self.image_data[i] == [None]:
|
272
|
+
self.modalities.append(None)
|
273
|
+
elif len(self.image_data[i]) == 1:
|
274
|
+
self.modalities.append("image")
|
275
|
+
elif len(self.image_data[i]) > 1:
|
276
|
+
self.modalities.append("multi-images")
|
162
277
|
# Expand parallel_sample_num
|
163
|
-
|
164
|
-
|
165
|
-
if not self.image_data:
|
166
|
-
self.image_data = [None] * num
|
167
|
-
elif not isinstance(self.image_data, list):
|
168
|
-
self.image_data = [self.image_data] * num
|
169
|
-
elif isinstance(self.image_data, list):
|
170
|
-
pass
|
171
|
-
|
172
|
-
if self.audio_data is None:
|
173
|
-
self.audio_data = [None] * num
|
174
|
-
elif not isinstance(self.audio_data, list):
|
175
|
-
self.audio_data = [self.audio_data] * num
|
176
|
-
elif isinstance(self.audio_data, list):
|
177
|
-
pass
|
178
|
-
|
179
|
-
if self.sampling_params is None:
|
180
|
-
self.sampling_params = [{}] * num
|
181
|
-
elif not isinstance(self.sampling_params, list):
|
182
|
-
self.sampling_params = [self.sampling_params] * num
|
183
|
-
|
184
|
-
if self.rid is None:
|
185
|
-
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
278
|
+
self.image_data = self.image_data * self.parallel_sample_num
|
279
|
+
self.modalities = self.modalities * self.parallel_sample_num
|
186
280
|
else:
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
self.
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
281
|
+
# List of images for a batch, wrap each in a list
|
282
|
+
wrapped_images = [[img] for img in self.image_data]
|
283
|
+
# Expand for parallel sampling
|
284
|
+
self.image_data = wrapped_images * self.parallel_sample_num
|
285
|
+
self.modalities = ["image"] * num
|
286
|
+
|
287
|
+
def _normalize_audio_data(self, num):
|
288
|
+
"""Normalize audio data for batch processing."""
|
289
|
+
if self.audio_data is None:
|
290
|
+
self.audio_data = [None] * num
|
291
|
+
elif not isinstance(self.audio_data, list):
|
292
|
+
self.audio_data = [self.audio_data] * num
|
293
|
+
elif isinstance(self.audio_data, list):
|
294
|
+
self.audio_data = self.audio_data * self.parallel_sample_num
|
295
|
+
|
296
|
+
def _normalize_sampling_params(self, num):
|
297
|
+
"""Normalize sampling parameters for batch processing."""
|
298
|
+
if self.sampling_params is None:
|
299
|
+
self.sampling_params = [{}] * num
|
300
|
+
elif isinstance(self.sampling_params, dict):
|
301
|
+
self.sampling_params = [self.sampling_params] * num
|
302
|
+
else: # Already a list
|
303
|
+
self.sampling_params = self.sampling_params * self.parallel_sample_num
|
304
|
+
|
305
|
+
def _normalize_rid(self, num):
|
306
|
+
"""Normalize request IDs for batch processing."""
|
307
|
+
if self.rid is None:
|
308
|
+
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
309
|
+
elif not isinstance(self.rid, list):
|
310
|
+
raise ValueError("The rid should be a list for batch processing.")
|
311
|
+
|
312
|
+
def _normalize_logprob_params(self, num):
|
313
|
+
"""Normalize logprob-related parameters for batch processing."""
|
314
|
+
|
315
|
+
# Helper function to normalize a parameter
|
316
|
+
def normalize_param(param, default_value, param_name):
|
317
|
+
if param is None:
|
318
|
+
return [default_value] * num
|
319
|
+
elif not isinstance(param, list):
|
320
|
+
return [param] * num
|
200
321
|
else:
|
201
|
-
|
322
|
+
if self.parallel_sample_num > 1:
|
323
|
+
raise ValueError(
|
324
|
+
f"Cannot use list {param_name} with parallel_sample_num > 1"
|
325
|
+
)
|
326
|
+
return param
|
327
|
+
|
328
|
+
# Normalize each logprob parameter
|
329
|
+
self.return_logprob = normalize_param(
|
330
|
+
self.return_logprob, False, "return_logprob"
|
331
|
+
)
|
332
|
+
self.logprob_start_len = normalize_param(
|
333
|
+
self.logprob_start_len, -1, "logprob_start_len"
|
334
|
+
)
|
335
|
+
self.top_logprobs_num = normalize_param(
|
336
|
+
self.top_logprobs_num, 0, "top_logprobs_num"
|
337
|
+
)
|
202
338
|
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
217
|
-
]
|
218
|
-
else:
|
219
|
-
assert self.parallel_sample_num == 1
|
339
|
+
# Handle token_ids_logprob specially due to its nested structure
|
340
|
+
if not self.token_ids_logprob: # covers both None and []
|
341
|
+
self.token_ids_logprob = [None] * num
|
342
|
+
elif not isinstance(self.token_ids_logprob, list):
|
343
|
+
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
|
344
|
+
elif not isinstance(self.token_ids_logprob[0], list):
|
345
|
+
self.token_ids_logprob = [
|
346
|
+
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
347
|
+
]
|
348
|
+
elif self.parallel_sample_num > 1:
|
349
|
+
raise ValueError(
|
350
|
+
"Cannot use list token_ids_logprob with parallel_sample_num > 1"
|
351
|
+
)
|
220
352
|
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
353
|
+
def _normalize_custom_logit_processor(self, num):
|
354
|
+
"""Normalize custom logit processor for batch processing."""
|
355
|
+
if self.custom_logit_processor is None:
|
356
|
+
self.custom_logit_processor = [None] * num
|
357
|
+
elif not isinstance(self.custom_logit_processor, list):
|
358
|
+
self.custom_logit_processor = [self.custom_logit_processor] * num
|
359
|
+
elif self.parallel_sample_num > 1:
|
360
|
+
raise ValueError(
|
361
|
+
"Cannot use list custom_logit_processor with parallel_sample_num > 1"
|
362
|
+
)
|
227
363
|
|
228
|
-
|
364
|
+
def _validate_session_params(self):
|
365
|
+
"""Validate that session parameters are properly formatted."""
|
229
366
|
if self.session_params is not None:
|
230
|
-
|
367
|
+
if not isinstance(self.session_params, dict) and not isinstance(
|
231
368
|
self.session_params[0], dict
|
232
|
-
)
|
369
|
+
):
|
370
|
+
raise ValueError("Session params must be a dict or a list of dicts.")
|
233
371
|
|
234
372
|
def regenerate_rid(self):
|
373
|
+
"""Generate a new request ID and return it."""
|
235
374
|
self.rid = uuid.uuid4().hex
|
236
375
|
return self.rid
|
237
376
|
|
@@ -300,13 +439,24 @@ class TokenizedGenerateReqInput:
|
|
300
439
|
# Whether to return hidden states
|
301
440
|
return_hidden_states: bool = False
|
302
441
|
|
442
|
+
# For disaggregated inference
|
443
|
+
bootstrap_host: Optional[str] = None
|
444
|
+
bootstrap_room: Optional[int] = None
|
445
|
+
|
303
446
|
|
304
447
|
@dataclass
|
305
448
|
class EmbeddingReqInput:
|
306
449
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
307
450
|
text: Optional[Union[List[str], str]] = None
|
308
|
-
# The image input. It can be
|
309
|
-
|
451
|
+
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
452
|
+
# Can be formatted as:
|
453
|
+
# - Single image for a single request
|
454
|
+
# - List of images (one per request in a batch)
|
455
|
+
# - List of lists of images (multiple images per request)
|
456
|
+
# See also python/sglang/srt/utils.py:load_image for more details.
|
457
|
+
image_data: Optional[
|
458
|
+
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
459
|
+
] = None
|
310
460
|
# The token ids for text; one can either specify text or input_ids.
|
311
461
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
312
462
|
# The request id.
|
@@ -550,10 +700,17 @@ class UpdateWeightsFromDistributedReqOutput:
|
|
550
700
|
|
551
701
|
@dataclass
|
552
702
|
class UpdateWeightsFromTensorReqInput:
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
703
|
+
"""Update model weights from tensor input.
|
704
|
+
|
705
|
+
- Tensors are serialized for transmission
|
706
|
+
- Data is structured in JSON for easy transmission over HTTP
|
707
|
+
"""
|
708
|
+
|
709
|
+
serialized_named_tensors: List[Union[str, bytes]]
|
710
|
+
# Optional format specification for loading
|
711
|
+
load_format: Optional[str] = None
|
712
|
+
# Whether to flush the cache after updating weights
|
713
|
+
flush_cache: bool = True
|
557
714
|
|
558
715
|
|
559
716
|
@dataclass
|
sglang/srt/managers/mm_utils.py
CHANGED
@@ -148,7 +148,8 @@ def get_embedding_and_mask(
|
|
148
148
|
placeholder_tensor,
|
149
149
|
).unsqueeze(-1)
|
150
150
|
|
151
|
-
num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
|
151
|
+
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
|
152
|
+
|
152
153
|
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
|
153
154
|
logger.warning(
|
154
155
|
f"Number of tokens in multimodal embedding does not match those in the input text."
|
@@ -172,7 +173,7 @@ def get_embedding_and_mask(
|
|
172
173
|
embedding = embedding[-num_multimodal:, :]
|
173
174
|
else:
|
174
175
|
raise RuntimeError(
|
175
|
-
"Insufficient multimodal embedding length. This is an internal error"
|
176
|
+
f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
|
176
177
|
)
|
177
178
|
|
178
179
|
return embedding, special_multimodal_mask
|