sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
+
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
@@ -645,7 +646,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_hip(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda_available:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
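Note: the new forward here is a small dispatcher. Under torch.compile it stays on the pure-PyTorch path (custom kernels cannot be traced), otherwise it prefers the CUDA kernel when one is available. A minimal sketch of the same pattern, with forward_native and forward_cuda left as hypothetical methods of the concrete op:

    import torch

    class DispatchingOp(torch.nn.Module):
        # Sketch of the dispatch above; forward_native/forward_cuda are
        # assumed to be defined by the concrete op (not shown here).
        def forward(self, *args, **kwargs):
            if torch.compiler.is_compiling():
                # torch.compile cannot trace into custom CUDA kernels.
                return self.forward_native(*args, **kwargs)
            if torch.cuda.is_available():
                return self.forward_cuda(*args, **kwargs)
            return self.forward_native(*args, **kwargs)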
sglang/srt/layers/sampler.py CHANGED
@@ -93,28 +93,23 @@ class Sampler(nn.Module):
             ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
+                    probs, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                # Check Nan will throw exception, only check when crash_on_warnings is True
+                check_nan = self.use_nan_detection and crash_on_warnings()
+                batch_next_token_ids = top_k_top_p_sampling_from_probs(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
+                    check_nan=check_nan,
                 )
 
-                if self.use_nan_detection and not torch.all(success):
-                    logger.warning("Detected errors during sampling!")
-                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
             batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
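Note: dropping uniform_samples tracks the newer sampling-kernel API, which draws its own randomness and reports NaNs via check_nan instead of returning a success mask. As a rough torch-native reference for what joint top-k/top-p filtering computes (a sketch with per-batch scalar top_k/top_p, not the batched kernel used above):

    import torch

    def top_k_top_p_sample_sketch(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        # probs: (batch, vocab), rows already normalized.
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cumsum = sorted_probs.cumsum(dim=-1)
        keep = torch.zeros_like(sorted_probs, dtype=torch.bool)
        keep[..., :top_k] = True
        # Nucleus filter; the highest-probability token always survives.
        keep &= (cumsum - sorted_probs) < top_p
        filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(filtered, num_samples=1)
        return sorted_idx.gather(-1, choice).squeeze(-1)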
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
 
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
@@ -115,3 +115,19 @@ class BaseLoRABackend:
 
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
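Note: the factory returns the backend class itself (despite the BaseLoRABackend annotation), so callers instantiate it afterwards. A usage sketch, with constructor arguments elided since they are backend-specific:

    from sglang.srt.lora.backend.base_backend import get_backend_from_name

    backend_cls = get_backend_from_name("triton")  # -> TritonLoRABackend
    # backend = backend_cls(...)  # backend-specific constructor arguments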
sglang/srt/lora/backend/flashinfer_backend.py CHANGED
@@ -2,7 +2,7 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
     gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
sglang/srt/lora/layers.py CHANGED
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/lora/lora.py CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
sglang/srt/lora/lora_manager.py CHANGED
@@ -22,7 +22,7 @@ import torch
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -14,7 +14,6 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import dataclasses
-import json
 import logging
 import os
 import signal
sglang/srt/managers/io_struct.py CHANGED
@@ -20,7 +20,13 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
+
+# handle serialization of Image for pydantic
+if TYPE_CHECKING:
+    from PIL.Image import Image
+else:
+    Image = Any
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
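Note: the TYPE_CHECKING guard keeps PIL optional at runtime. Type checkers see the real PIL.Image.Image type, while at runtime the name degrades to Any, so pydantic can still resolve the annotations on these dataclasses without importing PIL. The pattern in isolation (a sketch):

    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        from PIL.Image import Image  # evaluated only by static type checkers
    else:
        Image = Any  # runtime fallback; no PIL import required

    def preview(img: Image) -> None:  # annotation stays resolvable at runtime
        ...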
@@ -42,10 +48,16 @@ class GenerateReqInput:
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
-    # The image input. It can be a file name, a url, or base64 encoded string.
-    # See also python/sglang/srt/utils.py:load_image.
-    image_data: Optional[Union[List[str], str]] = None
-    # The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
+    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+    # Can be formatted as:
+    # - Single image for a single request
+    # - List of images (one per request in a batch)
+    # - List of lists of images (multiple images per request)
+    # See also python/sglang/srt/utils.py:load_image for more details.
+    image_data: Optional[
+        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
+    ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
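Note: the three accepted image_data shapes, shown on hypothetical values (a sketch; only the text and image_data fields are set):

    # Single request, one image
    GenerateReqInput(text="describe this", image_data="cat.png")

    # Batch of two requests, one image each
    GenerateReqInput(text=["a", "b"], image_data=["cat.png", "dog.png"])

    # Batch of two requests, multiple images for the first
    GenerateReqInput(text=["a", "b"], image_data=[["cat1.png", "cat2.png"], ["dog.png"]])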
@@ -83,7 +95,36 @@ class GenerateReqInput:
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # For disaggregated inference
+    bootstrap_host: Optional[str] = None
+    bootstrap_room: Optional[int] = None
+
     def normalize_batch_and_arguments(self):
+        """
+        Normalize the batch size and arguments for the request.
+
+        This method resolves various input formats and ensures all parameters
+        are properly formatted as either single values or batches depending on the input.
+        It also handles parallel sampling expansion and sets default values for
+        unspecified parameters.
+
+        Raises:
+            ValueError: If inputs are not properly specified (e.g., none or all of
+                text, input_ids, input_embeds are provided)
+        """
+        self._validate_inputs()
+        self._determine_batch_size()
+        self._handle_parallel_sampling()
+
+        if self.is_single:
+            self._normalize_single_inputs()
+        else:
+            self._normalize_batch_inputs()
+
+        self._validate_session_params()
+
+    def _validate_inputs(self):
+        """Validate that the input configuration is valid."""
         if (
             self.text is None and self.input_ids is None and self.input_embeds is None
         ) or (
@@ -95,7 +136,8 @@ class GenerateReqInput:
                 "Either text, input_ids or input_embeds should be provided."
             )
 
-        # Derive the batch size
+    def _determine_batch_size(self):
+        """Determine if this is a single example or a batch and the batch size."""
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
@@ -119,21 +161,25 @@ class GenerateReqInput:
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_embeds)
 
-        # Handle parallel sampling
-        # When parallel sampling is used, we always treat the input as a batch.
+    def _handle_parallel_sampling(self):
+        """Handle parallel sampling parameters and adjust batch size if needed."""
+        # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            assert all(
-                self.parallel_sample_num == sampling_params.get("n", 1)
-                for sampling_params in self.sampling_params
-            ), "The parallel_sample_num should be the same for all samples in sample params."
+            for sampling_params in self.sampling_params:
+                if self.parallel_sample_num != sampling_params.get("n", 1):
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
 
+        # If using parallel sampling with a single example, convert to batch
         if self.parallel_sample_num > 1 and self.is_single:
             self.is_single = False
             if self.text is not None:
@@ -141,97 +187,190 @@ class GenerateReqInput:
             if self.input_ids is not None:
                 self.input_ids = [self.input_ids]
 
-        # Fill in default arguments
-        if self.is_single:
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.return_logprob is None:
-                self.return_logprob = False
-            if self.logprob_start_len is None:
-                self.logprob_start_len = -1
-            if self.top_logprobs_num is None:
-                self.top_logprobs_num = 0
-            if not self.token_ids_logprob:  # covers both None and []
-                self.token_ids_logprob = None
+    def _normalize_single_inputs(self):
+        """Normalize inputs for a single example."""
+        if self.sampling_params is None:
+            self.sampling_params = {}
+        if self.rid is None:
+            self.rid = uuid.uuid4().hex
+        if self.return_logprob is None:
+            self.return_logprob = False
+        if self.logprob_start_len is None:
+            self.logprob_start_len = -1
+        if self.top_logprobs_num is None:
+            self.top_logprobs_num = 0
+        if not self.token_ids_logprob:  # covers both None and []
+            self.token_ids_logprob = None
+
+    def _normalize_batch_inputs(self):
+        """Normalize inputs for a batch of examples, including parallel sampling expansion."""
+        # Calculate expanded batch size
+        if self.parallel_sample_num == 1:
+            num = self.batch_size
         else:
-            if self.parallel_sample_num == 1:
-                num = self.batch_size
+            # Expand parallel_sample_num
+            num = self.batch_size * self.parallel_sample_num
+
+        # Expand input based on type
+        self._expand_inputs(num)
+        self._normalize_lora_paths(num)
+        self._normalize_image_data(num)
+        self._normalize_audio_data(num)
+        self._normalize_sampling_params(num)
+        self._normalize_rid(num)
+        self._normalize_logprob_params(num)
+        self._normalize_custom_logit_processor(num)
+
+    def _expand_inputs(self, num):
+        """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
+        if self.text is not None:
+            if not isinstance(self.text, list):
+                raise ValueError("Text should be a list for batch processing.")
+            self.text = self.text * self.parallel_sample_num
+        elif self.input_ids is not None:
+            if not isinstance(self.input_ids, list) or not isinstance(
+                self.input_ids[0], list
+            ):
+                raise ValueError(
+                    "input_ids should be a list of lists for batch processing."
+                )
+            self.input_ids = self.input_ids * self.parallel_sample_num
+        elif self.input_embeds is not None:
+            if not isinstance(self.input_embeds, list):
+                raise ValueError("input_embeds should be a list for batch processing.")
+            self.input_embeds = self.input_embeds * self.parallel_sample_num
+
+    def _normalize_lora_paths(self, num):
+        """Normalize LoRA paths for batch processing."""
+        if self.lora_path is not None:
+            if isinstance(self.lora_path, str):
+                self.lora_path = [self.lora_path] * num
+            elif isinstance(self.lora_path, list):
+                self.lora_path = self.lora_path * self.parallel_sample_num
             else:
+                raise ValueError("lora_path should be a list or a string.")
+
+    def _normalize_image_data(self, num):
+        """Normalize image data for batch processing."""
+        if self.image_data is None:
+            self.image_data = [None] * num
+        elif not isinstance(self.image_data, list):
+            # Single image, convert to list of single-image lists
+            self.image_data = [[self.image_data]] * num
+            self.modalities = ["image"] * num
+        elif isinstance(self.image_data, list):
+            if len(self.image_data) != self.batch_size:
+                raise ValueError(
+                    "The length of image_data should be equal to the batch size."
+                )
+
+            self.modalities = []
+            if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
+                # Already a list of lists, keep as is
+                for i in range(len(self.image_data)):
+                    if self.image_data[i] is None or self.image_data[i] == [None]:
+                        self.modalities.append(None)
+                    elif len(self.image_data[i]) == 1:
+                        self.modalities.append("image")
+                    elif len(self.image_data[i]) > 1:
+                        self.modalities.append("multi-images")
                 # Expand parallel_sample_num
-                num = self.batch_size * self.parallel_sample_num
-
-            if not self.image_data:
-                self.image_data = [None] * num
-            elif not isinstance(self.image_data, list):
-                self.image_data = [self.image_data] * num
-            elif isinstance(self.image_data, list):
-                pass
-
-            if self.audio_data is None:
-                self.audio_data = [None] * num
-            elif not isinstance(self.audio_data, list):
-                self.audio_data = [self.audio_data] * num
-            elif isinstance(self.audio_data, list):
-                pass
-
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * num
-            elif not isinstance(self.sampling_params, list):
-                self.sampling_params = [self.sampling_params] * num
-
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(num)]
+                self.image_data = self.image_data * self.parallel_sample_num
+                self.modalities = self.modalities * self.parallel_sample_num
             else:
-                assert isinstance(self.rid, list), "The rid should be a list."
-
-            if self.return_logprob is None:
-                self.return_logprob = [False] * num
-            elif not isinstance(self.return_logprob, list):
-                self.return_logprob = [self.return_logprob] * num
-            else:
-                assert self.parallel_sample_num == 1
-
-            if self.logprob_start_len is None:
-                self.logprob_start_len = [-1] * num
-            elif not isinstance(self.logprob_start_len, list):
-                self.logprob_start_len = [self.logprob_start_len] * num
+                # List of images for a batch, wrap each in a list
+                wrapped_images = [[img] for img in self.image_data]
+                # Expand for parallel sampling
+                self.image_data = wrapped_images * self.parallel_sample_num
+                self.modalities = ["image"] * num
+
+    def _normalize_audio_data(self, num):
+        """Normalize audio data for batch processing."""
+        if self.audio_data is None:
+            self.audio_data = [None] * num
+        elif not isinstance(self.audio_data, list):
+            self.audio_data = [self.audio_data] * num
+        elif isinstance(self.audio_data, list):
+            self.audio_data = self.audio_data * self.parallel_sample_num
+
+    def _normalize_sampling_params(self, num):
+        """Normalize sampling parameters for batch processing."""
+        if self.sampling_params is None:
+            self.sampling_params = [{}] * num
+        elif isinstance(self.sampling_params, dict):
+            self.sampling_params = [self.sampling_params] * num
+        else:  # Already a list
+            self.sampling_params = self.sampling_params * self.parallel_sample_num
+
+    def _normalize_rid(self, num):
+        """Normalize request IDs for batch processing."""
+        if self.rid is None:
+            self.rid = [uuid.uuid4().hex for _ in range(num)]
+        elif not isinstance(self.rid, list):
+            raise ValueError("The rid should be a list for batch processing.")
+
+    def _normalize_logprob_params(self, num):
+        """Normalize logprob-related parameters for batch processing."""
+
+        # Helper function to normalize a parameter
+        def normalize_param(param, default_value, param_name):
+            if param is None:
+                return [default_value] * num
+            elif not isinstance(param, list):
+                return [param] * num
             else:
-                assert self.parallel_sample_num == 1
+                if self.parallel_sample_num > 1:
+                    raise ValueError(
+                        f"Cannot use list {param_name} with parallel_sample_num > 1"
+                    )
+                return param
+
+        # Normalize each logprob parameter
+        self.return_logprob = normalize_param(
+            self.return_logprob, False, "return_logprob"
+        )
+        self.logprob_start_len = normalize_param(
+            self.logprob_start_len, -1, "logprob_start_len"
+        )
+        self.top_logprobs_num = normalize_param(
+            self.top_logprobs_num, 0, "top_logprobs_num"
+        )
 
-            if self.top_logprobs_num is None:
-                self.top_logprobs_num = [0] * num
-            elif not isinstance(self.top_logprobs_num, list):
-                self.top_logprobs_num = [self.top_logprobs_num] * num
-            else:
-                assert self.parallel_sample_num == 1
-
-            if not self.token_ids_logprob:  # covers both None and []
-                self.token_ids_logprob = [None] * num
-            elif not isinstance(self.token_ids_logprob, list):
-                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
-            elif not isinstance(self.token_ids_logprob[0], list):
-                self.token_ids_logprob = [
-                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
-                ]
-            else:
-                assert self.parallel_sample_num == 1
+        # Handle token_ids_logprob specially due to its nested structure
+        if not self.token_ids_logprob:  # covers both None and []
+            self.token_ids_logprob = [None] * num
+        elif not isinstance(self.token_ids_logprob, list):
+            self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
+        elif not isinstance(self.token_ids_logprob[0], list):
+            self.token_ids_logprob = [
+                copy.deepcopy(self.token_ids_logprob) for _ in range(num)
+            ]
+        elif self.parallel_sample_num > 1:
+            raise ValueError(
+                "Cannot use list token_ids_logprob with parallel_sample_num > 1"
+            )
 
-            if self.custom_logit_processor is None:
-                self.custom_logit_processor = [None] * num
-            elif not isinstance(self.custom_logit_processor, list):
-                self.custom_logit_processor = [self.custom_logit_processor] * num
-            else:
-                assert self.parallel_sample_num == 1
+    def _normalize_custom_logit_processor(self, num):
+        """Normalize custom logit processor for batch processing."""
+        if self.custom_logit_processor is None:
+            self.custom_logit_processor = [None] * num
+        elif not isinstance(self.custom_logit_processor, list):
+            self.custom_logit_processor = [self.custom_logit_processor] * num
+        elif self.parallel_sample_num > 1:
+            raise ValueError(
+                "Cannot use list custom_logit_processor with parallel_sample_num > 1"
+            )
 
-        # Other checks
+    def _validate_session_params(self):
+        """Validate that session parameters are properly formatted."""
         if self.session_params is not None:
-            assert isinstance(self.session_params, dict) or isinstance(
+            if not isinstance(self.session_params, dict) and not isinstance(
                 self.session_params[0], dict
-            )
+            ):
+                raise ValueError("Session params must be a dict or a list of dicts.")
 
     def regenerate_rid(self):
+        """Generate a new request ID and return it."""
         self.rid = uuid.uuid4().hex
         return self.rid
 
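Note: expansion for parallel sampling uses list repetition, so the expanded batch repeats the whole input batch n times rather than repeating each element n times. A sketch of the resulting order:

    # With batch_size=2 and n=2, expansion yields [t0, t1, t0, t1],
    # not [t0, t0, t1, t1].
    text = ["t0", "t1"]
    parallel_sample_num = 2
    num = len(text) * parallel_sample_num
    expanded = text * parallel_sample_num
    assert expanded == ["t0", "t1", "t0", "t1"] and len(expanded) == num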
@@ -300,13 +439,24 @@ class TokenizedGenerateReqInput:
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # For disaggregated inference
+    bootstrap_host: Optional[str] = None
+    bootstrap_room: Optional[int] = None
+
 
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
-    # The image input. It can be a file name, a url, or base64 encoded string.
-    image_data: Optional[Union[List[str], str]] = None
+    # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+    # Can be formatted as:
+    # - Single image for a single request
+    # - List of images (one per request in a batch)
+    # - List of lists of images (multiple images per request)
+    # See also python/sglang/srt/utils.py:load_image for more details.
+    image_data: Optional[
+        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
+    ] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
@@ -550,10 +700,17 @@ class UpdateWeightsFromDistributedReqOutput:
 
 @dataclass
 class UpdateWeightsFromTensorReqInput:
-    # List containing one serialized Dict[str, torch.Tensor] per TP worker
-    serialized_named_tensors: List[bytes]
-    load_format: Optional[str]
-    flush_cache: bool
+    """Update model weights from tensor input.
+
+    - Tensors are serialized for transmission
+    - Data is structured in JSON for easy transmission over HTTP
+    """
+
+    serialized_named_tensors: List[Union[str, bytes]]
+    # Optional format specification for loading
+    load_format: Optional[str] = None
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
 
 
 @dataclass
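Note: a hedged sketch of producing serialized_named_tensors; sglang's actual serializer may differ, this only illustrates one way to get bytes (or a JSON-safe base64 str) from a Dict[str, torch.Tensor]:

    import base64
    import io

    import torch

    def serialize_named_tensors(named_tensors, as_str=False):
        # Assumption: torch.save round-trips the dict; the real code path
        # may use a dedicated multiprocessing serializer instead.
        buf = io.BytesIO()
        torch.save(named_tensors, buf)
        data = buf.getvalue()
        return base64.b64encode(data).decode("ascii") if as_str else data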
@@ -677,6 +834,7 @@ class ProfileReq:
     activities: Optional[List[str]] = None
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
+    profile_id: Optional[str] = None
 
 
 @dataclass
sglang/srt/managers/mm_utils.py CHANGED
@@ -1,7 +1,8 @@
 """
-Multi-modality utils
+    Multi-modality utils
 """
 
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-from sglang.utils import logger
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
@@ -148,7 +149,8 @@ def get_embedding_and_mask(
         placeholder_tensor,
     ).unsqueeze(-1)
 
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
+
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -172,7 +174,7 @@ def get_embedding_and_mask(
         embedding = embedding[-num_multimodal:, :]
     else:
         raise RuntimeError(
-            "Insufficient multimodal embedding length. This is an internal error"
+            f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
         )
 
     return embedding, special_multimodal_mask
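Note: .sum() on a boolean mask returns a 0-dim tensor; .item() converts it to a plain Python int, so the count compares cleanly against num_mm_tokens_in_embedding and formats readably in the new error message. In isolation:

    import torch

    input_ids = torch.tensor([1, 7, 7, 2, 7])
    mask = input_ids == 7  # hypothetical placeholder token id
    num = mask.sum().item()  # plain int 3, not tensor(3)
    print(f"{num=}")  # -> num=3, the {name=} style used in the message above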
sglang/srt/managers/multimodal_processor.py CHANGED
@@ -5,8 +5,6 @@ import logging
 import pkgutil
 from functools import lru_cache
 
-from transformers import PROCESSOR_MAPPING
-
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )