sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
+
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
@@ -645,7 +646,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_hip(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda_available:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
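The new forward override above routes DeepseekScalingRotaryEmbedding to the native implementation while tracing under torch.compile and to the CUDA path otherwise. A minimal standalone sketch of that dispatch pattern (placeholder math, not sglang's real kernels; assumes a PyTorch version that provides torch.compiler.is_compiling):

import torch
from torch import nn


class DispatchingOp(nn.Module):
    """Sketch of the dispatch pattern above: prefer the CUDA path when it is
    available, but fall back to the native implementation under torch.compile,
    where a custom kernel may not be traceable."""

    def __init__(self, cuda_available: bool):
        super().__init__()
        self._cuda_available = cuda_available

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0  # placeholder for the pure-PyTorch implementation

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0  # placeholder for the fused-kernel implementation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.compiler.is_compiling():
            return self.forward_native(x)
        if self._cuda_available:
            return self.forward_cuda(x)
        return self.forward_native(x)


op = DispatchingOp(cuda_available=torch.cuda.is_available())
print(op(torch.ones(2)))  # tensor([2., 2.])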
sglang/srt/layers/sampler.py
CHANGED
@@ -93,28 +93,23 @@ class Sampler(nn.Module):
             ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
+                    probs, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                # Check Nan will throw exception, only check when crash_on_warnings is True
+                check_nan = self.use_nan_detection and crash_on_warnings()
+                batch_next_token_ids = top_k_top_p_sampling_from_probs(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
+                    check_nan=check_nan,
                 )
 
-                if self.use_nan_detection and not torch.all(success):
-                    logger.warning("Detected errors during sampling!")
-                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
             batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
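The sampler change above drops the host-allocated uniform_samples tensor and the success flag; NaN handling moves into the fused kernel call via the new check_nan argument. A torch-native sketch of joint top-k/top-p sampling with an optional NaN check (standalone and illustrative; the real path calls the fused flashinfer/sgl-kernel op, and top_k_top_p_sample here is a made-up name):

import torch


def top_k_top_p_sample(
    probs: torch.Tensor,
    top_ks: torch.Tensor,
    top_ps: torch.Tensor,
    check_nan: bool = False,
) -> torch.Tensor:
    if check_nan and torch.isnan(probs).any():
        raise ValueError("Detected NaN in sampling probabilities")
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs)
    # Jointly drop tokens past top-k and tokens whose preceding cumulative
    # mass already exceeds top-p (the highest-probability token is always kept).
    mask = (ranks >= top_ks.unsqueeze(-1)) | (cumsum - sorted_probs > top_ps.unsqueeze(-1))
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
print(top_k_top_p_sample(probs, torch.tensor([3, 5]), torch.tensor([0.9, 0.8])))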
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
 
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
@@ -115,3 +115,19 @@ class BaseLoRABackend:
 
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
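Illustrative use of the helper added above (assumes sglang is importable; "triton" and "flashinfer" are the only names it accepts):

from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # TritonLoRABackend
# get_backend_from_name("other") raises ValueError: Invalid backend: other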
sglang/srt/lora/backend/flashinfer_backend.py
CHANGED
@@ -2,7 +2,7 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
sglang/srt/lora/layers.py
CHANGED
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/lora/lora.py
CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
sglang/srt/lora/lora_manager.py
CHANGED
@@ -22,7 +22,7 @@ import torch
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,7 +20,13 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
+
+# handle serialization of Image for pydantic
+if TYPE_CHECKING:
+    from PIL.Image import Image
+else:
+    Image = Any
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -42,10 +48,16 @@ class GenerateReqInput:
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
-    # The image input. It can be a file name, a url, or base64 encoded string.
-    # See also python/sglang/srt/utils.py:load_image.
-    image_data: Optional[Union[List[str], str]] = None
-    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+    # Can be formatted as:
+    # - Single image for a single request
+    # - List of images (one per request in a batch)
+    # - List of lists of images (multiple images per request)
+    # See also python/sglang/srt/utils.py:load_image for more details.
+    image_data: Optional[
+        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
+    ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
@@ -83,7 +95,36 @@ class GenerateReqInput:
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # For disaggregated inference
+    bootstrap_host: Optional[str] = None
+    bootstrap_room: Optional[int] = None
+
     def normalize_batch_and_arguments(self):
+        """
+        Normalize the batch size and arguments for the request.
+
+        This method resolves various input formats and ensures all parameters
+        are properly formatted as either single values or batches depending on the input.
+        It also handles parallel sampling expansion and sets default values for
+        unspecified parameters.
+
+        Raises:
+            ValueError: If inputs are not properly specified (e.g., none or all of
+                text, input_ids, input_embeds are provided)
+        """
+        self._validate_inputs()
+        self._determine_batch_size()
+        self._handle_parallel_sampling()
+
+        if self.is_single:
+            self._normalize_single_inputs()
+        else:
+            self._normalize_batch_inputs()
+
+        self._validate_session_params()
+
+    def _validate_inputs(self):
+        """Validate that the input configuration is valid."""
         if (
             self.text is None and self.input_ids is None and self.input_embeds is None
         ) or (
@@ -95,7 +136,8 @@ class GenerateReqInput:
                 "Either text, input_ids or input_embeds should be provided."
             )
 
-        # Derive the batch size
+    def _determine_batch_size(self):
+        """Determine if this is a single example or a batch and the batch size."""
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
@@ -119,21 +161,25 @@ class GenerateReqInput:
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_embeds)
 
-        # Handle parallel sampling
-        # When parallel sampling is used, we always treat the input as a batch.
+    def _handle_parallel_sampling(self):
+        """Handle parallel sampling parameters and adjust batch size if needed."""
+        # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            assert all(
-                self.parallel_sample_num == sampling_params.get("n", 1)
-                for sampling_params in self.sampling_params
-            ), "The parallel_sample_num should be the same for all samples in sample params."
+            for sampling_params in self.sampling_params:
+                if self.parallel_sample_num != sampling_params.get("n", 1):
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
 
+        # If using parallel sampling with a single example, convert to batch
         if self.parallel_sample_num > 1 and self.is_single:
             self.is_single = False
             if self.text is not None:
@@ -141,97 +187,190 @@ class GenerateReqInput:
             if self.input_ids is not None:
                 self.input_ids = [self.input_ids]
 
-        # Fill in default arguments
-        if self.is_single:
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.return_logprob is None:
-                self.return_logprob = False
-            if self.logprob_start_len is None:
-                self.logprob_start_len = -1
-            if self.top_logprobs_num is None:
-                self.top_logprobs_num = 0
-            if not self.token_ids_logprob:  # covers both None and []
-                self.token_ids_logprob = None
-        else:
-            if self.parallel_sample_num == 1:
-                num = self.batch_size
-            else:
-                # Expand parallel_sample_num
-                num = self.batch_size * self.parallel_sample_num
-
-            if not self.image_data:
-                self.image_data = [None] * num
-            elif not isinstance(self.image_data, list):
-                self.image_data = [self.image_data] * num
-            elif isinstance(self.image_data, list):
-                pass
-
-            if self.audio_data is None:
-                self.audio_data = [None] * num
-            elif not isinstance(self.audio_data, list):
-                self.audio_data = [self.audio_data] * num
-            elif isinstance(self.audio_data, list):
-                pass
-
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * num
-            elif not isinstance(self.sampling_params, list):
-                self.sampling_params = [self.sampling_params] * num
-
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(num)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-
-            if self.return_logprob is None:
-                self.return_logprob = [False] * num
-            elif not isinstance(self.return_logprob, list):
-                self.return_logprob = [self.return_logprob] * num
-            else:
-                assert self.parallel_sample_num == 1
-
-            if self.logprob_start_len is None:
-                self.logprob_start_len = [-1] * num
-            elif not isinstance(self.logprob_start_len, list):
-                self.logprob_start_len = [self.logprob_start_len] * num
-            else:
-                assert self.parallel_sample_num == 1
-
-            if self.top_logprobs_num is None:
-                self.top_logprobs_num = [0] * num
-            elif not isinstance(self.top_logprobs_num, list):
-                self.top_logprobs_num = [self.top_logprobs_num] * num
-            else:
-                assert self.parallel_sample_num == 1
-
-            if not self.token_ids_logprob:  # covers both None and []
-                self.token_ids_logprob = [None] * num
-            elif not isinstance(self.token_ids_logprob, list):
-                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
-            elif not isinstance(self.token_ids_logprob[0], list):
-                self.token_ids_logprob = [
-                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
-                ]
-            else:
-                assert self.parallel_sample_num == 1
-
-            if self.custom_logit_processor is None:
-                self.custom_logit_processor = [None] * num
-            elif not isinstance(self.custom_logit_processor, list):
-                self.custom_logit_processor = [self.custom_logit_processor] * num
-            else:
-                assert self.parallel_sample_num == 1
-
-        # Session id related
-        if self.session_params is not None:
-            assert isinstance(self.session_params, dict) or isinstance(
-                self.session_params[0], dict
-            )
+    def _normalize_single_inputs(self):
+        """Normalize inputs for a single example."""
+        if self.sampling_params is None:
+            self.sampling_params = {}
+        if self.rid is None:
+            self.rid = uuid.uuid4().hex
+        if self.return_logprob is None:
+            self.return_logprob = False
+        if self.logprob_start_len is None:
+            self.logprob_start_len = -1
+        if self.top_logprobs_num is None:
+            self.top_logprobs_num = 0
+        if not self.token_ids_logprob:  # covers both None and []
+            self.token_ids_logprob = None
+
+    def _normalize_batch_inputs(self):
+        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
+        # Calculate expanded batch size
+        if self.parallel_sample_num == 1:
+            num = self.batch_size
+        else:
+            # Expand parallel_sample_num
+            num = self.batch_size * self.parallel_sample_num
+
+        # Expand input based on type
+        self._expand_inputs(num)
+        self._normalize_lora_paths(num)
+        self._normalize_image_data(num)
+        self._normalize_audio_data(num)
+        self._normalize_sampling_params(num)
+        self._normalize_rid(num)
+        self._normalize_logprob_params(num)
+        self._normalize_custom_logit_processor(num)
+
+    def _expand_inputs(self, num):
+        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
+        if self.text is not None:
+            if not isinstance(self.text, list):
+                raise ValueError("Text should be a list for batch processing.")
+            self.text = self.text * self.parallel_sample_num
+        elif self.input_ids is not None:
+            if not isinstance(self.input_ids, list) or not isinstance(
+                self.input_ids[0], list
+            ):
+                raise ValueError(
+                    "input_ids should be a list of lists for batch processing."
+                )
+            self.input_ids = self.input_ids * self.parallel_sample_num
+        elif self.input_embeds is not None:
+            if not isinstance(self.input_embeds, list):
+                raise ValueError("input_embeds should be a list for batch processing.")
+            self.input_embeds = self.input_embeds * self.parallel_sample_num
+
+    def _normalize_lora_paths(self, num):
+        """Normalize LoRA paths for batch processing."""
+        if self.lora_path is not None:
+            if isinstance(self.lora_path, str):
+                self.lora_path = [self.lora_path] * num
+            elif isinstance(self.lora_path, list):
+                self.lora_path = self.lora_path * self.parallel_sample_num
+            else:
+                raise ValueError("lora_path should be a list or a string.")
+
+    def _normalize_image_data(self, num):
+        """Normalize image data for batch processing."""
+        if self.image_data is None:
+            self.image_data = [None] * num
+        elif not isinstance(self.image_data, list):
+            # Single image, convert to list of single-image lists
+            self.image_data = [[self.image_data]] * num
+            self.modalities = ["image"] * num
+        elif isinstance(self.image_data, list):
+            if len(self.image_data) != self.batch_size:
+                raise ValueError(
+                    "The length of image_data should be equal to the batch size."
+                )
+
+            self.modalities = []
+            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
+                # Already a list of lists, keep as is
+                for i in range(len(self.image_data)):
+                    if self.image_data[i] is None or self.image_data[i] == [None]:
+                        self.modalities.append(None)
+                    elif len(self.image_data[i]) == 1:
+                        self.modalities.append("image")
+                    elif len(self.image_data[i]) > 1:
+                        self.modalities.append("multi-images")
+                # Expand parallel_sample_num
+                self.image_data = self.image_data * self.parallel_sample_num
+                self.modalities = self.modalities * self.parallel_sample_num
+            else:
+                # List of images for a batch, wrap each in a list
+                wrapped_images = [[img] for img in self.image_data]
+                # Expand for parallel sampling
+                self.image_data = wrapped_images * self.parallel_sample_num
+                self.modalities = ["image"] * num
+
+    def _normalize_audio_data(self, num):
+        """Normalize audio data for batch processing."""
+        if self.audio_data is None:
+            self.audio_data = [None] * num
+        elif not isinstance(self.audio_data, list):
+            self.audio_data = [self.audio_data] * num
+        elif isinstance(self.audio_data, list):
+            self.audio_data = self.audio_data * self.parallel_sample_num
+
+    def _normalize_sampling_params(self, num):
+        """Normalize sampling parameters for batch processing."""
+        if self.sampling_params is None:
+            self.sampling_params = [{}] * num
+        elif isinstance(self.sampling_params, dict):
+            self.sampling_params = [self.sampling_params] * num
+        else:  # Already a list
+            self.sampling_params = self.sampling_params * self.parallel_sample_num
+
+    def _normalize_rid(self, num):
+        """Normalize request IDs for batch processing."""
+        if self.rid is None:
+            self.rid = [uuid.uuid4().hex for _ in range(num)]
+        elif not isinstance(self.rid, list):
+            raise ValueError("The rid should be a list for batch processing.")
+
+    def _normalize_logprob_params(self, num):
+        """Normalize logprob-related parameters for batch processing."""
+
+        # Helper function to normalize a parameter
+        def normalize_param(param, default_value, param_name):
+            if param is None:
+                return [default_value] * num
+            elif not isinstance(param, list):
+                return [param] * num
+            else:
+                if self.parallel_sample_num > 1:
+                    raise ValueError(
+                        f"Cannot use list {param_name} with parallel_sample_num > 1"
+                    )
+                return param
+
+        # Normalize each logprob parameter
+        self.return_logprob = normalize_param(
+            self.return_logprob, False, "return_logprob"
+        )
+        self.logprob_start_len = normalize_param(
+            self.logprob_start_len, -1, "logprob_start_len"
+        )
+        self.top_logprobs_num = normalize_param(
+            self.top_logprobs_num, 0, "top_logprobs_num"
+        )
+
+        # Handle token_ids_logprob specially due to its nested structure
+        if not self.token_ids_logprob:  # covers both None and []
+            self.token_ids_logprob = [None] * num
+        elif not isinstance(self.token_ids_logprob, list):
+            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
+        elif not isinstance(self.token_ids_logprob[0], list):
+            self.token_ids_logprob = [
+                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
+            ]
+        elif self.parallel_sample_num > 1:
+            raise ValueError(
+                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
+            )
+
+    def _normalize_custom_logit_processor(self, num):
+        """Normalize custom logit processor for batch processing."""
+        if self.custom_logit_processor is None:
+            self.custom_logit_processor = [None] * num
+        elif not isinstance(self.custom_logit_processor, list):
+            self.custom_logit_processor = [self.custom_logit_processor] * num
+        elif self.parallel_sample_num > 1:
+            raise ValueError(
+                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
+            )
+
+    def _validate_session_params(self):
+        """Validate that session parameters are properly formatted."""
+        if self.session_params is not None:
+            if not isinstance(self.session_params, dict) and not isinstance(
+                self.session_params[0], dict
+            ):
+                raise ValueError("Session params must be a dict or a list of dicts.")
 
     def regenerate_rid(self):
+        """Generate a new request ID and return it."""
         self.rid = uuid.uuid4().hex
         return self.rid
 
@@ -300,13 +439,24 @@ class TokenizedGenerateReqInput:
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # For disaggregated inference
+    bootstrap_host: Optional[str] = None
+    bootstrap_room: Optional[int] = None
+
 
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The image input. It can be a file name, a url, or base64 encoded string.
-    image_data: Optional[Union[List[str], str]] = None
+    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+    # Can be formatted as:
+    # - Single image for a single request
+    # - List of images (one per request in a batch)
+    # - List of lists of images (multiple images per request)
+    # See also python/sglang/srt/utils.py:load_image for more details.
+    image_data: Optional[
+        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
+    ] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
@@ -550,10 +700,17 @@ class UpdateWeightsFromDistributedReqOutput:
 
 @dataclass
 class UpdateWeightsFromTensorReqInput:
-    # The serialized named tensors
-    serialized_named_tensors: bytes
-    load_format: Optional[str]
-    flush_cache: bool
+    """Update model weights from tensor input.
+
+    - Tensors are serialized for transmission
+    - Data is structured in JSON for easy transmission over HTTP
+    """
+
+    serialized_named_tensors: List[Union[str, bytes]]
+    # Optional format specification for loading
+    load_format: Optional[str] = None
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
 
 
 @dataclass
@@ -677,6 +834,7 @@ class ProfileReq:
     activities: Optional[List[str]] = None
     with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
+    profile_id: Optional[str] = None
 
 
 @dataclass
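The refactored normalization above expands a batch by list multiplication whenever sampling_params requests n > 1 parallel samples. A standalone sketch of that expansion rule (plain Python, no sglang imports; the variable names are illustrative):

# Request i of the expanded batch maps to original request i % batch_size.
batch_size, n = 2, 3
text = ["hello", "world"]        # one prompt per request
image_data = ["a.png", "b.png"]  # one image per request

num = batch_size * n
wrapped_images = [[img] for img in image_data]  # wrap singles into lists

text = text * n
image_data = wrapped_images * n

assert len(text) == len(image_data) == num
print(text)        # ['hello', 'world', 'hello', 'world', 'hello', 'world']
print(image_data)  # [['a.png'], ['b.png'], ['a.png'], ['b.png'], ['a.png'], ['b.png']]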
sglang/srt/managers/mm_utils.py
CHANGED
@@ -1,7 +1,8 @@
 """
-Multimodal utils
+Multi-modality utils
 """
 
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
@@ -148,7 +149,8 @@ def get_embedding_and_mask(
         placeholder_tensor,
     ).unsqueeze(-1)
 
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
+
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -172,7 +174,7 @@ def get_embedding_and_mask(
             embedding = embedding[-num_multimodal:, :]
         else:
             raise RuntimeError(
-                "Insufficient multimodal embedding length. This is an internal error"
+                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )
 
     return embedding, special_multimodal_mask
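A small sketch of the count check touched above, with a made-up placeholder token id: the boolean mask marks multimodal placeholder positions in input_ids, and the new .item() call moves the count to the host so the comparison and the expanded error message work with plain Python ints:

import torch

input_ids = torch.tensor([5, 5, 1000, 1000, 7])  # 1000 = hypothetical placeholder id
placeholder_tensor = torch.tensor([1000])

special_multimodal_mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
print(num_mm_tokens_in_input_ids)  # 2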