sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,12 @@
13
13
  # ==============================================================================
14
14
  """Radix attention."""
15
15
 
16
+ from typing import Optional
17
+
16
18
  from torch import nn
17
19
 
20
+ from sglang.srt.layers.linear import UnquantizedLinearMethod
21
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
18
22
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
19
23
 
20
24
 
@@ -34,7 +38,9 @@ class RadixAttention(nn.Module):
34
38
  v_head_dim: int = -1,
35
39
  sliding_window_size: int = -1,
36
40
  is_cross_attention: bool = False,
41
+ quant_config: Optional[QuantizationConfig] = None,
37
42
  prefix: str = "",
43
+ use_irope: bool = False,
38
44
  ):
39
45
  super().__init__()
40
46
  self.tp_q_head_num = num_heads
@@ -48,8 +54,16 @@ class RadixAttention(nn.Module):
48
54
  self.logit_cap = logit_cap
49
55
  self.sliding_window_size = sliding_window_size or -1
50
56
  self.is_cross_attention = is_cross_attention
57
+ self.use_irope = use_irope
51
58
  self.k_scale = None
52
59
  self.v_scale = None
60
+ self.k_scale_float = None
61
+ self.v_scale_float = None
62
+ self.quant_method = None
63
+ if quant_config is not None:
64
+ self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
65
+ if self.quant_method is not None:
66
+ self.quant_method.create_weights(self)
53
67
 
54
68
  def forward(
55
69
  self,
@@ -645,7 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
645
645
  cache = torch.cat((cos, sin), dim=-1)
646
646
  return cache
647
647
 
648
- def forward(
648
+ def forward_hip(self, *args, **kwargs):
649
+ return self.forward_native(*args, **kwargs)
650
+
651
+ def forward(self, *args, **kwargs):
652
+ if torch.compiler.is_compiling():
653
+ return self.forward_native(*args, **kwargs)
654
+ if _is_cuda_available:
655
+ return self.forward_cuda(*args, **kwargs)
656
+ else:
657
+ return self.forward_native(*args, **kwargs)
658
+
659
+ def forward_native(
649
660
  self,
650
661
  positions: torch.Tensor,
651
662
  query: torch.Tensor,
@@ -733,6 +744,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
733
744
  return new_freqs
734
745
 
735
746
 
747
+ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
748
+
749
+ def __init__(
750
+ self,
751
+ head_size: int,
752
+ rotary_dim: int,
753
+ max_position_embeddings: int,
754
+ base: int,
755
+ is_neox_style: bool,
756
+ dtype: torch.dtype,
757
+ ):
758
+ super().__init__(
759
+ head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
760
+ )
761
+
762
+ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
763
+ inv_freqs = super()._compute_inv_freq(base)
764
+ inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
765
+ return inv_freqs
766
+
767
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
768
+ inv_freq = self._compute_inv_freq(self.base)
769
+
770
+ # self.max_position_embeddings here is number of image patches
771
+ # i.e. (image_size // patch_size) ** 2
772
+ num_patches = self.max_position_embeddings
773
+ img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
774
+ img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
775
+ img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
776
+ num_patches_single_dim = int(math.sqrt(num_patches))
777
+ frequencies_x = img_idx % num_patches_single_dim
778
+ frequencies_y = img_idx // num_patches_single_dim
779
+ freqs_x = (
780
+ (frequencies_x + 1)[..., None] * inv_freq[None, None, :]
781
+ ).repeat_interleave(2, dim=-1)
782
+ freqs_y = (
783
+ (frequencies_y + 1)[..., None] * inv_freq[None, None, :]
784
+ ).repeat_interleave(2, dim=-1)
785
+ freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
786
+ freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
787
+ cache = torch.view_as_complex(
788
+ torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
789
+ )
790
+ return cache
791
+
792
+ def forward(
793
+ self,
794
+ query: torch.Tensor,
795
+ key: torch.Tensor,
796
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
797
+ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
798
+ query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
799
+ key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
800
+ broadcast_shape = [
801
+ d if i == 1 or i == (query_.ndim - 1) else 1
802
+ for i, d in enumerate(query_.shape)
803
+ ]
804
+ freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
805
+ query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
806
+ key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
807
+ return query_out.type_as(query), key_out.type_as(key)
808
+
809
+
736
810
  class MRotaryEmbedding(RotaryEmbedding):
737
811
  """Rotary Embedding with Multimodal Sections."""
738
812
 
@@ -20,7 +20,13 @@ import copy
20
20
  import uuid
21
21
  from dataclasses import dataclass, field
22
22
  from enum import Enum
23
- from typing import Any, Dict, List, Literal, Optional, Union
23
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
24
+
25
+ # handle serialization of Image for pydantic
26
+ if TYPE_CHECKING:
27
+ from PIL.Image import Image
28
+ else:
29
+ Image = Any
24
30
 
25
31
  from sglang.srt.managers.schedule_batch import BaseFinishReason
26
32
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -42,10 +48,16 @@ class GenerateReqInput:
42
48
  input_ids: Optional[Union[List[List[int]], List[int]]] = None
43
49
  # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
44
50
  input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
45
- # The image input. It can be a file name, a url, or base64 encoded string.
46
- # See also python/sglang/srt/utils.py:load_image.
47
- image_data: Optional[Union[List[str], str]] = None
48
- # The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
51
+ # The image input. It can be an image instance, file name, URL, or base64 encoded string.
52
+ # Can be formatted as:
53
+ # - Single image for a single request
54
+ # - List of images (one per request in a batch)
55
+ # - List of lists of images (multiple images per request)
56
+ # See also python/sglang/srt/utils.py:load_image for more details.
57
+ image_data: Optional[
58
+ Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
59
+ ] = None
60
+ # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
49
61
  audio_data: Optional[Union[List[str], str]] = None
50
62
  # The sampling_params. See descriptions below.
51
63
  sampling_params: Optional[Union[List[Dict], Dict]] = None
@@ -83,7 +95,36 @@ class GenerateReqInput:
83
95
  # Whether to return hidden states
84
96
  return_hidden_states: bool = False
85
97
 
98
+ # For disaggregated inference
99
+ bootstrap_host: Optional[str] = None
100
+ bootstrap_room: Optional[int] = None
101
+
86
102
  def normalize_batch_and_arguments(self):
103
+ """
104
+ Normalize the batch size and arguments for the request.
105
+
106
+ This method resolves various input formats and ensures all parameters
107
+ are properly formatted as either single values or batches depending on the input.
108
+ It also handles parallel sampling expansion and sets default values for
109
+ unspecified parameters.
110
+
111
+ Raises:
112
+ ValueError: If inputs are not properly specified (e.g., none or all of
113
+ text, input_ids, input_embeds are provided)
114
+ """
115
+ self._validate_inputs()
116
+ self._determine_batch_size()
117
+ self._handle_parallel_sampling()
118
+
119
+ if self.is_single:
120
+ self._normalize_single_inputs()
121
+ else:
122
+ self._normalize_batch_inputs()
123
+
124
+ self._validate_session_params()
125
+
126
+ def _validate_inputs(self):
127
+ """Validate that the input configuration is valid."""
87
128
  if (
88
129
  self.text is None and self.input_ids is None and self.input_embeds is None
89
130
  ) or (
@@ -95,7 +136,8 @@ class GenerateReqInput:
95
136
  "Either text, input_ids or input_embeds should be provided."
96
137
  )
97
138
 
98
- # Derive the batch size
139
+ def _determine_batch_size(self):
140
+ """Determine if this is a single example or a batch and the batch size."""
99
141
  if self.text is not None:
100
142
  if isinstance(self.text, str):
101
143
  self.is_single = True
@@ -119,21 +161,25 @@ class GenerateReqInput:
119
161
  self.is_single = True
120
162
  self.batch_size = 1
121
163
  else:
164
+ self.is_single = False
122
165
  self.batch_size = len(self.input_embeds)
123
166
 
124
- # Handle parallel sampling
125
- # When parallel sampling is used, we always treat the input as a batch.
167
+ def _handle_parallel_sampling(self):
168
+ """Handle parallel sampling parameters and adjust batch size if needed."""
169
+ # Determine parallel sample count
126
170
  if self.sampling_params is None:
127
171
  self.parallel_sample_num = 1
128
172
  elif isinstance(self.sampling_params, dict):
129
173
  self.parallel_sample_num = self.sampling_params.get("n", 1)
130
174
  else: # isinstance(self.sampling_params, list):
131
175
  self.parallel_sample_num = self.sampling_params[0].get("n", 1)
132
- assert all(
133
- self.parallel_sample_num == sampling_params.get("n", 1)
134
- for sampling_params in self.sampling_params
135
- ), "The parallel_sample_num should be the same for all samples in sample params."
176
+ for sampling_params in self.sampling_params:
177
+ if self.parallel_sample_num != sampling_params.get("n", 1):
178
+ raise ValueError(
179
+ "The parallel_sample_num should be the same for all samples in sample params."
180
+ )
136
181
 
182
+ # If using parallel sampling with a single example, convert to batch
137
183
  if self.parallel_sample_num > 1 and self.is_single:
138
184
  self.is_single = False
139
185
  if self.text is not None:
@@ -141,97 +187,190 @@ class GenerateReqInput:
141
187
  if self.input_ids is not None:
142
188
  self.input_ids = [self.input_ids]
143
189
 
144
- # Fill in default arguments
145
- if self.is_single:
146
- if self.sampling_params is None:
147
- self.sampling_params = {}
148
- if self.rid is None:
149
- self.rid = uuid.uuid4().hex
150
- if self.return_logprob is None:
151
- self.return_logprob = False
152
- if self.logprob_start_len is None:
153
- self.logprob_start_len = -1
154
- if self.top_logprobs_num is None:
155
- self.top_logprobs_num = 0
156
- if not self.token_ids_logprob: # covers both None and []
157
- self.token_ids_logprob = None
190
+ def _normalize_single_inputs(self):
191
+ """Normalize inputs for a single example."""
192
+ if self.sampling_params is None:
193
+ self.sampling_params = {}
194
+ if self.rid is None:
195
+ self.rid = uuid.uuid4().hex
196
+ if self.return_logprob is None:
197
+ self.return_logprob = False
198
+ if self.logprob_start_len is None:
199
+ self.logprob_start_len = -1
200
+ if self.top_logprobs_num is None:
201
+ self.top_logprobs_num = 0
202
+ if not self.token_ids_logprob: # covers both None and []
203
+ self.token_ids_logprob = None
204
+
205
+ def _normalize_batch_inputs(self):
206
+ """Normalize inputs for a batch of examples, including parallel sampling expansion."""
207
+ # Calculate expanded batch size
208
+ if self.parallel_sample_num == 1:
209
+ num = self.batch_size
158
210
  else:
159
- if self.parallel_sample_num == 1:
160
- num = self.batch_size
211
+ # Expand parallel_sample_num
212
+ num = self.batch_size * self.parallel_sample_num
213
+
214
+ # Expand input based on type
215
+ self._expand_inputs(num)
216
+ self._normalize_lora_paths(num)
217
+ self._normalize_image_data(num)
218
+ self._normalize_audio_data(num)
219
+ self._normalize_sampling_params(num)
220
+ self._normalize_rid(num)
221
+ self._normalize_logprob_params(num)
222
+ self._normalize_custom_logit_processor(num)
223
+
224
+ def _expand_inputs(self, num):
225
+ """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
226
+ if self.text is not None:
227
+ if not isinstance(self.text, list):
228
+ raise ValueError("Text should be a list for batch processing.")
229
+ self.text = self.text * self.parallel_sample_num
230
+ elif self.input_ids is not None:
231
+ if not isinstance(self.input_ids, list) or not isinstance(
232
+ self.input_ids[0], list
233
+ ):
234
+ raise ValueError(
235
+ "input_ids should be a list of lists for batch processing."
236
+ )
237
+ self.input_ids = self.input_ids * self.parallel_sample_num
238
+ elif self.input_embeds is not None:
239
+ if not isinstance(self.input_embeds, list):
240
+ raise ValueError("input_embeds should be a list for batch processing.")
241
+ self.input_embeds = self.input_embeds * self.parallel_sample_num
242
+
243
+ def _normalize_lora_paths(self, num):
244
+ """Normalize LoRA paths for batch processing."""
245
+ if self.lora_path is not None:
246
+ if isinstance(self.lora_path, str):
247
+ self.lora_path = [self.lora_path] * num
248
+ elif isinstance(self.lora_path, list):
249
+ self.lora_path = self.lora_path * self.parallel_sample_num
161
250
  else:
251
+ raise ValueError("lora_path should be a list or a string.")
252
+
253
+ def _normalize_image_data(self, num):
254
+ """Normalize image data for batch processing."""
255
+ if self.image_data is None:
256
+ self.image_data = [None] * num
257
+ elif not isinstance(self.image_data, list):
258
+ # Single image, convert to list of single-image lists
259
+ self.image_data = [[self.image_data]] * num
260
+ self.modalities = ["image"] * num
261
+ elif isinstance(self.image_data, list):
262
+ if len(self.image_data) != self.batch_size:
263
+ raise ValueError(
264
+ "The length of image_data should be equal to the batch size."
265
+ )
266
+
267
+ self.modalities = []
268
+ if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
269
+ # Already a list of lists, keep as is
270
+ for i in range(len(self.image_data)):
271
+ if self.image_data[i] is None or self.image_data[i] == [None]:
272
+ self.modalities.append(None)
273
+ elif len(self.image_data[i]) == 1:
274
+ self.modalities.append("image")
275
+ elif len(self.image_data[i]) > 1:
276
+ self.modalities.append("multi-images")
162
277
  # Expand parallel_sample_num
163
- num = self.batch_size * self.parallel_sample_num
164
-
165
- if not self.image_data:
166
- self.image_data = [None] * num
167
- elif not isinstance(self.image_data, list):
168
- self.image_data = [self.image_data] * num
169
- elif isinstance(self.image_data, list):
170
- pass
171
-
172
- if self.audio_data is None:
173
- self.audio_data = [None] * num
174
- elif not isinstance(self.audio_data, list):
175
- self.audio_data = [self.audio_data] * num
176
- elif isinstance(self.audio_data, list):
177
- pass
178
-
179
- if self.sampling_params is None:
180
- self.sampling_params = [{}] * num
181
- elif not isinstance(self.sampling_params, list):
182
- self.sampling_params = [self.sampling_params] * num
183
-
184
- if self.rid is None:
185
- self.rid = [uuid.uuid4().hex for _ in range(num)]
278
+ self.image_data = self.image_data * self.parallel_sample_num
279
+ self.modalities = self.modalities * self.parallel_sample_num
186
280
  else:
187
- assert isinstance(self.rid, list), "The rid should be a list."
188
-
189
- if self.return_logprob is None:
190
- self.return_logprob = [False] * num
191
- elif not isinstance(self.return_logprob, list):
192
- self.return_logprob = [self.return_logprob] * num
193
- else:
194
- assert self.parallel_sample_num == 1
195
-
196
- if self.logprob_start_len is None:
197
- self.logprob_start_len = [-1] * num
198
- elif not isinstance(self.logprob_start_len, list):
199
- self.logprob_start_len = [self.logprob_start_len] * num
281
+ # List of images for a batch, wrap each in a list
282
+ wrapped_images = [[img] for img in self.image_data]
283
+ # Expand for parallel sampling
284
+ self.image_data = wrapped_images * self.parallel_sample_num
285
+ self.modalities = ["image"] * num
286
+
287
+ def _normalize_audio_data(self, num):
288
+ """Normalize audio data for batch processing."""
289
+ if self.audio_data is None:
290
+ self.audio_data = [None] * num
291
+ elif not isinstance(self.audio_data, list):
292
+ self.audio_data = [self.audio_data] * num
293
+ elif isinstance(self.audio_data, list):
294
+ self.audio_data = self.audio_data * self.parallel_sample_num
295
+
296
+ def _normalize_sampling_params(self, num):
297
+ """Normalize sampling parameters for batch processing."""
298
+ if self.sampling_params is None:
299
+ self.sampling_params = [{}] * num
300
+ elif isinstance(self.sampling_params, dict):
301
+ self.sampling_params = [self.sampling_params] * num
302
+ else: # Already a list
303
+ self.sampling_params = self.sampling_params * self.parallel_sample_num
304
+
305
+ def _normalize_rid(self, num):
306
+ """Normalize request IDs for batch processing."""
307
+ if self.rid is None:
308
+ self.rid = [uuid.uuid4().hex for _ in range(num)]
309
+ elif not isinstance(self.rid, list):
310
+ raise ValueError("The rid should be a list for batch processing.")
311
+
312
+ def _normalize_logprob_params(self, num):
313
+ """Normalize logprob-related parameters for batch processing."""
314
+
315
+ # Helper function to normalize a parameter
316
+ def normalize_param(param, default_value, param_name):
317
+ if param is None:
318
+ return [default_value] * num
319
+ elif not isinstance(param, list):
320
+ return [param] * num
200
321
  else:
201
- assert self.parallel_sample_num == 1
322
+ if self.parallel_sample_num > 1:
323
+ raise ValueError(
324
+ f"Cannot use list {param_name} with parallel_sample_num > 1"
325
+ )
326
+ return param
327
+
328
+ # Normalize each logprob parameter
329
+ self.return_logprob = normalize_param(
330
+ self.return_logprob, False, "return_logprob"
331
+ )
332
+ self.logprob_start_len = normalize_param(
333
+ self.logprob_start_len, -1, "logprob_start_len"
334
+ )
335
+ self.top_logprobs_num = normalize_param(
336
+ self.top_logprobs_num, 0, "top_logprobs_num"
337
+ )
202
338
 
203
- if self.top_logprobs_num is None:
204
- self.top_logprobs_num = [0] * num
205
- elif not isinstance(self.top_logprobs_num, list):
206
- self.top_logprobs_num = [self.top_logprobs_num] * num
207
- else:
208
- assert self.parallel_sample_num == 1
209
-
210
- if not self.token_ids_logprob: # covers both None and []
211
- self.token_ids_logprob = [None] * num
212
- elif not isinstance(self.token_ids_logprob, list):
213
- self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
214
- elif not isinstance(self.token_ids_logprob[0], list):
215
- self.token_ids_logprob = [
216
- copy.deepcopy(self.token_ids_logprob) for _ in range(num)
217
- ]
218
- else:
219
- assert self.parallel_sample_num == 1
339
+ # Handle token_ids_logprob specially due to its nested structure
340
+ if not self.token_ids_logprob: # covers both None and []
341
+ self.token_ids_logprob = [None] * num
342
+ elif not isinstance(self.token_ids_logprob, list):
343
+ self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
344
+ elif not isinstance(self.token_ids_logprob[0], list):
345
+ self.token_ids_logprob = [
346
+ copy.deepcopy(self.token_ids_logprob) for _ in range(num)
347
+ ]
348
+ elif self.parallel_sample_num > 1:
349
+ raise ValueError(
350
+ "Cannot use list token_ids_logprob with parallel_sample_num > 1"
351
+ )
220
352
 
221
- if self.custom_logit_processor is None:
222
- self.custom_logit_processor = [None] * num
223
- elif not isinstance(self.custom_logit_processor, list):
224
- self.custom_logit_processor = [self.custom_logit_processor] * num
225
- else:
226
- assert self.parallel_sample_num == 1
353
+ def _normalize_custom_logit_processor(self, num):
354
+ """Normalize custom logit processor for batch processing."""
355
+ if self.custom_logit_processor is None:
356
+ self.custom_logit_processor = [None] * num
357
+ elif not isinstance(self.custom_logit_processor, list):
358
+ self.custom_logit_processor = [self.custom_logit_processor] * num
359
+ elif self.parallel_sample_num > 1:
360
+ raise ValueError(
361
+ "Cannot use list custom_logit_processor with parallel_sample_num > 1"
362
+ )
227
363
 
228
- # Other checks
364
+ def _validate_session_params(self):
365
+ """Validate that session parameters are properly formatted."""
229
366
  if self.session_params is not None:
230
- assert isinstance(self.session_params, dict) or isinstance(
367
+ if not isinstance(self.session_params, dict) and not isinstance(
231
368
  self.session_params[0], dict
232
- )
369
+ ):
370
+ raise ValueError("Session params must be a dict or a list of dicts.")
233
371
 
234
372
  def regenerate_rid(self):
373
+ """Generate a new request ID and return it."""
235
374
  self.rid = uuid.uuid4().hex
236
375
  return self.rid
237
376
 
@@ -300,13 +439,24 @@ class TokenizedGenerateReqInput:
300
439
  # Whether to return hidden states
301
440
  return_hidden_states: bool = False
302
441
 
442
+ # For disaggregated inference
443
+ bootstrap_host: Optional[str] = None
444
+ bootstrap_room: Optional[int] = None
445
+
303
446
 
304
447
  @dataclass
305
448
  class EmbeddingReqInput:
306
449
  # The input prompt. It can be a single prompt or a batch of prompts.
307
450
  text: Optional[Union[List[str], str]] = None
308
- # The image input. It can be a file name, a url, or base64 encoded string.
309
- image_data: Optional[Union[List[str], str]] = None
451
+ # The image input. It can be an image instance, file name, URL, or base64 encoded string.
452
+ # Can be formatted as:
453
+ # - Single image for a single request
454
+ # - List of images (one per request in a batch)
455
+ # - List of lists of images (multiple images per request)
456
+ # See also python/sglang/srt/utils.py:load_image for more details.
457
+ image_data: Optional[
458
+ Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
459
+ ] = None
310
460
  # The token ids for text; one can either specify text or input_ids.
311
461
  input_ids: Optional[Union[List[List[int]], List[int]]] = None
312
462
  # The request id.
@@ -550,10 +700,17 @@ class UpdateWeightsFromDistributedReqOutput:
550
700
 
551
701
  @dataclass
552
702
  class UpdateWeightsFromTensorReqInput:
553
- # List containing one serialized Dict[str, torch.Tensor] per TP worker
554
- serialized_named_tensors: List[bytes]
555
- load_format: Optional[str]
556
- flush_cache: bool
703
+ """Update model weights from tensor input.
704
+
705
+ - Tensors are serialized for transmission
706
+ - Data is structured in JSON for easy transmission over HTTP
707
+ """
708
+
709
+ serialized_named_tensors: List[Union[str, bytes]]
710
+ # Optional format specification for loading
711
+ load_format: Optional[str] = None
712
+ # Whether to flush the cache after updating weights
713
+ flush_cache: bool = True
557
714
 
558
715
 
559
716
  @dataclass
@@ -148,7 +148,8 @@ def get_embedding_and_mask(
148
148
  placeholder_tensor,
149
149
  ).unsqueeze(-1)
150
150
 
151
- num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
151
+ num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
152
+
152
153
  if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
153
154
  logger.warning(
154
155
  f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -172,7 +173,7 @@ def get_embedding_and_mask(
172
173
  embedding = embedding[-num_multimodal:, :]
173
174
  else:
174
175
  raise RuntimeError(
175
- "Insufficient multimodal embedding length. This is an internal error"
176
+ f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
176
177
  )
177
178
 
178
179
  return embedding, special_multimodal_mask