sglang 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (52)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +145 -36
  4. sglang/check_env.py +24 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -29
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/layers/logits_processor.py +1 -1
  13. sglang/srt/layers/radix_attention.py +2 -5
  14. sglang/srt/managers/schedule_batch.py +95 -324
  15. sglang/srt/managers/tokenizer_manager.py +6 -3
  16. sglang/srt/managers/tp_worker.py +20 -22
  17. sglang/srt/mem_cache/memory_pool.py +9 -14
  18. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  19. sglang/srt/model_executor/forward_batch_info.py +256 -0
  20. sglang/srt/model_executor/model_runner.py +6 -10
  21. sglang/srt/models/chatglm.py +1 -1
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/dbrx.py +1 -1
  24. sglang/srt/models/deepseek.py +1 -1
  25. sglang/srt/models/deepseek_v2.py +1 -1
  26. sglang/srt/models/gemma.py +1 -1
  27. sglang/srt/models/gemma2.py +1 -1
  28. sglang/srt/models/gpt_bigcode.py +1 -1
  29. sglang/srt/models/grok.py +1 -1
  30. sglang/srt/models/internlm2.py +1 -1
  31. sglang/srt/models/llama2.py +1 -1
  32. sglang/srt/models/llama_classification.py +1 -1
  33. sglang/srt/models/llava.py +1 -2
  34. sglang/srt/models/llavavid.py +1 -2
  35. sglang/srt/models/minicpm.py +1 -1
  36. sglang/srt/models/mixtral.py +1 -1
  37. sglang/srt/models/mixtral_quant.py +1 -1
  38. sglang/srt/models/qwen.py +1 -1
  39. sglang/srt/models/qwen2.py +1 -1
  40. sglang/srt/models/qwen2_moe.py +1 -1
  41. sglang/srt/models/stablelm.py +1 -1
  42. sglang/srt/openai_api/adapter.py +34 -12
  43. sglang/srt/openai_api/protocol.py +6 -0
  44. sglang/srt/server.py +24 -6
  45. sglang/srt/server_args.py +4 -0
  46. sglang/test/test_utils.py +1 -1
  47. sglang/version.py +1 -1
  48. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/METADATA +34 -24
  49. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/RECORD +52 -50
  50. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  51. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  52. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -18,7 +18,6 @@ limitations under the License.
  import logging
  import warnings
  from dataclasses import dataclass
- from enum import IntEnum, auto
  from typing import List, Union

  import numpy as np
@@ -46,15 +45,6 @@ global_server_args_dict = {
  logger = logging.getLogger(__name__)


- class ForwardMode(IntEnum):
- # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
- PREFILL = auto()
- # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
- EXTEND = auto()
- # Decode one token.
- DECODE = auto()
-
-
  class BaseFinishReason:
  def __init__(self, is_error: bool = False):
  self.is_error = is_error
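Note: ForwardMode (and, further down, InputMetadata) disappears from schedule_batch.py, while a new module sglang/srt/model_executor/forward_batch_info.py (+256 -0) appears in this release, so the enum presumably lives there now. For reference, a runnable sketch of the enum exactly as it is removed here:

    from enum import IntEnum, auto

    class ForwardMode(IntEnum):
        # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
        PREFILL = auto()
        # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
        EXTEND = auto()
        # Decode one token.
        DECODE = auto()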
@@ -110,6 +100,9 @@ class Req:
  self.output_ids = [] # Each decode stage's output ids
  self.input_ids = None # input_ids = origin_input_ids + output_ids

+ # Memory info
+ self.req_pool_idx = None
+
  # For incremental decoding
  # ----- | --------- read_ids -------|
  # ----- | surr_ids |
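Note: with this field each Req carries its own slot in the request-to-token pool, so the methods changed below can use req.req_pool_idx directly instead of threading a req_pool_indices_cpu array around. A toy, self-contained illustration of the ownership change (ToyReq is hypothetical, not sglang code):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyReq:
        rid: str
        req_pool_idx: Optional[int] = None  # assigned when the request joins a batch

    reqs = [ToyReq("a"), ToyReq("b")]
    for i, req in enumerate(reqs):   # mirrors the loop added in prepare_for_extend
        req.req_pool_idx = i         # instead of a separate req_pool_indices_cpu[i] lookup later
    assert [r.req_pool_idx for r in reqs] == [0, 1]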
@@ -284,7 +277,7 @@ class Req:


  @dataclass
- class Batch:
+ class ScheduleBatch:
  """Store all inforamtion of a batch."""

  # Request, memory pool, and cache
@@ -331,6 +324,9 @@ class Batch:
  return_logprob=return_logprob,
  )

+ def batch_size(self):
+ return len(self.reqs) if self.reqs is not None else 0
+
  def is_empty(self):
  return len(self.reqs) == 0

@@ -338,118 +334,127 @@ class Batch:
  # Return whether batch has at least 1 streaming request
  return any(r.stream for r in self.reqs)

+ def alloc_req_slots(self, num_reqs):
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+ if req_pool_indices is None:
+ raise RuntimeError(
+ "Out of memory. "
+ "Please set a smaller number for `--max-running-requests`."
+ )
+ return req_pool_indices
+
+ def alloc_token_slots(self, num_tokens: int):
+ out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+ if out_cache_loc is None:
+ if self.tree_cache is not None:
+ self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+ out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+ if out_cache_loc is None:
+ logger.error("Prefill out of memory. Try to lower your batch size.")
+ if self.tree_cache is not None:
+ self.tree_cache.pretty_print()
+ exit(1)
+
+ return out_cache_loc
+
+ def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+ device = "cuda"
+ bs, reqs = self.batch_size(), self.reqs
+ self.temperatures = torch.tensor(
+ [r.sampling_params.temperature for r in reqs],
+ dtype=torch.float,
+ device=device,
+ ).view(-1, 1)
+ self.top_ps = torch.tensor(
+ [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+ )
+ self.top_ks = torch.tensor(
+ [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+ )
+ self.frequency_penalties = torch.tensor(
+ [r.sampling_params.frequency_penalty for r in reqs],
+ dtype=torch.float,
+ device=device,
+ )
+ self.presence_penalties = torch.tensor(
+ [r.sampling_params.presence_penalty for r in reqs],
+ dtype=torch.float,
+ device=device,
+ )
+
+ # Handle logit bias but only allocate when needed
+ self.logit_bias = None
+ for i in range(bs):
+ if reqs[i].sampling_params.dtype == "int":
+ if self.logit_bias is None:
+ self.logit_bias = torch.zeros(
+ (bs, vocab_size), dtype=torch.float32, device=device
+ )
+ self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
  def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
  device = "cuda"
- bs = len(self.reqs)
+ bs = self.batch_size()
  reqs = self.reqs
  input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
  prefix_indices = [r.prefix_indices for r in reqs]

  # Handle prefix
- flatten_input_ids = []
  extend_lens = []
  prefix_lens = []
  seq_lens = []

- req_pool_indices = self.req_to_token_pool.alloc(bs)
+ req_pool_indices_cpu = self.alloc_req_slots(bs)

- if req_pool_indices is None:
- raise RuntimeError(
- "Out of memory. "
- "Please set a smaller number for `--max-running-requests`."
- )
-
- req_pool_indices_cpu = req_pool_indices.cpu().numpy()
- for i in range(bs):
- flatten_input_ids.extend(input_ids[i])
+ for i, req in enumerate(reqs):
+ req.req_pool_idx = req_pool_indices_cpu[i]
  extend_lens.append(len(input_ids[i]))

  if len(prefix_indices[i]) == 0:
  prefix_lens.append(0)
  else:
  prefix_lens.append(len(prefix_indices[i]))
- self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+ self.req_to_token_pool.req_to_token[req.req_pool_idx][
  : len(prefix_indices[i])
  ] = prefix_indices[i]

  seq_lens.append(prefix_lens[-1] + extend_lens[-1])

- position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
  # Allocate memory
  seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
  extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
- out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
- if out_cache_loc is None:
- if self.tree_cache is not None:
- self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
- out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
- if out_cache_loc is None:
- logger.error("Prefill out of memory. Try to lower your batch size.")
- if self.tree_cache is not None:
- self.tree_cache.pretty_print()
- exit(1)
+ out_cache_loc = self.alloc_token_slots(extend_num_tokens)

  pt = 0
- for i in range(bs):
- self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+ for i, req in enumerate(reqs):
+ self.req_to_token_pool.req_to_token[req.req_pool_idx][
  prefix_lens[i] : prefix_lens[i] + extend_lens[i]
  ] = out_cache_loc[pt : pt + extend_lens[i]]
  pt += extend_lens[i]

- # Handle logit bias but only allocate when needed
- logit_bias = None
- for i in range(bs):
- if reqs[i].sampling_params.dtype == "int":
- if logit_bias is None:
- logit_bias = torch.zeros(
- (bs, vocab_size), dtype=torch.float32, device=device
- )
- logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
  # Set fields
- self.input_ids = torch.tensor(
- flatten_input_ids, dtype=torch.int32, device=device
- )
+ with torch.device("cuda"):
+ self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+ self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+ self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+ self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+
  self.pixel_values = [r.pixel_values for r in reqs]
  self.image_sizes = [r.image_size for r in reqs]
  self.image_offsets = [
  r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
  ]
- self.req_pool_indices = req_pool_indices
- self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
  self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
- self.position_ids_offsets = position_ids_offsets
  self.extend_num_tokens = extend_num_tokens
  self.out_cache_loc = out_cache_loc
  self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

- self.temperatures = torch.tensor(
- [r.sampling_params.temperature for r in reqs],
- dtype=torch.float,
- device=device,
- ).view(-1, 1)
- self.top_ps = torch.tensor(
- [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
- )
- self.top_ks = torch.tensor(
- [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
- )
- self.frequency_penalties = torch.tensor(
- [r.sampling_params.frequency_penalty for r in reqs],
- dtype=torch.float,
- device=device,
- )
- self.presence_penalties = torch.tensor(
- [r.sampling_params.presence_penalty for r in reqs],
- dtype=torch.float,
- device=device,
- )
- self.logit_bias = logit_bias
+ self.batch_sampling_params(vocab_size, int_token_logit_bias)

  def check_decode_mem(self):
- bs = len(self.reqs)
+ bs = self.batch_size()
  if self.token_to_kv_pool.available_size() >= bs:
  return True

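Note: prepare_for_extend now builds its tensors inside a torch.device(...) context instead of passing device= to every factory call, and the manual flatten_input_ids loop is replaced by sum(input_ids, []). A minimal, self-contained sketch of the context-manager pattern (PyTorch >= 2.0; falls back to CPU here so it runs anywhere, whereas the real code always uses "cuda"):

    import torch

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.device(dev):
        # Factory functions allocate on `dev` without an explicit device= argument.
        input_ids = torch.tensor(sum([[1, 2], [3, 4, 5]], []), dtype=torch.int32)
        position_ids_offsets = torch.zeros((2,), dtype=torch.int32)
    assert input_ids.device.type == dev
    assert position_ids_offsets.device.type == dev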
@@ -474,7 +479,6 @@ class Batch:

  retracted_reqs = []
  seq_lens_cpu = self.seq_lens.cpu().numpy()
- req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
  while (
  self.token_to_kv_pool.available_size()
  < len(sorted_indices) * global_config.retract_decode_steps
@@ -492,20 +496,20 @@ class Batch:

  if isinstance(self.tree_cache, ChunkCache):
  # ChunkCache does not have eviction
- token_indices = self.req_to_token_pool.req_to_token[
- req_pool_indices_cpu[idx]
- ][: seq_lens_cpu[idx]]
+ token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+ : seq_lens_cpu[idx]
+ ]
  self.token_to_kv_pool.free(token_indices)
- self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+ self.req_to_token_pool.free(req.req_pool_idx)
  del self.tree_cache.entries[req.rid]
  else:
  # TODO: apply more fine-grained retraction
  last_uncached_pos = len(req.prefix_indices)
- token_indices = self.req_to_token_pool.req_to_token[
- req_pool_indices_cpu[idx]
- ][last_uncached_pos : seq_lens_cpu[idx]]
+ token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+ last_uncached_pos : seq_lens_cpu[idx]
+ ]
  self.token_to_kv_pool.free(token_indices)
- self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+ self.req_to_token_pool.free(req.req_pool_idx)

  # release the last node
  self.tree_cache.dec_lock_ref(req.last_node)
@@ -543,8 +547,6 @@ class Batch:
  jump_forward_reqs = []
  filter_indices = [i for i in range(len(self.reqs))]

- req_pool_indices_cpu = None
-
  for i, req in enumerate(self.reqs):
  if req.jump_forward_map is not None:
  jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -594,13 +596,11 @@ class Batch:
  req.vid += 1

  # insert the old request into tree_cache
- if req_pool_indices_cpu is None:
- req_pool_indices_cpu = self.req_pool_indices.tolist()
  self.tree_cache.cache_req(
  rid=req.rid,
  token_ids=cur_all_ids,
  last_uncached_pos=len(req.prefix_indices),
- req_pool_idx=req_pool_indices_cpu[i],
+ req_pool_idx=req.req_pool_idx,
  )

  # unlock the last node
@@ -636,14 +636,8 @@ class Batch:
  self.prefix_lens = None

  # Alloc mem
- bs = len(self.reqs)
- self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
- if self.out_cache_loc is None:
- logger.error("Decode out of memory. Try to lower your batch size.")
- if self.tree_cache is not None:
- self.tree_cache.pretty_print()
- exit(1)
+ bs = self.batch_size()
+ self.out_cache_loc = self.alloc_token_slots(bs)

  self.req_to_token_pool.req_to_token[
  self.req_pool_indices, self.seq_lens - 1
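Note: the decode path now goes through the same alloc_token_slots helper as prefill, so an allocation failure first tries to evict from the tree cache before giving up. A schematic of that shared path with toy stand-ins (the real method lives on ScheduleBatch, logs an error, and calls exit(1) rather than raising):

    def alloc_with_eviction(token_pool, tree_cache, num_tokens):
        # Try the KV-cache pool first.
        loc = token_pool.alloc(num_tokens)
        if loc is None and tree_cache is not None:
            # Evict cached prefixes back into the pool, then retry once.
            tree_cache.evict(num_tokens, token_pool.free)
            loc = token_pool.alloc(num_tokens)
        if loc is None:
            raise RuntimeError("Out of KV-cache memory; try a smaller batch size.")
        return loc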
@@ -673,7 +667,7 @@ class Batch:
  if self_val is not None: # logit_bias can be None
  setattr(self, item, self_val[new_indices])

- def merge(self, other: "Batch"):
+ def merge(self, other: "ScheduleBatch"):
  self.reqs.extend(other.reqs)

  self.req_pool_indices = torch.concat(
@@ -770,229 +764,6 @@ class Batch:
  return batch_next_token_ids


- @dataclass
- class InputMetadata:
- """Store all inforamtion of a forward pass."""
-
- forward_mode: ForwardMode
- batch_size: int
- total_num_tokens: int
- req_pool_indices: torch.Tensor
- seq_lens: torch.Tensor
- positions: torch.Tensor
- req_to_token_pool: ReqToTokenPool
- token_to_kv_pool: BaseTokenToKVPool
-
- # For extend
- extend_seq_lens: torch.Tensor
- extend_start_loc: torch.Tensor
- extend_no_prefix: bool
-
- # Output location of the KV cache
- out_cache_loc: torch.Tensor = None
-
- # Output options
- return_logprob: bool = False
- top_logprobs_nums: List[int] = None
-
- # Trition attention backend
- triton_max_seq_len: int = 0
- triton_max_extend_len: int = 0
- triton_start_loc: torch.Tensor = None
- triton_prefix_lens: torch.Tensor = None
-
- # FlashInfer attention backend
- flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
- flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
- flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
- flashinfer_use_ragged: bool = False
-
- @classmethod
- def create(
- cls,
- model_runner,
- forward_mode,
- req_pool_indices,
- seq_lens,
- prefix_lens,
- position_ids_offsets,
- out_cache_loc,
- top_logprobs_nums=None,
- return_logprob=False,
- skip_flashinfer_init=False,
- ):
- flashinfer_use_ragged = False
- if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
- if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
- flashinfer_use_ragged = True
- init_flashinfer_args(
- forward_mode,
- model_runner,
- req_pool_indices,
- seq_lens,
- prefix_lens,
- model_runner.flashinfer_decode_wrapper,
- flashinfer_use_ragged,
- )
-
- batch_size = len(req_pool_indices)
-
- if forward_mode == ForwardMode.DECODE:
- positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
- extend_seq_lens = extend_start_loc = extend_no_prefix = None
- if not model_runner.server_args.disable_flashinfer:
- # This variable is not needed in this case,
- # we do not compute it to make it compatbile with cuda graph.
- total_num_tokens = None
- else:
- total_num_tokens = int(torch.sum(seq_lens))
- else:
- seq_lens_cpu = seq_lens.cpu().numpy()
- prefix_lens_cpu = prefix_lens.cpu().numpy()
- position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
- positions = torch.tensor(
- np.concatenate(
- [
- np.arange(
- prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
- seq_lens_cpu[i] + position_ids_offsets_cpu[i],
- )
- for i in range(batch_size)
- ],
- axis=0,
- ),
- device="cuda",
- )
- extend_seq_lens = seq_lens - prefix_lens
- extend_start_loc = torch.zeros_like(seq_lens)
- extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
- extend_no_prefix = torch.all(prefix_lens == 0)
- total_num_tokens = int(torch.sum(seq_lens))
-
- ret = cls(
- forward_mode=forward_mode,
- batch_size=batch_size,
- total_num_tokens=total_num_tokens,
- req_pool_indices=req_pool_indices,
- seq_lens=seq_lens,
- positions=positions,
- req_to_token_pool=model_runner.req_to_token_pool,
- token_to_kv_pool=model_runner.token_to_kv_pool,
- out_cache_loc=out_cache_loc,
- extend_seq_lens=extend_seq_lens,
- extend_start_loc=extend_start_loc,
- extend_no_prefix=extend_no_prefix,
- return_logprob=return_logprob,
- top_logprobs_nums=top_logprobs_nums,
- flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
- flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
- flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
- flashinfer_use_ragged=flashinfer_use_ragged,
- )
-
- if model_runner.server_args.disable_flashinfer:
- (
- ret.triton_max_seq_len,
- ret.triton_max_extend_len,
- ret.triton_start_loc,
- ret.triton_prefix_lens,
- ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
- return ret
-
-
- def init_flashinfer_args(
- forward_mode,
- model_runner,
- req_pool_indices,
- seq_lens,
- prefix_lens,
- flashinfer_decode_wrapper,
- flashinfer_use_ragged=False,
- ):
- """Init auxiliary variables for FlashInfer attention backend."""
- num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
- num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
- head_dim = model_runner.model_config.head_dim
- batch_size = len(req_pool_indices)
- total_num_tokens = int(torch.sum(seq_lens))
-
- if flashinfer_use_ragged:
- paged_kernel_lens = prefix_lens
- else:
- paged_kernel_lens = seq_lens
-
- kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
- kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
- req_pool_indices_cpu = req_pool_indices.cpu().numpy()
- paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
- kv_indices = torch.cat(
- [
- model_runner.req_to_token_pool.req_to_token[
- req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
- ]
- for i in range(batch_size)
- ],
- dim=0,
- ).contiguous()
- kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
- if forward_mode == ForwardMode.DECODE:
- flashinfer_decode_wrapper.end_forward()
- flashinfer_decode_wrapper.begin_forward(
- kv_indptr,
- kv_indices,
- kv_last_page_len,
- num_qo_heads,
- num_kv_heads,
- head_dim,
- 1,
- )
- else:
- # extend part
- qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
- qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
- if flashinfer_use_ragged:
- model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
- model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
- qo_indptr,
- qo_indptr,
- num_qo_heads,
- num_kv_heads,
- head_dim,
- )
-
- # cached part
- model_runner.flashinfer_prefill_wrapper_paged.end_forward()
- model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
- qo_indptr,
- kv_indptr,
- kv_indices,
- kv_last_page_len,
- num_qo_heads,
- num_kv_heads,
- head_dim,
- 1,
- )
-
-
- def init_triton_args(forward_mode, seq_lens, prefix_lens):
- """Init auxiliary variables for triton attention backend."""
- batch_size = len(seq_lens)
- max_seq_len = int(torch.max(seq_lens))
- start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
- start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
- if forward_mode == ForwardMode.DECODE:
- max_extend_len = None
- else:
- extend_seq_lens = seq_lens - prefix_lens
- max_extend_len = int(torch.max(extend_seq_lens))
-
- return max_seq_len, max_extend_len, start_loc, prefix_lens
-
-
  def top_k_top_p_sampling_from_probs_torch(
  probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
  ):
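Note: the removed InputMetadata dataclass and the flashinfer/triton init helpers account for the bulk of the -324 lines in this file; given the brand-new sglang/srt/model_executor/forward_batch_info.py (+256 -0), they most likely moved there, and the one-line edits across sglang/srt/models/*.py look like the matching import update. A hypothetical sketch of that update (the path is assumed from the file list, not shown in this diff):

    # Presumed new import location in 0.2.11 model files (verify against the sources):
    from sglang.srt.model_executor.forward_batch_info import InputMetadata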
sglang/srt/managers/tokenizer_manager.py

@@ -308,7 +308,6 @@ class TokenizerManager:
  event = asyncio.Event()
  state = ReqState([], False, event)
  self.rid_to_state[rid] = state
-
  # Then wait for all responses
  output_list = []
  for i in range(batch_size):
@@ -341,7 +340,6 @@ class TokenizerManager:
  )
  assert state.finished
  del self.rid_to_state[rid]
-
  yield output_list

  def _validate_input_length(self, input_ids: List[int]):
@@ -390,8 +388,13 @@ class TokenizerManager:
  obj.return_text_in_logprobs,
  )

+ # Log requests
  if self.server_args.log_requests and state.finished:
- logger.info(f"in={obj.text}, out={out}")
+ if obj.text is None:
+ in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+ else:
+ in_obj = {"text": obj.text}
+ logger.info(f"in={in_obj}, out={out}")

  state.out_list = []
  if state.finished:
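Note: with request logging enabled (server_args.log_requests), requests that arrive as raw token ids (obj.text is None) are now decoded before logging instead of printing in=None. A toy, self-contained sketch of the new branch (StubTokenizer and log_input are hypothetical stand-ins; the real code uses self.tokenizer inside TokenizerManager):

    class StubTokenizer:
        def decode(self, ids):
            return " ".join(f"<{t}>" for t in ids)

    def log_input(text, input_ids, tokenizer):
        # Mirrors the new branch: fall back to decoding token ids when no text is given.
        if text is None:
            return {"text": tokenizer.decode(input_ids)}
        return {"text": text}

    print(log_input(None, [101, 2023], StubTokenizer()))  # {'text': '<101> <2023>'}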