sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
@@ -37,7 +36,7 @@ class GenerateReqInput:
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
@@ -53,9 +52,6 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-
-    is_single: bool = True
-
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
@@ -65,19 +61,41 @@ class GenerateReqInput:
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
-        if (
-            isinstance(self.sampling_params, dict)
-            and self.sampling_params.get("n", 1) != 1
-        ):
-            is_single = False
+        self.is_single = False
+        if self.text is not None:
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.text)
         else:
-            if self.text is not None:
-                is_single = isinstance(self.text, str)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
             else:
-                is_single = isinstance(self.input_ids[0], int)
-            self.is_single = is_single
-
-        if is_single:
+                self.batch_size = len(self.input_ids)
+
+        if self.sampling_params is None:
+            self.parallel_sample_num = 1
+        elif isinstance(self.sampling_params, dict):
+            self.parallel_sample_num = self.sampling_params.get("n", 1)
+        else:  # isinstance(self.sampling_params, list):
+            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
+            for sp in self.sampling_params:
+                # TODO cope with the case that the parallel_sample_num is different for different samples
+                assert self.parallel_sample_num == sp.get(
+                    "n", 1
+                ), "The parallel_sample_num should be the same for all samples in sample params."
+
+        if self.parallel_sample_num > 1:
+            if self.is_single:
+                self.is_single = False
+                if self.text is not None:
+                    self.text = [self.text]
+                if self.input_ids is not None:
+                    self.input_ids = [self.input_ids]
+
+        if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
             if self.rid is None:
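The rewritten post_init above folds the old scattered is_single checks into one pass that records is_single, batch_size, and parallel_sample_num up front, then promotes any single request with n > 1 into a batch of one prompt. Below is a minimal standalone sketch of those batching rules; normalize_request is a hypothetical helper for illustration, not sglang's API:

    from typing import Dict, List, Optional, Union


    def normalize_request(
        text: Optional[Union[List[str], str]] = None,
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
    ):
        # A request is "single" when text is a str or input_ids is a flat
        # list of ints; otherwise it is a batch.
        if text is not None:
            is_single = isinstance(text, str)
            batch_size = 1 if is_single else len(text)
        else:
            is_single = isinstance(input_ids[0], int)
            batch_size = 1 if is_single else len(input_ids)

        # parallel_sample_num comes from "n"; for a list, all entries
        # must agree (the diff asserts this).
        if sampling_params is None:
            n = 1
        elif isinstance(sampling_params, dict):
            n = sampling_params.get("n", 1)
        else:
            n = sampling_params[0].get("n", 1)

        if n > 1 and is_single:
            is_single = False  # promoted to a batch of one prompt
        return is_single, batch_size, n


    assert normalize_request(text="hi") == (True, 1, 1)
    assert normalize_request(text="hi", sampling_params={"n": 4}) == (False, 1, 4)
    assert normalize_request(input_ids=[[1, 2], [3]]) == (False, 2, 1)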
@@ -89,79 +107,54 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num_list = []
-            if isinstance(self.sampling_params, dict):
-                parallel_sample_num = self.sampling_params.get("n", 1)
-            elif isinstance(self.sampling_params, list):
-                for sp in self.sampling_params:
-                    parallel_sample_num = sp.get("n", 1)
-                    parallel_sample_num_list.append(parallel_sample_num)
-                parallel_sample_num = max(parallel_sample_num_list)
-                all_equal = all(
-                    element == parallel_sample_num
-                    for element in parallel_sample_num_list
-                )
-                if parallel_sample_num > 1 and (not all_equal):
-                    # TODO cope with the case that the parallel_sample_num is different for different samples
-                    raise ValueError(
-                        "The parallel_sample_num should be the same for all samples in sample params."
-                    )
+            if self.parallel_sample_num == 1:
+                num = self.batch_size
             else:
-                parallel_sample_num = 1
-            self.parallel_sample_num = parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # parallel sampling +1 represents the original prefill stage
-                num = parallel_sample_num + 1
-                if isinstance(self.text, list):
-                    # suppot batch operation
-                    self.batch_size = len(self.text)
-                    num = num * len(self.text)
-                elif isinstance(self.input_ids, list) and isinstance(
-                    self.input_ids[0], list
-                ):
-                    self.batch_size = len(self.input_ids)
-                    num = num * len(self.input_ids)
-                else:
-                    self.batch_size = 1
-            else:
-                # support select operation
-                num = len(self.text) if self.text is not None else len(self.input_ids)
-                self.batch_size = num
+                # FIXME support cascade inference
+                # first bs samples are used for caching the prefix for parallel sampling
+                num = self.batch_size + self.parallel_sample_num * self.batch_size
 
             if self.image_data is None:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-                # multi-image with n > 1
+                # FIXME incorrect order for duplication
                 self.image_data = self.image_data * num
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
+                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
             elif not isinstance(self.return_logprob, list):
                 self.return_logprob = [self.return_logprob] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.logprob_start_len is None:
                 self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = [0] * num
             elif not isinstance(self.top_logprobs_num, list):
                 self.top_logprobs_num = [self.top_logprobs_num] * num
+            else:
+                assert self.parallel_sample_num == 1
 
 
 @dataclass
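The hunk above also replaces the old "parallel_sample_num + 1" fan-out with batch_size + parallel_sample_num * batch_size, where (per the new comment) the first batch_size entries only prefill and cache the shared prefix. A worked example with illustrative numbers:

    # batch of 2 prompts, n = 3 parallel samples per prompt
    batch_size, parallel_sample_num = 2, 3
    # per-request fields (rid, image_data, sampling_params, ...) are
    # replicated to this many entries; the first batch_size of them are
    # the prefix-caching prefill requests
    num = batch_size + parallel_sample_num * batch_size
    assert num == 8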
@@ -172,12 +165,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The pixel values for input images
-    pixel_values: List[float]
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +177,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None
 
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
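For orientation only: the three image fields removed above, plus the modalites flag removed here, appear to be consolidated into the single image_inputs dict (presumably filled by the new sglang/srt/managers/image_processor.py listed in this release). The keys below are an unverified assumption sketched from the removed fields, not a confirmed schema:

    # hypothetical contents of image_inputs; key names are guesses
    image_inputs = {
        "pixel_values": [],  # was pixel_values: List[float]
        "image_hashes": [],  # was image_hashes: List[int]
        "image_sizes": [],   # was image_sizes: List[List[int]]
        "modalities": None,  # was modalites: Optional[List[str]]
    }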
@@ -206,8 +193,6 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -215,12 +200,11 @@ class EmbeddingReqInput:
             raise ValueError("Either text or input_ids should be provided.")
 
         if self.text is not None:
-            is_single = isinstance(self.text, str)
+            self.is_single = isinstance(self.text, str)
         else:
-            is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
+            self.is_single = isinstance(self.input_ids[0], int)
 
-        if is_single:
+        if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
@@ -254,6 +238,50 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
 
 
+@dataclass
+class RewardReqInput:
+    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
+    conv: Union[List[List[Dict]], List[Dict]]
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    def post_init(self):
+        self.is_single = isinstance(self.conv[0], dict)
+
+        if self.is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = len(self.conv)
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedRewardReqInput:
+    # The request id
+    rid: str
+    # The input text
+    input_text: str
+    # The input token ids
+    input_ids: List[int]
+    # Dummy sampling params for compatibility
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
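A brief usage sketch for the new reward request type added above (assumes sglang 0.3.3 is installed; the conv payload follows the chat format described in the field comment):

    from sglang.srt.managers.io_struct import RewardReqInput

    req = RewardReqInput(
        conv=[
            {"role": "user", "content": "Is the sky blue?"},
            {"role": "assistant", "content": "Yes, on a clear day."},
        ]
    )
    req.post_init()
    assert req.is_single  # one conversation, not a batch
    assert req.sampling_params["max_new_tokens"] == 1  # dummy one-token decode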
@@ -268,10 +296,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-
 
 @dataclass
 class BatchStrOut: