sglang 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -98,17 +98,21 @@ class ModelTpServer:
  if server_args.max_prefill_tokens is None
  else server_args.max_prefill_tokens
  )
- self.max_running_requests = (
- self.max_total_num_tokens // 2
- if server_args.max_running_requests is None
- else server_args.max_running_requests
- )
  self.max_running_requests = min(
- self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+ (
+ self.max_total_num_tokens // 2
+ if server_args.max_running_requests is None
+ else server_args.max_running_requests
+ ),
+ self.model_runner.req_to_token_pool.size - 1,
  )
  self.int_token_logit_bias = torch.tensor(
  get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
  )
+ self.max_req_input_len = min(
+ self.model_config.context_len - 1,
+ self.max_total_num_tokens - 1,
+ )
  set_random_seed(server_args.random_seed)

  # Print info
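
The effect of the new constructor logic above can be reproduced with plain arithmetic. The numbers below are illustrative assumptions, not values taken from the package:

# Standalone sketch of the 0.2.6 scheduler limits; every input is a made-up example.
max_total_num_tokens = 200_000      # assumed KV-cache capacity (tokens)
req_to_token_pool_size = 2048       # assumed model_runner.req_to_token_pool.size
context_len = 8192                  # assumed model context length
max_running_requests_arg = None     # server_args.max_running_requests

# New in 0.2.6: the configured limit (or the default of half the token pool)
# is additionally capped by the request pool size minus one.
max_running_requests = min(
    (
        max_total_num_tokens // 2
        if max_running_requests_arg is None
        else max_running_requests_arg
    ),
    req_to_token_pool_size - 1,
)

# New in 0.2.6: a single request can never exceed either global limit.
max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)

print(max_running_requests)  # 2047, the pool-size cap wins here
print(max_req_input_len)     # 8191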
@@ -295,18 +299,20 @@ class ModelTpServer:
  )

  # Truncate prompts that are too long
- req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+ if len(req.origin_input_ids) >= self.max_req_input_len:
+ logger.warn(
+ "Request length is longer than the KV cache pool size or "
+ "the max context length. Truncated!!!"
+ )
+ req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
  req.sampling_params.max_new_tokens = min(
- req.sampling_params.max_new_tokens,
- self.model_config.context_len - 1 - len(req.origin_input_ids),
- self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+ (
+ req.sampling_params.max_new_tokens
+ if req.sampling_params.max_new_tokens is not None
+ else 1 << 30
+ ),
+ self.max_req_input_len - 1 - len(req.origin_input_ids),
  )
- if req.sampling_params.max_new_tokens < 0:
- req.origin_input_ids = req.origin_input_ids[
- : self.max_total_num_tokens - 128
- ]
- logger.error("Request longer than memory pool size, truncated!!!")
-
  self.forward_queue.append(req)

  def get_new_prefill_batch(self) -> Optional[Batch]:
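
A minimal sketch of how request admission behaves after this hunk, again with assumed numbers (nothing here comes from the package):

# Assumed example values.
max_req_input_len = 8191              # computed in the constructor change above
origin_input_ids = list(range(5000))  # stand-in for a tokenized prompt
requested_max_new_tokens = None       # the user did not set max_new_tokens

# Prompts at or beyond the limit are truncated (the server also logs a warning).
if len(origin_input_ids) >= max_req_input_len:
    origin_input_ids = origin_input_ids[:max_req_input_len]

# A missing max_new_tokens now falls back to a huge sentinel (1 << 30)
# before being clamped by the room left for this request.
max_new_tokens = min(
    requested_max_new_tokens if requested_max_new_tokens is not None else 1 << 30,
    max_req_input_len - 1 - len(origin_input_ids),
)
print(max_new_tokens)  # 3190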
@@ -449,7 +455,7 @@ class ModelTpServer:
  torch.arange(len(next_token_ids), device=next_token_ids.device),
  next_token_ids,
  ].tolist()
- output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+ output.input_token_logprobs = output.input_token_logprobs.tolist()
  output.normalized_prompt_logprobs = (
  output.normalized_prompt_logprobs.tolist()
  )
@@ -475,24 +481,24 @@ class ModelTpServer:
  if req.normalized_prompt_logprob is None:
  req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

- if req.prefill_token_logprobs is None:
+ if req.input_token_logprobs is None:
  # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
- req.prefill_token_logprobs = list(
+ req.input_token_logprobs = list(
  zip(
- output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+ output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
  req.input_ids[-req.extend_input_len + 1 :],
  )
  )
  if req.logprob_start_len == 0:
- req.prefill_token_logprobs = [
+ req.input_token_logprobs = [
  (None, req.input_ids[0])
- ] + req.prefill_token_logprobs
+ ] + req.input_token_logprobs

  if req.last_update_decode_tokens != 0:
- req.decode_token_logprobs.extend(
+ req.output_token_logprobs.extend(
  list(
  zip(
- output.prefill_token_logprobs[
+ output.input_token_logprobs[
  pt
  + req.extend_input_len
  - req.last_update_decode_tokens : pt
@@ -504,21 +510,21 @@ class ModelTpServer:
  )
  )

- req.decode_token_logprobs.append(
+ req.output_token_logprobs.append(
  (output.next_token_logprobs[i], next_token_ids[i])
  )

  if req.top_logprobs_num > 0:
- if req.prefill_top_logprobs is None:
- req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+ if req.input_top_logprobs is None:
+ req.input_top_logprobs = output.input_top_logprobs[i]
  if req.logprob_start_len == 0:
- req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+ req.input_top_logprobs = [None] + req.input_top_logprobs

  if req.last_update_decode_tokens != 0:
- req.decode_top_logprobs.extend(
- output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+ req.output_top_logprobs.extend(
+ output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
  )
- req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+ req.output_top_logprobs.append(output.output_top_logprobs[i])

  def cache_filled_batch(self, batch: Batch):
  req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
@@ -583,11 +589,11 @@ class ModelTpServer:
  req.check_finished()

  if req.return_logprob:
- req.decode_token_logprobs.append(
+ req.output_token_logprobs.append(
  (next_token_logprobs[i], next_token_id)
  )
  if req.top_logprobs_num > 0:
- req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+ req.output_top_logprobs.append(output.output_top_logprobs[i])

  self.handle_finished_requests(batch)

@@ -639,16 +645,16 @@ class ModelTpServer:
  }
  if req.return_logprob:
  (
- meta_info["prefill_token_logprobs"],
- meta_info["decode_token_logprobs"],
- meta_info["prefill_top_logprobs"],
- meta_info["decode_top_logprobs"],
+ meta_info["input_token_logprobs"],
+ meta_info["output_token_logprobs"],
+ meta_info["input_top_logprobs"],
+ meta_info["output_top_logprobs"],
  meta_info["normalized_prompt_logprob"],
  ) = (
- req.prefill_token_logprobs,
- req.decode_token_logprobs,
- req.prefill_top_logprobs,
- req.decode_top_logprobs,
+ req.input_token_logprobs,
+ req.output_token_logprobs,
+ req.input_top_logprobs,
+ req.output_top_logprobs,
  req.normalized_prompt_logprob,
  )
  output_meta_info.append(meta_info)
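
The logprob-related hunks in this diff follow one renaming scheme, summarized here for reference:

# Old name (0.2.5) -> new name (0.2.6), as used on the request object,
# the batch output, and the returned meta_info.
LOGPROB_FIELD_RENAMES = {
    "prefill_token_logprobs": "input_token_logprobs",
    "decode_token_logprobs": "output_token_logprobs",
    "prefill_top_logprobs": "input_top_logprobs",
    "decode_top_logprobs": "output_top_logprobs",
}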
@@ -20,7 +20,7 @@ class GenerateReqInput:
  # The image input. It can be a file name, a url, or base64 encoded string.
  # See also python/sglang/srt/utils.py:load_image.
  image_data: Optional[Union[List[str], str]] = None
- # The sampling_params.
+ # The sampling_params. See descriptions below.
  sampling_params: Union[List[Dict], Dict] = None
  # The request id.
  rid: Optional[Union[List[str], str]] = None
@@ -30,7 +30,7 @@ class GenerateReqInput:
  logprob_start_len: Optional[Union[List[int], int]] = None
  # The number of top logprobs to return.
  top_logprobs_num: Optional[Union[List[int], int]] = None
- # Whether to detokenize tokens in logprobs.
+ # Whether to detokenize tokens in text in the returned logprobs.
  return_text_in_logprobs: bool = False
  # Whether to stream output.
  stream: bool = False
@@ -133,24 +133,10 @@ class TokenizerManager:
  async for response in self._handle_batch_request(obj, request):
  yield response

- async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
- if is_prefill:
- if isinstance(obj.text, list):
- input_text = obj.text[index]
- rid = obj.rid[index]
- else:
- input_text = obj.text
- rid = obj.rid[0]
- input_ids = self.tokenizer.encode(input_text)
- sampling_params = SamplingParams(**obj.sampling_params[0])
- sampling_params.max_new_tokens = 0
- pixel_values, image_hash, image_size = await self._get_pixel_values(
- obj.image_data[0]
- )
- return_logprob = obj.return_logprob[0]
- logprob_start_len = obj.logprob_start_len[0]
- top_logprobs_num = obj.top_logprobs_num[0]
- else:
+ async def _handle_single_request(
+ self, obj, request, index=None, is_cache_for_prefill=False
+ ):
+ if not is_cache_for_prefill:
  rid = obj.rid if index is None else obj.rid[index]
  input_text = obj.text if index is None else obj.text[index]
  input_ids = (
@@ -177,6 +163,22 @@ class TokenizerManager:
  top_logprobs_num = (
  obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
  )
+ else:
+ if isinstance(obj.text, list):
+ input_text = obj.text[index]
+ rid = obj.rid[index]
+ else:
+ input_text = obj.text
+ rid = obj.rid[0]
+ input_ids = self.tokenizer.encode(input_text)
+ sampling_params = SamplingParams(**obj.sampling_params[0])
+ sampling_params.max_new_tokens = 0
+ pixel_values, image_hash, image_size = await self._get_pixel_values(
+ obj.image_data[0]
+ )
+ return_logprob = obj.return_logprob[0]
+ logprob_start_len = obj.logprob_start_len[0]
+ top_logprobs_num = obj.top_logprobs_num[0]

  tokenized_obj = TokenizedGenerateReqInput(
  rid,
@@ -196,26 +198,26 @@ class TokenizerManager:
  event = asyncio.Event()
  state = ReqState([], False, event)
  self.rid_to_state[rid] = state
- if is_prefill:
- await self._wait_for_prefill_response(event, state, obj, request, rid)
- yield input_ids
- else:
+ if not is_cache_for_prefill:
  async for response in self._wait_for_response(
  event, state, obj, rid, request
  ):
  yield response
+ else:
+ await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+ yield input_ids

- async def _handle_batch_request(self, obj, request):
+ async def _handle_batch_request(self, obj: GenerateReqInput, request):
  batch_size = obj.batch_size
  parallel_sample_num = obj.sampling_params[0].get("n", 1)

  if parallel_sample_num != 1:
- ## send prefill requests
+ # Send prefill requests to cache the common input
  parallel_sample_num += 1
  input_id_result = [] if obj.input_ids is None else None
  for i in range(batch_size):
  async for input_id in self._handle_single_request(
- obj, request, index=i, is_prefill=True
+ obj, request, index=i, is_cache_for_prefill=True
  ):
  if input_id_result is not None:
  input_id_result.append(input_id)
@@ -224,6 +226,7 @@ class TokenizerManager:
  obj.input_ids = input_id_result
  elif input_id_result is not None:
  obj.input_ids = input_id_result[0]
+
  # First send out all requests
  for i in range(batch_size):
  for j in range(parallel_sample_num):
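
For parallel sampling (n > 1 in sampling_params), the handler now first issues one extra request per prompt with max_new_tokens = 0 so the shared prefix is prefilled and cached, and only then dispatches the actual sampled generations. A rough standalone sketch of that control flow; every name in it is invented for illustration and none of it is the TokenizerManager API:

import asyncio

async def handle_single(prompt: str, max_new_tokens: int) -> str:
    # Stand-in for tokenizing, enqueueing, and waiting on the detokenizer.
    await asyncio.sleep(0)
    return f"{prompt!r} (max_new_tokens={max_new_tokens})"

async def handle_batch(prompts: list, n: int) -> list:
    results = []
    if n != 1:
        # Cache-for-prefill pass, mirroring sampling_params.max_new_tokens = 0 above.
        for prompt in prompts:
            await handle_single(prompt, max_new_tokens=0)
    # Then the real sampled generations are sent out.
    for prompt in prompts:
        for _ in range(n):
            results.append(await handle_single(prompt, max_new_tokens=128))
    return results

print(asyncio.run(handle_batch(["Hello", "World"], n=2)))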
@@ -308,17 +311,15 @@ class TokenizerManager:

  yield output_list

- def _validate_input_length(self, input_ids):
+ def _validate_input_length(self, input_ids: List[int]):
  if len(input_ids) >= self.context_len:
  raise ValueError(
  f"The input ({len(input_ids)} tokens) is longer than the "
  f"model's context length ({self.context_len} tokens)."
  )

- def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+ def _get_sampling_params(self, sampling_params_data: dict):
  sampling_params = SamplingParams(**sampling_params_data)
- if max_new_tokens is not None:
- sampling_params.max_new_tokens = max_new_tokens
  if sampling_params.max_new_tokens != 0:
  sampling_params.normalize(self.tokenizer)
  sampling_params.verify()
@@ -332,7 +333,14 @@ class TokenizerManager:
  else:
  return None, None, None

- async def _wait_for_response(self, event, state, obj, rid, request):
+ async def _wait_for_response(
+ self,
+ event: asyncio.Event,
+ state: ReqState,
+ obj: GenerateReqInput,
+ rid: str,
+ request,
+ ):
  while True:
  try:
  await asyncio.wait_for(event.wait(), timeout=4)
@@ -361,7 +369,14 @@ class TokenizerManager:
  event.clear()
  yield out

- async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+ async def _wait_for_cache_prefill_response(
+ self,
+ event: asyncio.Event,
+ state: ReqState,
+ obj: GenerateReqInput,
+ rid: str,
+ request,
+ ):
  while True:
  try:
  await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +395,7 @@ class TokenizerManager:
  req = FlushCacheReq()
  self.send_to_router.send_pyobj(req)

- def abort_request(self, rid):
+ def abort_request(self, rid: str):
  if rid not in self.rid_to_state:
  return
  del self.rid_to_state[rid]
@@ -426,31 +441,35 @@ class TokenizerManager:
  state.event.set()

  def convert_logprob_style(
- self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+ self,
+ ret: dict,
+ return_logprob: bool,
+ top_logprobs_num: int,
+ return_text_in_logprobs: bool,
  ):
  if return_logprob:
- ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
- ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+ ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
+ ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
  )
- ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
- ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+ ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
+ ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
  )

  if top_logprobs_num > 0:
- ret["meta_info"]["prefill_top_logprobs"] = (
+ ret["meta_info"]["input_top_logprobs"] = (
  self.detokenize_top_logprobs_tokens(
- ret["meta_info"]["prefill_top_logprobs"],
+ ret["meta_info"]["input_top_logprobs"],
  return_text_in_logprobs,
  )
  )
- ret["meta_info"]["decode_top_logprobs"] = (
+ ret["meta_info"]["output_top_logprobs"] = (
  self.detokenize_top_logprobs_tokens(
- ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+ ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
  )
  )
  return ret

- def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+ def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
  if not decode_to_text:
  return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

@@ -461,7 +480,7 @@ class TokenizerManager:
  for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
  ]

- def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
  for i, t in enumerate(top_logprobs):
  if t:
  top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
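
On the client side, the renames surface in the meta_info of a /generate response when logprobs are requested. A hedged example of reading the new keys; the endpoint path and port are assumptions about the default sglang HTTP server, not something stated in this diff:

import requests  # assumes a locally running sglang server

resp = requests.post(
    "http://localhost:30000/generate",  # assumed default server address
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        "return_logprob": True,
        "top_logprobs_num": 2,
        "return_text_in_logprobs": True,
    },
).json()

meta = resp["meta_info"]
# 0.2.6 key names (previously prefill_*/decode_*):
print(meta["input_token_logprobs"][:3])
print(meta["output_token_logprobs"][:3])
print(meta["input_top_logprobs"][:1])
print(meta["output_top_logprobs"][:1])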
@@ -36,6 +36,11 @@ class ModelConfig:
  "head_dim",
  self.hf_config.hidden_size // self.hf_config.num_attention_heads,
  )
+
+ # FIXME: temporary special judge for deepseek v2 MLA architecture
+ if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+ self.head_dim = 256
+
  self.num_attention_heads = self.hf_config.num_attention_heads
  self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
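
The DeepSeek-V2 special case above can be checked in isolation with a stand-in config object. The architecture string comes from the hunk; the hidden_size and num_attention_heads values are assumptions used only to show that the default formula is being overridden:

from types import SimpleNamespace

# Fake HF config; attribute names mirror those used in the ModelConfig hunk above.
hf_config = SimpleNamespace(
    architectures=["DeepseekV2ForCausalLM"],
    hidden_size=5120,            # assumed value
    num_attention_heads=128,     # assumed value
)

# Default rule: head_dim falls back to hidden_size // num_attention_heads.
head_dim = getattr(
    hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads
)

# 0.2.6 special case for the DeepSeek-V2 MLA attention layout.
if "DeepseekV2ForCausalLM" in hf_config.architectures:
    head_dim = 256

print(head_dim)  # 256 rather than hidden_size // num_attention_heads == 40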