sglang-0.2.11-py3-none-any.whl → sglang-0.2.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
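Notable in this release is the new sglang/srt/sampling/penaltylib package (files 35-40 and 49 above), which adds batched logit penalizers for frequency, presence, repetition, and minimum-new-token constraints behind a shared orchestrator. The sketch below illustrates the standard penalty formulas such penalizers typically implement; it is a simplified stand-alone illustration, not sglang's actual implementation, and every name in it is invented for the example.

    import torch

    def apply_penalties(
        logits: torch.Tensor,        # [vocab] next-token logits for one request
        token_counts: torch.Tensor,  # [vocab] how often each token was generated
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        # OpenAI-style additive penalties: frequency scales with occurrence
        # count, presence applies once to any token that has appeared.
        logits = logits - frequency_penalty * token_counts
        logits = logits - presence_penalty * (token_counts > 0).float()
        # CTRL-style multiplicative repetition penalty: make previously seen
        # tokens less likely regardless of the sign of their logit.
        seen = token_counts > 0
        scaled = torch.where(
            logits > 0, logits / repetition_penalty, logits * repetition_penalty
        )
        return torch.where(seen, scaled, logits)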
sglang/srt/managers/tokenizer_manager.py

@@ -21,7 +21,7 @@ import dataclasses
  import logging
  import multiprocessing as mp
  import os
- from typing import Dict, List, Tuple
+ from typing import Dict, List, Tuple, Union
 
  import numpy as np
  import transformers
@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
  )
  from sglang.srt.managers.io_struct import (
      AbortReq,
+     BatchEmbeddingOut,
      BatchStrOut,
      BatchTokenIDOut,
+     EmbeddingReqInput,
      FlushCacheReq,
      GenerateReqInput,
+     TokenizedEmbeddingReqInput,
      TokenizedGenerateReqInput,
  )
  from sglang.srt.mm_utils import expand2square, process_anyres_image
  from sglang.srt.sampling_params import SamplingParams
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import is_multimodal_model, load_image
+ from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
  from sglang.utils import get_exception_traceback
 
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -85,31 +88,35 @@ class TokenizerManager:
              trust_remote_code=server_args.trust_remote_code,
              model_overide_args=model_overide_args,
          )
+         self.is_generation = is_generation_model(self.hf_config.architectures)
 
          if server_args.context_length is not None:
              self.context_len = server_args.context_length
          else:
              self.context_len = get_context_length(self.hf_config)
 
-         if is_multimodal_model(self.model_path):
-             self.processor = get_processor(
-                 server_args.tokenizer_path,
-                 tokenizer_mode=server_args.tokenizer_mode,
-                 trust_remote_code=server_args.trust_remote_code,
-             )
-             self.tokenizer = self.processor.tokenizer
-             os.environ["TOKENIZERS_PARALLELISM"] = "false"
-             self.executor = concurrent.futures.ProcessPoolExecutor(
-                 initializer=init_global_processor,
-                 mp_context=mp.get_context("fork"),
-                 initargs=(server_args,),
-             )
+         if server_args.skip_tokenizer_init:
+             self.tokenizer = self.processor = None
          else:
-             self.tokenizer = get_tokenizer(
-                 server_args.tokenizer_path,
-                 tokenizer_mode=server_args.tokenizer_mode,
-                 trust_remote_code=server_args.trust_remote_code,
-             )
+             if is_multimodal_model(self.model_path):
+                 self.processor = get_processor(
+                     server_args.tokenizer_path,
+                     tokenizer_mode=server_args.tokenizer_mode,
+                     trust_remote_code=server_args.trust_remote_code,
+                 )
+                 self.tokenizer = self.processor.tokenizer
+                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+                 self.executor = concurrent.futures.ProcessPoolExecutor(
+                     initializer=init_global_processor,
+                     mp_context=mp.get_context("fork"),
+                     initargs=(server_args,),
+                 )
+             else:
+                 self.tokenizer = get_tokenizer(
+                     server_args.tokenizer_path,
+                     tokenizer_mode=server_args.tokenizer_mode,
+                     trust_remote_code=server_args.trust_remote_code,
+                 )
 
          self.to_create_loop = True
          self.rid_to_state: Dict[str, ReqState] = {}
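Two behavioral changes land in this constructor: the manager now records up front whether the model generates text or produces embeddings, and the new skip_tokenizer_init server argument lets the server run without loading any tokenizer or processor. A plausible sketch of the architecture check follows; the real is_generation_model lives in sglang/srt/utils.py, and the specific architecture name below is an assumption based on the new llama_embedding.py module in this release.

    # Hypothetical approximation of sglang.srt.utils.is_generation_model;
    # consult sglang/srt/utils.py in 0.2.12 for the authoritative check.
    def is_generation_model(model_architectures):
        # Embedding models (e.g. the LlamaEmbeddingModel added in this
        # release) return pooled hidden states rather than sampled tokens.
        embedding_architectures = {"LlamaEmbeddingModel"}
        return not any(a in embedding_architectures for a in model_architectures)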
@@ -133,7 +140,9 @@ class TokenizerManager:
              image_data, aspect_ratio, grid_pinpoints, self.processor
          )
 
-     async def generate_request(self, obj: GenerateReqInput, request=None):
+     async def generate_request(
+         self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
+     ):
          if self.to_create_loop:
              self.create_handle_loop()
 
@@ -144,46 +153,55 @@
              async for response in self._handle_single_request(obj, request):
                  yield response
          else:
-             if obj.stream:
+             if hasattr(obj, "stream") and obj.stream:
                  raise ValueError("Do not support stream for batch mode.")
 
              async for response in self._handle_batch_request(obj, request):
                  yield response
 
      async def _handle_single_request(
-         self, obj, request, index=None, is_cache_for_prefill=False
+         self,
+         obj: Union[GenerateReqInput, EmbeddingReqInput],
+         request,
+         index=None,
+         is_cache_for_prefill=False,
      ):
          if not is_cache_for_prefill:  # The normal case with a single prompt
              not_use_index = index is None
 
              rid = obj.rid if not_use_index else obj.rid[index]
              input_text = obj.text if not_use_index else obj.text[index]
-             input_ids = (
-                 self.tokenizer.encode(input_text)
-                 if obj.input_ids is None
-                 else obj.input_ids
-             )
-             if not not_use_index and obj.input_ids:
-                 input_ids = obj.input_ids[index]
+             if obj.input_ids is None:
+                 assert self.tokenizer is not None
+                 input_ids = self.tokenizer.encode(input_text)
+             else:
+                 input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
 
              self._validate_input_length(input_ids)
 
              sampling_params = self._get_sampling_params(
                  obj.sampling_params if not_use_index else obj.sampling_params[index]
              )
-             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                 obj.image_data if not_use_index else obj.image_data[index]
-             )
-             return_logprob = (
-                 obj.return_logprob if not_use_index else obj.return_logprob[index]
-             )
-             logprob_start_len = (
-                 obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
-             )
-             top_logprobs_num = (
-                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
-             )
+
+             if self.is_generation:
+                 pixel_values, image_hash, image_size = await self._get_pixel_values(
+                     obj.image_data if not_use_index else obj.image_data[index]
+                 )
+                 return_logprob = (
+                     obj.return_logprob if not_use_index else obj.return_logprob[index]
+                 )
+                 logprob_start_len = (
+                     obj.logprob_start_len
+                     if not_use_index
+                     else obj.logprob_start_len[index]
+                 )
+                 top_logprobs_num = (
+                     obj.top_logprobs_num
+                     if not_use_index
+                     else obj.top_logprobs_num[index]
+                 )
          else:  # A prefill request to cache the common prompt for parallel sampling
+             assert self.is_generation
              if obj.text is not None:
                  if isinstance(obj.text, list):
                      input_text = obj.text[index]
@@ -191,7 +209,20 @@
                  else:
                      input_text = obj.text
                      rid = obj.rid[0]
-                 input_ids = self.tokenizer.encode(input_text)
+                 if self.tokenizer is not None:
+                     input_ids = self.tokenizer.encode(input_text)
+                 else:
+                     assert obj.input_ids is not None
+                     input_ids = obj.input_ids
+                     if isinstance(obj.input_ids, list) and isinstance(
+                         obj.input_ids[0], list
+                     ):
+                         # when obj["input_ids"] is List[List[int]]
+                         input_ids = obj.input_ids[index]
+                         rid = obj.rid[index]
+                     else:
+                         input_ids = obj.input_ids
+                         rid = obj.rid[0]
              else:
                  input_text = None
                  if isinstance(obj.input_ids, list) and isinstance(
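When the tokenizer is skipped, every request must arrive pre-tokenized; the branch above falls back to obj.input_ids and asserts it is present. A hedged usage sketch follows: the payload keys mirror GenerateReqInput, but the port, the token ids, and the exact --skip-tokenizer-init flag spelling are assumptions.

    import requests

    # Assumes a server launched with tokenizer initialization skipped, e.g.:
    #   python -m sglang.launch_server --model-path <model> --skip-tokenizer-init
    # Token ids are placeholders; encode the prompt client-side.
    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "input_ids": [1, 15043, 29892, 590, 1024, 338],
            "sampling_params": {"max_new_tokens": 16, "temperature": 0},
        },
    )
    print(resp.json())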
@@ -213,19 +244,28 @@
                  logprob_start_len = obj.logprob_start_len[0]
                  top_logprobs_num = obj.top_logprobs_num[0]
 
-         tokenized_obj = TokenizedGenerateReqInput(
-             rid,
-             input_text,
-             input_ids,
-             pixel_values,
-             image_hash,
-             image_size,
-             sampling_params,
-             return_logprob,
-             logprob_start_len,
-             top_logprobs_num,
-             obj.stream,
-         )
+         if self.is_generation:
+             tokenized_obj = TokenizedGenerateReqInput(
+                 rid,
+                 input_text,
+                 input_ids,
+                 pixel_values,
+                 image_hash,
+                 image_size,
+                 sampling_params,
+                 return_logprob,
+                 logprob_start_len,
+                 top_logprobs_num,
+                 obj.stream,
+             )
+         else:  # is embedding
+             tokenized_obj = TokenizedEmbeddingReqInput(
+                 rid,
+                 input_text,
+                 input_ids,
+                 sampling_params,
+             )
+
          self.send_to_router.send_pyobj(tokenized_obj)
 
          event = asyncio.Event()
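The embedding branch sends a much smaller message than the generation branch: no image data, logprob options, or streaming flag. Reconstructing from the positional arguments at this call site, the new request type carries roughly the fields below; the authoritative definition is in sglang/srt/managers/io_struct.py, so the field names here are inferred, not copied.

    import dataclasses
    from typing import List

    from sglang.srt.sampling_params import SamplingParams

    @dataclasses.dataclass
    class TokenizedEmbeddingReqInput:
        # Fields inferred from the call site above, in positional order.
        rid: str                         # request id
        input_text: str                  # original prompt text, if any
        input_ids: List[int]             # tokenized prompt
        sampling_params: SamplingParams  # kept for scheduler compatibility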
@@ -237,27 +277,33 @@
              ):
                  yield response
          else:
+             assert self.is_generation
              await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
              yield input_ids
 
-     async def _handle_batch_request(self, obj: GenerateReqInput, request):
+     async def _handle_batch_request(
+         self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+     ):
          batch_size = obj.batch_size
-         parallel_sample_num = obj.parallel_sample_num
-
-         if parallel_sample_num != 1:
-             # Send prefill requests to cache the common input
-             parallel_sample_num += 1
-             input_id_result = [] if obj.input_ids is None else None
-             for i in range(batch_size):
-                 async for input_id in self._handle_single_request(
-                     obj, request, index=i, is_cache_for_prefill=True
-                 ):
-                     if input_id_result is not None:
-                         input_id_result.append(input_id)
-             if input_id_result is not None and len(input_id_result) > 1:
-                 obj.input_ids = input_id_result
-             elif input_id_result is not None:
-                 obj.input_ids = input_id_result[0]
+         if self.is_generation:
+             parallel_sample_num = obj.parallel_sample_num
+
+             if parallel_sample_num != 1:
+                 # Send prefill requests to cache the common input
+                 parallel_sample_num += 1
+                 input_id_result = [] if obj.input_ids is None else None
+                 for i in range(batch_size):
+                     async for input_id in self._handle_single_request(
+                         obj, request, index=i, is_cache_for_prefill=True
+                     ):
+                         if input_id_result is not None:
+                             input_id_result.append(input_id)
+                 if input_id_result is not None and len(input_id_result) > 1:
+                     obj.input_ids = input_id_result
+                 elif input_id_result is not None:
+                     obj.input_ids = input_id_result[0]
+         else:
+             parallel_sample_num = 1
 
          # First send out all requests
          for i in range(batch_size):
@@ -286,28 +332,38 @@
                  input_text = None
                  input_ids = obj.input_ids[i]
              sampling_params = self._get_sampling_params(obj.sampling_params[index])
-             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                 obj.image_data[index]
-             )
 
-             tokenized_obj = TokenizedGenerateReqInput(
-                 rid,
-                 input_text,
-                 input_ids,
-                 pixel_values,
-                 image_hash,
-                 image_size,
-                 sampling_params,
-                 obj.return_logprob[index],
-                 obj.logprob_start_len[index],
-                 obj.top_logprobs_num[index],
-                 obj.stream,
-             )
+             if self.is_generation:
+                 pixel_values, image_hash, image_size = await self._get_pixel_values(
+                     obj.image_data[index]
+                 )
+
+                 tokenized_obj = TokenizedGenerateReqInput(
+                     rid,
+                     input_text,
+                     input_ids,
+                     pixel_values,
+                     image_hash,
+                     image_size,
+                     sampling_params,
+                     obj.return_logprob[index],
+                     obj.logprob_start_len[index],
+                     obj.top_logprobs_num[index],
+                     obj.stream,
+                 )
+             else:
+                 tokenized_obj = TokenizedEmbeddingReqInput(
+                     rid,
+                     input_text,
+                     input_ids,
+                     sampling_params,
+                 )
              self.send_to_router.send_pyobj(tokenized_obj)
 
              event = asyncio.Event()
              state = ReqState([], False, event)
              self.rid_to_state[rid] = state
+
          # Then wait for all responses
          output_list = []
          for i in range(batch_size):
@@ -330,14 +386,17 @@
                  self.abort_request(rid)
                  raise ValueError(f"Abort request {rid}")
              continue
-         output_list.append(
-             self.convert_logprob_style(
-                 state.out_list[-1],
-                 obj.return_logprob[index],
-                 obj.top_logprobs_num[index],
-                 obj.return_text_in_logprobs,
+         if self.is_generation:
+             output_list.append(
+                 self.convert_logprob_style(
+                     state.out_list[-1],
+                     obj.return_logprob[index],
+                     obj.top_logprobs_num[index],
+                     obj.return_text_in_logprobs,
+                 )
              )
-         )
+         else:
+             output_list.append(state.out_list[-1])
          assert state.finished
          del self.rid_to_state[rid]
      yield output_list
@@ -368,7 +427,7 @@
          self,
          event: asyncio.Event,
          state: ReqState,
-         obj: GenerateReqInput,
+         obj: Union[GenerateReqInput, EmbeddingReqInput],
          rid: str,
          request,
      ):
@@ -381,17 +440,20 @@
                      raise ValueError(f"Abort request {rid}")
                  continue
 
-             out = self.convert_logprob_style(
-                 state.out_list[-1],
-                 obj.return_logprob,
-                 obj.top_logprobs_num,
-                 obj.return_text_in_logprobs,
-             )
+             if self.is_generation:
+                 out = self.convert_logprob_style(
+                     state.out_list[-1],
+                     obj.return_logprob,
+                     obj.top_logprobs_num,
+                     obj.return_text_in_logprobs,
+                 )
+             else:  # isinstance(obj, EmbeddingReqInput)
+                 out = state.out_list[-1]
 
              # Log requests
              if self.server_args.log_requests and state.finished:
                  if obj.text is None:
-                     in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                     in_obj = {"input_ids": obj.input_ids}
                  else:
                      in_obj = {"text": obj.text}
                  logger.info(f"in={in_obj}, out={out}")
@@ -459,19 +521,38 @@
 
      async def handle_loop(self):
          while True:
-             recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
-             assert isinstance(recv_obj, BatchStrOut)
-
+             recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
+                 await self.recv_from_detokenizer.recv_pyobj()
+             )
+             assert isinstance(
+                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
+             ), f"Unexpected obj received: {type(recv_obj)}"
              for i, rid in enumerate(recv_obj.rids):
                  state = self.rid_to_state.get(rid, None)
                  if state is None:
                      continue
 
                  recv_obj.meta_info[i]["id"] = rid
-                 out_dict = {
-                     "text": recv_obj.output_strs[i],
-                     "meta_info": recv_obj.meta_info[i],
-                 }
+                 if isinstance(recv_obj, BatchStrOut):
+                     out_dict = {
+                         "text": recv_obj.output_strs[i],
+                         "meta_info": recv_obj.meta_info[i],
+                     }
+                 elif isinstance(recv_obj, BatchTokenIDOut):
+                     read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
+                     out_dict = {
+                         "token_ids": recv_obj.decode_ids[
+                             read_start : recv_obj.read_offsets[i]
+                         ],
+                         "meta_info": recv_obj.meta_info[i],
+                     }
+
+                 else:
+                     assert isinstance(recv_obj, BatchEmbeddingOut)
+                     out_dict = {
+                         "embedding": recv_obj.embeddings[i],
+                         "meta_info": recv_obj.meta_info[i],
+                     }
                  state.out_list.append(out_dict)
                  state.finished = recv_obj.finished_reason[i] is not None
                  state.event.set()
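The new BatchTokenIDOut branch (used when detokenization is bypassed) packs every request's token ids into one flat decode_ids buffer and uses cumulative read_offsets to recover per-request slices. A self-contained illustration of that slicing, with invented data:

    # Invented example: token ids of three requests packed into one buffer.
    decode_ids = [11, 12, 13, 21, 22, 31, 32, 33, 34]
    read_offsets = [3, 5, 9]  # cumulative end offset of each request

    for i in range(len(read_offsets)):
        read_start = 0 if i == 0 else read_offsets[i - 1]
        print(f"request {i}: {decode_ids[read_start:read_offsets[i]]}")
    # request 0: [11, 12, 13]
    # request 1: [21, 22]
    # request 2: [31, 32, 33, 34]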
@@ -511,6 +592,7 @@
          if not decode_to_text:
              return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
+         assert self.tokenizer is not None
          token_ids = [tid for _, tid in token_logprobs]
          token_texts = self.tokenizer.batch_decode(token_ids)
          return [